Rename get_ty to typeof which is an alias of get_aval

PiperOrigin-RevId: 735946640
This commit is contained in:
Yash Katariya 2025-03-11 17:34:05 -07:00 committed by jax authors
parent c6b164dc09
commit 3a26804c68
4 changed files with 7 additions and 7 deletions

View File

@ -79,7 +79,7 @@ from jax._src.lib import xla_client as _xc
Device = _xc.Device
del _xc
from jax._src.core import get_ty as get_ty
from jax._src.core import typeof as typeof
from jax._src.api import effects_barrier as effects_barrier
from jax._src.api import block_until_ready as block_until_ready
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401

View File

@ -1576,7 +1576,7 @@ def get_aval(x):
return get_aval(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")
get_ty = get_aval
typeof = get_aval
def is_concrete(x):
return to_concrete_value(x) is not None

View File

@ -216,8 +216,8 @@ class MutableArrayTest(jtu.JaxTestCase):
@jax.jit
def f(x_ref):
self.assertEqual(core.get_ty(x_ref).sharding.spec,
core.get_ty(x_ref[...]).sharding.spec)
self.assertEqual(core.typeof(x_ref).sharding.spec,
core.typeof(x_ref[...]).sharding.spec)
y = x_ref[...] + 1
return y

View File

@ -4883,11 +4883,11 @@ class ShardingInTypesTest(jtu.JaxTestCase):
arr = jax.device_put(np_inp, s)
def f(x):
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
self.assertEqual(jax.typeof(x).sharding.spec, s.spec)
x = x * 2
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
self.assertEqual(jax.typeof(x).sharding.spec, s.spec)
x = x * x
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
self.assertEqual(jax.typeof(x).sharding.spec, s.spec)
return x
# Eager mode