[XLA:Client] Expose HloSharding pybind factories for iota tile/partial tile, replicated and manual sharding,

PiperOrigin-RevId: 534600886
This commit is contained in:
Ce Zheng 2023-05-23 16:37:02 -07:00 committed by jax authors
parent 7f7f995bf4
commit 4f1f5e4516

View File

@ -55,6 +55,7 @@ from jax._src.interpreters import pxla
from jax.interpreters import mlir
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.util import curry, unzip2, safe_zip
@ -3890,6 +3891,10 @@ class UtilTest(jtu.JaxTestCase):
hs2 = xc.HloSharding.from_proto(op2)
hs3 = xc.HloSharding.from_proto(op3)
if xla_extension_version >= 155:
self.assertEqual(hs1, xc.HloSharding.iota_tile((2, 2)))
self.assertEqual(hs2, xc.HloSharding.iota_tile((2, 2)))
self.assertEqual(hs3, xc.HloSharding.iota_tile((4, 2)))
self.assertEqual(hash(hs1), hash(hs2))
self.assertNotEqual(hash(hs1), hash(hs3))
self.assertNotEqual(hash(hs2), hash(hs3))
@ -3898,19 +3903,38 @@ class UtilTest(jtu.JaxTestCase):
op1 = xc.OpSharding()
op1.type = xc.OpSharding.Type.OTHER
op1.tile_assignment_dimensions = [4, 1]
op1.tile_assignment_devices = [0, 1, 2, 3]
op1.tile_assignment_devices = [0, 2, 1, 3]
op1.last_tile_dims = [xc.OpSharding.Type.REPLICATED]
op2 = xc.OpSharding()
op2.type = xc.OpSharding.Type.OTHER
op2.tile_assignment_dimensions = [4, 1]
op2.tile_assignment_devices = [0, 1, 2, 3]
op2.tile_assignment_devices = [0, 2, 1, 3]
op2.last_tile_dims = [xc.OpSharding.Type.REPLICATED]
self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))
hs1 = xc.HloSharding.from_proto(op1)
hs2 = xc.HloSharding.from_proto(op2)
if xla_extension_version >= 155:
self.assertEqual(
hs1,
xc.HloSharding.iota_tile(
(4, 1),
reshape_dims=(2, 2),
transpose_perm=(1, 0),
replicate_last_dim=True,
),
)
self.assertEqual(
hs2,
xc.HloSharding.iota_tile(
(4, 1),
reshape_dims=(2, 2),
transpose_perm=(1, 0),
replicate_last_dim=True,
),
)
self.assertEqual(hash(hs1), hash(hs2))
def test_op_sharding_tuple_shardings(self):
@ -3940,6 +3964,33 @@ class UtilTest(jtu.JaxTestCase):
hs2 = xc.HloSharding.from_proto(op2)
self.assertNotEqual(hash(hs1), hash(hs2))
def test_hlo_sharding_iota_tile_error(self):
if xla_extension_version < 155:
self.skipTest('Requires xla_extension_version >= 155')
self.assertRaisesRegex(
xla_extension.XlaRuntimeError,
'INVALID_ARGUMENT: `dims` should not be empty.',
lambda: xc.HloSharding.iota_tile(())
)
self.assertRaisesRegex(
xla_extension.XlaRuntimeError,
'INVALID_ARGUMENT: Cannot reshape from',
lambda: xc.HloSharding.iota_tile(
(2, 2),
reshape_dims=(2, 4),
transpose_perm=(1, 0),
),
)
self.assertRaisesRegex(
xla_extension.XlaRuntimeError,
'INVALID_ARGUMENT: `reshape_dims` and `transpose_perm` should have the'
' same size',
lambda: xc.HloSharding.iota_tile(
(2, 2),
transpose_perm=(1, 0),
),
)
def test_device_indices_cache(self):
op1 = xc.OpSharding()
op1.type = xc.OpSharding.Type.OTHER
@ -4020,6 +4071,16 @@ class UtilTest(jtu.JaxTestCase):
self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))
self.assertTrue(op_shardings.are_op_shardings_equal(op1, op3))
def test_hlo_sharding_manual_replicated(self):
if xla_extension_version < 155:
self.skipTest('Requires xla_extension_version >= 155')
hs1 = xc.HloSharding.manual()
self.assertTrue(hs1.is_manual())
hs2 = xc.HloSharding.replicate()
self.assertTrue(hs2.is_replicated())
def test_op_sharding_cache_on_mesh_pspec_sharding(self):
ndim = 2
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))