Add an is_fully_replicated property to Shardings. This lets us scrub the use of is_op_sharding_replicated from JAX: callers can query the property on a Sharding directly and skip an expensive round trip through OpSharding creation.

PiperOrigin-RevId: 524379122
Yash Katariya 2023-04-14 13:55:52 -07:00 committed by jax authors
parent 88a5ffb2e8
commit 673730c065
7 changed files with 118 additions and 45 deletions
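
A minimal usage sketch of the change (not part of the diff below); the single-device mesh and axis name are illustrative, and the commented-out lines show the old internal OpSharding round trip that this commit scrubs:

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# P() partitions nothing, so this sharding is fully replicated.
mesh = Mesh(np.array(jax.devices()[:1]), axis_names=('x',))
s = NamedSharding(mesh, P())

# New: query the sharding directly; no OpSharding proto is built.
assert s.is_fully_replicated

# Old (internal) path: materialize a proto, then inspect it.
# proto = s._to_xla_op_sharding(2)
# op_shardings.is_op_sharding_replicated(proto)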

View File

@@ -330,7 +330,7 @@ class ArrayImpl(basearray.Array):
@property
def is_fully_replicated(self) -> bool:
return self.shape == self._arrays[0].shape
return self.sharding.is_fully_replicated
def __repr__(self):
prefix = 'Array('

View File

@@ -34,7 +34,6 @@ from jax._src import core
from jax._src import dtypes
from jax._src import effects as effects_lib
from jax._src import linear_util as lu
from jax._src import op_shardings
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import util
@@ -47,6 +46,7 @@ from jax._src.lib.mlir import dialects
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.sharding_impls import XLACompatibleSharding
map, unsafe_map = util.safe_map, map
@@ -491,7 +491,7 @@ def flatten_lowering_ir_args(
_module_name_regex = re.compile(r"[^\w.-]")
def sharded_aval(aval: core.AbstractValue,
sharding: Optional[xc.OpSharding]) -> core.AbstractValue:
sharding: Optional[XLACompatibleSharding]) -> core.AbstractValue:
"""Returns the new aval sharded based on sharding proto."""
if sharding is None:
return aval
@@ -499,18 +499,7 @@ def sharded_aval(aval: core.AbstractValue,
return aval
if not isinstance(aval, core.ShapedArray):
raise NotImplementedError
if (op_shardings.is_op_sharding_replicated(sharding) or
sharding.type == xc.OpSharding.Type.MANUAL):
return aval
partitions, _ = op_shardings.get_num_ways_dim_sharded(sharding)
out = []
for s, p in zip(aval.shape, partitions):
quotient, remainder = divmod(s, p)
assert remainder == 0
out.append(quotient)
return aval.update(tuple(out))
return aval.update(sharding.shard_shape(aval.shape))
def eval_dynamic_shape(ctx: LoweringRuleContext,
@@ -537,6 +526,16 @@ class LoweringResult(NamedTuple):
_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"]
def _to_logical_op_sharding(
aval: core.AbstractValue, sharding: Optional[XLACompatibleSharding],
) -> Optional[xc.OpSharding]:
if sharding is None:
return None
assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
assert isinstance(aval, core.ShapedArray)
return sharding._to_xla_op_sharding(aval.ndim)
def lower_jaxpr_to_module(
module_name: str,
jaxpr: core.ClosedJaxpr,
@@ -547,8 +546,8 @@ def lower_jaxpr_to_module(
name_stack: source_info_util.NameStack,
donated_args: Sequence[bool],
replicated_args: Optional[Sequence[bool]] = None,
arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
arg_shardings: Optional[Sequence[Optional[XLACompatibleSharding]]] = None,
result_shardings: Optional[Sequence[Optional[XLACompatibleSharding]]] = None,
arg_names: Optional[Sequence[Optional[str]]] = None,
result_names: Optional[Sequence[Optional[str]]] = None,
num_replicas: int = 1,
@@ -596,6 +595,13 @@ def lower_jaxpr_to_module(
else:
dim_vars = ()
arg_op_shardings = (
map(_to_logical_op_sharding, jaxpr.in_avals, arg_shardings)
if arg_shardings is not None else arg_shardings)
result_op_shardings = (
map(_to_logical_op_sharding, jaxpr.out_avals, result_shardings)
if result_shardings is not None else result_shardings)
ctx = ModuleContext(backend_or_name, platform, axis_context, name_stack,
keepalives, channel_iter, host_callbacks, dim_vars=dim_vars)
with ctx.context, ir.Location.unknown(ctx.context):
@@ -611,9 +617,11 @@ def lower_jaxpr_to_module(
replace_tokens_with_dummy=True,
num_output_tokens=0,
replicated_args=replicated_args,
arg_shardings=arg_shardings, result_shardings=result_shardings,
arg_shardings=arg_op_shardings,
result_shardings=result_op_shardings,
input_output_aliases=input_output_aliases,
arg_names=arg_names, result_names=result_names)
arg_names=arg_names,
result_names=result_names)
if not ctx.module.operation.verify():
module_string = module_to_string(ctx.module)

View File

@@ -1892,21 +1892,21 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
dispatch.raise_warnings_or_errors_for_jit_of_pmap(
nreps, backend, fun_name, jaxpr)
in_op_shardings: Optional[List[Optional[xc.OpSharding]]]
out_op_shardings: Optional[List[Optional[xc.OpSharding]]]
in_mlir_shardings: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
out_mlir_shardings: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
axis_ctx: mlir.AxisContext
if nreps == 1:
in_op_shardings = map(_to_logical_op_sharding, global_in_avals, in_shardings)
out_op_shardings = map(_to_logical_op_sharding, global_out_avals, out_shardings)
in_mlir_shardings = map(_to_logical_sharding, global_in_avals, in_shardings)
out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings)
replicated_args = [False] * len(global_in_avals)
axis_ctx = sharding_impls.ShardingContext(device_assignment)
num_partitions = len(device_assignment)
else:
# This path is triggered for `jit(pmap)` cases.
replicated_args = None
in_op_shardings = None
out_op_shardings = None
in_mlir_shardings = None
out_mlir_shardings = None
axis_env = sharding_impls.AxisEnv(nreps, (), ())
axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
num_partitions = 1
@@ -1929,8 +1929,8 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
name_stack,
donated_invars,
replicated_args=replicated_args,
arg_shardings=in_op_shardings,
result_shardings=out_op_shardings,
arg_shardings=in_mlir_shardings,
result_shardings=out_mlir_shardings,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths,
num_replicas=nreps,
@@ -2113,14 +2113,14 @@ def lower_sharding_computation(
pmap_nreps=nreps)
def _to_logical_op_sharding(
def _to_logical_sharding(
aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTOAxisResource]
) -> Optional[xc.OpSharding]:
) -> Optional[sharding_impls.XLACompatibleSharding]:
if is_unspecified(sharding) or is_auto(sharding):
return None
elif isinstance(aval, ShapedArray):
assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
return sharding._to_xla_op_sharding(aval.ndim)
return sharding
elif isinstance(aval, core.AbstractToken):
return None
else:
@@ -2219,12 +2219,12 @@ def lower_mesh_computation(
# 2. Build up the HLO
tuple_args = dispatch.should_tuple_args(len(in_jaxpr_avals), backend.platform)
in_partitions: Optional[List[Optional[xc.OpSharding]]]
out_partitions: Optional[List[Optional[xc.OpSharding]]]
in_partitions: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
out_partitions: Optional[List[Optional[sharding_impls.XLACompatibleSharding]]]
axis_ctx: mlir.AxisContext
if spmd_lowering:
in_partitions = map(_to_logical_op_sharding, global_in_avals, in_shardings)
out_partitions = map(_to_logical_op_sharding, global_out_avals, out_shardings)
in_partitions = map(_to_logical_sharding, global_in_avals, in_shardings)
out_partitions = map(_to_logical_sharding, global_out_avals, out_shardings)
replicated_args = [False] * len(in_jaxpr_avals)
axis_ctx = sharding_impls.SPMDAxisContext(mesh, manual_axes)
num_replicas = 1
@@ -2370,12 +2370,7 @@ def _get_input_indices(
index = tuple(
(slice(None),) for _ in range(num_addressable_devices))
else:
# We special case this logic to support fully replicated values because
# the mesh is global mesh and the indices returned by `spec_to_indices` will
# represent index for each device in the global mesh. But here we want
# indices for the local devices of the global mesh.
proto = sharding._to_xla_op_sharding(aval.ndim)
if op_shardings.is_op_sharding_replicated(proto):
if sharding.is_fully_replicated:
index = tuple(
(slice(None),) * aval.ndim for _ in range(num_addressable_devices)) # type: ignore
else:

View File

@@ -1029,9 +1029,8 @@ def _resolve_in_shardings(
'multiple devices is not supported.')
else:
if (isinstance(arg, np.ndarray) and
not op_shardings.is_op_sharding_replicated(
pjit_in_s._to_xla_op_sharding(arg.ndim)) # type: ignore
and xb.process_count() > 1):
not pjit_in_s.is_fully_replicated and # type: ignore
xb.process_count() > 1):
raise ValueError(
'Passing non-trivial shardings for numpy '
'inputs is not allowed. To fix this error, either specify a '

View File

@@ -80,6 +80,11 @@ class Sharding:
"""
raise NotImplementedError('Subclasses should implement this method.')
@property
def is_fully_replicated(self) -> bool:
"""Returns if a sharding is fully replicated on all the devices."""
raise NotImplementedError('Subclasses should implement this method.')
#############################################################################
# Default implementations below that all subclasses will inherit.

View File

@@ -191,7 +191,7 @@ class NamedSharding(XLACompatibleSharding):
mesh: mesh_lib.Mesh
spec: PartitionSpec
_parsed_pspec: Optional[Any]
_parsed_pspec: ParsedPartitionSpec
@use_cpp_method()
def __init__(
@@ -269,6 +269,17 @@ class NamedSharding(XLACompatibleSharding):
# across multiple NamedSharding objects will be the same.
return self.mesh._local_devices_set
@functools.cached_property
def is_fully_replicated(self) -> bool:
if self.mesh.size == 1:
return True
array_mapping = cast(ParsedPartitionSpec, get_array_mapping(self._parsed_pspec))
mesh_shape = self.mesh.shape
num_partitions = 1
for name in array_mapping:
num_partitions *= mesh_shape[name]
return num_partitions == 1
@functools.lru_cache(maxsize=4096)
def _to_xla_op_sharding(
self,
@@ -350,6 +361,10 @@ class SingleDeviceSharding(XLACompatibleSharding):
def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
return get_replicated_op_sharding()
@property
def is_fully_replicated(self) -> bool:
return True
@use_cpp_class(xc.PmapSharding)
class PmapSharding(XLACompatibleSharding):
@@ -447,6 +462,13 @@ class PmapSharding(XLACompatibleSharding):
def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
raise NotImplementedError("pmap doesn't use OpSharding.")
@functools.cached_property
def is_fully_replicated(self) -> bool:
for s in self.sharding_spec.sharding:
if isinstance(s, sharding_specs.Unstacked):
return False
return True
@functools.lru_cache(maxsize=4096)
def shard_shape(self, global_shape: Shape) -> Shape:
sharded_dim = None
@@ -554,6 +576,10 @@ class PositionalSharding(XLACompatibleSharding):
def device_set(self) -> set[xc.Device]:
return set(self._devices)
@functools.cached_property
def is_fully_replicated(self) -> bool:
return self.shape == (1,) * self.ndim
# XLACompatibleSharding interface
@property
@@ -670,6 +696,10 @@ class GSPMDSharding(XLACompatibleSharding):
def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
return self._op_sharding
@functools.cached_property
def is_fully_replicated(self) -> bool:
return is_op_sharding_replicated(self._op_sharding)
@classmethod
def get_replicated(cls, device_assignment):
proto = get_replicated_op_sharding()

View File

@@ -860,7 +860,7 @@ class ShardingTest(jtu.JaxTestCase):
("3d_mesh_none_y_none", (2, 2, 2), P(None, 'y', None)),
("3d_mesh_x_y_none", (2, 2, 2), P('x', 'y', None)),
("3d_mesh_none_yz", (2, 2, 2), P(None, ('y', 'z'))),
("3d_mesh2_none_y_none", (1, 2, 4), P(None, None, 'z')),
("3d_mesh2_none_none_z", (1, 2, 4), P(None, None, 'z')),
("3d_mesh2_x_none_none", (1, 2, 4), P('x', None, None)),
)
def test_positional_sharding_from_op_sharding(self, mesh_shape, pspec):
@@ -875,6 +875,42 @@ class ShardingTest(jtu.JaxTestCase):
self.assertTrue(op_shardings.are_op_shardings_equal(
original_op_sharding, out_op_sharding))
@parameterized.named_parameters(
("2d_mesh_x", (1, 1), P("x", "y")),
("2d_mesh_x_y", (4, 2), P("x", "y")),
("2d_mesh_empty", (2, 1), P()),
("2d_mesh_p_none", (2, 1), P(None)),
("2d_mesh_none_none", (2, 1), P(None, None)),
("2d_mesh_tuple_empty", (2, 1), P((),)),
("2d_mesh_x_none", (2, 1), P(('x',), None)),
("2d_mesh_xy_none", (2, 1), P(('x', 'y'), None)),
("2d_mesh_none", (2, 1), None),
("2d_mesh_x_tuple_empty", (2, 1), P('x', (), (), ())),
("2d_mesh_3_tuple_empty", (2, 1), P((), (), ())),
("3d_mesh2_x_none_none", (1, 2, 4), P('x', None, None)),
("3d_mesh2_x_y_none", (1, 1, 4), P('x', 'y', None)),
("3d_mesh2_xy_none", (1, 1, 4), P(('x', 'y'), None)),
)
def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec):
if len(mesh_shape) == 2:
axis_names = ('x', 'y')
elif len(mesh_shape) == 3:
axis_names = ('x', 'y', 'z')
else:
axis_names = ('x',)
mesh = jtu.create_global_mesh(mesh_shape, axis_names)
mps = jax.sharding.NamedSharding(mesh, pspec)
shape = (8, 2, 4)
mps_op_sharding = mps._to_xla_op_sharding(len(shape))
ops_ifr = op_shardings.is_op_sharding_replicated(mps_op_sharding)
self.assertEqual(mps.is_fully_replicated, ops_ifr)
ps = _from_op_sharding_to_pos_sharding(mps_op_sharding,
mps._device_assignment)
self.assertEqual(ps.is_fully_replicated,
op_shardings.is_op_sharding_replicated(
ps._to_xla_op_sharding(len(shape))))
def test_devices_sharding_respects_init_mesh_shape(self):
value_shape = (8, 4)