mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Use Array
in __repr__
instead of the class name which is ArrayImpl
.
PiperOrigin-RevId: 477465432
This commit is contained in:
parent
0282b4bfad
commit
c89cb5d8a4
@ -310,7 +310,7 @@ class ArrayImpl(basearray.Array):
|
||||
return self.shape == self._arrays[0].shape
|
||||
|
||||
def __repr__(self):
|
||||
prefix = '{}('.format(self.__class__.__name__.lstrip('_'))
|
||||
prefix = 'Array('
|
||||
if self.aval is not None and self.aval.weak_type:
|
||||
dtype_str = f'dtype={self.dtype.name}, weak_type=True)'
|
||||
else:
|
||||
|
@ -187,7 +187,7 @@ class JaxArrayTest(jtu.JaxTestCase):
|
||||
input_shape = (8, 2)
|
||||
arr, _ = create_array(
|
||||
input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
|
||||
repr(arr) # doesn't crash
|
||||
self.assertStartsWith(repr(arr), "Array(")
|
||||
|
||||
def test_jnp_array(self):
|
||||
arr = jnp.array([1, 2, 3])
|
||||
|
Loading…
x
Reference in New Issue
Block a user