mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Fix PRNGKey handling under jit-of-pmap
This commit is contained in:
parent
cd6e012326
commit
5150cfeeb0
@ -1319,6 +1319,8 @@ def _hlo_shard(aval, axis_env, xs, in_axis):
|
||||
if aval is core.abstract_token:
|
||||
return xs
|
||||
elif isinstance(aval, core.ShapedArray):
|
||||
if dtypes.issubdtype(aval.dtype, dtypes.extended):
|
||||
aval = aval.dtype._rules.physical_element_aval(aval.dtype)
|
||||
x, = xs
|
||||
dims = list(aval.shape)
|
||||
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
|
||||
|
@ -2207,6 +2207,15 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
y = jax.pmap(jax.scipy.linalg.expm)(jnp.array([x, x]))
|
||||
y.block_until_ready() # doesn't crash
|
||||
|
||||
def test_pmap_of_prng_key(self):
|
||||
# Regression test for https://github.com/google/jax/issues/20392
|
||||
keys = jax.random.split(jax.random.key(0), jax.device_count())
|
||||
result1 = jax.pmap(jax.random.bits)(keys)
|
||||
with jtu.ignore_warning(
|
||||
category=UserWarning, message="The jitted function foo includes a pmap"):
|
||||
result2 = jax.jit(jax.pmap(jax.random.bits))(keys)
|
||||
self.assertArraysEqual(result1, result2)
|
||||
|
||||
|
||||
@jtu.pytest_mark_if_available('multiaccelerator')
|
||||
class CppPmapTest(PythonPmapTest):
|
||||
|
Loading…
x
Reference in New Issue
Block a user