Merge pull request #16143 from jakevdp:fix-shape-poly

PiperOrigin-RevId: 535427698
This commit is contained in:
jax authors 2023-05-25 16:31:09 -07:00
commit 7833528765

View File

@ -2348,7 +2348,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
poly_axes=[None, (0, 1)],
override_jax_config_flags=override_jax_config_flags), # type: ignore
PolyHarness("random_split", f"{flags_name}",
lambda key, a: jax.random.split(key, 2 * a.shape[0]),
lambda key, a: jax.random.key_data(jax.random.split(key, 2 * a.shape[0])),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((3, 4), _f32)],
poly_axes=[None, (0,)],