api_util: make shaped_abstractify respect raise_to_shaped

This commit is contained in:
Jake VanderPlas 2022-05-05 17:20:00 -07:00
parent 212edd66d6
commit 5d45458c7b
2 changed files with 4 additions and 2 deletions

View File

@ -318,7 +318,8 @@ def _dtype(x):
def shaped_abstractify(x):
try:
return core.raise_to_shaped(core.get_aval(x))
return core.raise_to_shaped(
x if isinstance(x, core.AbstractValue) else core.get_aval(x))
except TypeError:
pass

View File

@ -27,6 +27,7 @@ import numpy as np
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax import dtypes
from jax import stages
from jax.errors import JAXTypeError
from jax import lax
@ -827,7 +828,7 @@ class PJitTest(jtu.BufferDonationTestCase):
return x @ y
shape = (8, 8)
aval = jax.ShapedArray(shape, jnp.int64)
aval = jax.ShapedArray(shape, dtypes.canonicalize_dtype(jnp.int64))
x = jnp.arange(np.prod(shape)).reshape(shape)
exe = f.lower(aval, x, _global_avals=True).compile()
self.assertIsInstance(exe, stages.Compiled)