mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
api_util: make shaped_abstractify respect raise_to_shaped
This commit is contained in:
parent
212edd66d6
commit
5d45458c7b
@ -318,7 +318,8 @@ def _dtype(x):
|
||||
|
||||
def shaped_abstractify(x):
|
||||
try:
|
||||
return core.raise_to_shaped(core.get_aval(x))
|
||||
return core.raise_to_shaped(
|
||||
x if isinstance(x, core.AbstractValue) else core.get_aval(x))
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
|
@ -27,6 +27,7 @@ import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
from jax import dtypes
|
||||
from jax import stages
|
||||
from jax.errors import JAXTypeError
|
||||
from jax import lax
|
||||
@ -827,7 +828,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
return x @ y
|
||||
|
||||
shape = (8, 8)
|
||||
aval = jax.ShapedArray(shape, jnp.int64)
|
||||
aval = jax.ShapedArray(shape, dtypes.canonicalize_dtype(jnp.int64))
|
||||
x = jnp.arange(np.prod(shape)).reshape(shape)
|
||||
exe = f.lower(aval, x, _global_avals=True).compile()
|
||||
self.assertIsInstance(exe, stages.Compiled)
|
||||
|
Loading…
x
Reference in New Issue
Block a user