mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Expose get_ty
aka get_aval from jax namespace
PiperOrigin-RevId: 728490205
This commit is contained in:
parent
c825241ccc
commit
b35083331c
@ -79,6 +79,7 @@ from jax._src.lib import xla_client as _xc
|
||||
Device = _xc.Device
|
||||
del _xc
|
||||
|
||||
from jax._src.core import get_ty as get_ty
|
||||
from jax._src.api import effects_barrier as effects_barrier
|
||||
from jax._src.api import block_until_ready as block_until_ready
|
||||
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401
|
||||
|
@ -1552,6 +1552,7 @@ def get_aval(x):
|
||||
return get_aval(x.__jax_array__())
|
||||
raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")
|
||||
|
||||
get_ty = get_aval
|
||||
|
||||
def is_concrete(x):
|
||||
return to_concrete_value(x) is not None
|
||||
|
@ -4795,11 +4795,11 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
def f(x):
|
||||
self.assertEqual(x.aval.sharding.spec, s.spec)
|
||||
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
|
||||
x = x * 2
|
||||
self.assertEqual(x.aval.sharding.spec, s.spec)
|
||||
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
|
||||
x = x * x
|
||||
self.assertEqual(x.aval.sharding.spec, s.spec)
|
||||
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
|
||||
return x
|
||||
|
||||
# Eager mode
|
||||
|
Loading…
x
Reference in New Issue
Block a user