mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Set ArrayImpl.__name__ to ArrayImpl
Fixes https://github.com/google/jax/issues/14768 PiperOrigin-RevId: 515097907
This commit is contained in:
parent
2c2dfbe89b
commit
44082be103
@ -28,6 +28,7 @@ from jax._src import dispatch
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.util import safe_zip
|
||||
from jax.interpreters import pxla
|
||||
from jax.experimental.pjit import pjit
|
||||
@ -82,6 +83,10 @@ def create_array(shape, sharding, global_data=None):
|
||||
@jtu.with_config(jax_array=True)
|
||||
class JaxArrayTest(jtu.JaxTestCase):
|
||||
|
||||
def test_array_impl_name(self):
|
||||
expected = "Array" if xla_extension_version < 135 else "ArrayImpl"
|
||||
self.assertEqual(array.ArrayImpl.__name__, expected)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("mesh_x_y", P("x", "y")),
|
||||
("mesh_x", P("x")),
|
||||
|
@ -32,6 +32,7 @@ from jax import numpy as jnp
|
||||
|
||||
from jax._src import dtypes
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
@ -498,7 +499,10 @@ class JaxNumpyOperatorTests(jtu.JaxTestCase):
|
||||
other = othertype(data)
|
||||
|
||||
if config.jax_array:
|
||||
val_str = 'Array'
|
||||
if xla_extension_version < 135:
|
||||
val_str = 'Array'
|
||||
else:
|
||||
val_str = 'ArrayImpl'
|
||||
else:
|
||||
val_str = 'DeviceArray'
|
||||
msg = f"unsupported operand type.* '{val_str}' and '{othertype.__name__}'"
|
||||
@ -517,7 +521,10 @@ class JaxNumpyOperatorTests(jtu.JaxTestCase):
|
||||
other = othertype(data)
|
||||
|
||||
if config.jax_array:
|
||||
val_str = 'Array'
|
||||
if xla_extension_version < 135:
|
||||
val_str = 'Array'
|
||||
else:
|
||||
val_str = 'ArrayImpl'
|
||||
else:
|
||||
val_str = 'DeviceArray'
|
||||
msg = f"unsupported operand type.* '{othertype.__name__}' and '{val_str}'"
|
||||
|
Loading…
x
Reference in New Issue
Block a user