Use Array in __repr__ instead of the class name which is ArrayImpl.

PiperOrigin-RevId: 477465432
This commit is contained in:
Yash Katariya 2022-09-28 08:57:07 -07:00 committed by jax authors
parent 0282b4bfad
commit c89cb5d8a4
2 changed files with 2 additions and 2 deletions

View File

@ -310,7 +310,7 @@ class ArrayImpl(basearray.Array):
return self.shape == self._arrays[0].shape
def __repr__(self):
prefix = '{}('.format(self.__class__.__name__.lstrip('_'))
prefix = 'Array('
if self.aval is not None and self.aval.weak_type:
dtype_str = f'dtype={self.dtype.name}, weak_type=True)'
else:

View File

@ -187,7 +187,7 @@ class JaxArrayTest(jtu.JaxTestCase):
input_shape = (8, 2)
arr, _ = create_array(
input_shape, sharding.MeshPspecSharding(global_mesh, P('x', 'y')))
repr(arr) # doesn't crash
self.assertStartsWith(repr(arr), "Array(")
def test_jnp_array(self):
arr = jnp.array([1, 2, 3])