Allow get_aval to work on ShapeDtypeStruct

This is necessary to be able to call jit(f).lower(ShapeDtypeStruct(...) when
--jax_dynamic_shapes is on. The code in partial_eval.infer_lambda_input_type
calls get_aval.
This commit is contained in:
George Necula 2022-09-03 08:17:38 +03:00
parent 15fbf22715
commit fe055d06ba
2 changed files with 13 additions and 0 deletions

View File

@ -2958,6 +2958,10 @@ class ShapeDtypeStruct:
named = frozenset(self.named_shape.items())
return hash((self.shape, self.dtype, named))
core.pytype_aval_mappings[ShapeDtypeStruct] = (
lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype),
weak_type=False, named_shape=x.named_shape))
def eval_shape(fun: Callable, *args, **kwargs):
"""Compute the shape/dtype of ``fun`` without any FLOPs.

View File

@ -1273,6 +1273,15 @@ class DynamicShapeTest(jtu.JaxTestCase):
mhlo = f_lowered.compiler_ir('mhlo')
self.assertIn('tensor<?xi32>', str(mhlo))
def test_lower_abstracted_axes_shapedtypestruct(self):
@partial(jax.jit, abstracted_axes=('n',))
def f(x):
return x.sum()
f_lowered = f.lower(jax.ShapeDtypeStruct((3,), np.int32))
mhlo = f_lowered.compiler_ir('mhlo')
self.assertIn('tensor<?xi32>', str(mhlo))
def test_vmap_abstracted_axis(self):
def foo(x, y):
z = jax.vmap(jnp.sin)(x) * y