Add JAX unit test for Shardy which causes the compiler to introduce the mlir::tensor::TensorDialect. This was causing the compiler to crash.

PiperOrigin-RevId: 714896947
This commit is contained in:
Bart Chrzaszcz 2025-01-13 03:07:54 -08:00 committed by jax authors
parent 91ffb640a8
commit c14e5b4332

View File

@ -6813,6 +6813,19 @@ class ShardyTest(jtu.JaxTestCase):
self.assertEqual(repr(dim_sharding),
"SdyDimSharding({'data', 'model', ?}p2)")
def test_tensor_dialect(self):
# While this doesn't emit any `mlir::TensorDialect` ops, some pass in the
# compiler pipeline is temporarily introducing it before then discarding it
# again. Make sure this doesn't crash.
mesh = jtu.create_mesh((2,), ('x'))
in_sds = jax.ShapeDtypeStruct((4, 8), jnp.float32)
@partial(jax.jit, out_shardings=NamedSharding(mesh, P('x')))
def gen_dummy_inputs():
return tuple(jax.random.normal(jax.random.key(42), shape=in_sds.shape
).astype(in_sds.dtype))
gen_dummy_inputs() # doesn't crash
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())