mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Allow get_aval to work on ShapeDtypeStruct
This is necessary to be able to call jit(f).lower(ShapeDtypeStruct(...) when --jax_dynamic_shapes is on. The code in partial_eval.infer_lambda_input_type calls get_aval.
This commit is contained in:
parent
15fbf22715
commit
fe055d06ba
@ -2958,6 +2958,10 @@ class ShapeDtypeStruct:
|
||||
named = frozenset(self.named_shape.items())
|
||||
return hash((self.shape, self.dtype, named))
|
||||
|
||||
core.pytype_aval_mappings[ShapeDtypeStruct] = (
|
||||
lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype),
|
||||
weak_type=False, named_shape=x.named_shape))
|
||||
|
||||
def eval_shape(fun: Callable, *args, **kwargs):
|
||||
"""Compute the shape/dtype of ``fun`` without any FLOPs.
|
||||
|
||||
|
@ -1273,6 +1273,15 @@ class DynamicShapeTest(jtu.JaxTestCase):
|
||||
mhlo = f_lowered.compiler_ir('mhlo')
|
||||
self.assertIn('tensor<?xi32>', str(mhlo))
|
||||
|
||||
def test_lower_abstracted_axes_shapedtypestruct(self):
|
||||
@partial(jax.jit, abstracted_axes=('n',))
|
||||
def f(x):
|
||||
return x.sum()
|
||||
|
||||
f_lowered = f.lower(jax.ShapeDtypeStruct((3,), np.int32))
|
||||
mhlo = f_lowered.compiler_ir('mhlo')
|
||||
self.assertIn('tensor<?xi32>', str(mhlo))
|
||||
|
||||
def test_vmap_abstracted_axis(self):
|
||||
def foo(x, y):
|
||||
z = jax.vmap(jnp.sin)(x) * y
|
||||
|
Loading…
x
Reference in New Issue
Block a user