From be1cf46a49d260e979c68a78348ce7ce931e8151 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 10 Apr 2023 10:15:08 -0700 Subject: [PATCH] Split sharding_impls into its own Bazel target. * Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies. * Fix a handful of new pytype errors. PiperOrigin-RevId: 523146076 --- jax/BUILD | 20 +- jax/_src/api.py | 11 +- jax/_src/callback.py | 7 +- jax/_src/checkify.py | 3 +- jax/_src/debugging.py | 12 +- jax/_src/dispatch.py | 9 +- jax/_src/interpreters/mlir.py | 78 +----- jax/_src/interpreters/pxla.py | 141 ++++------ jax/_src/interpreters/xla.py | 12 +- jax/_src/lax/parallel.py | 36 ++- jax/_src/lax/windowed_reductions.py | 1 + jax/_src/maps.py | 51 ++-- jax/_src/pjit.py | 340 +++++------------------ jax/_src/sharding_impls.py | 350 +++++++++++++++++++++++- jax/_src/stages.py | 3 +- jax/experimental/custom_partitioning.py | 7 +- jax/experimental/host_callback.py | 24 +- jax/experimental/jax2tf/jax2tf.py | 3 +- jax/experimental/jax2tf/jax_export.py | 17 +- jax/experimental/jet.py | 9 +- jax/experimental/pjit.py | 24 +- jax/experimental/shard_map.py | 16 +- jax/experimental/sparse/transform.py | 12 +- jax/interpreters/mlir.py | 13 +- jax/interpreters/pxla.py | 16 +- jax/interpreters/xla.py | 5 +- tests/pjit_test.py | 33 ++- tests/pmap_test.py | 3 +- 28 files changed, 673 insertions(+), 583 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index 9ab5074c6..89ca58c9d 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -110,7 +110,6 @@ py_library_providing_imports_info( "_src/prng.py", "_src/public_test_util.py", "_src/random.py", - "_src/sharding_impls.py", "_src/stages.py", ] + glob( [ @@ -183,6 +182,7 @@ py_library_providing_imports_info( ":pretty_printer", ":profiler", ":sharding", + ":sharding_impls", ":sharding_specs", ":source_info_util", ":traceback_util", @@ -345,6 +345,7 @@ pytype_strict_library( ":effects", ":op_shardings", ":partial_eval", + ":sharding_impls", ":source_info_util", 
":util", ":xla", @@ -418,6 +419,22 @@ pytype_strict_library( ], ) +pytype_strict_library( + name = "sharding_impls", + srcs = ["_src/sharding_impls.py"], + deps = [ + ":mesh", + ":op_shardings", + ":partition_spec", + ":sharding", + ":sharding_specs", + ":tree_util", + ":util", + ":xla_bridge", + "//jax/_src/lib", + ] + py_deps("numpy"), +) + pytype_strict_library( name = "sharding_specs", srcs = ["_src/sharding_specs.py"], @@ -504,6 +521,7 @@ pytype_strict_library( ":abstract_arrays", ":config", ":core", + ":sharding_impls", ":source_info_util", ":typing", ":util", diff --git a/jax/_src/api.py b/jax/_src/api.py index f01d28321..cb93dcb8a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -46,6 +46,7 @@ from jax._src import dispatch from jax._src import effects from jax._src import array from jax._src import dtypes +from jax._src import sharding_impls from jax._src import sharding_specs from jax._src import source_info_util from jax._src import traceback_util @@ -146,8 +147,8 @@ float0 = dtypes.float0 def jit( fun: Callable, - in_shardings=pxla._UNSPECIFIED, - out_shardings=pxla._UNSPECIFIED, + in_shardings=sharding_impls.UNSPECIFIED, + out_shardings=sharding_impls.UNSPECIFIED, static_argnums: Union[int, Sequence[int], None] = None, static_argnames: Union[str, Iterable[str], None] = None, donate_argnums: Union[int, Sequence[int]] = (), @@ -503,11 +504,11 @@ def xla_computation(fun: Callable, def make_axis_env(nreps): if axis_env is None: - return xla.AxisEnv(nreps, (), ()) + return sharding_impls.AxisEnv(nreps, (), ()) else: nreps = nreps * math.prod(size for name, size in axis_env) names, sizes = unzip2(axis_env) - return xla.AxisEnv(nreps, names, sizes) + return sharding_impls.AxisEnv(nreps, names, sizes) @wraps(fun) @api_boundary @@ -553,7 +554,7 @@ def xla_computation(fun: Callable, ordered_effects=ordered_effects, backend_or_name=backend, platform=platform, - axis_context=mlir.ReplicaAxisContext(axis_env_), + 
axis_context=sharding_impls.ReplicaAxisContext(axis_env_), name_stack=source_info_util.new_name_stack( wrap_name(fun_name, "xla_computation")), donated_args=donated_invars, diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 9dcff38a8..091e7a1ce 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -23,6 +23,7 @@ from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src import effects +from jax._src import sharding_impls from jax._src import tree_util from jax._src import util from jax._src.interpreters import ad @@ -107,13 +108,13 @@ def pure_callback_lowering(ctx, *args, callback, **params): sharding = None axis_context = ctx.module_context.axis_context - if isinstance(axis_context, mlir.ShardingContext): + if isinstance(axis_context, sharding_impls.ShardingContext): if len(axis_context.device_assignment) > 1: raise NotImplementedError( "pure_callback is only supported in spmd computations when all mesh" " axes are partitioned manually (no partial automatic sharding)." ) - if isinstance(axis_context, mlir.SPMDAxisContext): + if isinstance(axis_context, sharding_impls.SPMDAxisContext): if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names): raise NotImplementedError( "pure_callback is only supported in spmd computations when all mesh" @@ -272,7 +273,7 @@ def io_callback_lowering(ctx, *args, callback, ordered, **params): # can only safely maximally shard. Should we allow device_index to be passed # in like host_callback? if isinstance(ctx.module_context.axis_context, - (mlir.SPMDAxisContext, mlir.ShardingContext)): + (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext)): # Apply maximal sharding so pjit only executes the callback on device 0. 
sharding = xc.OpSharding() sharding.type = xc.OpSharding.Type.MAXIMAL diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 829907058..69a8c7566 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -31,6 +31,7 @@ from jax._src import custom_derivatives from jax._src import effects from jax._src import pjit from jax._src import prng +from jax._src import sharding_impls from jax._src import source_info_util from jax._src import traceback_util from jax._src import tree_util as jtu @@ -869,7 +870,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, # Update pjit params to account for extra error values. num_error_vals = len(err_vals) num_out_error_vals = out_tree.num_leaves - len(out_shardings) - sharding = pjit._UNSPECIFIED + sharding = sharding_impls.UNSPECIFIED new_in_shardings = (*[sharding] * num_error_vals, *in_shardings) new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index e78bdade5..9eb3c26b1 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -29,6 +29,7 @@ from jax._src import effects from jax._src import linear_util as lu from jax._src import mesh as mesh_lib from jax._src import pjit +from jax._src import sharding_impls from jax._src import tree_util from jax._src import util from jax._src.interpreters import ad @@ -123,13 +124,16 @@ ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule def debug_callback_lowering(ctx, *args, effect, callback, **params): axis_context = ctx.module_context.axis_context - if (isinstance(axis_context, mlir.SPMDAxisContext) and + if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)): # If we have fully manual sharding during lowering, that means the JAX # program has per-device semantics, so we run the callback on each device. 
sharding = xc.OpSharding() sharding.type = xc.OpSharding.Type.MANUAL - elif isinstance(axis_context, (mlir.ShardingContext, mlir.SPMDAxisContext)): + elif isinstance( + axis_context, + (sharding_impls.ShardingContext, sharding_impls.SPMDAxisContext), + ): # If we have fully automatic sharding during lowering, that means the JAX # program has bulk array semantics, so we run the callback with a MAXIMAL # sharding and hence execute it only once on the full logical value). @@ -312,9 +316,9 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *, mesh = mesh_lib.thread_resources.env.physical_mesh axis_context = ctx.module_context.axis_context - if isinstance(axis_context, mlir.ShardingContext): + if isinstance(axis_context, sharding_impls.ShardingContext): devices = axis_context.device_assignment - elif isinstance(axis_context, mlir.SPMDAxisContext): + elif isinstance(axis_context, sharding_impls.SPMDAxisContext): devices = list(axis_context.mesh.devices.flat) else: raise NotImplementedError(type(axis_context)) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 2b2d0ebaf..f8e65873b 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -53,7 +53,8 @@ from jax._src.monitoring import record_event_duration_secs from jax._src.partition_spec import PartitionSpec from jax._src.sharding import Sharding from jax._src.sharding_impls import ( - PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding) + PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding, + UNSPECIFIED) JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration" @@ -211,13 +212,13 @@ def xla_primitive_callable(prim, *arg_specs: ArgSpec, **params): def sharded_lowering(fun, name, donated_invars, keep_unused, *arg_specs, lowering_platform: Optional[str]): in_avals, in_shardings = util.unzip2(arg_specs) - in_shardings = [pxla._UNSPECIFIED if i is None else i for i in in_shardings] # type: ignore + in_shardings = [UNSPECIFIED if i is 
None else i for i in in_shardings] # type: ignore - # Pass in a singleton `_UNSPECIFIED` for out_shardings because we don't know + # Pass in a singleton `UNSPECIFIED` for out_shardings because we don't know # the number of output avals at this stage. lower_sharding_computation will # apply it to all out_avals. return pxla.lower_sharding_computation( - fun, 'jit', name, in_shardings, pxla._UNSPECIFIED, donated_invars, + fun, 'jit', name, in_shardings, UNSPECIFIED, donated_invars, tuple(in_avals), keep_unused=keep_unused, always_lower=False, devices_from_context=None, lowering_platform=lowering_platform) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index e7fa85c50..c4a660634 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -24,7 +24,7 @@ import itertools import re import typing from typing import (Any, Callable, Dict, Iterator, List, NamedTuple, Optional, - Protocol, Sequence, Set, Tuple, Type, Union, FrozenSet) + Protocol, Sequence, Set, Tuple, Type, Union) import warnings import numpy as np @@ -35,6 +35,7 @@ from jax._src import dtypes from jax._src import effects as effects_lib from jax._src import linear_util as lu from jax._src import op_shardings +from jax._src import sharding_impls from jax._src import source_info_util from jax._src import util from jax._src import xla_bridge as xb @@ -339,68 +340,11 @@ def make_ir_context() -> ir.Context: return context -Mesh = Any -MeshAxisName = Any - -@dataclasses.dataclass(frozen=True) -class SPMDAxisContext: - """A hardware axis context for parallel computations that use the GSPMD partitioner. - - This includes the mesh that will later by used to execute this computation, - as well as a set of mesh axes that are currently (e.g. because the current lowering - is invoked inside an xmap) lowered in the MANUAL sharding mode. 
- """ - mesh: Mesh - manual_axes: FrozenSet[MeshAxisName] = frozenset() - - @property - def axis_env(self): - # All collectives that touch axis_env should remember to set use_global_device_ids - # when this context is enabled! - if self.manual_axes != frozenset(self.mesh.axis_names): - raise NotImplementedError( - "Collectives in manually partitioned computations are only supported " - "when all mesh axes are partitioned manually (no partial automatic sharding). " - "Make sure that you mention all mesh axes in axis_resources!") - return self.unsafe_axis_env - - @property - def unsafe_axis_env(self): - return xla.AxisEnv( - nreps=self.mesh.size, - names=self.mesh.axis_names, - sizes=tuple(self.mesh.shape.values())) - - def extend_manual(self, axes: FrozenSet[MeshAxisName]) -> SPMDAxisContext: - return SPMDAxisContext(self.mesh, self.manual_axes | axes) - - -@dataclasses.dataclass(frozen=True) -class ReplicaAxisContext: - """A hardware axis context for parallel computations that are partitioned by JAX. - - Unlike in the SPMDAxisContext, this means that JAX might need to emit calls to - explicit collectives. - """ - axis_env: xla.AxisEnv - - -@dataclasses.dataclass(frozen=True) -class ShardingContext: - """A hardware axis context for parallel computations that use the sharding - interface. - - This context also uses the GSPMD partitioner. - """ - device_assignment: Sequence[xc.Device] - - # Similar to SPMDContext as ShardingContext also uses the GSPMD partitioner. 
- @property - def axis_env(self): - return xla.AxisEnv(nreps=1, names=(), sizes=()) - - -AxisContext = Union[SPMDAxisContext, ReplicaAxisContext, ShardingContext] +AxisContext = Union[ + sharding_impls.SPMDAxisContext, + sharding_impls.ReplicaAxisContext, + sharding_impls.ShardingContext, +] @dataclasses.dataclass class ModuleContext: @@ -428,7 +372,7 @@ class ModuleContext: @property - def axis_env(self) -> xla.AxisEnv: + def axis_env(self) -> sharding_impls.AxisEnv: return self.axis_context.axis_env def __init__( @@ -1539,7 +1483,7 @@ def xla_fallback_lowering(prim: core.Primitive): def fallback(ctx: LoweringRuleContext, *args, **params): module_ctx = ctx.module_context axis_ctx = module_ctx.axis_context - if isinstance(axis_ctx, SPMDAxisContext): + if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): axis_env = axis_ctx.unsafe_axis_env else: axis_env = module_ctx.axis_env @@ -1645,7 +1589,7 @@ def _emit_tpu_python_callback( ctx: LoweringRuleContext, callback, token: Optional[Any], - operands: List[ir.Value], + operands: Sequence[ir.Value], operand_avals: List[core.ShapedArray], operand_shapes: List[xc.Shape], result_avals: List[core.ShapedArray], @@ -1716,7 +1660,7 @@ def _aval_to_default_layout(aval): def emit_python_callback( ctx: LoweringRuleContext, callback, token: Optional[Any], - operands: List[ir.Value], operand_avals: List[core.ShapedArray], + operands: Sequence[ir.Value], operand_avals: List[core.ShapedArray], result_avals: List[core.ShapedArray], has_side_effect: bool, *, sharding: Optional[xc.OpSharding] = None, operand_layouts: Optional[Sequence[Optional[Sequence[int]]]] = None, diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 7edb69d3e..ae4020bd8 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -17,13 +17,12 @@ from __future__ import annotations import enum from contextlib import contextmanager -from collections import defaultdict, OrderedDict, namedtuple +from collections import 
defaultdict, namedtuple import dataclasses from functools import partial, lru_cache, cached_property import itertools as it import logging import math -import sys from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet, Sequence, Set, Tuple, Type, Union, Iterable, TYPE_CHECKING, cast) @@ -40,7 +39,7 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import effects from jax._src import linear_util as lu -from jax._src import mesh +from jax._src import mesh as mesh_lib from jax._src import op_shardings from jax._src import sharding_specs from jax._src import profiler @@ -61,6 +60,11 @@ from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec +from jax._src.sharding_impls import ( + ArrayMapping, ArrayMappingOrAutoOrUnspecified, + AUTOAxisResource, UnspecifiedValue, UNSPECIFIED, + get_array_mapping as _get_array_mapping, is_auto, is_unspecified +) from jax._src.util import (unzip3, safe_map, safe_zip, partition_list, wrap_name, tuple_delete, distributed_debug_log, unzip2, HashableFunction, weakref_lru_cache) @@ -71,11 +75,6 @@ class WeakRefList(list): pass -if sys.version_info >= (3, 9): - OrderedDictType = OrderedDict -else: - OrderedDictType = Dict - xe = xc._xla unsafe_map, map = map, safe_map # type: ignore @@ -92,8 +91,8 @@ ShardedAxis = sharding_specs.ShardedAxis Replicated = sharding_specs.Replicated AvalDimSharding = Union[Unstacked, Chunked, NoSharding] -Mesh = jax._src.mesh.Mesh -MeshAxisName = mesh.MeshAxisName +Mesh = mesh_lib.Mesh +MeshAxisName = sharding_impls.MeshAxisName MeshDimAssignment = Union[ShardedAxis, Replicated] ShardingSpec = sharding_specs.ShardingSpec @@ -233,56 +232,6 @@ def _shard_abstract_array(size, axis: int, x): shard_aval_handlers[ShapedArray] = _shard_abstract_array -class AUTOAxisResource: - pass -AUTO = AUTOAxisResource() - -def is_auto(x): - return isinstance(x, 
AUTOAxisResource) - - -class UnspecifiedValue: - def __repr__(self): - return "UnspecifiedValue" -_UNSPECIFIED = UnspecifiedValue() - -def _is_unspecified(x): - return isinstance(x, UnspecifiedValue) - -""" -ArrayMapping specifies how an ndarray should map to mesh axes. - -Note that the ordering is crucial for the cases when this mapping is non-injective -(i.e. when multiple mesh axes map to the same positional axis). Then, the -order of entries of the mapping determines a major-to-minor order on mesh axes, -according to which chunks of the value along the repeated dimension will be assigned. - -For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}. -The second dimension of the value would get chunked into 6 pieces, and assigned to the -mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case, -that would mean that a flat list of chunks would get assigned to a flattened list of -mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the -mesh devices ndarray would have to be transposed before flattening and assignment. 
-""" -ArrayMapping = OrderedDictType[mesh.MeshAxisName, int] -ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTOAxisResource, - UnspecifiedValue] - - -def array_mapping_to_axis_resources(array_mapping: ArrayMapping): - if not array_mapping: - return PartitionSpec() - max_index = -1 - reverse_map = defaultdict(list) - for axis, index in array_mapping.items(): - reverse_map[index].append(axis) - if index > max_index: - max_index = index - partitions = tuple(tuple(reverse_map[i]) if reverse_map[i] else None - for i in range(max_index + 1)) - return PartitionSpec(*partitions) - - def local_aval_to_result_handler( aval: core.AbstractValue, sharding: sharding_impls.XLACompatibleSharding, @@ -370,7 +319,7 @@ def make_sharded_device_array( if sharding_spec is None: sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape) - mesh = jax._src.mesh.thread_resources.env.physical_mesh + mesh = mesh_lib.thread_resources.env.physical_mesh sharding: sharding_impls.XLACompatibleSharding if mesh.empty: sharding = sharding_impls.PmapSharding( @@ -930,7 +879,7 @@ def lower_parallel_callable( shards.num_global_shards, avals, replicas.num_global_replicas, parts.num_partitions) - axis_env = xla.AxisEnv( + axis_env = sharding_impls.AxisEnv( replicas.num_global_replicas, (axis_name,), (global_axis_size,)) name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap')) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) @@ -952,7 +901,7 @@ def lower_parallel_callable( ordered_effects, backend, lowering_platform or backend.platform, - mlir.ReplicaAxisContext(axis_env), + sharding_impls.ReplicaAxisContext(axis_env), name_stack, donated_invars, replicated_args=replicated_args, @@ -1161,6 +1110,7 @@ class UnloadedPmapExecutable: input_indices = [] for aval, spec in safe_zip(self.local_input_avals, self.input_shardings): assert isinstance(spec, sharding_impls.PmapSharding), spec + assert isinstance(aval, core.ShapedArray), aval input_indices.append( 
sharding_specs.spec_to_indices(aval.shape, spec.sharding_spec) if spec.sharding_spec is not None else None) @@ -1723,7 +1673,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name, with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore sub_ctx = ctx.module_context.replace( - axis_context=mlir.ReplicaAxisContext(new_env), + axis_context=sharding_impls.ReplicaAxisContext(new_env), name_stack=ctx.module_context.name_stack.extend( util.wrap_name(name, 'pmap'))) sharded_outs, _ = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, mlir.TokenSet(), (), @@ -1805,7 +1755,9 @@ def _full_to_shard_abstract_eval(x, axes, mesh, **_): # TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes! return tile_aval_nd(mesh.shape, axes, x) -def manual_proto(aval: core.ShapedArray, manual_axes_set: FrozenSet[mesh.MeshAxisName], mesh: Mesh): +def manual_proto( + aval: core.ShapedArray, + manual_axes_set: FrozenSet[sharding_impls.MeshAxisName], mesh: Mesh): """Create an OpSharding proto that declares all mesh axes from `axes` as manual and all others as replicated. """ @@ -1832,7 +1784,7 @@ def manual_proto(aval: core.ShapedArray, manual_axes_set: FrozenSet[mesh.MeshAxi @partial(mlir.register_lowering, full_to_shard_p) def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, - manual_axes: FrozenSet[mesh.MeshAxisName]): + manual_axes: FrozenSet[sharding_impls.MeshAxisName]): # TODO: Can we short-circuit for replicated values? Probably not. 
aval_in, = ctx.avals_in aval_out, = ctx.avals_out @@ -1852,7 +1804,7 @@ def _shard_to_full_abstract_eval(x, axes, mesh, **_): @partial(mlir.register_lowering, shard_to_full_p) def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, - manual_axes: FrozenSet[mesh.MeshAxisName]): + manual_axes: FrozenSet[sharding_impls.MeshAxisName]): aval_in, = ctx.avals_in aval_out, = ctx.avals_out proto = manual_proto(aval_in, manual_axes, mesh) @@ -1863,7 +1815,7 @@ def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh, return mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, unspecified_dims), @lu.transformation -def vtile_manual(manual_axes: FrozenSet[mesh.MeshAxisName], +def vtile_manual(manual_axes: FrozenSet[sharding_impls.MeshAxisName], mesh: Mesh, in_axes: Sequence[ArrayMapping], out_axes: Sequence[ArrayMapping], @@ -1882,7 +1834,7 @@ class TileVectorize: @dataclasses.dataclass(frozen=True) class TileManual: - manual_axes: FrozenSet[mesh.MeshAxisName] + manual_axes: FrozenSet[sharding_impls.MeshAxisName] TilingMethod = Union[TileVectorize, TileManual] @@ -1971,7 +1923,7 @@ def _get_and_check_device_assignment( devices = list(devices) for i, s_type, source_info in shardings: - if is_auto(i) or _is_unspecified(i): + if is_auto(i) or is_unspecified(i): continue # Assign `first_sharding_info` after `AUTO` and `UNSPECIFIED` have been # skipped. @@ -2096,14 +2048,14 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, in_op_shardings = map(_to_logical_op_sharding, global_in_avals, in_shardings) out_op_shardings = map(_to_logical_op_sharding, global_out_avals, out_shardings) replicated_args = [False] * len(global_in_avals) - axis_ctx = mlir.ShardingContext(device_assignment) + axis_ctx = sharding_impls.ShardingContext(device_assignment) else: # This path is triggered for `jit(pmap)` cases. 
replicated_args = None in_op_shardings = None out_op_shardings = None - axis_env = xla.AxisEnv(nreps, (), ()) - axis_ctx = mlir.ReplicaAxisContext(axis_env) + axis_env = sharding_impls.AxisEnv(nreps, (), ()) + axis_ctx = sharding_impls.ReplicaAxisContext(axis_env) module_name = f"{api_name}_{fun_name}" @@ -2189,10 +2141,10 @@ def lower_sharding_computation( ) -> MeshComputation: """Lowers a computation to XLA. It can take arbitrary shardings as input. - The caller of this code can pass in a singleton _UNSPECIFIED because the + The caller of this code can pass in a singleton UNSPECIFIED because the number of out_avals might not be known at that time and lower_sharding_computation calculates the number of out_avals so it can apply - the singleton _UNSPECIFIED to all out_avals. + the singleton UNSPECIFIED to all out_avals. """ # 1. Trace to jaxpr and preprocess/verify it (closed_jaxpr, global_in_avals, global_out_avals, donated_invars, @@ -2202,8 +2154,8 @@ def lower_sharding_computation( jaxpr = closed_jaxpr.jaxpr in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx) - if _is_unspecified(out_shardings): - out_shardings = (_UNSPECIFIED,) * len(global_out_avals) + if is_unspecified(out_shardings): + out_shardings = (UNSPECIFIED,) * len(global_out_avals) assert isinstance(out_shardings, tuple) assert len(out_shardings) == len(global_out_avals), ( len(out_shardings), len(global_out_avals)) @@ -2221,12 +2173,12 @@ def lower_sharding_computation( committed = bool( devices_from_context or len(device_assignment) > 1 or - any(not _is_unspecified(i) for i in in_shardings) or - any(not _is_unspecified(js) for js, _ in jaxpr_sharding) or - any(not _is_unspecified(o) for o in out_shardings)) + any(not is_unspecified(i) for i in in_shardings) or + any(not is_unspecified(js) for js, _ in jaxpr_sharding) or + any(not is_unspecified(o) for o in out_shardings)) in_shardings = tuple(sharding_impls.GSPMDSharding.get_replicated(device_assignment) - if 
_is_unspecified(i) else i for i in in_shardings) + if is_unspecified(i) else i for i in in_shardings) da_object = _create_da_object(tuple(device_assignment)) @@ -2260,7 +2212,7 @@ def lower_sharding_computation( # and don't need to evaluate their arguments. if (not always_lower and not (jaxpr.effects or has_outfeed) and (not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and - all(_is_unspecified(o) for o in out_shardings)): + all(is_unspecified(o) for o in out_shardings)): return MeshComputation( str(name_stack), None, True, donated_invars, jaxpr=jaxpr, consts=closed_jaxpr.consts, global_in_avals=global_in_avals, @@ -2309,7 +2261,7 @@ def lower_sharding_computation( def _to_logical_op_sharding( aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTOAxisResource] ) -> Optional[xc.OpSharding]: - if _is_unspecified(sharding) or is_auto(sharding): + if is_unspecified(sharding) or is_auto(sharding): return None elif isinstance(aval, ShapedArray): assert isinstance(sharding, sharding_impls.XLACompatibleSharding) @@ -2419,15 +2371,16 @@ def lower_mesh_computation( in_partitions = map(_to_logical_op_sharding, global_in_avals, in_shardings) out_partitions = map(_to_logical_op_sharding, global_out_avals, out_shardings) replicated_args = [False] * len(in_jaxpr_avals) - axis_ctx = mlir.SPMDAxisContext(mesh, manual_axes) + axis_ctx = sharding_impls.SPMDAxisContext(mesh, manual_axes) else: replicated_args = [not get_array_mapping(i.spec) for i in in_shardings] # type: ignore in_partitions = None out_partitions = None - axis_env = xla.AxisEnv(nreps=mesh.size, - names=tuple(global_axis_sizes.keys()), - sizes=tuple(global_axis_sizes.values())) - axis_ctx = mlir.ReplicaAxisContext(axis_env) + axis_env = sharding_impls.AxisEnv( + nreps=mesh.size, + names=tuple(global_axis_sizes.keys()), + sizes=tuple(global_axis_sizes.values())) + axis_ctx = sharding_impls.ReplicaAxisContext(axis_env) closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) module: Union[str, xc.XlaComputation] 
module_name = f"{api_name}_{fun_name}" @@ -2834,7 +2787,7 @@ class UnloadedMeshExecutable: for x, o in safe_zip(out_shardings_xla, out_shardings) ] out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple) - elif (out_shardings and any(_is_unspecified(o) for o in out_shardings) + elif (out_shardings and any(is_unspecified(o) for o in out_shardings) and pmap_nreps == 1): assert mesh is None _, out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore @@ -2844,7 +2797,7 @@ class UnloadedMeshExecutable: out_shardings, are_out_shardings_from_xla = [], [] # type: ignore for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings, global_out_avals): - if _is_unspecified(orig): + if is_unspecified(orig): out_shardings.append(xla_s) are_out_shardings_from_xla.append(True) else: @@ -3133,7 +3086,7 @@ def create_mesh_pspec_sharding( def check_device_backend_on_shardings(shardings) -> bool: for i in shardings: - if _is_unspecified(i) or is_auto(i): + if is_unspecified(i) or is_auto(i): continue if hasattr(i, '_original_sharding') and getattr( i._original_sharding, '_device_backend', False): @@ -3163,10 +3116,8 @@ def check_gda_or_array_xla_sharding_match( def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified: - # Import here to avoid cyclic import error when importing gda in pjit.py. 
- from jax.experimental.pjit import get_array_mapping as _get_array_mapping, _prepare_axis_resources - - parsed_pspec, _, _ = _prepare_axis_resources(pspec, "pspec to array_mapping") + parsed_pspec, _, _ = sharding_impls.prepare_axis_resources( + pspec, "pspec to array_mapping") return _get_array_mapping(parsed_pspec) diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 46bf34138..f0924f524 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -22,7 +22,7 @@ import itertools as it import math import operator import re -from typing import (Any, Callable, Dict, NamedTuple, Optional, Protocol, +from typing import (Any, Callable, Dict, Optional, Protocol, Sequence, Set, Type, Tuple, Union, TYPE_CHECKING) import numpy as np @@ -34,6 +34,7 @@ from jax._src import dtypes from jax._src import source_info_util from jax._src.abstract_arrays import numpy_scalar_types from jax._src.core import ConcreteArray, ShapedArray +from jax._src.sharding_impls import AxisEnv from jax._src.util import safe_zip, safe_map from jax._src.typing import Shape @@ -254,13 +255,6 @@ def primitive_subcomputation(platform: str, axis_env: 'AxisEnv', ### compiling jaxprs - -class AxisEnv(NamedTuple): - """Represents a pmap mesh (only along the replica axes).""" - nreps: int - names: Tuple[Any, ...] - sizes: Tuple[int, ...] 
- @dataclasses.dataclass class TranslationContext: builder: xc.XlaBuilder @@ -272,8 +266,6 @@ class TranslationContext: def replace(self, **kw): return dataclasses.replace(self, **kw) - - def xla_destructure(c, ans): num_elements = len(c.get_shape(ans).tuple_shapes()) return [xops.GetTupleElement(ans, i) for i in range(num_elements)] diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index a9fc31884..dc141bf07 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -28,6 +28,7 @@ from jax import tree_util from jax._src import core from jax._src import dtypes +from jax._src import sharding_impls from jax._src import util from jax._src.core import ShapedArray, AxisName, raise_to_shaped from jax._src.interpreters import ad @@ -750,8 +751,10 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups): _replica_groups(ctx.module_context.axis_env, named_axes, axis_index_groups)) axis_context = ctx.module_context.axis_context - is_spmd = isinstance(axis_context, - (mlir.SPMDAxisContext, mlir.ShardingContext)) + is_spmd = isinstance( + axis_context, + (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + ) def all_reduce(aval, x): if is_spmd: @@ -880,7 +883,10 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm): full_perm = full_perm.reshape((-1, 2)) axis_context = ctx.module_context.axis_context - is_manual = isinstance(axis_context, mlir.SPMDAxisContext) and axis_context.manual_axes + is_manual = ( + isinstance(axis_context, sharding_impls.SPMDAxisContext) + and axis_context.manual_axes + ) if is_manual: channel = ctx.module_context.new_channel() other_args = dict( @@ -978,8 +984,10 @@ def _all_to_all_lowering(ctx, x, *, split_count = len(replica_groups[0]) if not all(split_count == len(g) for g in replica_groups): raise ValueError('Replica groups must be equally sized') - is_spmd = isinstance(ctx.module_context.axis_context, - (mlir.SPMDAxisContext, mlir.ShardingContext)) + is_spmd = isinstance( + 
ctx.module_context.axis_context, + (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + ) if is_spmd: # We want to emit the all-gather with global device IDs and a unique # channel ID, as otherwise it interprets the devices as replicas instead @@ -1209,8 +1217,10 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, x_aval, = ctx.avals_in out_aval, = ctx.avals_out axis_context = ctx.module_context.axis_context - is_spmd = isinstance(axis_context, - (mlir.SPMDAxisContext, mlir.ShardingContext)) + is_spmd = isinstance( + axis_context, + (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + ) if (ctx.module_context.platform == 'tpu' or ctx.module_context.platform in ('cuda', 'rocm') and all_gather_dimension == 0): @@ -1354,8 +1364,10 @@ def _reduce_scatter_lowering(prim, reducer, ctx, x, scatter_out_shape = list(x_aval.shape) scatter_out_shape[scatter_dimension] //= axis_size axis_context = ctx.module_context.axis_context - is_spmd = isinstance(axis_context, - (mlir.SPMDAxisContext, mlir.ShardingContext)) + is_spmd = isinstance( + axis_context, + (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + ) if is_spmd: # We want to emit the all-gather with global device IDs and a unique # channel ID, as otherwise it interprets the devices as replicas instead @@ -1572,8 +1584,10 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): ) mod = mlir.ir_constant(np.array(axis_env.sizes[axis_pos], dtype=np.uint32)) axis_context = ctx.module_context.axis_context - is_spmd = isinstance(axis_context, - (mlir.SPMDAxisContext, mlir.ShardingContext)) + is_spmd = isinstance( + axis_context, + (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + ) if is_spmd: device_id = hlo.PartitionIdOp() else: diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 4ef575a70..d61d1f472 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ 
-662,6 +662,7 @@ def _select_and_gather_add_lowering( max_bits=64): _, operand_aval, = ctx.avals_in out_aval, = ctx.avals_out + assert isinstance(operand_aval, core.ShapedArray), operand_aval dtype = operand_aval.dtype etype = mlir.dtype_to_ir_type(dtype) nbits = dtypes.finfo(dtype).bits diff --git a/jax/_src/maps.py b/jax/_src/maps.py index 8da5a840e..c339816e9 100644 --- a/jax/_src/maps.py +++ b/jax/_src/maps.py @@ -26,9 +26,10 @@ from jax import numpy as jnp from jax._src import core from jax._src import dispatch from jax._src import effects -from jax._src import mesh +from jax._src import mesh as mesh_lib from jax._src import linear_util as lu from jax._src import op_shardings +from jax._src import sharding_impls from jax._src import source_info_util from jax._src import stages from jax._src import traceback_util @@ -47,10 +48,11 @@ from jax._src.interpreters.partial_eval import ( convert_constvars_jaxpr, new_jaxpr_eqn) from jax._src.interpreters import pxla from jax._src.interpreters import xla -from jax._src.pjit import ( - sharding_constraint_p, ParsedPartitionSpec, get_unconstrained_dims, - GSPMDSharding) -from jax._src.sharding_impls import NamedSharding +from jax._src.pjit import (sharding_constraint_p, get_unconstrained_dims, + GSPMDSharding) +from jax._src.sharding_impls import ( + ArrayMapping, NamedSharding, ParsedPartitionSpec, + array_mapping_to_axis_resources) from jax._src.tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map, treedef_tuple) from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3, @@ -90,12 +92,12 @@ class FrozenDict(abc.Mapping): # Multi-dimensional generalized map AxisName = core.AxisName -ResourceAxisName = mesh.ResourceAxisName # Different name just for documentation purposes -Mesh = mesh.Mesh -MeshAxisName = mesh.MeshAxisName -ResourceEnv = mesh.ResourceEnv -EMPTY_ENV = mesh.EMPTY_ENV -thread_resources = mesh.thread_resources +ResourceAxisName = mesh_lib.ResourceAxisName # Different name 
just for documentation purposes +Mesh = mesh_lib.Mesh +MeshAxisName = mesh_lib.MeshAxisName +ResourceEnv = mesh_lib.ResourceEnv +EMPTY_ENV = mesh_lib.EMPTY_ENV +thread_resources = mesh_lib.thread_resources class SerialLoop: @@ -165,7 +167,7 @@ def serial_loop(name: ResourceAxisName, length: int): axis_resources={'i': 'l'})(x) """ old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV) - thread_resources.env = old_env.with_extra_loop(mesh.Loop(name, length)) + thread_resources.env = old_env.with_extra_loop(mesh_lib.Loop(name, length)) try: yield finally: @@ -677,9 +679,9 @@ def make_xmap_callable(fun: lu.WrappedFun, tiling_method = pxla.TileManual(manual_mesh_axes) else: tiling_method = pxla.TileVectorize() - in_shardings = [NamedSharding(mesh, pxla.array_mapping_to_axis_resources(i)) + in_shardings = [NamedSharding(mesh, array_mapping_to_axis_resources(i)) for i in mesh_in_axes] - out_shardings = [NamedSharding(mesh, pxla.array_mapping_to_axis_resources(o)) + out_shardings = [NamedSharding(mesh, array_mapping_to_axis_resources(o)) for o in mesh_out_axes] return pxla.lower_mesh_computation( f, 'xmap', name, mesh, @@ -937,7 +939,7 @@ def _resource_typing_xmap(avals, raise JAXTypeError( f"Detected disallowed xmap axis name shadowing at " f"{source_info_util.summarize(source_info)} " - f"(shadowed axes: {mesh.show_axes(overlap)})") + f"(shadowed axes: {mesh_lib.show_axes(overlap)})") if resource_env.physical_mesh != params['resource_env'].physical_mesh: raise RuntimeError("Changing the physical mesh is not allowed inside xmap.") @@ -965,9 +967,9 @@ def _resource_typing_xmap(avals, raise JAXTypeError( f"One of xmapped function ({params['name']}) outputs is broadcast " f"along axis `{baxis}` which is assigned to resources " - f"{mesh.show_axes(baxis_resources)}, but the output is already " - f"partitioned along {mesh.show_axes(overlap)}, because its " - f"named shape contains {mesh.show_axes(partitioning_axes)}") + f"{mesh_lib.show_axes(baxis_resources)}, but 
the output is already " + f"partitioned along {mesh_lib.show_axes(overlap)}, because its " + f"named shape contains {mesh_lib.show_axes(partitioning_axes)}") pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap @@ -1269,7 +1271,7 @@ batching.BatchTrace.post_process_xmap = _batch_trace_post_process_xmap # -------- nested xmap handling -------- def _xmap_lowering_rule(ctx, *args, **kwargs): - if isinstance(ctx.module_context.axis_context, mlir.SPMDAxisContext): + if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext): if config.experimental_xmap_spmd_lowering_manual: return _xmap_lowering_rule_spmd_manual(ctx, *args, **kwargs) else: @@ -1277,9 +1279,9 @@ def _xmap_lowering_rule(ctx, *args, **kwargs): # Here ShardingContext is used in place of ReplicaAxisContext because when # axis_resources and mesh is not used with xmap, `make_xmap_callable` will # go via `dispatch.sharded_lowering` path which sets the context to - # ShardingContext. mlir.ShardingContext is not used for SPMD. + # ShardingContext. sharding_impls.ShardingContext is not used for SPMD. elif isinstance(ctx.module_context.axis_context, - (mlir.ReplicaAxisContext, mlir.ShardingContext)): + (sharding_impls.ReplicaAxisContext, sharding_impls.ShardingContext)): return _xmap_lowering_rule_replica(ctx, *args, **kwargs) else: raise AssertionError("Unrecognized axis context type!") @@ -1382,7 +1384,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes, # XXX: We modify mesh_in_axes and mesh_out_axes here def add_spmd_axes( - flat_mesh_axes: Sequence[pxla.ArrayMapping], + flat_mesh_axes: Sequence[ArrayMapping], flat_extra_axes: Optional[Sequence[Sequence[Sequence[MeshAxisName]]]]): if flat_extra_axes is None: return @@ -1456,7 +1458,8 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes, # We in-line here rather than generating a Call HLO as in the xla_call # translation rule just because the extra tuple stuff is a pain. 
- assert isinstance(ctx.module_context.axis_context, mlir.SPMDAxisContext) + assert isinstance(ctx.module_context.axis_context, + sharding_impls.SPMDAxisContext) sub_ctx = ctx.module_context.replace( name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')), axis_context=ctx.module_context.axis_context.extend_manual(manual_mesh_axes)) @@ -1755,7 +1758,7 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env, s = arg.sharding xmap_sharding = pxla.create_mesh_pspec_sharding( - mesh, pxla.array_mapping_to_axis_resources(xmap_array_mapping)) + mesh, array_mapping_to_axis_resources(xmap_array_mapping)) # This check is cached because comparing OpSharding is expensive during # dispatch and if the shardings are the same, then there is no need to # compare twice. diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 395a5ca3d..a33e2f21c 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -13,10 +13,8 @@ # limitations under the License. import dataclasses -from enum import IntEnum import inspect import numpy as np -from collections import OrderedDict, Counter from typing import (Callable, Sequence, Tuple, Union, cast, List, Optional, Iterable, NamedTuple, Any) import itertools as it @@ -30,6 +28,7 @@ from jax._src import dispatch from jax._src import mesh as mesh_lib from jax._src import linear_util as lu from jax._src import op_shardings +from jax._src import sharding_impls from jax._src import source_info_util from jax._src import traceback_util from jax._src import util @@ -51,10 +50,13 @@ from jax._src.interpreters import pxla from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib import xla_client as xc -from jax._src.sharding import Sharding from jax._src.sharding_impls import ( NamedSharding, XLACompatibleSharding, GSPMDSharding, - XLADeviceAssignment, SingleDeviceSharding, PmapSharding) + XLADeviceAssignment, SingleDeviceSharding, PmapSharding, + AUTOAxisResource, 
UNSPECIFIED, UnspecifiedValue, + CanonicalizedParsedPartitionSpec, ParsedPartitionSpec, + SpecSync, get_single_pspec, is_auto, is_unspecified, is_unspecified_or_auto, + prepare_axis_resources) from jax._src.traceback_util import api_boundary from jax._src.tree_util import ( tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, @@ -62,7 +64,7 @@ from jax._src.tree_util import ( prefix_errors, generate_key_paths) from jax._src.util import ( HashableFunction, safe_map, safe_zip, wraps, - distributed_debug_log, split_list, tuple_insert, weakref_lru_cache, + distributed_debug_log, split_list, weakref_lru_cache, merge_lists) map, unsafe_map = safe_map, map @@ -70,41 +72,10 @@ zip, unsafe_zip = safe_zip, zip traceback_util.register_exclusion(__file__) - -_AUTOAxisResource = pxla.AUTOAxisResource -AUTO = pxla.AUTO # type: ignore -is_auto = pxla.is_auto - -_UnspecifiedValue = pxla.UnspecifiedValue -_UNSPECIFIED = pxla._UNSPECIFIED # type: ignore -_is_unspecified = pxla._is_unspecified - -def _is_unspecified_or_auto(x): - return is_auto(x) or _is_unspecified(x) - - -PjitSharding = Union[GSPMDSharding, _UnspecifiedValue, _AUTOAxisResource] -PjitShardingMinusUnspecified = Union[GSPMDSharding, _AUTOAxisResource] -MeshSharding = Union[NamedSharding, _UnspecifiedValue, _AUTOAxisResource] -MeshShardingMinusUnspecified = Union[NamedSharding, _AUTOAxisResource] - - -def _check_all_or_none_unspecified(axis_resources, name): - if not axis_resources: - return False - unspecified_count = 0 - unspecified = _is_unspecified(axis_resources[0]) - for resource in axis_resources: - current_is_unspecified = _is_unspecified(resource) - if current_is_unspecified: - unspecified_count += 1 - assert unspecified_count == 1 - if current_is_unspecified != unspecified: - raise ValueError(f'`pjit._UNSPECIFIED` exists in {name}. 
' - f'Make sure that every entry in {name} is ' - '`pjit._UNSPECIFIED`.') - return unspecified - +PjitSharding = Union[GSPMDSharding, UnspecifiedValue, AUTOAxisResource] +PjitShardingMinusUnspecified = Union[GSPMDSharding, AUTOAxisResource] +MeshSharding = Union[NamedSharding, UnspecifiedValue, AUTOAxisResource] +MeshShardingMinusUnspecified = Union[NamedSharding, AUTOAxisResource] def _try_infer_args(f, tree): dummy_args = tree_unflatten(tree, [False] * tree.num_leaves) @@ -281,27 +252,27 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames, def _resolve_axis_resources_and_shardings_arg( in_shardings, out_shardings, in_axis_resources, out_axis_resources): - if not _is_unspecified(in_shardings) and not _is_unspecified(in_axis_resources): + if not is_unspecified(in_shardings) and not is_unspecified(in_axis_resources): raise ValueError( 'Setting both in_shardings and in_axis_resources is not ' 'allowed. in_axis_resources is deprecated. Please use in_shardings.') - if not _is_unspecified(out_shardings) and not _is_unspecified(out_axis_resources): + if not is_unspecified(out_shardings) and not is_unspecified(out_axis_resources): raise ValueError( 'Setting both out_shardings and out_axis_resources is not ' 'allowed. out_axis_resources is deprecated. Please use out_shardings.') - if (not _is_unspecified(in_axis_resources) or - not _is_unspecified(out_axis_resources)): + if (not is_unspecified(in_axis_resources) or + not is_unspecified(out_axis_resources)): warnings.warn( 'in_axis_resources and out_axis_resources are deprecated. 
Please use ' 'in_shardings and out_shardings as their replacement.', DeprecationWarning) - if not _is_unspecified(in_axis_resources): + if not is_unspecified(in_axis_resources): final_in_shardings = in_axis_resources else: final_in_shardings = in_shardings - if not _is_unspecified(out_axis_resources): + if not is_unspecified(out_axis_resources): final_out_shardings = out_axis_resources else: final_out_shardings = out_shardings @@ -326,10 +297,10 @@ def pre_infer_params(fun, in_shardings, out_shardings, if device is not None and backend is not None: raise ValueError("can't specify both a device and a backend for jit, " f"got {device=} and {backend=}") - if not _is_unspecified(in_shardings): + if not is_unspecified(in_shardings): raise ValueError('If backend or device is specified on jit, then ' 'in_shardings should not be specified.') - if not _is_unspecified(out_shardings): + if not is_unspecified(out_shardings): raise ValueError('If backend or device is specified on jit, then ' 'out_shardings should not be specified.') @@ -341,8 +312,8 @@ def pre_infer_params(fun, in_shardings, out_shardings, # rather than raising an error. 
https://github.com/google/jax/issues/2367 in_shardings = tuple(in_shardings) - in_shardings, _, _ = _prepare_axis_resources(in_shardings, 'in_shardings') - out_shardings, _, _ = _prepare_axis_resources(out_shardings, 'out_shardings') + in_shardings, _, _ = prepare_axis_resources(in_shardings, 'in_shardings') + out_shardings, _, _ = prepare_axis_resources(out_shardings, 'out_shardings') donate_argnums, static_argnums, static_argnames = resolve_argnums( fun, donate_argnums, static_argnums, static_argnames) @@ -394,8 +365,8 @@ def _pjit_explicit_sharding(in_shardings, out_shardings, device, out_shardings_flat, _ = tree_flatten(out_shardings) return (device is not None or backend is not None or - any(not _is_unspecified(i) for i in in_shardings_flat) or - any(not _is_unspecified(i) for i in out_shardings_flat)) + any(not is_unspecified(i) for i in in_shardings_flat) or + any(not is_unspecified(i) for i in out_shardings_flat)) class PjitInfo(NamedTuple): @@ -418,7 +389,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs): donate_argnums, device, backend, keep_unused, inline, resource_env, abstracted_axes) = pjit_info_args - if kwargs and not _is_unspecified(user_in_shardings): + if kwargs and not is_unspecified(user_in_shardings): raise ValueError( "pjit does not support kwargs when in_shardings is specified.") @@ -511,7 +482,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs): num_extra_args = len(implicit_args) + len(consts) canonicalized_in_shardings_flat = \ - (_UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat + (UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat donated_invars = (False,) * num_extra_args + donated_invars assert (len(canonicalized_in_shardings_flat) == len(donated_invars) == len(consts) + len(args_flat)) @@ -574,10 +545,10 @@ def _flat_axes_specs(abstracted_axes, *args, **kwargs # because `None` means that the input is fully replicated. 
def pjit( fun: Callable, - in_shardings=_UNSPECIFIED, - out_shardings=_UNSPECIFIED, - in_axis_resources=_UNSPECIFIED, - out_axis_resources=_UNSPECIFIED, + in_shardings=UNSPECIFIED, + out_shardings=UNSPECIFIED, + in_axis_resources=UNSPECIFIED, + out_axis_resources=UNSPECIFIED, static_argnums: Union[int, Sequence[int], None] = None, static_argnames: Union[str, Iterable[str], None] = None, donate_argnums: Union[int, Sequence[int]] = (), @@ -775,13 +746,13 @@ def hashable_pytree(pytree): @lru_cache(maxsize=4096) def _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x): - if _is_unspecified_or_auto(x): + if is_unspecified_or_auto(x): return x return pxla.create_mesh_pspec_sharding(mesh, x.user_spec, x) def _create_sharding_for_array(mesh, x, name): - if isinstance(x, XLACompatibleSharding) or _is_unspecified_or_auto(x): + if isinstance(x, XLACompatibleSharding) or is_unspecified_or_auto(x): return x if mesh is None: msg = ('jax.jit only supports `XLACompatibleSharding`s being passed to' @@ -803,7 +774,7 @@ def _create_sharding_for_array(mesh, x, name): ' call site? Alternatively, provide `XLACompatibleSharding`s to' ' `in_shardings` and `out_shardings` and then the mesh context manager' ' is not required.') - # A nice user error is raised in _prepare_axis_resources. + # A nice user error is raised in prepare_axis_resources. assert isinstance(x, ParsedPartitionSpec), x return _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x) @@ -884,7 +855,7 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree, orig_in_shardings = in_shardings_thunk() # Only do this if original in_shardings are unspecified. If it is AUTO, go # via flatten_axis_resources. 
- if _is_unspecified(orig_in_shardings): + if is_unspecified(orig_in_shardings): in_shardings_flat = (orig_in_shardings,) * len(in_avals) else: in_shardings_flat = flatten_axis_resources( @@ -895,7 +866,7 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree, pjit_check_aval_sharding(in_shardings_flat, in_avals, "pjit arguments", allow_uneven_sharding=False) canonicalized_shardings = tuple( - i if _is_unspecified_or_auto(i) else to_gspmd_sharding(i, aval.ndim) + i if is_unspecified_or_auto(i) else to_gspmd_sharding(i, aval.ndim) for i, aval in zip(in_shardings_flat, in_avals)) return canonicalized_shardings @@ -933,7 +904,7 @@ def _check_and_canonicalize_out_shardings( # instead. This condition exists because flatten_axis_resources passes in an # `object()` while unflattening which breaks assertion is user defined # pytrees (which shouldn't exist but they do). - if (_is_unspecified(orig_out_shardings) or + if (is_unspecified(orig_out_shardings) or isinstance(orig_out_shardings, XLACompatibleSharding)): out_shardings_flat = (orig_out_shardings,) * len(out_type) else: @@ -946,7 +917,7 @@ def _check_and_canonicalize_out_shardings( allow_uneven_sharding=False) canonicalized_out_shardings_flat = tuple( - o if _is_unspecified(o) or is_auto(o) else to_gspmd_sharding(o, aval.ndim) + o if is_unspecified(o) or is_auto(o) else to_gspmd_sharding(o, aval.ndim) for o, aval in zip(out_shardings_flat, out_type) ) return canonicalized_out_shardings_flat @@ -965,7 +936,7 @@ def _pjit_jaxpr(fun, out_shardings_thunk, in_type, debug_info, out_tree, def pjit_check_aval_sharding( shardings, flat_avals, what_aval: str, allow_uneven_sharding: bool): for aval, s in zip(flat_avals, shardings): - if _is_unspecified_or_auto(s): + if is_unspecified_or_auto(s): continue shape = aval.shape try: @@ -994,170 +965,6 @@ def pjit_check_aval_sharding( f"(full shape: {shape}) ") -class SpecSync(IntEnum): - """Encodes how much out of sync the real value of partitions is compared to the 
user specified one. - - We use this to make sure we don't show garbage modified values while claiming - that the users have specified them like that. - """ - OUT_OF_SYNC = 0 # Arbitrary changes, including new axes inserted - DIM_PERMUTE = 1 # Dimensions permuted, but no new sharding axes - IN_SYNC = 2 # Entirely in sync - -class ParsedPartitionSpec: - __slots__ = ('unsafe_user_spec', 'partitions', 'sync') - - def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC): - self.unsafe_user_spec = user_spec - # None in partitions represents unconstrained dim. - # TODO(yashkatariya): May use a sentinel value. - self.partitions = tuple(partitions) - self.sync = sync - - @property - def user_spec(self): - return self.unsynced_user_spec(SpecSync.IN_SYNC) - - def get_partition_spec(self) -> PartitionSpec: - if self.sync < SpecSync.IN_SYNC: - return _get_single_pspec(self) - else: - if isinstance(self.unsafe_user_spec, PartitionSpec): - return self.unsafe_user_spec - else: - return _get_single_pspec(self) - - def unsynced_user_spec(self, min_sync): - if self.sync < min_sync: - raise AssertionError(f"Please open a bug report! 
({self.sync} >= {min_sync})") - return self.unsafe_user_spec - - def insert_axis_partitions(self, dim, val): - parts = self.partitions - too_short = dim - len(parts) - if too_short > 0: - parts += ((),) * too_short - new_partitions = tuple_insert(parts, dim, val) - new_sync = SpecSync.DIM_PERMUTE if (val == () or val is None) else SpecSync.OUT_OF_SYNC - return ParsedPartitionSpec(self.unsafe_user_spec, new_partitions, sync=new_sync) - - @classmethod - def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False): - if entry is None: - return cls(entry, ()) - if not isinstance(entry, PartitionSpec): - raise TypeError(f"{arg_name} are expected to be " - f"PartitionSpec instances or None, but got {entry}") - axis_specs = [] - for axis_spec in entry: - if axis_spec is None: - axis_spec = () - elif isinstance(axis_spec, (list, tuple)): - axis_spec = tuple(axis_spec) - elif axis_spec == PartitionSpec.UNCONSTRAINED: - if not allow_unconstrained_dims: - raise ValueError(f"Unconstrained dims are not allowed: {entry}") - axis_spec = None - else: - axis_spec = (axis_spec,) - axis_specs.append(axis_spec) - return cls(entry, axis_specs) - - def __hash__(self): - return hash((self.partitions, self.sync)) - - def __eq__(self, other): - return (self.partitions == other.partitions and - self.sync == other.sync) - - def __len__(self): - return len(self.partitions) - - def __getitem__(self, i): - return self.partitions[i] - - def __iter__(self): - return iter(self.partitions) - - def __repr__(self): - return (f"ParsedPartitionSpec(partitions={self.partitions}, " - f"unsafe_user_spec={self.unsafe_user_spec}, " - f"sync={self.sync})") - -class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec): - """ParsedPartitionSpecs that are canonicalized. - - ParsedPartitionSpecs may contain trailing empty tuples, that make them - semantically different in general, and yet in some situations we prefer - to regard them as equivalent. 
For example, partitions of () and ((),) - cannot be always considered equivalent, since the first one is a valid - spec for a scalar value, while the second is not! However, when either of - those are applied to a 2D array, they both mean that the array is fully - replicated. - - So CanonicalizedParsedPartitionSpecs removes the trailing empty tuples from - partitions. - """ - - def __init__(self, parsed_pspec: ParsedPartitionSpec): - partitions = list(parsed_pspec.partitions) - while partitions and partitions[-1] == (): - partitions.pop() - - super().__init__(parsed_pspec.unsafe_user_spec, partitions, - parsed_pspec.sync) - - def __repr__(self): - return (f"CanonicalizedParsedPartitionSpec(partitions={self.partitions}, " - f"unsafe_user_spec={self.unsafe_user_spec}, " - f"sync={self.sync})") - - -def _prepare_axis_resources(axis_resources, - arg_name, - allow_unconstrained_dims=False): - # PyTrees don't treat None values as leaves, so we use an is_leaf function. - entries, treedef = tree_flatten(axis_resources, is_leaf=lambda x: x is None) - what = f"{arg_name} leaf specifications" - # All entries should be specified or if unspecified then there should only - # be 1 entry for that since _UNSPECIFIED is a private API. 
- _check_all_or_none_unspecified(entries, arg_name) - - new_entries = [] - for entry in entries: - if _is_unspecified_or_auto(entry): - new_entries.append(entry) - elif isinstance(entry, Sharding): - if isinstance(entry, PmapSharding): - raise ValueError(f'One of {what} got sharding {entry} which is not ' - 'allowed.') - if not isinstance(entry, XLACompatibleSharding): - raise ValueError(f'One of {what} got sharding {entry} which is not a ' - 'subclass of XLACompatibleSharding.') - new_entries.append(entry) - else: - new_entries.append(ParsedPartitionSpec.from_user_input( - entry, what, allow_unconstrained_dims=allow_unconstrained_dims)) - - _check_unique_resources(new_entries, arg_name) - return tree_unflatten(treedef, new_entries), new_entries, treedef - - -def _check_unique_resources(axis_resources, arg_name): - for arg_axis_resources in axis_resources: - if not arg_axis_resources: continue - if (_is_unspecified_or_auto(arg_axis_resources) or - isinstance(arg_axis_resources, XLACompatibleSharding)): - continue - constrained_dims = [d for d in arg_axis_resources if d is not None] - resource_counts = Counter(it.chain.from_iterable(constrained_dims)) - if not resource_counts: continue - if resource_counts.most_common(1)[0][1] > 1: - multiple_uses = [r for r, c in resource_counts.items() if c > 1] - if multiple_uses: - raise ValueError(f"A single {arg_name} specification can map every mesh axis " - f"to at most one positional dimension, but {arg_axis_resources.user_spec} " - f"has duplicate entries for {mesh_lib.show_axes(multiple_uses)}") # -------------------- pjit rules -------------------- @@ -1203,21 +1010,21 @@ def _resolve_in_shardings( resolved_in_shardings = [] for arg, pjit_in_s in zip(args, pjit_in_shardings): arg_s, committed = ((arg.sharding, getattr(arg, '_committed', True)) - if hasattr(arg, 'sharding') else (_UNSPECIFIED, False)) - if _is_unspecified(pjit_in_s): - if _is_unspecified(arg_s): + if hasattr(arg, 'sharding') else (UNSPECIFIED, False)) + 
if is_unspecified(pjit_in_s): + if is_unspecified(arg_s): resolved_in_shardings.append(arg_s) else: if committed: # If the arg has a PmapSharding, then reshard it unconditionally. if isinstance(arg_s, PmapSharding): - resolved_in_shardings.append(_UNSPECIFIED) + resolved_in_shardings.append(UNSPECIFIED) else: resolved_in_shardings.append(to_gspmd_sharding( cast(XLACompatibleSharding, arg_s), arg.ndim)) else: if dispatch.is_single_device_sharding(arg_s): - resolved_in_shardings.append(_UNSPECIFIED) + resolved_in_shardings.append(UNSPECIFIED) else: raise NotImplementedError('Having uncommitted Array sharded on ' 'multiple devices is not supported.') @@ -1239,7 +1046,7 @@ def _resolve_in_shardings( 'Please see the jax.Array migration guide for more information ' 'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. ' f'Got arg shape: {arg.shape}, arg value: {arg}') - if not _is_unspecified(arg_s): + if not is_unspecified(arg_s): if (committed and not isinstance(arg_s, PmapSharding) and not op_shardings.are_op_shardings_equal( @@ -1265,7 +1072,7 @@ def _pjit_call_impl(*args, jaxpr, args, in_shardings, out_shardings, resource_env.physical_mesh if resource_env is not None else None) - _allow_propagation_to_outputs = [_is_unspecified(o) for o in out_shardings] + _allow_propagation_to_outputs = [is_unspecified(o) for o in out_shardings] compiled = _pjit_lower( jaxpr, in_shardings, out_shardings, resource_env, donated_invars, name, keep_unused, @@ -1428,8 +1235,8 @@ def _pjit_lower_cached( def pjit_staging_rule(trace, *args, **params): if (params["inline"] and - all(_is_unspecified(i) for i in params["in_shardings"]) and - all(_is_unspecified(o) for o in params["out_shardings"])): + all(is_unspecified(i) for i in params["in_shardings"]) and + all(is_unspecified(o) for o in params["out_shardings"])): jaxpr = params['jaxpr'] return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) elif config.jax_dynamic_shapes: 
@@ -1488,9 +1295,9 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, output_types = [mlir.token_type()] * len(effects) + output_types flat_output_types = util.flatten(output_types) - arg_shardings = [None if _is_unspecified(i) else i._to_xla_op_sharding(aval.ndim) + arg_shardings = [None if is_unspecified(i) else i._to_xla_op_sharding(aval.ndim) for aval, i in zip(ctx.avals_in, in_shardings)] - result_shardings = [None if _is_unspecified(o) else o._to_xla_op_sharding(aval.ndim) + result_shardings = [None if is_unspecified(o) else o._to_xla_op_sharding(aval.ndim) for aval, o in zip(ctx.avals_out, out_shardings)] # TODO(b/228598865): inlined calls cannot have shardings set directly on the @@ -1557,9 +1364,9 @@ batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False, None) pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None) def _pjit_batcher_for_sharding( - s: Union[GSPMDSharding, _UnspecifiedValue], + s: Union[GSPMDSharding, UnspecifiedValue], dim: int, val: Tuple[str, ...], mesh, ndim: int): - if _is_unspecified(s): + if is_unspecified(s): return s if not val: new_op = s._op_sharding.clone() # type: ignore @@ -1631,7 +1438,7 @@ def _pjit_partial_eval(trace, *in_tracers, def keep_where(l, should_keep): return tuple(x for x, keep in unsafe_zip(l, should_keep) if keep) - residual_shardings = (_UNSPECIFIED,) * num_residuals + residual_shardings = (UNSPECIFIED,) * num_residuals # Compute the known outputs known_params = dict( jaxpr=known_jaxpr, @@ -1649,7 +1456,7 @@ def _pjit_partial_eval(trace, *in_tracers, # Only forward the outvars where the out_sharding is UNSPECIFIED. 
known_user_out_shardings = keep_where(known_params['out_shardings'], known_outs) fwds_known_user = [ - fwd if _is_unspecified(os) else None + fwd if is_unspecified(os) else None for os, fwd in zip(known_user_out_shardings, fwds_known[:len(known_user_out_shardings)])] fwds_known = fwds_known_user + fwds_known[len(known_user_out_shardings):] @@ -1728,7 +1535,7 @@ def _pjit_partial_eval_custom_params_updater( if num_res == 0: residual_shardings = [] else: - residual_shardings = [_UNSPECIFIED] * num_res + residual_shardings = [UNSPECIFIED] * num_res _, out_shardings_known = pe.partition_list(kept_outs_known, params_known['out_shardings']) new_params_known = dict(params_known, in_shardings=tuple(in_shardings_known), @@ -1862,7 +1669,7 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r raise RuntimeError("Changing the physical mesh is not allowed inside pjit.") for aval, s in zip(jaxpr.in_avals, params['in_shardings']): - if _is_unspecified(s) or is_auto(s): + if is_unspecified(s) or is_auto(s): continue elif hasattr(s, '_original_sharding') and hasattr( s._original_sharding, '_parsed_pspec'): @@ -1884,7 +1691,7 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r what = "pjit output" for aval, s in zip(jaxpr.out_avals, params['out_shardings']): - if _is_unspecified(s) or is_auto(s): + if is_unspecified(s) or is_auto(s): continue elif hasattr(s, '_original_sharding') and hasattr( s._original_sharding, '_parsed_pspec'): @@ -1907,9 +1714,9 @@ def _pjit_pp_rule(eqn, context, settings): del params['inline'] if not any(params['donated_invars']): del params['donated_invars'] - if all(pxla._is_unspecified(s) for s in params['in_shardings']): + if all(is_unspecified(s) for s in params['in_shardings']): del params['in_shardings'] - if all(pxla._is_unspecified(s) for s in params['out_shardings']): + if all(is_unspecified(s) for s in params['out_shardings']): del params['out_shardings'] if not params['keep_unused']: 
del params['keep_unused'] @@ -1923,17 +1730,17 @@ core.pp_eqn_rules[pjit_p] = _pjit_pp_rule # -------------------- with_sharding_constraint -------------------- def _resolve_wsc_args(axis_resources, shardings): - if not _is_unspecified(axis_resources) and not _is_unspecified(shardings): + if not is_unspecified(axis_resources) and not is_unspecified(shardings): raise ValueError( 'Setting both axis_resources and shardings is not ' 'allowed. axis_resources is deprecated. Please use shardings.') - if _is_unspecified(axis_resources) and _is_unspecified(shardings): + if is_unspecified(axis_resources) and is_unspecified(shardings): raise ValueError( 'Not specifying shardings to `with_sharding_constraint` is not allowed. ' 'Please specify the shardings argument with a concrete sharding. Note ' 'that axis_resources is deprecated, so use the shardings argument.') - if not _is_unspecified(axis_resources): + if not is_unspecified(axis_resources): warnings.warn( 'axis_resources is deprecated. Please use shardings argument instead.', DeprecationWarning) @@ -1946,11 +1753,11 @@ def _resolve_wsc_args(axis_resources, shardings): # TODO(yashkatariya): Remove the axis_resources argument and make the signature # `with_sharding_constraint(x, shardings)` with no defaults after deprecation # period is finished. The deprecation period expires 3 months from Feb 13, 2023. -def with_sharding_constraint(x, shardings=_UNSPECIFIED, - axis_resources=_UNSPECIFIED): +def with_sharding_constraint(x, shardings=UNSPECIFIED, + axis_resources=UNSPECIFIED): final_shardings = _resolve_wsc_args(axis_resources, shardings) x_flat, tree = tree_flatten(x) - user_shardings, _, _ = _prepare_axis_resources( + user_shardings, _, _ = prepare_axis_resources( final_shardings, "shardings", allow_unconstrained_dims=True) del final_shardings @@ -1996,7 +1803,7 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, # axis_ctx and manual_axes is *only used with xmap* and xmap only works with # NamedSharding. 
So convert the GSPMDSharding to NamedSharding # and then convert it back with the added special axes. - if isinstance(axis_ctx, mlir.SPMDAxisContext): + if isinstance(axis_ctx, sharding_impls.SPMDAxisContext): mesh = resource_env.physical_mesh parsed_pspec = parse_flatten_op_sharding(sharding._op_sharding, mesh)[0] mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec) @@ -2054,18 +1861,6 @@ pxla.custom_resource_typing_rules[sharding_constraint_p] = \ # -------------------- helpers -------------------- -def get_array_mapping( - axis_resources: Union[ParsedPartitionSpec, _AUTOAxisResource, _UnspecifiedValue] -) -> pxla.ArrayMappingOrAutoOrUnspecified: - # TODO(yashkatariya): Use `TypeGuard` on `is_auto` when it is supported. - # Don't use `is_auto` here to satisfy pytype and mypy. - if isinstance(axis_resources, (_AUTOAxisResource, _UnspecifiedValue)): - return axis_resources - return OrderedDict((axis, i) - for i, axes in enumerate(axis_resources) - if axes is not None for axis in axes) - - def to_gspmd_sharding(s: XLACompatibleSharding, ndim: int) -> GSPMDSharding: if isinstance(s, GSPMDSharding): return s @@ -2085,7 +1880,7 @@ def _fast_path_get_device_assignment( shardings: Iterable[PjitSharding]) -> Optional[XLADeviceAssignment]: da = None for i in shardings: - if is_auto(i) or _is_unspecified(i): + if is_auto(i) or is_unspecified(i): continue da = i._device_assignment # type: ignore break @@ -2227,11 +2022,8 @@ def parse_flatten_op_sharding(op_sharding: xc.OpSharding, raise AssertionError("Unhandled OpSharding type. 
Please open a bug report!") -_get_single_pspec = lambda p: pxla.array_mapping_to_axis_resources( - cast(pxla.ArrayMapping, get_array_mapping(p))) - def _get_partition_spec(ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[PartitionSpec]: - return [_get_single_pspec(p) for p in ppspec] + return [get_single_pspec(p) for p in ppspec] def _get_op_sharding_from_executable( diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index d86ca08cd..a26f7cd19 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -14,24 +14,35 @@ from __future__ import annotations +import collections +import dataclasses +import enum import functools -from collections import Counter +import itertools import operator as op -from typing import (Any, Sequence, List, Tuple, Optional, Mapping, Dict, Set, - FrozenSet, Union, cast) +import sys +from typing import (Any, Dict, FrozenSet, List, Mapping, Optional, OrderedDict, + NamedTuple, Sequence, Set, Tuple, Union, cast) from jax._src import mesh as mesh_lib from jax._src import op_shardings from jax._src import sharding from jax._src import sharding_specs +from jax._src import tree_util +from jax._src import util from jax._src import xla_bridge from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method from jax._src.lib import xla_client as xc -from jax._src.interpreters import mlir from jax._src.partition_spec import PartitionSpec import numpy as np +if sys.version_info >= (3, 9): + OrderedDictType = OrderedDict +else: + OrderedDictType = Dict + + Shape = Tuple[int, ...] Device = xc.Device Index = Tuple[slice, ...] 
@@ -134,7 +145,7 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int] 'create a device to index mapping for your sharding from which replica ' 'ids will be calculated.') from None - index_to_replica: Dict[int, int] = Counter() + index_to_replica: Dict[int, int] = collections.Counter() out = {} for device, index in device_indices_map_fn(global_shape).items(): h_index = hashed_index(index) @@ -204,8 +215,7 @@ class NamedSharding(XLACompatibleSharding): # TODO(yaskatariya): Remove this and replace this with a normalized # representation of Parsed Pspec if self._parsed_pspec is None: - from jax._src import pjit - self._parsed_pspec, _, _ = pjit._prepare_axis_resources( + self._parsed_pspec, _, _ = prepare_axis_resources( self.spec, "NamedSharding spec") _check_mesh_resource_axis(self.mesh, self._parsed_pspec) @@ -263,9 +273,8 @@ class NamedSharding(XLACompatibleSharding): def _to_xla_op_sharding( self, num_dimensions: int, - axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None + axis_ctx: Optional[Union[SPMDAxisContext, ShardingContext]] = None ) -> xc.OpSharding: - from jax._src.pjit import get_array_mapping assert self._parsed_pspec is not None array_mapping = get_array_mapping(self._parsed_pspec) # TODO(yashkatariya): Move away from sharding spec in NamedSharding @@ -275,7 +284,7 @@ class NamedSharding(XLACompatibleSharding): # Used in `with_sharding_constraint`. special_axes = {} # Manual axes is only used with xmap. - if axis_ctx is not None and isinstance(axis_ctx, mlir.SPMDAxisContext): + if axis_ctx is not None and isinstance(axis_ctx, SPMDAxisContext): axis_names = self.mesh.axis_names # Ignore type because mypy doesn't recognize the `hasattr` check above. 
for manual_axis in axis_ctx.manual_axes: # type: ignore @@ -634,3 +643,324 @@ class GSPMDSharding(XLACompatibleSharding): def get_replicated(cls, device_assignment): proto = get_replicated_op_sharding() return cls(device_assignment, proto) + + +class AUTOAxisResource: + pass +AUTO = AUTOAxisResource() + +def is_auto(x): + return isinstance(x, AUTOAxisResource) + + +class UnspecifiedValue: + def __repr__(self): + return "UnspecifiedValue" +UNSPECIFIED = UnspecifiedValue() + +def is_unspecified(x): + return isinstance(x, UnspecifiedValue) + +def is_unspecified_or_auto(x): + return is_auto(x) or is_unspecified(x) + + +MeshAxisName = Any + +""" +ArrayMapping specifies how an ndarray should map to mesh axes. + +Note that the ordering is crucial for the cases when this mapping is non-injective +(i.e. when multiple mesh axes map to the same positional axis). Then, the +order of entries of the mapping determines a major-to-minor order on mesh axes, +according to which chunks of the value along the repeated dimension will be assigned. + +For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}. +The second dimension of the value would get chunked into 6 pieces, and assigned to the +mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case, +that would mean that a flat list of chunks would get assigned to a flattened list of +mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the +mesh devices ndarray would have to be transposed before flattening and assignment. 
+""" +ArrayMapping = OrderedDictType[MeshAxisName, int] +ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTOAxisResource, + UnspecifiedValue] + +def array_mapping_to_axis_resources(array_mapping: ArrayMapping): + if not array_mapping: + return PartitionSpec() + max_index = -1 + reverse_map = collections.defaultdict(list) + for axis, index in array_mapping.items(): + reverse_map[index].append(axis) + if index > max_index: + max_index = index + partitions = tuple(tuple(reverse_map[i]) if reverse_map[i] else None + for i in range(max_index + 1)) + return PartitionSpec(*partitions) + +def get_array_mapping( + axis_resources: Union[ParsedPartitionSpec, AUTOAxisResource, UnspecifiedValue] +) -> ArrayMappingOrAutoOrUnspecified: + # TODO(yashkatariya): Use `TypeGuard` on `is_auto` when it is supported. + # Don't use `is_auto` here to satisfy pytype and mypy. + if isinstance(axis_resources, (AUTOAxisResource, UnspecifiedValue)): + return axis_resources + return OrderedDict((axis, i) + for i, axes in enumerate(axis_resources) + if axes is not None for axis in axes) + + +get_single_pspec = lambda p: array_mapping_to_axis_resources( + cast(ArrayMapping, get_array_mapping(p))) + + +class SpecSync(enum.IntEnum): + """Encodes how much out of sync the real value of partitions is compared to the user specified one. + + We use this to make sure we don't show garbage modified values while claiming + that the users have specified them like that. + """ + OUT_OF_SYNC = 0 # Arbitrary changes, including new axes inserted + DIM_PERMUTE = 1 # Dimensions permuted, but no new sharding axes + IN_SYNC = 2 # Entirely in sync + +class ParsedPartitionSpec: + __slots__ = ('unsafe_user_spec', 'partitions', 'sync') + + def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC): + self.unsafe_user_spec = user_spec + # None in partitions represents unconstrained dim. + # TODO(yashkatariya): May use a sentinel value. 
+ self.partitions = tuple(partitions) + self.sync = sync + + @property + def user_spec(self): + return self.unsynced_user_spec(SpecSync.IN_SYNC) + + def get_partition_spec(self) -> PartitionSpec: + if self.sync < SpecSync.IN_SYNC: + return get_single_pspec(self) + else: + if isinstance(self.unsafe_user_spec, PartitionSpec): + return self.unsafe_user_spec + else: + return get_single_pspec(self) + + def unsynced_user_spec(self, min_sync): + if self.sync < min_sync: + raise AssertionError(f"Please open a bug report! ({self.sync} >= {min_sync})") + return self.unsafe_user_spec + + def insert_axis_partitions(self, dim, val): + parts = self.partitions + too_short = dim - len(parts) + if too_short > 0: + parts += ((),) * too_short + new_partitions = util.tuple_insert(parts, dim, val) + new_sync = SpecSync.DIM_PERMUTE if (val == () or val is None) else SpecSync.OUT_OF_SYNC + return ParsedPartitionSpec(self.unsafe_user_spec, new_partitions, sync=new_sync) + + @classmethod + def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False): + if entry is None: + return cls(entry, ()) + if not isinstance(entry, PartitionSpec): + raise TypeError(f"{arg_name} are expected to be " + f"PartitionSpec instances or None, but got {entry}") + axis_specs = [] + for axis_spec in entry: + if axis_spec is None: + axis_spec = () + elif isinstance(axis_spec, (list, tuple)): + axis_spec = tuple(axis_spec) + elif axis_spec == PartitionSpec.UNCONSTRAINED: + if not allow_unconstrained_dims: + raise ValueError(f"Unconstrained dims are not allowed: {entry}") + axis_spec = None + else: + axis_spec = (axis_spec,) + axis_specs.append(axis_spec) + return cls(entry, axis_specs) + + def __hash__(self): + return hash((self.partitions, self.sync)) + + def __eq__(self, other): + return (self.partitions == other.partitions and + self.sync == other.sync) + + def __len__(self): + return len(self.partitions) + + def __getitem__(self, i): + return self.partitions[i] + + def __iter__(self): + return 
iter(self.partitions) + + def __repr__(self): + return (f"ParsedPartitionSpec(partitions={self.partitions}, " + f"unsafe_user_spec={self.unsafe_user_spec}, " + f"sync={self.sync})") + +class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec): + """ParsedPartitionSpecs that are canonicalized. + + ParsedPartitionSpecs may contain trailing empty tuples, that make them + semantically different in general, and yet in some situations we prefer + to regard them as equivalent. For example, partitions of () and ((),) + cannot be always considered equivalent, since the first one is a valid + spec for a scalar value, while the second is not! However, when either of + those are applied to a 2D array, they both mean that the array is fully + replicated. + + So CanonicalizedParsedPartitionSpecs removes the trailing empty tuples from + partitions. + """ + + def __init__(self, parsed_pspec: ParsedPartitionSpec): + partitions = list(parsed_pspec.partitions) + while partitions and partitions[-1] == (): + partitions.pop() + + super().__init__(parsed_pspec.unsafe_user_spec, partitions, + parsed_pspec.sync) + + def __repr__(self): + return (f"CanonicalizedParsedPartitionSpec(partitions={self.partitions}, " + f"unsafe_user_spec={self.unsafe_user_spec}, " + f"sync={self.sync})") + + +def check_all_or_none_unspecified(axis_resources, name): + if not axis_resources: + return False + unspecified_count = 0 + unspecified = is_unspecified(axis_resources[0]) + for resource in axis_resources: + current_is_unspecified = is_unspecified(resource) + if current_is_unspecified: + unspecified_count += 1 + assert unspecified_count == 1 + if current_is_unspecified != unspecified: + raise ValueError(f'`pjit.UNSPECIFIED` exists in {name}. ' + f'Make sure that every entry in {name} is ' + '`pjit.UNSPECIFIED`.') + return unspecified + + +def prepare_axis_resources(axis_resources, + arg_name, + allow_unconstrained_dims=False): + # PyTrees don't treat None values as leaves, so we use an is_leaf function. 
+ entries, treedef = tree_util.tree_flatten( + axis_resources, is_leaf=lambda x: x is None) + what = f"{arg_name} leaf specifications" + # All entries should be specified or if unspecified then there should only + # be 1 entry for that since UNSPECIFIED is a private API. + check_all_or_none_unspecified(entries, arg_name) + + new_entries = [] + for entry in entries: + if is_unspecified_or_auto(entry): + new_entries.append(entry) + elif isinstance(entry, sharding.Sharding): + if isinstance(entry, PmapSharding): + raise ValueError(f'One of {what} got sharding {entry} which is not ' + 'allowed.') + if not isinstance(entry, XLACompatibleSharding): + raise ValueError(f'One of {what} got sharding {entry} which is not a ' + 'subclass of XLACompatibleSharding.') + new_entries.append(entry) + else: + new_entries.append(ParsedPartitionSpec.from_user_input( + entry, what, allow_unconstrained_dims=allow_unconstrained_dims)) + + _check_unique_resources(new_entries, arg_name) + return tree_util.tree_unflatten(treedef, new_entries), new_entries, treedef + + +def _check_unique_resources(axis_resources, arg_name): + for arg_axis_resources in axis_resources: + if not arg_axis_resources: continue + if (is_unspecified_or_auto(arg_axis_resources) or + isinstance(arg_axis_resources, XLACompatibleSharding)): + continue + constrained_dims = [d for d in arg_axis_resources if d is not None] + resource_counts = collections.Counter( + itertools.chain.from_iterable(constrained_dims)) + if not resource_counts: continue + if resource_counts.most_common(1)[0][1] > 1: + multiple_uses = [r for r, c in resource_counts.items() if c > 1] + if multiple_uses: + raise ValueError(f"A single {arg_name} specification can map every mesh axis " + f"to at most one positional dimension, but {arg_axis_resources.user_spec} " + f"has duplicate entries for {mesh_lib.show_axes(multiple_uses)}") + +# Axis environments + +class AxisEnv(NamedTuple): + """Represents a pmap mesh (only along the replica axes).""" + nreps: 
int + names: Tuple[Any, ...] + sizes: Tuple[int, ...] + + +@dataclasses.dataclass(frozen=True) +class SPMDAxisContext: + """A hardware axis context for parallel computations that use the GSPMD partitioner. + + This includes the mesh that will later be used to execute this computation, + as well as a set of mesh axes that are currently (e.g. because the current lowering + is invoked inside an xmap) lowered in the MANUAL sharding mode. + """ + mesh: mesh_lib.Mesh + manual_axes: FrozenSet[MeshAxisName] = frozenset() + + @property + def axis_env(self): + # All collectives that touch axis_env should remember to set use_global_device_ids + # when this context is enabled! + if self.manual_axes != frozenset(self.mesh.axis_names): + raise NotImplementedError( + "Collectives in manually partitioned computations are only supported " + "when all mesh axes are partitioned manually (no partial automatic sharding). " + "Make sure that you mention all mesh axes in axis_resources!") + return self.unsafe_axis_env + + @property + def unsafe_axis_env(self): + return AxisEnv( + nreps=self.mesh.size, + names=self.mesh.axis_names, + sizes=tuple(self.mesh.shape.values())) + + def extend_manual(self, axes: FrozenSet[MeshAxisName]) -> SPMDAxisContext: + return SPMDAxisContext(self.mesh, self.manual_axes | axes) + + +@dataclasses.dataclass(frozen=True) +class ReplicaAxisContext: + """A hardware axis context for parallel computations that are partitioned by JAX. + + Unlike in the SPMDAxisContext, this means that JAX might need to emit calls to + explicit collectives. + """ + axis_env: AxisEnv + + +@dataclasses.dataclass(frozen=True) +class ShardingContext: + """A hardware axis context for parallel computations that use the sharding + interface. + + This context also uses the GSPMD partitioner. + """ + device_assignment: Sequence[xc.Device] + + # Similar to SPMDAxisContext as ShardingContext also uses the GSPMD partitioner. 
+ @property + def axis_env(self): + return AxisEnv(nreps=1, names=(), sizes=()) diff --git a/jax/_src/stages.py b/jax/_src/stages.py index bc66f6648..9d60b4d75 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -38,6 +38,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Protocol, Sequence, Tu import jax from jax._src import core +from jax._src import sharding_impls from jax._src import source_info_util from jax._src import traceback_util from jax._src import tree_util @@ -596,7 +597,7 @@ class Lowered(Stage): if isinstance(self._lowering, pxla.MeshComputation): kw.update( _allow_propagation_to_outputs=[ - pxla._is_unspecified(o) + sharding_impls.is_unspecified(o) for o in self._lowering.compile_args["out_shardings"] ] ) diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index 169e2ebd4..4155565e6 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -17,6 +17,7 @@ import inspect from jax._src import core from jax import tree_util from jax._src import linear_util as lu +from jax._src import sharding_impls from jax.experimental import pjit from jax.errors import UnexpectedTracerError from jax._src import mesh as mesh_lib @@ -97,7 +98,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape, ] closed_jaxpr = jax.make_jaxpr( lower_fn, axis_env=list(info.mesh.shape.items()))(*tiled_args) - axis_context = mlir.SPMDAxisContext(info.mesh) + axis_context = sharding_impls.SPMDAxisContext(info.mesh) built = mlir.build_xla_computation_helper( closed_jaxpr, name="tmp_xla_computation", @@ -373,9 +374,9 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values, mesh = mesh_lib.thread_resources.env.physical_mesh axis_context = ctx.module_context.axis_context - if isinstance(axis_context, mlir.ShardingContext): + if isinstance(axis_context, sharding_impls.ShardingContext): devices = axis_context.device_assignment - elif 
isinstance(axis_context, mlir.SPMDAxisContext): + elif isinstance(axis_context, sharding_impls.SPMDAxisContext): devices = list(axis_context.mesh.devices.flat) else: devices = None diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index e69b2a638..8cb396339 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -518,9 +518,9 @@ from jax._src.interpreters import xla from jax._src import ad_checkpoint from jax._src import dispatch from jax._src import pretty_printer as pp +from jax._src import sharding_impls from jax._src import source_info_util from jax._src import util -from jax._src import lib as jaxlib from jax._src.lib import pytree from jax._src import xla_bridge as xb from jax._src.lib import xla_client @@ -1245,8 +1245,10 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext, result_arrays = () return result_arrays - if isinstance(ctx.module_context.axis_context, - (mlir.SPMDAxisContext, mlir.ShardingContext)): + if isinstance( + ctx.module_context.axis_context, + (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), + ): # Apply maximal sharding so pjit only executes the callback on device device_index. 
sharding = xla_client.OpSharding() sharding.type = xla_client.OpSharding.Type.MAXIMAL @@ -1715,11 +1717,17 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn], eqn.params, jaxpr=_rewrite_closed_jaxpr(jaxpr, True, True), donated_invars=eqn.params["donated_invars"] + (False, False), - in_shardings=(eqn.params["in_shardings"] + - (pjit._UNSPECIFIED, pjit._UNSPECIFIED)), - out_shardings=(eqn.params["out_shardings"] + - (pjit._UNSPECIFIED, pjit._UNSPECIFIED)), - ))) + in_shardings=( + eqn.params["in_shardings"] + + (sharding_impls.UNSPECIFIED, sharding_impls.UNSPECIFIED) + ), + out_shardings=( + eqn.params["out_shardings"] + + (sharding_impls.UNSPECIFIED, sharding_impls.UNSPECIFIED) + ), + ), + ) + ) elif eqn.primitive is ad_checkpoint.remat_p: jaxpr_ = cast(core.Jaxpr, eqn.params["jaxpr"]) eqns.append( diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index e4416325d..1bff7a504 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -49,6 +49,7 @@ from jax._src import dispatch from jax._src import dtypes from jax._src import linear_util as lu from jax._src import op_shardings +from jax._src import sharding_impls from jax._src import pjit from jax._src import prng from jax._src import random as random_internal @@ -3096,7 +3097,7 @@ def _shard_value(val: TfVal, sd: sharding.XLACompatibleSharding, *, skip_replicated_sharding: bool) -> TfVal: """Apply sharding to a TfVal.""" - if pxla._is_unspecified(sd): + if sharding_impls.is_unspecified(sd): return val sharding_proto: xla_client.OpSharding = cast( diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py index 91d4cef15..3e5d17596 100644 --- a/jax/experimental/jax2tf/jax_export.py +++ b/jax/experimental/jax2tf/jax_export.py @@ -28,6 +28,7 @@ from jax import sharding from jax._src import core from jax._src import pjit +from jax._src import sharding_impls from jax._src import source_info_util from jax._src 
import util from jax._src import xla_bridge as xb @@ -270,7 +271,7 @@ def _add_dim_arg_computation(module: mlir.ir.Module, with ir.InsertionPoint(entry_block): orig_main_args: List[mlir.ir.Value] = [] module_context = mlir.ModuleContext( - "cpu", "cpu", mlir.ShardingContext([]), + "cpu", "cpu", sharding_impls.ShardingContext([]), source_info_util.new_name_stack(), [], itertools.count(1), [], module=new_module, context=context) ctx = mlir.LoweringRuleContext(module_context=module_context, @@ -467,7 +468,7 @@ def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported: map(lambda a: a.at_least_vspace(), primal.out_avals))) # Expand in_shardings to all in_avals even not kept ones. - all_in_shardings = [pxla._UNSPECIFIED] * len(primal.in_avals) + all_in_shardings = [sharding_impls.UNSPECIFIED] * len(primal.in_avals) for idx, in_s in zip(sorted(primal.module_kept_var_idx), primal.in_shardings): all_in_shardings[idx] = in_s # type: ignore @@ -475,13 +476,13 @@ def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported: # Cannot mix unspecified and specified shardings. Make the unspecified # ones replicated. 
specified_shardings = [ - s for s in all_shardings if not pxla._is_unspecified(s)] + s for s in all_shardings if not sharding_impls.is_unspecified(s)] vjp_in_shardings: Any # The primal inputs followed by output cotangents vjp_out_shardings: Any # The primal output cotangents if 0 == len(specified_shardings): - vjp_in_shardings = pxla._UNSPECIFIED - vjp_out_shardings = pxla._UNSPECIFIED + vjp_in_shardings = sharding_impls.UNSPECIFIED + vjp_out_shardings = sharding_impls.UNSPECIFIED else: if len(specified_shardings) < len(all_shardings): # There are some specified, but not all; pjit front-end does not liwk @@ -489,13 +490,13 @@ def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported: assert isinstance(in_s, sharding.XLACompatibleSharding) replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment) all_shardings = [ - s if not pxla._is_unspecified(s) else replicated_s + s if not sharding_impls.is_unspecified(s) else replicated_s for s in all_shardings] vjp_in_shardings = tuple(all_shardings) vjp_out_shardings = tuple(all_shardings[:len(primal.in_avals)]) - if all(pxla._is_unspecified(s) for s in vjp_out_shardings): - vjp_out_shardings = pxla._UNSPECIFIED + if all(sharding_impls.is_unspecified(s) for s in vjp_out_shardings): + vjp_out_shardings = sharding_impls.UNSPECIFIED fun_vjp_jax = pjit.pjit(fun_vjp_jax, in_shardings=vjp_in_shardings, diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index bc50f767e..c622b1b82 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -68,6 +68,7 @@ from jax._src import ad_util from jax._src import core from jax._src import dispatch from jax._src import linear_util as lu +from jax._src import sharding_impls from jax._src.api_util import shaped_abstractify from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal @@ -734,8 +735,12 @@ def _pjit_jet_rule(primals_in, series_in, **params): **params, 'jaxpr': jaxpr_jet, 'in_shardings': ( - 
params['in_shardings'] + (pjit._UNSPECIFIED,) * num_series_in), - 'out_shardings': params['out_shardings'] + (pjit._UNSPECIFIED,) * num_series_out, + params['in_shardings'] + (sharding_impls.UNSPECIFIED,) * num_series_in + ), + 'out_shardings': ( + params['out_shardings'] + + (sharding_impls.UNSPECIFIED,) * num_series_out + ), 'donated_invars': params['donated_invars'] + (False,) * num_series_in, } result = pjit.pjit_p.bind(*primals_and_series, **new_params) diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index f379da1d5..7fad15796 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -15,34 +15,36 @@ # flake8: noqa from jax._src.pjit import ( - AUTO as AUTO, - ParsedPartitionSpec as ParsedPartitionSpec, - get_array_mapping as get_array_mapping, hashable_pytree as hashable_pytree, parse_flatten_op_sharding as parse_flatten_op_sharding, pjit as pjit, pjit_p as pjit_p, with_sharding_constraint as with_sharding_constraint, ) +from jax._src.sharding_impls import ( + AUTO as AUTO, + UNSPECIFIED as _UNSPECIFIED, + ParsedPartitionSpec as ParsedPartitionSpec, + get_array_mapping as get_array_mapping, + prepare_axis_resources as _prepare_axis_resources +) -from jax._src.pjit import (_UNSPECIFIED, _prepare_axis_resources, - _get_op_sharding_from_executable, +from jax._src.pjit import (_get_op_sharding_from_executable, _get_pspec_from_executable, _pjit_lower_cached, _pjit_lower, _pjit_jaxpr, _process_in_axis_resources) - -from jax._src.pjit import ( +from jax._src.sharding_impls import ( NamedSharding as _deprecated_NamedSharding, +) +from jax._src.partition_spec import ( PartitionSpec as _deprecated_PartitionSpec, ) import typing if typing.TYPE_CHECKING: - from jax._src.pjit import ( - NamedSharding as NamedSharding, - PartitionSpec as PartitionSpec, - ) + from jax._src.sharding_impls import NamedSharding as NamedSharding + from jax._src.partition_spec import PartitionSpec as PartitionSpec del typing _deprecations = { diff --git 
a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 8e3b54697..7538ec6d4 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -34,6 +34,7 @@ from jax._src import linear_util as lu from jax._src import ops from jax._src import pjit from jax._src import prng +from jax._src import sharding_impls from jax._src import source_info_util from jax._src import traceback_util from jax._src import util @@ -47,7 +48,6 @@ from jax.api_util import flatten_fun_nokwargs, shaped_abstractify from jax._src.interpreters import batching from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe -from jax._src.interpreters import xla from jax._src.interpreters import pxla from jax.interpreters import ad from jax.tree_util import (tree_map, tree_flatten, tree_unflatten, @@ -468,7 +468,9 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, sharded_avals = [v.aval for v in jaxpr.invars] in_nodes_ = map(partial(_xla_shard, mesh), in_names, ctx.avals_in, sharded_avals, in_nodes) - new_axis_context = mlir.SPMDAxisContext(mesh, frozenset(mesh.axis_names)) + new_axis_context = sharding_impls.SPMDAxisContext( + mesh, frozenset(mesh.axis_names) + ) sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) with core.extend_axis_env_nd(tuple(mesh.shape.items())): out_nodes_, _ = mlir.jaxpr_subcomp(sub_ctx, jaxpr, mlir.TokenSet(), @@ -483,8 +485,9 @@ def _xla_shard(mesh, names, aval_in, aval_out, x): manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names), mesh) result_type, = mlir.aval_to_ir_types(aval_out) axes = {name: i for i, ns in names.items() for name in ns} - shard_proto = NamedSharding(mesh, pxla.array_mapping_to_axis_resources(axes) # type: ignore - )._to_xla_op_sharding(aval_in.ndim) + shard_proto = NamedSharding( + mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore + )._to_xla_op_sharding(aval_in.ndim) if 
core.is_opaque_dtype(aval_in.dtype): shard_proto = aval_in.dtype._rules.physical_op_sharding(aval_in, shard_proto) sx = mlir.wrap_with_sharding_op(x, shard_proto, unspecified_dims=set()) @@ -496,8 +499,9 @@ def _xla_unshard(mesh, names, aval_in, aval_out, xs): result_type, = mlir.aval_to_ir_types(aval_out) sx = mlir.wrap_with_sharding_op(x, manual_proto, unspecified_dims=set()) axes = {name: i for i, ns in names.items() for name in ns} - shard_proto = NamedSharding(mesh, pxla.array_mapping_to_axis_resources(axes) # type: ignore - )._to_xla_op_sharding(aval_out.ndim) + shard_proto = NamedSharding( + mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore + )._to_xla_op_sharding(aval_out.ndim) if core.is_opaque_dtype(aval_out.dtype): shard_proto = aval_out.dtype._rules.physical_op_sharding(aval_out, shard_proto) return mlir.wrap_with_shard_to_full_op(result_type, sx, shard_proto, set()) diff --git a/jax/experimental/sparse/transform.py b/jax/experimental/sparse/transform.py index 993e7bb11..a233145fc 100644 --- a/jax/experimental/sparse/transform.py +++ b/jax/experimental/sparse/transform.py @@ -53,18 +53,16 @@ from typing import ( import numpy as np -import jax from jax import lax from jax._src import core from jax._src import linear_util as lu from jax._src import pjit +from jax._src import sharding_impls from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_sparse import jax.numpy as jnp from jax._src.api_util import flatten_fun_nokwargs from jax._src.lib import pytree from jax._src.interpreters import partial_eval as pe -from jax._src.interpreters import xla -from jax._src.interpreters import pxla from jax.tree_util import tree_flatten, tree_map, tree_unflatten from jax.util import safe_map, safe_zip, split_list from jax._src.config import config @@ -777,9 +775,13 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings, # TODO(yashkatariya, vanderplas): Flatten twice and set the correct sharding # for data 
and indices. in_shardings = in_shardings + tuple( - pxla._UNSPECIFIED for _ in range(len(args_flat) - len(in_shardings))) + sharding_impls.UNSPECIFIED + for _ in range(len(args_flat) - len(in_shardings)) + ) out_shardings = out_shardings + tuple( - pxla._UNSPECIFIED for _ in range(len(sp_call_jaxpr.out_avals) - len(out_shardings))) + sharding_impls.UNSPECIFIED + for _ in range(len(sp_call_jaxpr.out_avals) - len(out_shardings)) + ) out_flat = pjit.pjit_p.bind( *args_flat, diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index ff6edca99..0d31aab5a 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -19,14 +19,9 @@ from jax._src.interpreters.mlir import ( LoweringResult as LoweringResult, LoweringRule as LoweringRule, LoweringRuleContext as LoweringRuleContext, - Mesh as Mesh, - MeshAxisName as MeshAxisName, ModuleContext as ModuleContext, RECV_FROM_HOST_TYPE as RECV_FROM_HOST_TYPE, - ReplicaAxisContext as ReplicaAxisContext, SEND_TO_HOST_TYPE as SEND_TO_HOST_TYPE, - SPMDAxisContext as SPMDAxisContext, - ShardingContext as ShardingContext, Token as Token, TokenSet as TokenSet, Value as Value, @@ -64,3 +59,11 @@ from jax._src.interpreters.mlir import ( token_type as token_type, xla_computation_to_mlir_module as xla_computation_to_mlir_module, ) + +from jax._src.mesh import Mesh as Mesh +from jax._src.sharding_impls import ( + MeshAxisName as MeshAxisName, + ReplicaAxisContext as ReplicaAxisContext, + SPMDAxisContext as SPMDAxisContext, + ShardingContext as ShardingContext, +) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 38ef041e4..7ebaf0675 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -13,9 +13,6 @@ # limitations under the License. 
from jax._src.interpreters.pxla import ( - AUTO as AUTO, - ArrayMapping as ArrayMapping, - ArrayMappingOrAutoOrUnspecified as ArrayMappingOrAutoOrUnspecified, AvalDimSharding as AvalDimSharding, EmapInfo as EmapInfo, ExecuteReplicated as ExecuteReplicated, @@ -26,7 +23,6 @@ from jax._src.interpreters.pxla import ( MeshComputation as MeshComputation, MeshDimAssignment as MeshDimAssignment, MeshExecutable as MeshExecutable, - OrderedDictType as OrderedDictType, ParallelCallableInfo as ParallelCallableInfo, PartitionInfo as PartitionInfo, PartitionsOrReplicated as PartitionsOrReplicated, @@ -43,12 +39,9 @@ from jax._src.interpreters.pxla import ( UnloadedMeshExecutable as UnloadedMeshExecutable, UnloadedPmapExecutable as UnloadedPmapExecutable, WeakRefList as WeakRefList, - _UNSPECIFIED as _UNSPECIFIED, _create_pmap_sharding_spec as _create_pmap_sharding_spec, _get_and_check_device_assignment as _get_and_check_device_assignment, - _is_unspecified as _is_unspecified, _pmap_sharding_spec as _pmap_sharding_spec, - array_mapping_to_axis_resources as array_mapping_to_axis_resources, array_types as array_types, custom_resource_typing_rules as custom_resource_typing_rules, device_put as _deprecated_device_put, @@ -103,6 +96,15 @@ from jax._src.op_shardings import ( op_sharding_to_indices as op_sharding_to_indices, ) +from jax._src.sharding_impls import ( + ArrayMapping as ArrayMapping, + ArrayMappingOrAutoOrUnspecified as ArrayMappingOrAutoOrUnspecified, + AUTO as AUTO, + UNSPECIFIED as _UNSPECIFIED, + array_mapping_to_axis_resources as array_mapping_to_axis_resources, + is_unspecified as _is_unspecified, +) + from jax._src.sharding_specs import ( Chunked as Chunked, NoSharding as NoSharding, diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 30789c1d9..e751a8318 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -13,7 +13,6 @@ # limitations under the License. 
from jax._src.interpreters.xla import ( - AxisEnv as AxisEnv, DeviceArray as _deprecated_DeviceArray, TranslationContext as TranslationContext, TranslationRule as TranslationRule, @@ -46,6 +45,10 @@ from jax._src.dispatch import ( backend_compile as backend_compile, ) +from jax._src.sharding_impls import ( + AxisEnv as AxisEnv, +) + from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc # type: ignore diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 6a7aa43fb..724c8ba94 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -44,10 +44,12 @@ from jax.experimental.custom_partitioning import custom_partitioning from jax._src import array from jax._src.sharding import Sharding from jax._src import op_shardings -from jax._src.sharding_impls import (NamedSharding, GSPMDSharding, - PositionalSharding, SingleDeviceSharding) +from jax._src import sharding_impls +from jax._src.sharding_impls import ( + AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding, + SingleDeviceSharding) import jax._src.pjit as pjit_lib -from jax._src.pjit import (pjit, pjit_p, AUTO) +from jax._src.pjit import pjit, pjit_p from jax._src import mesh from jax._src.interpreters import pxla from jax.interpreters import mlir @@ -3519,7 +3521,9 @@ class UtilTest(jtu.JaxTestCase): ("multi_skip", {'x': 0, 'y': 1, 'z': 3}, P(('x',), ('y',), None, ('z',))), ) def test_array_mapping_to_axis_resources(self, inp, expected_out): - self.assertEqual(pxla.array_mapping_to_axis_resources(inp), expected_out) + self.assertEqual( + sharding_impls.array_mapping_to_axis_resources(inp), expected_out + ) def test_get_input_indices_fully_replicated(self): global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) @@ -3552,22 +3556,22 @@ class UtilTest(jtu.JaxTestCase): pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval, array_mapping) @parameterized.named_parameters( - ("all_unspecified", (pjit_lib._UNSPECIFIED, pjit_lib._UNSPECIFIED), AssertionError), - 
("only_unspecified", pjit_lib._UNSPECIFIED), + ("all_unspecified", (UNSPECIFIED, UNSPECIFIED), AssertionError), + ("only_unspecified", UNSPECIFIED), ("all_specified", (P('x'), P('y'))), ("only_specified", P('x')), - ("mix_1", (P('x'), pjit_lib._UNSPECIFIED), ValueError), - ("mix_2", (P('x'), pjit_lib._UNSPECIFIED, P('y')), ValueError), - ("mix_3", (pjit_lib._UNSPECIFIED, P('x'), P('y')), ValueError), - ("mix_4", (pjit_lib._UNSPECIFIED, P('x'), pjit_lib._UNSPECIFIED), ValueError), + ("mix_1", (P('x'), UNSPECIFIED), ValueError), + ("mix_2", (P('x'), UNSPECIFIED, P('y')), ValueError), + ("mix_3", (UNSPECIFIED, P('x'), P('y')), ValueError), + ("mix_4", (UNSPECIFIED, P('x'), UNSPECIFIED), ValueError), ) def test_all_or_non_unspecified(self, axis_resources, error=None): entries, _ = jax.tree_util.tree_flatten(axis_resources, is_leaf=lambda x: x is None) if error is not None: with self.assertRaises(error): - pjit_lib._check_all_or_none_unspecified(entries, 'test axis resources') + sharding_impls.check_all_or_none_unspecified(entries, 'test axis resources') else: - pjit_lib._check_all_or_none_unspecified(entries, 'test axis resources') + sharding_impls.check_all_or_none_unspecified(entries, 'test axis resources') def test_op_sharding_equality_and_hash_equality(self): op1 = xc.OpSharding() @@ -3673,7 +3677,6 @@ class UtilTest(jtu.JaxTestCase): cache_info4 = GSPMDSharding.devices_indices_map.cache_info() self.assertEqual(cache_info4.hits, cache_info3.hits + 1) - def test_op_sharding_semantically_replicated(self): op1 = xc.OpSharding() op1.type = xc.OpSharding.Type.OTHER @@ -3769,8 +3772,8 @@ class UtilTest(jtu.JaxTestCase): self.assertEqual(recovered_parsed_pspec[0].get_partition_spec(), P(('x',), ('y',))) - out_of_sync_parsed_pspec = pjit_lib.ParsedPartitionSpec( - P('x', 'y'), ('x', 'y'), pjit_lib.SpecSync.OUT_OF_SYNC) + out_of_sync_parsed_pspec = sharding_impls.ParsedPartitionSpec( + P('x', 'y'), ('x', 'y'), sharding_impls.SpecSync.OUT_OF_SYNC) 
self.assertEqual(out_of_sync_parsed_pspec.get_partition_spec(), P(('x',), ('y',))) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 01d2feb4e..8be2fa526 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -43,6 +43,7 @@ from jax._src import core from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr, linearize, device_put) from jax._src import config as jax_config +from jax._src import sharding_impls from jax._src import sharding_specs from jax._src import xla_bridge from jax._src.lib import xla_extension @@ -1091,7 +1092,7 @@ class PythonPmapTest(jtu.JaxTestCase): self.assertAllClose(ans, expected) def testAxisGroups(self): - axis_env = xla.AxisEnv(8, ('i', 'j'), (4, 2)) + axis_env = sharding_impls.AxisEnv(8, ('i', 'j'), (4, 2)) groups = xla.axis_groups(axis_env, 'i') self.assertEqual(groups, ((0, 2, 4, 6), (1, 3, 5, 7)))