mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
jax2tf: make shape_poly_test pass with custom PRNG
This commit is contained in:
parent
2155b9181f
commit
b853ce9967
@ -2348,7 +2348,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
poly_axes=[None, (0, 1)],
|
||||
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
||||
PolyHarness("random_split", f"{flags_name}",
|
||||
lambda key, a: jax.random.split(key, 2 * a.shape[0]),
|
||||
lambda key, a: jax.random.key_data(jax.random.split(key, 2 * a.shape[0])),
|
||||
arg_descriptors=[RandArg((key_size,), np.uint32),
|
||||
RandArg((3, 4), _f32)],
|
||||
poly_axes=[None, (0,)],
|
||||
|
Loading…
x
Reference in New Issue
Block a user