mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Export HloSharding
via pybind which is a C++ wrapper around OpSharding proto.
PiperOrigin-RevId: 463992136
This commit is contained in:
parent
560c936a46
commit
47623264db
@ -21,6 +21,8 @@ import jax
|
||||
from jax import lax
|
||||
from jax.experimental import sparse
|
||||
from jax._src.api_util import shaped_abstractify # technically not an api fn
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax.interpreters import pxla
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
@ -471,6 +473,23 @@ def bench_shaped_abstractify(state):
|
||||
_ = [shaped_abstractify(x) for x in args]
|
||||
|
||||
|
||||
@google_benchmark.register
|
||||
@google_benchmark.option.unit(google_benchmark.kMicrosecond)
|
||||
def bench_are_op_shardings_equal(state):
|
||||
op1 = xc.OpSharding()
|
||||
op1.type = xc.OpSharding.Type.OTHER
|
||||
op1.tile_assignment_dimensions = [4, 192, 16]
|
||||
op1.tile_assignment_devices = list(range(12288))
|
||||
|
||||
op2 = xc.OpSharding()
|
||||
op2.type = xc.OpSharding.Type.OTHER
|
||||
op2.tile_assignment_dimensions = [4, 192, 16]
|
||||
op2.tile_assignment_devices = list(range(12288))
|
||||
|
||||
while state:
|
||||
pxla.are_op_shardings_equal(op1, op2)
|
||||
|
||||
|
||||
def swap(a, b):
|
||||
return b, a
|
||||
|
||||
|
@ -70,6 +70,7 @@ from jax._src.abstract_arrays import array_types
|
||||
from jax._src.config import config
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib import pmap_lib
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
@ -2713,18 +2714,20 @@ def _get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
|
||||
|
||||
|
||||
def are_op_shardings_equal(op1, op2):
|
||||
# TODO(yashkatariya): Use HloSharding class to check for equality.
|
||||
if id(op1) == id(op2):
|
||||
return True
|
||||
if op1.type != op2.type:
|
||||
return False
|
||||
if op1.type == xc.OpSharding.Type.TUPLE:
|
||||
return all(are_op_shardings_equal(i, j)
|
||||
for i, j in safe_zip(op1.tuple_shardings, op2.tuple_shardings))
|
||||
return (op1.tile_assignment_dimensions == op2.tile_assignment_dimensions and
|
||||
op1.tile_assignment_devices == op2.tile_assignment_devices and
|
||||
op1.last_tile_dims == op2.last_tile_dims and
|
||||
op1.replicate_on_last_tile_dim == op2.replicate_on_last_tile_dim)
|
||||
if xla_extension_version >= 81:
|
||||
return xc.HloSharding.from_proto(op1) == xc.HloSharding.from_proto(op2)
|
||||
else:
|
||||
if op1.type == xc.OpSharding.Type.TUPLE:
|
||||
return all(are_op_shardings_equal(i, j)
|
||||
for i, j in safe_zip(op1.tuple_shardings, op2.tuple_shardings))
|
||||
return (op1.tile_assignment_dimensions == op2.tile_assignment_dimensions and
|
||||
op1.tile_assignment_devices == op2.tile_assignment_devices and
|
||||
op1.last_tile_dims == op2.last_tile_dims and
|
||||
op1.replicate_on_last_tile_dim == op2.replicate_on_last_tile_dim)
|
||||
|
||||
|
||||
_forbidden_primitives = {
|
||||
|
@ -45,6 +45,7 @@ from jax.experimental.pjit import (pjit, pjit_p, with_sharding_constraint,
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.lib import xla_client as xc, xla_bridge
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.util import prod, curry, unzip2, safe_zip
|
||||
|
||||
from jax.config import config
|
||||
@ -2056,15 +2057,15 @@ class UtilTest(jtu.JaxTestCase):
|
||||
else:
|
||||
pjit_lib._check_all_or_none_unspecified(entries, 'test axis resources')
|
||||
|
||||
def test_op_sharding_equality(self):
|
||||
def test_op_sharding_equality_and_hash_equality(self):
|
||||
op1 = xc.OpSharding()
|
||||
op1.type = xc.OpSharding.Type.OTHER
|
||||
op1.tile_assignment_dimensions = [4, 2]
|
||||
op1.tile_assignment_dimensions = [2, 2]
|
||||
op1.tile_assignment_devices = [0, 1, 2, 3]
|
||||
|
||||
op2 = xc.OpSharding()
|
||||
op2.type = xc.OpSharding.Type.OTHER
|
||||
op2.tile_assignment_dimensions = [4, 2]
|
||||
op2.tile_assignment_dimensions = [2, 2]
|
||||
op2.tile_assignment_devices = [0, 1, 2, 3]
|
||||
|
||||
op3 = xc.OpSharding()
|
||||
@ -2076,6 +2077,15 @@ class UtilTest(jtu.JaxTestCase):
|
||||
self.assertFalse(pxla.are_op_shardings_equal(op1, op3))
|
||||
self.assertFalse(pxla.are_op_shardings_equal(op2, op3))
|
||||
|
||||
if xla_extension_version >= 81:
|
||||
hs1 = xc.HloSharding.from_proto(op1)
|
||||
hs2 = xc.HloSharding.from_proto(op2)
|
||||
hs3 = xc.HloSharding.from_proto(op3)
|
||||
|
||||
self.assertEqual(hash(hs1), hash(hs2))
|
||||
self.assertNotEqual(hash(hs1), hash(hs3))
|
||||
self.assertNotEqual(hash(hs2), hash(hs3))
|
||||
|
||||
def test_op_sharding_partial_sharding(self):
|
||||
op1 = xc.OpSharding()
|
||||
op1.type = xc.OpSharding.Type.OTHER
|
||||
@ -2091,6 +2101,11 @@ class UtilTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertTrue(pxla.are_op_shardings_equal(op1, op2))
|
||||
|
||||
if xla_extension_version >= 81:
|
||||
hs1 = xc.HloSharding.from_proto(op1)
|
||||
hs2 = xc.HloSharding.from_proto(op2)
|
||||
self.assertEqual(hash(hs1), hash(hs2))
|
||||
|
||||
def test_op_sharding_tuple_shardings(self):
|
||||
top1 = xc.OpSharding()
|
||||
top1.type = xc.OpSharding.Type.OTHER
|
||||
@ -2100,7 +2115,7 @@ class UtilTest(jtu.JaxTestCase):
|
||||
|
||||
top2 = xc.OpSharding()
|
||||
top2.type = xc.OpSharding.Type.OTHER
|
||||
top2.tile_assignment_dimensions = [2, 1]
|
||||
top2.tile_assignment_dimensions = [2, 2]
|
||||
top2.tile_assignment_devices = [0, 1, 2, 3]
|
||||
top2.replicate_on_last_tile_dim = True
|
||||
|
||||
@ -2114,6 +2129,11 @@ class UtilTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertFalse(pxla.are_op_shardings_equal(op1, op2))
|
||||
|
||||
if xla_extension_version >= 81:
|
||||
hs1 = xc.HloSharding.from_proto(op1)
|
||||
hs2 = xc.HloSharding.from_proto(op2)
|
||||
self.assertNotEqual(hash(hs1), hash(hs2))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user