mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Implement device_get for typed PRNG keys
This commit is contained in:
parent
c3b4b76080
commit
b9ad519a29
@ -2445,6 +2445,13 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
|
||||
def _device_get(x):
|
||||
if isinstance(x, core.Tracer):
|
||||
return x
|
||||
if dtypes.issubdtype(getattr(x, "dtype", None), dtypes.extended):
|
||||
try:
|
||||
to_device = x.dtype._rules.device_get
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
return to_device(x)
|
||||
try:
|
||||
toarray = x.__array__
|
||||
except AttributeError:
|
||||
|
@ -400,6 +400,11 @@ class KeyTyRules:
|
||||
phys_result = phys_handler(phys_arrays)
|
||||
return PRNGKeyArray(aval.dtype._impl, phys_result)
|
||||
|
||||
@staticmethod
|
||||
def device_get(val):
|
||||
buffer = api.device_get(random_unwrap(val))
|
||||
return random_wrap(buffer, impl=val.dtype._impl)
|
||||
|
||||
@staticmethod
|
||||
def device_put_sharded(vals, aval, sharding, devices):
|
||||
physical_aval = core.physical_aval(aval)
|
||||
|
@ -936,6 +936,11 @@ class KeyArrayTest(jtu.JaxTestCase):
|
||||
x = jnp.array([True, False, False])
|
||||
f(x) # doesn't crash
|
||||
|
||||
def test_device_get(self):
|
||||
keys = self.make_keys(4)
|
||||
keys_on_host = jax.device_get(keys)
|
||||
self.assertKeysEqual(keys, keys_on_host)
|
||||
|
||||
def test_device_put(self):
|
||||
device = jax.devices()[0]
|
||||
keys = self.make_keys(4)
|
||||
|
Loading…
x
Reference in New Issue
Block a user