Add is_fully_replicated method to Shardings. This lets us remove the use of is_op_sharding_replicated from JAX, because we can query replication directly on Shardings and save an expensive round trip through OpSharding creation.
PiperOrigin-RevId: 524379122
This commit is contained in:
parent 88a5ffb2e8
commit 673730c065
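
For context, a minimal usage sketch of the change (not part of this commit). It contrasts the new property with the older OpSharding round trip that this commit removes. NamedSharding, PartitionSpec, is_fully_replicated, _to_xla_op_sharding, and op_shardings.is_op_sharding_replicated come from the diff below; the Mesh construction, axis name 'x', and the ndim value 2 are illustrative assumptions, and the private helpers may be renamed in later JAX versions.

# Hedged sketch only; assumes a single-host JAX of roughly this commit's vintage.
import jax
import numpy as np
from jax._src import op_shardings  # private module referenced in the diff
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=('x',))
s = NamedSharding(mesh, P())  # nothing is partitioned, so fully replicated

# New: ask the Sharding directly, no proto construction needed.
assert s.is_fully_replicated

# Old: materialize the OpSharding proto first, then inspect it.
proto = s._to_xla_op_sharding(2)
assert op_shardings.is_op_sharding_replicated(proto)
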
@@ -330,7 +330,7 @@ class ArrayImpl(basearray.Array):
 
   @property
   def is_fully_replicated(self) -> bool:
-    return self.shape == self._arrays[0].shape
+    return self.sharding.is_fully_replicated
 
   def __repr__(self):
     prefix = 'Array('

@@ -34,7 +34,6 @@ from jax._src import core
 from jax._src import dtypes
 from jax._src import effects as effects_lib
 from jax._src import linear_util as lu
-from jax._src import op_shardings
 from jax._src import sharding_impls
 from jax._src import source_info_util
 from jax._src import util

@@ -47,6 +46,7 @@ from jax._src.lib.mlir import dialects
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib.mlir.dialects import func as func_dialect
+from jax._src.sharding_impls import XLACompatibleSharding
 
 
 map, unsafe_map = util.safe_map, map

@@ -491,7 +491,7 @@ def flatten_lowering_ir_args(
 _module_name_regex = re.compile(r"[^\w.-]")
 
 def sharded_aval(aval: core.AbstractValue,
-                 sharding: Optional[xc.OpSharding]) -> core.AbstractValue:
+                 sharding: Optional[XLACompatibleSharding]) -> core.AbstractValue:
   """Returns the new aval sharded based on sharding proto."""
   if sharding is None:
     return aval

@@ -499,18 +499,7 @@ def sharded_aval(aval: core.AbstractValue,
     return aval
   if not isinstance(aval, core.ShapedArray):
     raise NotImplementedError
-
-  if (op_shardings.is_op_sharding_replicated(sharding) or
-      sharding.type == xc.OpSharding.Type.MANUAL):
-    return aval
-
-  partitions, _ = op_shardings.get_num_ways_dim_sharded(sharding)
-  out = []
-  for s, p in zip(aval.shape, partitions):
-    quotient, remainder = divmod(s, p)
-    assert remainder == 0
-    out.append(quotient)
-  return aval.update(tuple(out))
+  return aval.update(sharding.shard_shape(aval.shape))
 
 
 def eval_dynamic_shape(ctx: LoweringRuleContext,

@@ -537,6 +526,16 @@ class LoweringResult(NamedTuple):
 _platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
 
 
+def _to_logical_op_sharding(
+    aval: core.AbstractValue, sharding: Optional[XLACompatibleSharding],
+) -> Optional[xc.OpSharding]:
+  if sharding is None:
+    return None
+  assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
+  assert isinstance(aval, core.ShapedArray)
+  return sharding._to_xla_op_sharding(aval.ndim)
+
+
 def lower_jaxpr_to_module(
     module_name: str,
     jaxpr: core.ClosedJaxpr,

@@ -547,8 +546,8 @@ def lower_jaxpr_to_module(
     name_stack: source_info_util.NameStack,
     donated_args: Sequence[bool],
     replicated_args: Optional[Sequence[bool]] = None,
-    arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
-    result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
+    arg_shardings: Optional[Sequence[Optional[XLACompatibleSharding]]] = None,
+    result_shardings: Optional[Sequence[Optional[XLACompatibleSharding]]] = None,
     arg_names: Optional[Sequence[Optional[str]]] = None,
     result_names: Optional[Sequence[Optional[str]]] = None,
     num_replicas: int = 1,

@@ -596,6 +595,13 @@ def lower_jaxpr_to_module(
   else:
     dim_vars = ()
 
+  arg_op_shardings = (
+      map(_to_logical_op_sharding, jaxpr.in_avals, arg_shardings)
+      if arg_shardings is not None else arg_shardings)
+  result_op_shardings = (
+      map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
+      if result_shardings is not None else result_shardings)
+
   ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
                       keepalives, channel_iter, host_callbacks, dim_vars=dim_vars)
   with ctx.context, ir.Location.unknown(ctx.context):

@@ -611,9 +617,11 @@ def lower_jaxpr_to_module(
         replace_tokens_with_dummy=True,
         num_output_tokens=0,
         replicated_args=replicated_args,
-        arg_shardings=arg_shardings, result_shardings=result_shardings,
+        arg_shardings=arg_op_shardings,
+        result_shardings=result_op_shardings,
         input_output_aliases=input_output_aliases,
-        arg_names=arg_names, result_names=result_names)
+        arg_names=arg_names,
+        result_names=result_names)
 
     if not ctx.module.operation.verify():
       module_string = module_to_string(ctx.module)

@@ -1892,21 +1892,21 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
   dispatch.raise_warnings_or_errors_for_jit_of_pmap(
       nreps, backend, fun_name, jaxpr)
 
-  in_op_shardings: Optional[List[Optional[xc.OpSharding]]]
-  out_op_shardings: Optional[List[Optional[xc.OpSharding]]]
+  in_mlir_shardings: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
+  out_mlir_shardings: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
   axis_ctx: mlir.AxisContext
 
   if nreps == 1:
-    in_op_shardings = map(_to_logical_op_sharding, global_in_avals, in_shardings)
-    out_op_shardings = map(_to_logical_op_sharding, global_out_avals, out_shardings)
+    in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings)
+    out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings)
     replicated_args = [False] * len(global_in_avals)
     axis_ctx = sharding_impls.ShardingContext(device_assignment)
     num_partitions = len(device_assignment)
   else:
     # This path is triggered for `jit(pmap)` cases.
     replicated_args = None
-    in_op_shardings = None
-    out_op_shardings = None
+    in_mlir_shardings = None
+    out_mlir_shardings = None
     axis_env = sharding_impls.AxisEnv(nreps, (), ())
     axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
     num_partitions = 1

@@ -1929,8 +1929,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
       name_stack,
       donated_invars,
       replicated_args=replicated_args,
-      arg_shardings=in_op_shardings,
-      result_shardings=out_op_shardings,
+      arg_shardings=in_mlir_shardings,
+      result_shardings=out_mlir_shardings,
       arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
       result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
       num_replicas=nreps,

@@ -2113,14 +2113,14 @@ def lower_sharding_computation(
       pmap_nreps=nreps)
 
 
-def _to_logical_op_sharding(
+def _to_logical_sharding(
     aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTOAxisResource]
-) -> Optional[xc.OpSharding]:
+) -> Optional[sharding_impls.XLACompatibleSharding]:
   if is_unspecified(sharding) or is_auto(sharding):
     return None
   elif isinstance(aval, ShapedArray):
     assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
-    return sharding._to_xla_op_sharding(aval.ndim)
+    return sharding
   elif isinstance(aval, core.AbstractToken):
     return None
   else:

@@ -2219,12 +2219,12 @@ def lower_mesh_computation(
   # 2. Build up the HLO
   tuple_args = dispatch.should_tuple_args(len(in_jaxpr_avals), backend.platform)
 
-  in_partitions: Optional[List[Optional[xc.OpSharding]]]
-  out_partitions: Optional[List[Optional[xc.OpSharding]]]
+  in_partitions: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
+  out_partitions: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
   axis_ctx: mlir.AxisContext
   if spmd_lowering:
-    in_partitions = map(_to_logical_op_sharding, global_in_avals, in_shardings)
-    out_partitions = map(_to_logical_op_sharding, global_out_avals, out_shardings)
+    in_partitions = map(_to_logical_sharding, global_in_avals, in_shardings)
+    out_partitions = map(_to_logical_sharding, global_out_avals, out_shardings)
     replicated_args = [False] * len(in_jaxpr_avals)
     axis_ctx = sharding_impls.SPMDAxisContext(mesh, manual_axes)
     num_replicas = 1

@@ -2370,12 +2370,7 @@ def _get_input_indices(
       index = tuple(
           (slice(None),) for _ in range(num_addressable_devices))
     else:
-      # We special case this logic to support fully replicated values because
-      # the mesh is global mesh and the indices returned by `spec_to_indices` will
-      # represent index for each device in the global mesh. But here we want
-      # indices for the local devices of the global mesh.
-      proto = sharding._to_xla_op_sharding(aval.ndim)
-      if op_shardings.is_op_sharding_replicated(proto):
+      if sharding.is_fully_replicated:
        index = tuple(
            (slice(None),) * aval.ndim for _ in range(num_addressable_devices))  # type: ignore
      else:

@@ -1029,9 +1029,8 @@ def _resolve_in_shardings(
            'multiple devices is not supported.')
       else:
         if (isinstance(arg, np.ndarray) and
-            not op_shardings.is_op_sharding_replicated(
-                pjit_in_s._to_xla_op_sharding(arg.ndim))  # type: ignore
-            and xb.process_count() > 1):
+            not pjit_in_s.is_fully_replicated and  # type: ignore
+            xb.process_count() > 1):
           raise ValueError(
               'Passing non-trivial shardings for numpy '
               'inputs is not allowed. To fix this error, either specify a '

@@ -80,6 +80,11 @@ class Sharding:
     """
     raise NotImplementedError('Subclasses should implement this method.')
 
+  @property
+  def is_fully_replicated(self) -> bool:
+    """Returns if a sharding is fully replicated on all the devices."""
+    raise NotImplementedError('Subclasses should implement this method.')
+
   #############################################################################
   # Default implementations below that all subclasses will inherit.
 

@@ -191,7 +191,7 @@ class NamedSharding(XLACompatibleSharding):
 
   mesh: mesh_lib.Mesh
   spec: PartitionSpec
-  _parsed_pspec: Optional[Any]
+  _parsed_pspec: ParsedPartitionSpec
 
   @use_cpp_method()
   def __init__(

@@ -269,6 +269,17 @@ class NamedSharding(XLACompatibleSharding):
     # across multiple NamedSharding objects will be the same.
     return self.mesh._local_devices_set
 
+  @functools.cached_property
+  def is_fully_replicated(self) -> bool:
+    if self.mesh.size == 1:
+      return True
+    array_mapping = cast(ParsedPartitionSpec, get_array_mapping(self._parsed_pspec))
+    mesh_shape = self.mesh.shape
+    num_partitions = 1
+    for name in array_mapping:
+      num_partitions *= mesh_shape[name]
+    return num_partitions == 1
+
   @functools.lru_cache(maxsize=4096)
   def _to_xla_op_sharding(
       self,

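The NamedSharding check added above multiplies the sizes of the mesh axes that actually appear in the partition spec; if that product is 1, every shard is the whole array. A purely illustrative walk-through of that arithmetic (the mesh sizes and axis list below are made up, not taken from this diff):

# Illustrative sketch only, mirroring the loop in is_fully_replicated
# with hypothetical numbers.
mesh_shape = {'x': 1, 'y': 2, 'z': 4}  # hypothetical mesh axis sizes
used_axes = ['x']                      # axes named by a spec like P('x', None, None)

num_partitions = 1
for name in used_axes:
    num_partitions *= mesh_shape[name]

# Only size-1 axes are used, so each shard equals the full array.
assert num_partitions == 1  # i.e. fully replicated
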
@@ -350,6 +361,10 @@ class SingleDeviceSharding(XLACompatibleSharding):
   def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
     return get_replicated_op_sharding()
 
+  @property
+  def is_fully_replicated(self) -> bool:
+    return True
+
 
 @use_cpp_class(xc.PmapSharding)
 class PmapSharding(XLACompatibleSharding):

@@ -447,6 +462,13 @@ class PmapSharding(XLACompatibleSharding):
   def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
     raise NotImplementedError("pmap doesn't use OpSharding.")
 
+  @functools.cached_property
+  def is_fully_replicated(self) -> bool:
+    for s in self.sharding_spec.sharding:
+      if isinstance(s, sharding_specs.Unstacked):
+        return False
+    return True
+
   @functools.lru_cache(maxsize=4096)
   def shard_shape(self, global_shape: Shape) -> Shape:
     sharded_dim = None

@@ -554,6 +576,10 @@ class PositionalSharding(XLACompatibleSharding):
   def device_set(self) -> set[xc.Device]:
     return set(self._devices)
 
+  @functools.cached_property
+  def is_fully_replicated(self) -> bool:
+    return self.shape == (1,) * self.ndim
+
   # XLACompatibleSharding interface
 
   @property

@@ -670,6 +696,10 @@ class GSPMDSharding(XLACompatibleSharding):
   def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
     return self._op_sharding
 
+  @functools.cached_property
+  def is_fully_replicated(self) -> bool:
+    return is_op_sharding_replicated(self._op_sharding)
+
   @classmethod
   def get_replicated(cls, device_assignment):
     proto = get_replicated_op_sharding()

@@ -860,7 +860,7 @@ class ShardingTest(jtu.JaxTestCase):
       ("3d_mesh_none_y_none", (2, 2, 2), P(None, 'y', None)),
       ("3d_mesh_x_y_none", (2, 2, 2), P('x', 'y', None)),
       ("3d_mesh_none_yz", (2, 2, 2), P(None, ('y', 'z'))),
-      ("3d_mesh2_none_y_none", (1, 2, 4), P(None, None, 'z')),
+      ("3d_mesh2_none_none_z", (1, 2, 4), P(None, None, 'z')),
       ("3d_mesh2_x_none_none", (1, 2, 4), P('x', None, None)),
   )
   def test_positional_sharding_from_op_sharding(self, mesh_shape, pspec):

@@ -875,6 +875,42 @@ class ShardingTest(jtu.JaxTestCase):
     self.assertTrue(op_shardings.are_op_shardings_equal(
         original_op_sharding, out_op_sharding))
 
+  @parameterized.named_parameters(
+      ("2d_mesh_x", (1, 1), P("x", "y")),
+      ("2d_mesh_x_y", (4, 2), P("x", "y")),
+      ("2d_mesh_empty", (2, 1), P()),
+      ("2d_mesh_p_none", (2, 1), P(None)),
+      ("2d_mesh_none_none", (2, 1), P(None, None)),
+      ("2d_mesh_tuple_empty", (2, 1), P((),)),
+      ("2d_mesh_x_none", (2, 1), P(('x',), None)),
+      ("2d_mesh_xy_none", (2, 1), P(('x', 'y'), None)),
+      ("2d_mesh_none", (2, 1), None),
+      ("2d_mesh_x_tuple_empty", (2, 1), P('x', (), (), ())),
+      ("2d_mesh_3_tuple_empty", (2, 1), P((), (), ())),
+      ("3d_mesh2_x_none_none", (1, 2, 4), P('x', None, None)),
+      ("3d_mesh2_x_y_none", (1, 1, 4), P('x', 'y', None)),
+      ("3d_mesh2_xy_none", (1, 1, 4), P(('x', 'y'), None)),
+  )
+  def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec):
+    if len(mesh_shape) == 2:
+      axis_names = ('x', 'y')
+    elif len(mesh_shape) == 3:
+      axis_names = ('x', 'y', 'z')
+    else:
+      axis_names = ('x',)
+    mesh = jtu.create_global_mesh(mesh_shape, axis_names)
+    mps = jax.sharding.NamedSharding(mesh, pspec)
+    shape = (8, 2, 4)
+    mps_op_sharding = mps._to_xla_op_sharding(len(shape))
+    ops_ifr = op_shardings.is_op_sharding_replicated(mps_op_sharding)
+    self.assertEqual(mps.is_fully_replicated, ops_ifr)
+
+    ps = _from_op_sharding_to_pos_sharding(mps_op_sharding,
+                                           mps._device_assignment)
+    self.assertEqual(ps.is_fully_replicated,
+                     op_shardings.is_op_sharding_replicated(
+                         ps._to_xla_op_sharding(len(shape))))
+
   def test_devices_sharding_respects_init_mesh_shape(self):
     value_shape = (8, 4)