mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #16143 from jakevdp:fix-shape-poly
PiperOrigin-RevId: 535427698
This commit is contained in:
commit
7833528765
@ -2348,7 +2348,7 @@ _POLY_SHAPE_TEST_HARNESSES = [
|
||||
poly_axes=[None, (0, 1)],
|
||||
override_jax_config_flags=override_jax_config_flags), # type: ignore
|
||||
PolyHarness("random_split", f"{flags_name}",
|
||||
lambda key, a: jax.random.split(key, 2 * a.shape[0]),
|
||||
lambda key, a: jax.random.key_data(jax.random.split(key, 2 * a.shape[0])),
|
||||
arg_descriptors=[RandArg((key_size,), np.uint32),
|
||||
RandArg((3, 4), _f32)],
|
||||
poly_axes=[None, (0,)],
|
||||
|
Loading…
x
Reference in New Issue
Block a user