mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
3e302077ae
commit
ddd29e724e
@ -2333,6 +2333,8 @@ for method_name in _nondiff_methods + _diff_methods:
|
||||
setattr(ShapedArray, "reshape", core.aval_method(_reshape))
|
||||
setattr(ShapedArray, "flatten", core.aval_method(ravel))
|
||||
setattr(ShapedArray, "T", core.aval_property(transpose))
|
||||
setattr(ShapedArray, "real", core.aval_property(real))
|
||||
setattr(ShapedArray, "imag", core.aval_property(imag))
|
||||
setattr(ShapedArray, "astype", core.aval_method(lax.convert_element_type))
|
||||
|
||||
|
||||
@ -2345,6 +2347,8 @@ for method_name in _nondiff_methods + _diff_methods:
|
||||
setattr(DeviceArray, "reshape", _reshape)
|
||||
setattr(DeviceArray, "flatten", ravel)
|
||||
setattr(DeviceArray, "T", property(transpose))
|
||||
setattr(DeviceArray, "real", property(real))
|
||||
setattr(DeviceArray, "imag", property(imag))
|
||||
setattr(DeviceArray, "astype", lax.convert_element_type)
|
||||
|
||||
|
||||
|
@ -537,6 +537,15 @@ class APITest(jtu.JaxTestCase):
|
||||
self.assertAllClose(grad_ans, 3. * 4. + onp.cos(onp.sin(3. * 4)),
|
||||
check_dtypes=False)
|
||||
|
||||
def test_devicearray_repr(self):
|
||||
x = device_put(np.zeros(3))
|
||||
self.assertIsInstance(x, DeviceArray)
|
||||
repr(x) # doesn't crash
|
||||
|
||||
x = device_put(np.ones(3) + 1j * np.ones(3))
|
||||
self.assertIsInstance(x, DeviceArray)
|
||||
repr(x) # doesn't crash
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user