diff --git a/jax/__init__.py b/jax/__init__.py index 950c3ed4b..ae3bac4ad 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -79,7 +79,7 @@ from jax._src.lib import xla_client as _xc Device = _xc.Device del _xc -from jax._src.core import get_ty as get_ty +from jax._src.core import typeof as typeof from jax._src.api import effects_barrier as effects_barrier from jax._src.api import block_until_ready as block_until_ready from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401 diff --git a/jax/_src/core.py b/jax/_src/core.py index 9d8edeb8b..b17e26255 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1576,7 +1576,7 @@ def get_aval(x): return get_aval(x.__jax_array__()) raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type") -get_ty = get_aval +typeof = get_aval def is_concrete(x): return to_concrete_value(x) is not None diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index c510c2cfa..4c6a8eb7a 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -216,8 +216,8 @@ class MutableArrayTest(jtu.JaxTestCase): @jax.jit def f(x_ref): - self.assertEqual(core.get_ty(x_ref).sharding.spec, - core.get_ty(x_ref[...]).sharding.spec) + self.assertEqual(core.typeof(x_ref).sharding.spec, + core.typeof(x_ref[...]).sharding.spec) y = x_ref[...] + 1 return y diff --git a/tests/pjit_test.py b/tests/pjit_test.py index bd7954d60..4cd1af9d3 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -4883,11 +4883,11 @@ class ShardingInTypesTest(jtu.JaxTestCase): arr = jax.device_put(np_inp, s) def f(x): - self.assertEqual(jax.get_ty(x).sharding.spec, s.spec) + self.assertEqual(jax.typeof(x).sharding.spec, s.spec) x = x * 2 - self.assertEqual(jax.get_ty(x).sharding.spec, s.spec) + self.assertEqual(jax.typeof(x).sharding.spec, s.spec) x = x * x - self.assertEqual(jax.get_ty(x).sharding.spec, s.spec) + self.assertEqual(jax.typeof(x).sharding.spec, s.spec) return x # Eager mode