mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add JAX unit test for Shardy which causes the compiler to introduce the mlir::tensor::TensorDialect
. This was causing the compiler to crash.
PiperOrigin-RevId: 714896947
This commit is contained in:
parent
91ffb640a8
commit
c14e5b4332
@ -6813,6 +6813,19 @@ class ShardyTest(jtu.JaxTestCase):
|
||||
self.assertEqual(repr(dim_sharding),
|
||||
"SdyDimSharding({'data', 'model', ?}p2)")
|
||||
|
||||
def test_tensor_dialect(self):
|
||||
# While this doesn't emit any `mlir::TensorDialect` ops, some pass in the
|
||||
# compiler pipeline is temporarily introducing it before then discarding it
|
||||
# again. Make sure this doesn't crash.
|
||||
mesh = jtu.create_mesh((2,), ('x'))
|
||||
in_sds = jax.ShapeDtypeStruct((4, 8), jnp.float32)
|
||||
|
||||
@partial(jax.jit, out_shardings=NamedSharding(mesh, P('x')))
|
||||
def gen_dummy_inputs():
|
||||
return tuple(jax.random.normal(jax.random.key(42), shape=in_sds.shape
|
||||
).astype(in_sds.dtype))
|
||||
gen_dummy_inputs() # doesn't crash
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user