mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[XLA:Client] Expose HloSharding pybind factories for iota tile/partial tile, replicated and manual sharding,
PiperOrigin-RevId: 534600886
This commit is contained in:
parent
7f7f995bf4
commit
4f1f5e4516
@ -55,6 +55,7 @@ from jax._src.interpreters import pxla
|
||||
from jax.interpreters import mlir
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.util import curry, unzip2, safe_zip
|
||||
|
||||
@ -3890,6 +3891,10 @@ class UtilTest(jtu.JaxTestCase):
|
||||
hs2 = xc.HloSharding.from_proto(op2)
|
||||
hs3 = xc.HloSharding.from_proto(op3)
|
||||
|
||||
if xla_extension_version >= 155:
|
||||
self.assertEqual(hs1, xc.HloSharding.iota_tile((2, 2)))
|
||||
self.assertEqual(hs2, xc.HloSharding.iota_tile((2, 2)))
|
||||
self.assertEqual(hs3, xc.HloSharding.iota_tile((4, 2)))
|
||||
self.assertEqual(hash(hs1), hash(hs2))
|
||||
self.assertNotEqual(hash(hs1), hash(hs3))
|
||||
self.assertNotEqual(hash(hs2), hash(hs3))
|
||||
@ -3898,19 +3903,38 @@ class UtilTest(jtu.JaxTestCase):
|
||||
op1 = xc.OpSharding()
|
||||
op1.type = xc.OpSharding.Type.OTHER
|
||||
op1.tile_assignment_dimensions = [4, 1]
|
||||
op1.tile_assignment_devices = [0, 1, 2, 3]
|
||||
op1.tile_assignment_devices = [0, 2, 1, 3]
|
||||
op1.last_tile_dims = [xc.OpSharding.Type.REPLICATED]
|
||||
|
||||
op2 = xc.OpSharding()
|
||||
op2.type = xc.OpSharding.Type.OTHER
|
||||
op2.tile_assignment_dimensions = [4, 1]
|
||||
op2.tile_assignment_devices = [0, 1, 2, 3]
|
||||
op2.tile_assignment_devices = [0, 2, 1, 3]
|
||||
op2.last_tile_dims = [xc.OpSharding.Type.REPLICATED]
|
||||
|
||||
self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))
|
||||
|
||||
hs1 = xc.HloSharding.from_proto(op1)
|
||||
hs2 = xc.HloSharding.from_proto(op2)
|
||||
if xla_extension_version >= 155:
|
||||
self.assertEqual(
|
||||
hs1,
|
||||
xc.HloSharding.iota_tile(
|
||||
(4, 1),
|
||||
reshape_dims=(2, 2),
|
||||
transpose_perm=(1, 0),
|
||||
replicate_last_dim=True,
|
||||
),
|
||||
)
|
||||
self.assertEqual(
|
||||
hs2,
|
||||
xc.HloSharding.iota_tile(
|
||||
(4, 1),
|
||||
reshape_dims=(2, 2),
|
||||
transpose_perm=(1, 0),
|
||||
replicate_last_dim=True,
|
||||
),
|
||||
)
|
||||
self.assertEqual(hash(hs1), hash(hs2))
|
||||
|
||||
def test_op_sharding_tuple_shardings(self):
|
||||
@ -3940,6 +3964,33 @@ class UtilTest(jtu.JaxTestCase):
|
||||
hs2 = xc.HloSharding.from_proto(op2)
|
||||
self.assertNotEqual(hash(hs1), hash(hs2))
|
||||
|
||||
def test_hlo_sharding_iota_tile_error(self):
|
||||
if xla_extension_version < 155:
|
||||
self.skipTest('Requires xla_extension_version >= 155')
|
||||
self.assertRaisesRegex(
|
||||
xla_extension.XlaRuntimeError,
|
||||
'INVALID_ARGUMENT: `dims` should not be empty.',
|
||||
lambda: xc.HloSharding.iota_tile(())
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
xla_extension.XlaRuntimeError,
|
||||
'INVALID_ARGUMENT: Cannot reshape from',
|
||||
lambda: xc.HloSharding.iota_tile(
|
||||
(2, 2),
|
||||
reshape_dims=(2, 4),
|
||||
transpose_perm=(1, 0),
|
||||
),
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
xla_extension.XlaRuntimeError,
|
||||
'INVALID_ARGUMENT: `reshape_dims` and `transpose_perm` should have the'
|
||||
' same size',
|
||||
lambda: xc.HloSharding.iota_tile(
|
||||
(2, 2),
|
||||
transpose_perm=(1, 0),
|
||||
),
|
||||
)
|
||||
|
||||
def test_device_indices_cache(self):
|
||||
op1 = xc.OpSharding()
|
||||
op1.type = xc.OpSharding.Type.OTHER
|
||||
@ -4020,6 +4071,16 @@ class UtilTest(jtu.JaxTestCase):
|
||||
self.assertTrue(op_shardings.are_op_shardings_equal(op1, op2))
|
||||
self.assertTrue(op_shardings.are_op_shardings_equal(op1, op3))
|
||||
|
||||
def test_hlo_sharding_manual_replicated(self):
|
||||
if xla_extension_version < 155:
|
||||
self.skipTest('Requires xla_extension_version >= 155')
|
||||
|
||||
hs1 = xc.HloSharding.manual()
|
||||
self.assertTrue(hs1.is_manual())
|
||||
|
||||
hs2 = xc.HloSharding.replicate()
|
||||
self.assertTrue(hs2.is_replicated())
|
||||
|
||||
def test_op_sharding_cache_on_mesh_pspec_sharding(self):
|
||||
ndim = 2
|
||||
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
|
||||
|
Loading…
x
Reference in New Issue
Block a user