jax2tf: make shape_poly_test pass with custom PRNG

This commit is contained in:
Jake VanderPlas 2023-05-25 15:16:46 -07:00
parent 2155b9181f
commit b853ce9967

View File

@ -2348,7 +2348,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
poly_axes=[None, (0, 1)],
override_jax_config_flags=override_jax_config_flags), # type: ignore
PolyHarness("random_split", f"{flags_name}",
lambda key, a: jax.random.split(key, 2 * a.shape[0]),
lambda key, a: jax.random.key_data(jax.random.split(key, 2 * a.shape[0])),
arg_descriptors=[RandArg((key_size,), np.uint32),
RandArg((3, 4), _f32)],
poly_axes=[None, (0,)],