mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Return PositionalSharding instead of GSPMDSharding in custom_partitioning when mesh is not defined
PiperOrigin-RevId: 539719517
This commit is contained in:
parent
79a1bc9a3e
commit
4d698c30b9
@ -2439,7 +2439,7 @@ orig_out_sharding_handlers[sharding_impls.NamedSharding] = _gspmd_to_named_shard
|
||||
def _gspmd_to_positional_sharding(
|
||||
op_sharding: xc.OpSharding,
|
||||
self: sharding_impls.PositionalSharding) -> sharding_impls.PositionalSharding:
|
||||
return sharding_impls._from_op_sharding_to_pos_sharding(
|
||||
return sharding_impls._op_sharding_to_pos_sharding(
|
||||
op_sharding, self._device_assignment)
|
||||
orig_out_sharding_handlers[sharding_impls.PositionalSharding] = _gspmd_to_positional_sharding
|
||||
|
||||
|
@ -495,7 +495,7 @@ class PmapSharding(XLACompatibleSharding):
|
||||
return global_shape[:sharded_dim] + global_shape[sharded_dim+1:]
|
||||
|
||||
|
||||
def _from_op_sharding_to_pos_sharding(
|
||||
def _op_sharding_to_pos_sharding(
|
||||
op_sharding: Union[xc.OpSharding, xc.HloSharding],
|
||||
device_assignment: Sequence[xc.Device]) -> PositionalSharding:
|
||||
if isinstance(op_sharding, xc.HloSharding):
|
||||
|
@ -25,6 +25,7 @@ from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.sharding_impls import _op_sharding_to_pos_sharding
|
||||
from jax._src import custom_api_util
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.api_util import flatten_fun_nokwargs
|
||||
@ -475,15 +476,14 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
|
||||
return mlir.lower_fun(
|
||||
core.jaxpr_as_fun(call), multiple_results=True)(ctx, *values)
|
||||
|
||||
def to_mesh_pspec_sharding(op_sharding: Optional[xc.OpSharding]):
|
||||
if op_sharding is None:
|
||||
return op_sharding
|
||||
def to_mesh_pspec_sharding(hlo_sharding: Optional[xc.HloSharding]):
|
||||
if hlo_sharding is None:
|
||||
return hlo_sharding
|
||||
if mesh.empty or not decode_shardings:
|
||||
from jax._src.sharding_impls import GSPMDSharding
|
||||
assert devices is not None
|
||||
return GSPMDSharding(devices, op_sharding.to_proto())
|
||||
return _op_sharding_to_pos_sharding(hlo_sharding, devices)
|
||||
pspec = sharding_impls.parse_flatten_op_sharding(
|
||||
op_sharding.to_proto(), mesh)[0].get_partition_spec()
|
||||
hlo_sharding, mesh)[0].get_partition_spec()
|
||||
return jax.sharding.NamedSharding(mesh, pspec)
|
||||
|
||||
sharding_callback_info = _ShardingCallbackInfo(propagate_user_sharding,
|
||||
|
@ -31,7 +31,7 @@ from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.sharding_impls import _from_op_sharding_to_pos_sharding
|
||||
from jax._src.sharding_impls import _op_sharding_to_pos_sharding
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.experimental import multihost_utils
|
||||
from jax.sharding import PartitionSpec as P
|
||||
@ -922,7 +922,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
mesh_shape, ('x', 'y') if ndim == 2 else ('x', 'y', 'z'))
|
||||
mps = jax.sharding.NamedSharding(mesh, pspec)
|
||||
original_op_sharding = mps._to_xla_hlo_sharding(ndim)
|
||||
ps = _from_op_sharding_to_pos_sharding(original_op_sharding,
|
||||
ps = _op_sharding_to_pos_sharding(original_op_sharding,
|
||||
mps._device_assignment)
|
||||
out_op_sharding = ps._to_xla_hlo_sharding(ndim)
|
||||
self.assertTrue(op_shardings.are_op_shardings_equal(
|
||||
@ -958,7 +958,7 @@ class ShardingTest(jtu.JaxTestCase):
|
||||
ops_ifr = op_shardings.is_op_sharding_replicated(mps_op_sharding)
|
||||
self.assertEqual(mps.is_fully_replicated, ops_ifr)
|
||||
|
||||
ps = _from_op_sharding_to_pos_sharding(mps_op_sharding,
|
||||
ps = _op_sharding_to_pos_sharding(mps_op_sharding,
|
||||
mps._device_assignment)
|
||||
self.assertEqual(ps.is_fully_replicated,
|
||||
op_shardings.is_op_sharding_replicated(
|
||||
|
Loading…
x
Reference in New Issue
Block a user