fix DeviceArray.__repr__ for complex dtypes, test

c.f. #666
This commit is contained in:
Matthew Johnson 2019-05-02 19:27:22 -07:00
parent 3e302077ae
commit ddd29e724e
2 changed files with 13 additions and 0 deletions

View File

@ -2333,6 +2333,8 @@ for method_name in _nondiff_methods + _diff_methods:
setattr(ShapedArray, "reshape", core.aval_method(_reshape))
setattr(ShapedArray, "flatten", core.aval_method(ravel))
setattr(ShapedArray, "T", core.aval_property(transpose))
setattr(ShapedArray, "real", core.aval_property(real))
setattr(ShapedArray, "imag", core.aval_property(imag))
setattr(ShapedArray, "astype", core.aval_method(lax.convert_element_type))
@ -2345,6 +2347,8 @@ for method_name in _nondiff_methods + _diff_methods:
setattr(DeviceArray, "reshape", _reshape)
setattr(DeviceArray, "flatten", ravel)
setattr(DeviceArray, "T", property(transpose))
setattr(DeviceArray, "real", property(real))
setattr(DeviceArray, "imag", property(imag))
setattr(DeviceArray, "astype", lax.convert_element_type)

View File

@ -537,6 +537,15 @@ class APITest(jtu.JaxTestCase):
self.assertAllClose(grad_ans, 3. * 4. + onp.cos(onp.sin(3. * 4)),
check_dtypes=False)
def test_devicearray_repr(self):
x = device_put(np.zeros(3))
self.assertIsInstance(x, DeviceArray)
repr(x) # doesn't crash
x = device_put(np.ones(3) + 1j * np.ones(3))
self.assertIsInstance(x, DeviceArray)
repr(x) # doesn't crash
if __name__ == '__main__':
absltest.main()