fix legacy numpy issue with DeviceArray.__repr__

This commit is contained in:
Matthew Johnson 2019-05-03 08:14:03 -07:00
parent 7c5d683915
commit 7fc3f3f704
2 changed files with 16 additions and 0 deletions

View File

@ -442,6 +442,18 @@ class DeviceArray(DeviceValue):
def __repr__(self):
return onp.array_repr(self)
def item(self):
if onp.issubdtype(self.dtype, onp.complexfloating):
return complex(self)
elif onp.issubdtype(self.dtype, onp.floating):
return float(self)
elif onp.issubdtype(self.dtype, onp.integer):
return int(self)
elif onp.issubdtype(self.dtype, onp.bool_):
return bool(self)
else:
raise TypeError(self.dtype)
def __len__(self):
try:
return self.shape[0]

View File

@ -555,6 +555,10 @@ class APITest(jtu.JaxTestCase):
tup = device_put(pack((1, 2)))
self.assertEqual(repr(tup), 'DeviceTuple(len=2)')
def test_legacy_devicearray_repr(self):
dx = device_put(3.)
str(dx.item()) # doesn't crash
if __name__ == '__main__':
absltest.main()