mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Rename get_ty
to typeof
which is an alias of get_aval
PiperOrigin-RevId: 735946640
This commit is contained in:
parent
c6b164dc09
commit
3a26804c68
@ -79,7 +79,7 @@ from jax._src.lib import xla_client as _xc
|
||||
Device = _xc.Device
|
||||
del _xc
|
||||
|
||||
from jax._src.core import get_ty as get_ty
|
||||
from jax._src.core import typeof as typeof
|
||||
from jax._src.api import effects_barrier as effects_barrier
|
||||
from jax._src.api import block_until_ready as block_until_ready
|
||||
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint # noqa: F401
|
||||
|
@ -1576,7 +1576,7 @@ def get_aval(x):
|
||||
return get_aval(x.__jax_array__())
|
||||
raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")
|
||||
|
||||
get_ty = get_aval
|
||||
typeof = get_aval
|
||||
|
||||
def is_concrete(x):
|
||||
return to_concrete_value(x) is not None
|
||||
|
@ -216,8 +216,8 @@ class MutableArrayTest(jtu.JaxTestCase):
|
||||
|
||||
@jax.jit
|
||||
def f(x_ref):
|
||||
self.assertEqual(core.get_ty(x_ref).sharding.spec,
|
||||
core.get_ty(x_ref[...]).sharding.spec)
|
||||
self.assertEqual(core.typeof(x_ref).sharding.spec,
|
||||
core.typeof(x_ref[...]).sharding.spec)
|
||||
y = x_ref[...] + 1
|
||||
return y
|
||||
|
||||
|
@ -4883,11 +4883,11 @@ class ShardingInTypesTest(jtu.JaxTestCase):
|
||||
arr = jax.device_put(np_inp, s)
|
||||
|
||||
def f(x):
|
||||
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
|
||||
self.assertEqual(jax.typeof(x).sharding.spec, s.spec)
|
||||
x = x * 2
|
||||
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
|
||||
self.assertEqual(jax.typeof(x).sharding.spec, s.spec)
|
||||
x = x * x
|
||||
self.assertEqual(jax.get_ty(x).sharding.spec, s.spec)
|
||||
self.assertEqual(jax.typeof(x).sharding.spec, s.spec)
|
||||
return x
|
||||
|
||||
# Eager mode
|
||||
|
Loading…
x
Reference in New Issue
Block a user