mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
fix legacy numpy issue with DeviceArray.__repr__
This commit is contained in:
parent
7c5d683915
commit
7fc3f3f704
@ -442,6 +442,18 @@ class DeviceArray(DeviceValue):
|
||||
def __repr__(self):
|
||||
return onp.array_repr(self)
|
||||
|
||||
def item(self):
|
||||
if onp.issubdtype(self.dtype, onp.complexfloating):
|
||||
return complex(self)
|
||||
elif onp.issubdtype(self.dtype, onp.floating):
|
||||
return float(self)
|
||||
elif onp.issubdtype(self.dtype, onp.integer):
|
||||
return int(self)
|
||||
elif onp.issubdtype(self.dtype, onp.bool_):
|
||||
return bool(self)
|
||||
else:
|
||||
raise TypeError(self.dtype)
|
||||
|
||||
def __len__(self):
|
||||
try:
|
||||
return self.shape[0]
|
||||
|
@ -555,6 +555,10 @@ class APITest(jtu.JaxTestCase):
|
||||
tup = device_put(pack((1, 2)))
|
||||
self.assertEqual(repr(tup), 'DeviceTuple(len=2)')
|
||||
|
||||
def test_legacy_devicearray_repr(self):
|
||||
dx = device_put(3.)
|
||||
str(dx.item()) # doesn't crash
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user