Expose get_ty aka get_aval from jax namespace

PiperOrigin-RevId: 728490205
This commit is contained in:
Yash Katariya 2025-02-18 21:21:37 -08:00 committed by jax authors
parent c825241ccc
commit b35083331c
3 changed files with 5 additions and 3 deletions

View File

@ -79,6 +79,7 @@ from jax._src.lib import xla_client as _xc
Device = _xc.Device
del _xc
from jax._src.core import get_ty as get_ty
from jax._src.api import effects_barrier as effects_barrier
from jax._src.api import block_until_ready as block_until_ready
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401

View File

@ -1552,6 +1552,7 @@ def get_aval(x):
return get_aval(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")
get_ty = get_aval
def is_concrete(x):
return to_concrete_value(x) is not None

View File

@ -4795,11 +4795,11 @@ class ShardingInTypesTest(jtu.JaxTestCase):
arr = jax.device_put(np_inp, s)
def f(x):
self.assertEqual(x.aval.sharding.spec, s.spec)
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
x = x * 2
self.assertEqual(x.aval.sharding.spec, s.spec)
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
x = x * x
self.assertEqual(x.aval.sharding.spec, s.spec)
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
return x
# Eager mode