Export HloSharding via pybind which is a C++ wrapper around OpSharding proto.

PiperOrigin-RevId: 463992136
This commit is contained in:
Yash Katariya 2022-07-28 21:00:33 -07:00 committed by jax authors
parent 560c936a46
commit 47623264db
3 changed files with 54 additions and 12 deletions

View File

@ -21,6 +21,8 @@ import jax
from jax import lax
from jax.experimental import sparse
from jax._src.api_util import shaped_abstractify # technically not an api fn
from jax._src.lib import xla_client as xc
from jax.interpreters import pxla
import jax.numpy as jnp
import numpy as np
@ -471,6 +473,23 @@ def bench_shaped_abstractify(state):
_ = [shaped_abstractify(x) for x in args]
@google_benchmark.register
@google_benchmark.option.unit(google_benchmark.kMicrosecond)
def bench_are_op_shardings_equal(state):
op1 = xc.OpSharding()
op1.type = xc.OpSharding.Type.OTHER
op1.tile_assignment_dimensions = [4, 192, 16]
op1.tile_assignment_devices = list(range(12288))
op2 = xc.OpSharding()
op2.type = xc.OpSharding.Type.OTHER
op2.tile_assignment_dimensions = [4, 192, 16]
op2.tile_assignment_devices = list(range(12288))
while state:
pxla.are_op_shardings_equal(op1, op2)
def swap(a, b):
return b, a

View File

@ -70,6 +70,7 @@ from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib import pmap_lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
@ -2713,18 +2714,20 @@ def _get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
def are_op_shardings_equal(op1, op2):
# TODO(yashkatariya): Use HloSharding class to check for equality.
if id(op1) == id(op2):
return True
if op1.type != op2.type:
return False
if op1.type == xc.OpSharding.Type.TUPLE:
return all(are_op_shardings_equal(i, j)
for i, j in safe_zip(op1.tuple_shardings, op2.tuple_shardings))
return (op1.tile_assignment_dimensions == op2.tile_assignment_dimensions and
op1.tile_assignment_devices == op2.tile_assignment_devices and
op1.last_tile_dims == op2.last_tile_dims and
op1.replicate_on_last_tile_dim == op2.replicate_on_last_tile_dim)
if xla_extension_version >= 81:
return xc.HloSharding.from_proto(op1) == xc.HloSharding.from_proto(op2)
else:
if op1.type == xc.OpSharding.Type.TUPLE:
return all(are_op_shardings_equal(i, j)
for i, j in safe_zip(op1.tuple_shardings, op2.tuple_shardings))
return (op1.tile_assignment_dimensions == op2.tile_assignment_dimensions and
op1.tile_assignment_devices == op2.tile_assignment_devices and
op1.last_tile_dims == op2.last_tile_dims and
op1.replicate_on_last_tile_dim == op2.replicate_on_last_tile_dim)
_forbidden_primitives = {

View File

@ -45,6 +45,7 @@ from jax.experimental.pjit import (pjit, pjit_p, with_sharding_constraint,
from jax.interpreters import pxla
from jax.interpreters import mlir
from jax._src.lib import xla_client as xc, xla_bridge
from jax._src.lib import xla_extension_version
from jax._src.util import prod, curry, unzip2, safe_zip
from jax.config import config
@ -2056,15 +2057,15 @@ class UtilTest(jtu.JaxTestCase):
else:
pjit_lib._check_all_or_none_unspecified(entries, 'test axis resources')
def test_op_sharding_equality(self):
def test_op_sharding_equality_and_hash_equality(self):
op1 = xc.OpSharding()
op1.type = xc.OpSharding.Type.OTHER
op1.tile_assignment_dimensions = [4, 2]
op1.tile_assignment_dimensions = [2, 2]
op1.tile_assignment_devices = [0, 1, 2, 3]
op2 = xc.OpSharding()
op2.type = xc.OpSharding.Type.OTHER
op2.tile_assignment_dimensions = [4, 2]
op2.tile_assignment_dimensions = [2, 2]
op2.tile_assignment_devices = [0, 1, 2, 3]
op3 = xc.OpSharding()
@ -2076,6 +2077,15 @@ class UtilTest(jtu.JaxTestCase):
self.assertFalse(pxla.are_op_shardings_equal(op1, op3))
self.assertFalse(pxla.are_op_shardings_equal(op2, op3))
if xla_extension_version >= 81:
hs1 = xc.HloSharding.from_proto(op1)
hs2 = xc.HloSharding.from_proto(op2)
hs3 = xc.HloSharding.from_proto(op3)
self.assertEqual(hash(hs1), hash(hs2))
self.assertNotEqual(hash(hs1), hash(hs3))
self.assertNotEqual(hash(hs2), hash(hs3))
def test_op_sharding_partial_sharding(self):
op1 = xc.OpSharding()
op1.type = xc.OpSharding.Type.OTHER
@ -2091,6 +2101,11 @@ class UtilTest(jtu.JaxTestCase):
self.assertTrue(pxla.are_op_shardings_equal(op1, op2))
if xla_extension_version >= 81:
hs1 = xc.HloSharding.from_proto(op1)
hs2 = xc.HloSharding.from_proto(op2)
self.assertEqual(hash(hs1), hash(hs2))
def test_op_sharding_tuple_shardings(self):
top1 = xc.OpSharding()
top1.type = xc.OpSharding.Type.OTHER
@ -2100,7 +2115,7 @@ class UtilTest(jtu.JaxTestCase):
top2 = xc.OpSharding()
top2.type = xc.OpSharding.Type.OTHER
top2.tile_assignment_dimensions = [2, 1]
top2.tile_assignment_dimensions = [2, 2]
top2.tile_assignment_devices = [0, 1, 2, 3]
top2.replicate_on_last_tile_dim = True
@ -2114,6 +2129,11 @@ class UtilTest(jtu.JaxTestCase):
self.assertFalse(pxla.are_op_shardings_equal(op1, op2))
if xla_extension_version >= 81:
hs1 = xc.HloSharding.from_proto(op1)
hs2 = xc.HloSharding.from_proto(op2)
self.assertNotEqual(hash(hs1), hash(hs2))
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())