Re-land #24589 with fixes to handle dtype that is not compatible with NumPy.

Previously, this change did not account for that fact that `device_get` may be called on objects that have a non-NumPy-compatible `dtype` attribute, such as tensorflow tensors. This change adds new dtype handling aimed at being robust to this case.

Reverts 2bed1e88e4276558e4dd5e6a6d5afe6f2396a25d

PiperOrigin-RevId: 691568933
This commit is contained in:
Jake VanderPlas 2024-10-30 15:12:04 -07:00 committed by jax authors
parent 242e6634ff
commit 0181cb396d
3 changed files with 21 additions and 0 deletions

View File

@ -2440,6 +2440,17 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
def _device_get(x):
if isinstance(x, core.Tracer):
return x
# Extended dtypes dispatch via their device_get rule.
if isinstance(x, basearray.Array) and dtypes.issubdtype(x.dtype, dtypes.extended):
try:
to_device = x.dtype._rules.device_get
except AttributeError:
pass
else:
return to_device(x)
# Other types dispatch via their __array__ method.
try:
toarray = x.__array__
except AttributeError:

View File

@ -400,6 +400,11 @@ class KeyTyRules:
phys_result = phys_handler(phys_arrays)
return PRNGKeyArray(aval.dtype._impl, phys_result)
@staticmethod
def device_get(val):
buffer = api.device_get(random_unwrap(val))
return random_wrap(buffer, impl=val.dtype._impl)
@staticmethod
def device_put_sharded(vals, aval, sharding, devices):
physical_aval = core.physical_aval(aval)

View File

@ -936,6 +936,11 @@ class KeyArrayTest(jtu.JaxTestCase):
x = jnp.array([True, False, False])
f(x) # doesn't crash
def test_device_get(self):
keys = self.make_keys(4)
keys_on_host = jax.device_get(keys)
self.assertKeysEqual(keys, keys_on_host)
def test_device_put(self):
device = jax.devices()[0]
keys = self.make_keys(4)