mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Split sharding_impls into its own Bazel target.
* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies. * Fix a handful of new pytype errors. PiperOrigin-RevId: 523146076
This commit is contained in:
parent
a1797170af
commit
be1cf46a49
20
jax/BUILD
20
jax/BUILD
@ -110,7 +110,6 @@ py_library_providing_imports_info(
|
||||
"_src/prng.py",
|
||||
"_src/public_test_util.py",
|
||||
"_src/random.py",
|
||||
"_src/sharding_impls.py",
|
||||
"_src/stages.py",
|
||||
] + glob(
|
||||
[
|
||||
@ -183,6 +182,7 @@ py_library_providing_imports_info(
|
||||
":pretty_printer",
|
||||
":profiler",
|
||||
":sharding",
|
||||
":sharding_impls",
|
||||
":sharding_specs",
|
||||
":source_info_util",
|
||||
":traceback_util",
|
||||
@ -345,6 +345,7 @@ pytype_strict_library(
|
||||
":effects",
|
||||
":op_shardings",
|
||||
":partial_eval",
|
||||
":sharding_impls",
|
||||
":source_info_util",
|
||||
":util",
|
||||
":xla",
|
||||
@ -418,6 +419,22 @@ pytype_strict_library(
|
||||
],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "sharding_impls",
|
||||
srcs = ["_src/sharding_impls.py"],
|
||||
deps = [
|
||||
":mesh",
|
||||
":op_shardings",
|
||||
":partition_spec",
|
||||
":sharding",
|
||||
":sharding_specs",
|
||||
":tree_util",
|
||||
":util",
|
||||
":xla_bridge",
|
||||
"//jax/_src/lib",
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "sharding_specs",
|
||||
srcs = ["_src/sharding_specs.py"],
|
||||
@ -504,6 +521,7 @@ pytype_strict_library(
|
||||
":abstract_arrays",
|
||||
":config",
|
||||
":core",
|
||||
":sharding_impls",
|
||||
":source_info_util",
|
||||
":typing",
|
||||
":util",
|
||||
|
@ -46,6 +46,7 @@ from jax._src import dispatch
|
||||
from jax._src import effects
|
||||
from jax._src import array
|
||||
from jax._src import dtypes
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import sharding_specs
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
@ -146,8 +147,8 @@ float0 = dtypes.float0
|
||||
|
||||
def jit(
|
||||
fun: Callable,
|
||||
in_shardings=pxla._UNSPECIFIED,
|
||||
out_shardings=pxla._UNSPECIFIED,
|
||||
in_shardings=sharding_impls.UNSPECIFIED,
|
||||
out_shardings=sharding_impls.UNSPECIFIED,
|
||||
static_argnums: Union[int, Sequence[int], None] = None,
|
||||
static_argnames: Union[str, Iterable[str], None] = None,
|
||||
donate_argnums: Union[int, Sequence[int]] = (),
|
||||
@ -503,11 +504,11 @@ def xla_computation(fun: Callable,
|
||||
|
||||
def make_axis_env(nreps):
|
||||
if axis_env is None:
|
||||
return xla.AxisEnv(nreps, (), ())
|
||||
return sharding_impls.AxisEnv(nreps, (), ())
|
||||
else:
|
||||
nreps = nreps * math.prod(size for name, size in axis_env)
|
||||
names, sizes = unzip2(axis_env)
|
||||
return xla.AxisEnv(nreps, names, sizes)
|
||||
return sharding_impls.AxisEnv(nreps, names, sizes)
|
||||
|
||||
@wraps(fun)
|
||||
@api_boundary
|
||||
@ -553,7 +554,7 @@ def xla_computation(fun: Callable,
|
||||
ordered_effects=ordered_effects,
|
||||
backend_or_name=backend,
|
||||
platform=platform,
|
||||
axis_context=mlir.ReplicaAxisContext(axis_env_),
|
||||
axis_context=sharding_impls.ReplicaAxisContext(axis_env_),
|
||||
name_stack=source_info_util.new_name_stack(
|
||||
wrap_name(fun_name, "xla_computation")),
|
||||
donated_args=donated_invars,
|
||||
|
@ -23,6 +23,7 @@ from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
@ -107,13 +108,13 @@ def pure_callback_lowering(ctx, *args, callback, **params):
|
||||
|
||||
sharding = None
|
||||
axis_context = ctx.module_context.axis_context
|
||||
if isinstance(axis_context, mlir.ShardingContext):
|
||||
if isinstance(axis_context, sharding_impls.ShardingContext):
|
||||
if len(axis_context.device_assignment) > 1:
|
||||
raise NotImplementedError(
|
||||
"pure_callback is only supported in spmd computations when all mesh"
|
||||
" axes are partitioned manually (no partial automatic sharding)."
|
||||
)
|
||||
if isinstance(axis_context, mlir.SPMDAxisContext):
|
||||
if isinstance(axis_context, sharding_impls.SPMDAxisContext):
|
||||
if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names):
|
||||
raise NotImplementedError(
|
||||
"pure_callback is only supported in spmd computations when all mesh"
|
||||
@ -272,7 +273,7 @@ def io_callback_lowering(ctx, *args, callback, ordered, **params):
|
||||
# can only safely maximally shard. Should we allow device_index to be passed
|
||||
# in like host_callback?
|
||||
if isinstance(ctx.module_context.axis_context,
|
||||
(mlir.SPMDAxisContext, mlir.ShardingContext)):
|
||||
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext)):
|
||||
# Apply maximal sharding so pjit only executes the callback on device 0.
|
||||
sharding = xc.OpSharding()
|
||||
sharding.type = xc.OpSharding.Type.MAXIMAL
|
||||
|
@ -31,6 +31,7 @@ from jax._src import custom_derivatives
|
||||
from jax._src import effects
|
||||
from jax._src import pjit
|
||||
from jax._src import prng
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import tree_util as jtu
|
||||
@ -869,7 +870,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
|
||||
# Update pjit params to account for extra error values.
|
||||
num_error_vals = len(err_vals)
|
||||
num_out_error_vals = out_tree.num_leaves - len(out_shardings)
|
||||
sharding = pjit._UNSPECIFIED
|
||||
sharding = sharding_impls.UNSPECIFIED
|
||||
|
||||
new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
|
||||
new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings)
|
||||
|
@ -29,6 +29,7 @@ from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src import pjit
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
@ -123,13 +124,16 @@ ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule
|
||||
def debug_callback_lowering(ctx, *args, effect, callback, **params):
|
||||
|
||||
axis_context = ctx.module_context.axis_context
|
||||
if (isinstance(axis_context, mlir.SPMDAxisContext) and
|
||||
if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
|
||||
set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)):
|
||||
# If we have fully manual sharding during lowering, that means the JAX
|
||||
# program has per-device semantics, so we run the callback on each device.
|
||||
sharding = xc.OpSharding()
|
||||
sharding.type = xc.OpSharding.Type.MANUAL
|
||||
elif isinstance(axis_context, (mlir.ShardingContext, mlir.SPMDAxisContext)):
|
||||
elif isinstance(
|
||||
axis_context,
|
||||
(sharding_impls.ShardingContext, sharding_impls.SPMDAxisContext),
|
||||
):
|
||||
# If we have fully automatic sharding during lowering, that means the JAX
|
||||
# program has bulk array semantics, so we run the callback with a MAXIMAL
|
||||
# sharding and hence execute it only once on the full logical value).
|
||||
@ -312,9 +316,9 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
axis_context = ctx.module_context.axis_context
|
||||
|
||||
if isinstance(axis_context, mlir.ShardingContext):
|
||||
if isinstance(axis_context, sharding_impls.ShardingContext):
|
||||
devices = axis_context.device_assignment
|
||||
elif isinstance(axis_context, mlir.SPMDAxisContext):
|
||||
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
|
||||
devices = list(axis_context.mesh.devices.flat)
|
||||
else:
|
||||
raise NotImplementedError(type(axis_context))
|
||||
|
@ -53,7 +53,8 @@ from jax._src.monitoring import record_event_duration_secs
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import (
|
||||
PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding)
|
||||
PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding,
|
||||
UNSPECIFIED)
|
||||
|
||||
|
||||
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
|
||||
@ -211,13 +212,13 @@ def xla_primitive_callable(prim, *arg_specs: ArgSpec, **params):
|
||||
def sharded_lowering(fun, name, donated_invars, keep_unused,
|
||||
*arg_specs, lowering_platform: Optional[str]):
|
||||
in_avals, in_shardings = util.unzip2(arg_specs)
|
||||
in_shardings = [pxla._UNSPECIFIED if i is None else i for i in in_shardings] # type: ignore
|
||||
in_shardings = [UNSPECIFIED if i is None else i for i in in_shardings] # type: ignore
|
||||
|
||||
# Pass in a singleton `_UNSPECIFIED` for out_shardings because we don't know
|
||||
# Pass in a singleton `UNSPECIFIED` for out_shardings because we don't know
|
||||
# the number of output avals at this stage. lower_sharding_computation will
|
||||
# apply it to all out_avals.
|
||||
return pxla.lower_sharding_computation(
|
||||
fun, 'jit', name, in_shardings, pxla._UNSPECIFIED, donated_invars,
|
||||
fun, 'jit', name, in_shardings, UNSPECIFIED, donated_invars,
|
||||
tuple(in_avals), keep_unused=keep_unused, always_lower=False,
|
||||
devices_from_context=None, lowering_platform=lowering_platform)
|
||||
|
||||
|
@ -24,7 +24,7 @@ import itertools
|
||||
import re
|
||||
import typing
|
||||
from typing import (Any, Callable, Dict, Iterator, List, NamedTuple, Optional,
|
||||
Protocol, Sequence, Set, Tuple, Type, Union, FrozenSet)
|
||||
Protocol, Sequence, Set, Tuple, Type, Union)
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
@ -35,6 +35,7 @@ from jax._src import dtypes
|
||||
from jax._src import effects as effects_lib
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import op_shardings
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge as xb
|
||||
@ -339,68 +340,11 @@ def make_ir_context() -> ir.Context:
|
||||
return context
|
||||
|
||||
|
||||
Mesh = Any
|
||||
MeshAxisName = Any
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class SPMDAxisContext:
|
||||
"""A hardware axis context for parallel computations that use the GSPMD partitioner.
|
||||
|
||||
This includes the mesh that will later by used to execute this computation,
|
||||
as well as a set of mesh axes that are currently (e.g. because the current lowering
|
||||
is invoked inside an xmap) lowered in the MANUAL sharding mode.
|
||||
"""
|
||||
mesh: Mesh
|
||||
manual_axes: FrozenSet[MeshAxisName] = frozenset()
|
||||
|
||||
@property
|
||||
def axis_env(self):
|
||||
# All collectives that touch axis_env should remember to set use_global_device_ids
|
||||
# when this context is enabled!
|
||||
if self.manual_axes != frozenset(self.mesh.axis_names):
|
||||
raise NotImplementedError(
|
||||
"Collectives in manually partitioned computations are only supported "
|
||||
"when all mesh axes are partitioned manually (no partial automatic sharding). "
|
||||
"Make sure that you mention all mesh axes in axis_resources!")
|
||||
return self.unsafe_axis_env
|
||||
|
||||
@property
|
||||
def unsafe_axis_env(self):
|
||||
return xla.AxisEnv(
|
||||
nreps=self.mesh.size,
|
||||
names=self.mesh.axis_names,
|
||||
sizes=tuple(self.mesh.shape.values()))
|
||||
|
||||
def extend_manual(self, axes: FrozenSet[MeshAxisName]) -> SPMDAxisContext:
|
||||
return SPMDAxisContext(self.mesh, self.manual_axes | axes)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ReplicaAxisContext:
|
||||
"""A hardware axis context for parallel computations that are partitioned by JAX.
|
||||
|
||||
Unlike in the SPMDAxisContext, this means that JAX might need to emit calls to
|
||||
explicit collectives.
|
||||
"""
|
||||
axis_env: xla.AxisEnv
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ShardingContext:
|
||||
"""A hardware axis context for parallel computations that use the sharding
|
||||
interface.
|
||||
|
||||
This context also uses the GSPMD partitioner.
|
||||
"""
|
||||
device_assignment: Sequence[xc.Device]
|
||||
|
||||
# Similar to SPMDContext as ShardingContext also uses the GSPMD partitioner.
|
||||
@property
|
||||
def axis_env(self):
|
||||
return xla.AxisEnv(nreps=1, names=(), sizes=())
|
||||
|
||||
|
||||
AxisContext = Union[SPMDAxisContext, ReplicaAxisContext, ShardingContext]
|
||||
AxisContext = Union[
|
||||
sharding_impls.SPMDAxisContext,
|
||||
sharding_impls.ReplicaAxisContext,
|
||||
sharding_impls.ShardingContext,
|
||||
]
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModuleContext:
|
||||
@ -428,7 +372,7 @@ class ModuleContext:
|
||||
|
||||
|
||||
@property
|
||||
def axis_env(self) -> xla.AxisEnv:
|
||||
def axis_env(self) -> sharding_impls.AxisEnv:
|
||||
return self.axis_context.axis_env
|
||||
|
||||
def __init__(
|
||||
@ -1539,7 +1483,7 @@ def xla_fallback_lowering(prim: core.Primitive):
|
||||
def fallback(ctx: LoweringRuleContext, *args, **params):
|
||||
module_ctx = ctx.module_context
|
||||
axis_ctx = module_ctx.axis_context
|
||||
if isinstance(axis_ctx, SPMDAxisContext):
|
||||
if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
|
||||
axis_env = axis_ctx.unsafe_axis_env
|
||||
else:
|
||||
axis_env = module_ctx.axis_env
|
||||
@ -1645,7 +1589,7 @@ def _emit_tpu_python_callback(
|
||||
ctx: LoweringRuleContext,
|
||||
callback,
|
||||
token: Optional[Any],
|
||||
operands: List[ir.Value],
|
||||
operands: Sequence[ir.Value],
|
||||
operand_avals: List[core.ShapedArray],
|
||||
operand_shapes: List[xc.Shape],
|
||||
result_avals: List[core.ShapedArray],
|
||||
@ -1716,7 +1660,7 @@ def _aval_to_default_layout(aval):
|
||||
|
||||
def emit_python_callback(
|
||||
ctx: LoweringRuleContext, callback, token: Optional[Any],
|
||||
operands: List[ir.Value], operand_avals: List[core.ShapedArray],
|
||||
operands: Sequence[ir.Value], operand_avals: List[core.ShapedArray],
|
||||
result_avals: List[core.ShapedArray],
|
||||
has_side_effect: bool, *, sharding: Optional[xc.OpSharding] = None,
|
||||
operand_layouts: Optional[Sequence[Optional[Sequence[int]]]] = None,
|
||||
|
@ -17,13 +17,12 @@ from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from contextlib import contextmanager
|
||||
from collections import defaultdict, OrderedDict, namedtuple
|
||||
from collections import defaultdict, namedtuple
|
||||
import dataclasses
|
||||
from functools import partial, lru_cache, cached_property
|
||||
import itertools as it
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
|
||||
Sequence, Set, Tuple, Type, Union, Iterable,
|
||||
TYPE_CHECKING, cast)
|
||||
@ -40,7 +39,7 @@ from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import mesh
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src import op_shardings
|
||||
from jax._src import sharding_specs
|
||||
from jax._src import profiler
|
||||
@ -61,6 +60,11 @@ from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
from jax._src.sharding_impls import (
|
||||
ArrayMapping, ArrayMappingOrAutoOrUnspecified,
|
||||
AUTOAxisResource, UnspecifiedValue, UNSPECIFIED,
|
||||
get_array_mapping as _get_array_mapping, is_auto, is_unspecified
|
||||
)
|
||||
from jax._src.util import (unzip3, safe_map, safe_zip, partition_list,
|
||||
wrap_name, tuple_delete, distributed_debug_log,
|
||||
unzip2, HashableFunction, weakref_lru_cache)
|
||||
@ -71,11 +75,6 @@ class WeakRefList(list):
|
||||
pass
|
||||
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
OrderedDictType = OrderedDict
|
||||
else:
|
||||
OrderedDictType = Dict
|
||||
|
||||
xe = xc._xla
|
||||
|
||||
unsafe_map, map = map, safe_map # type: ignore
|
||||
@ -92,8 +91,8 @@ ShardedAxis = sharding_specs.ShardedAxis
|
||||
Replicated = sharding_specs.Replicated
|
||||
|
||||
AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
|
||||
Mesh = jax._src.mesh.Mesh
|
||||
MeshAxisName = mesh.MeshAxisName
|
||||
Mesh = mesh_lib.Mesh
|
||||
MeshAxisName = sharding_impls.MeshAxisName
|
||||
MeshDimAssignment = Union[ShardedAxis, Replicated]
|
||||
ShardingSpec = sharding_specs.ShardingSpec
|
||||
|
||||
@ -233,56 +232,6 @@ def _shard_abstract_array(size, axis: int, x):
|
||||
shard_aval_handlers[ShapedArray] = _shard_abstract_array
|
||||
|
||||
|
||||
class AUTOAxisResource:
|
||||
pass
|
||||
AUTO = AUTOAxisResource()
|
||||
|
||||
def is_auto(x):
|
||||
return isinstance(x, AUTOAxisResource)
|
||||
|
||||
|
||||
class UnspecifiedValue:
|
||||
def __repr__(self):
|
||||
return "UnspecifiedValue"
|
||||
_UNSPECIFIED = UnspecifiedValue()
|
||||
|
||||
def _is_unspecified(x):
|
||||
return isinstance(x, UnspecifiedValue)
|
||||
|
||||
"""
|
||||
ArrayMapping specifies how an ndarray should map to mesh axes.
|
||||
|
||||
Note that the ordering is crucial for the cases when this mapping is non-injective
|
||||
(i.e. when multiple mesh axes map to the same positional axis). Then, the
|
||||
order of entries of the mapping determines a major-to-minor order on mesh axes,
|
||||
according to which chunks of the value along the repeated dimension will be assigned.
|
||||
|
||||
For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
|
||||
The second dimension of the value would get chunked into 6 pieces, and assigned to the
|
||||
mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
|
||||
that would mean that a flat list of chunks would get assigned to a flattened list of
|
||||
mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
|
||||
mesh devices ndarray would have to be transposed before flattening and assignment.
|
||||
"""
|
||||
ArrayMapping = OrderedDictType[mesh.MeshAxisName, int]
|
||||
ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTOAxisResource,
|
||||
UnspecifiedValue]
|
||||
|
||||
|
||||
def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
|
||||
if not array_mapping:
|
||||
return PartitionSpec()
|
||||
max_index = -1
|
||||
reverse_map = defaultdict(list)
|
||||
for axis, index in array_mapping.items():
|
||||
reverse_map[index].append(axis)
|
||||
if index > max_index:
|
||||
max_index = index
|
||||
partitions = tuple(tuple(reverse_map[i]) if reverse_map[i] else None
|
||||
for i in range(max_index + 1))
|
||||
return PartitionSpec(*partitions)
|
||||
|
||||
|
||||
def local_aval_to_result_handler(
|
||||
aval: core.AbstractValue,
|
||||
sharding: sharding_impls.XLACompatibleSharding,
|
||||
@ -370,7 +319,7 @@ def make_sharded_device_array(
|
||||
if sharding_spec is None:
|
||||
sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
|
||||
|
||||
mesh = jax._src.mesh.thread_resources.env.physical_mesh
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
sharding: sharding_impls.XLACompatibleSharding
|
||||
if mesh.empty:
|
||||
sharding = sharding_impls.PmapSharding(
|
||||
@ -930,7 +879,7 @@ def lower_parallel_callable(
|
||||
shards.num_global_shards, avals, replicas.num_global_replicas,
|
||||
parts.num_partitions)
|
||||
|
||||
axis_env = xla.AxisEnv(
|
||||
axis_env = sharding_impls.AxisEnv(
|
||||
replicas.num_global_replicas, (axis_name,), (global_axis_size,))
|
||||
name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap'))
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
@ -952,7 +901,7 @@ def lower_parallel_callable(
|
||||
ordered_effects,
|
||||
backend,
|
||||
lowering_platform or backend.platform,
|
||||
mlir.ReplicaAxisContext(axis_env),
|
||||
sharding_impls.ReplicaAxisContext(axis_env),
|
||||
name_stack,
|
||||
donated_invars,
|
||||
replicated_args=replicated_args,
|
||||
@ -1161,6 +1110,7 @@ class UnloadedPmapExecutable:
|
||||
input_indices = []
|
||||
for aval, spec in safe_zip(self.local_input_avals, self.input_shardings):
|
||||
assert isinstance(spec, sharding_impls.PmapSharding), spec
|
||||
assert isinstance(aval, core.ShapedArray), aval
|
||||
input_indices.append(
|
||||
sharding_specs.spec_to_indices(aval.shape, spec.sharding_spec)
|
||||
if spec.sharding_spec is not None else None)
|
||||
@ -1723,7 +1673,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
|
||||
|
||||
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
|
||||
sub_ctx = ctx.module_context.replace(
|
||||
axis_context=mlir.ReplicaAxisContext(new_env),
|
||||
axis_context=sharding_impls.ReplicaAxisContext(new_env),
|
||||
name_stack=ctx.module_context.name_stack.extend(
|
||||
util.wrap_name(name, 'pmap')))
|
||||
sharded_outs, _ = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, mlir.TokenSet(), (),
|
||||
@ -1805,7 +1755,9 @@ def _full_to_shard_abstract_eval(x, axes, mesh, **_):
|
||||
# TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes!
|
||||
return tile_aval_nd(mesh.shape, axes, x)
|
||||
|
||||
def manual_proto(aval: core.ShapedArray, manual_axes_set: FrozenSet[mesh.MeshAxisName], mesh: Mesh):
|
||||
def manual_proto(
|
||||
aval: core.ShapedArray,
|
||||
manual_axes_set: FrozenSet[sharding_impls.MeshAxisName], mesh: Mesh):
|
||||
"""Create an OpSharding proto that declares all mesh axes from `axes` as manual
|
||||
and all others as replicated.
|
||||
"""
|
||||
@ -1832,7 +1784,7 @@ def manual_proto(aval: core.ShapedArray, manual_axes_set: FrozenSet[mesh.MeshAxi
|
||||
|
||||
@partial(mlir.register_lowering, full_to_shard_p)
|
||||
def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
|
||||
manual_axes: FrozenSet[mesh.MeshAxisName]):
|
||||
manual_axes: FrozenSet[sharding_impls.MeshAxisName]):
|
||||
# TODO: Can we short-circuit for replicated values? Probably not.
|
||||
aval_in, = ctx.avals_in
|
||||
aval_out, = ctx.avals_out
|
||||
@ -1852,7 +1804,7 @@ def _shard_to_full_abstract_eval(x, axes, mesh, **_):
|
||||
|
||||
@partial(mlir.register_lowering, shard_to_full_p)
|
||||
def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
|
||||
manual_axes: FrozenSet[mesh.MeshAxisName]):
|
||||
manual_axes: FrozenSet[sharding_impls.MeshAxisName]):
|
||||
aval_in, = ctx.avals_in
|
||||
aval_out, = ctx.avals_out
|
||||
proto = manual_proto(aval_in, manual_axes, mesh)
|
||||
@ -1863,7 +1815,7 @@ def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
|
||||
return mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, unspecified_dims),
|
||||
|
||||
@lu.transformation
|
||||
def vtile_manual(manual_axes: FrozenSet[mesh.MeshAxisName],
|
||||
def vtile_manual(manual_axes: FrozenSet[sharding_impls.MeshAxisName],
|
||||
mesh: Mesh,
|
||||
in_axes: Sequence[ArrayMapping],
|
||||
out_axes: Sequence[ArrayMapping],
|
||||
@ -1882,7 +1834,7 @@ class TileVectorize:
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TileManual:
|
||||
manual_axes: FrozenSet[mesh.MeshAxisName]
|
||||
manual_axes: FrozenSet[sharding_impls.MeshAxisName]
|
||||
|
||||
TilingMethod = Union[TileVectorize, TileManual]
|
||||
|
||||
@ -1971,7 +1923,7 @@ def _get_and_check_device_assignment(
|
||||
devices = list(devices)
|
||||
|
||||
for i, s_type, source_info in shardings:
|
||||
if is_auto(i) or _is_unspecified(i):
|
||||
if is_auto(i) or is_unspecified(i):
|
||||
continue
|
||||
# Assign `first_sharding_info` after `AUTO` and `UNSPECIFIED` have been
|
||||
# skipped.
|
||||
@ -2096,14 +2048,14 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
|
||||
in_op_shardings = map(_to_logical_op_sharding, global_in_avals, in_shardings)
|
||||
out_op_shardings = map(_to_logical_op_sharding, global_out_avals, out_shardings)
|
||||
replicated_args = [False] * len(global_in_avals)
|
||||
axis_ctx = mlir.ShardingContext(device_assignment)
|
||||
axis_ctx = sharding_impls.ShardingContext(device_assignment)
|
||||
else:
|
||||
# This path is triggered for `jit(pmap)` cases.
|
||||
replicated_args = None
|
||||
in_op_shardings = None
|
||||
out_op_shardings = None
|
||||
axis_env = xla.AxisEnv(nreps, (), ())
|
||||
axis_ctx = mlir.ReplicaAxisContext(axis_env)
|
||||
axis_env = sharding_impls.AxisEnv(nreps, (), ())
|
||||
axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
|
||||
|
||||
module_name = f"{api_name}_{fun_name}"
|
||||
|
||||
@ -2189,10 +2141,10 @@ def lower_sharding_computation(
|
||||
) -> MeshComputation:
|
||||
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
|
||||
|
||||
The caller of this code can pass in a singleton _UNSPECIFIED because the
|
||||
The caller of this code can pass in a singleton UNSPECIFIED because the
|
||||
number of out_avals might not be known at that time and
|
||||
lower_sharding_computation calculates the number of out_avals so it can apply
|
||||
the singleton _UNSPECIFIED to all out_avals.
|
||||
the singleton UNSPECIFIED to all out_avals.
|
||||
"""
|
||||
# 1. Trace to jaxpr and preprocess/verify it
|
||||
(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
|
||||
@ -2202,8 +2154,8 @@ def lower_sharding_computation(
|
||||
jaxpr = closed_jaxpr.jaxpr
|
||||
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
|
||||
|
||||
if _is_unspecified(out_shardings):
|
||||
out_shardings = (_UNSPECIFIED,) * len(global_out_avals)
|
||||
if is_unspecified(out_shardings):
|
||||
out_shardings = (UNSPECIFIED,) * len(global_out_avals)
|
||||
assert isinstance(out_shardings, tuple)
|
||||
assert len(out_shardings) == len(global_out_avals), (
|
||||
len(out_shardings), len(global_out_avals))
|
||||
@ -2221,12 +2173,12 @@ def lower_sharding_computation(
|
||||
committed = bool(
|
||||
devices_from_context or
|
||||
len(device_assignment) > 1 or
|
||||
any(not _is_unspecified(i) for i in in_shardings) or
|
||||
any(not _is_unspecified(js) for js, _ in jaxpr_sharding) or
|
||||
any(not _is_unspecified(o) for o in out_shardings))
|
||||
any(not is_unspecified(i) for i in in_shardings) or
|
||||
any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
|
||||
any(not is_unspecified(o) for o in out_shardings))
|
||||
|
||||
in_shardings = tuple(sharding_impls.GSPMDSharding.get_replicated(device_assignment)
|
||||
if _is_unspecified(i) else i for i in in_shardings)
|
||||
if is_unspecified(i) else i for i in in_shardings)
|
||||
|
||||
da_object = _create_da_object(tuple(device_assignment))
|
||||
|
||||
@ -2260,7 +2212,7 @@ def lower_sharding_computation(
|
||||
# and don't need to evaluate their arguments.
|
||||
if (not always_lower and not (jaxpr.effects or has_outfeed) and
|
||||
(not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and
|
||||
all(_is_unspecified(o) for o in out_shardings)):
|
||||
all(is_unspecified(o) for o in out_shardings)):
|
||||
return MeshComputation(
|
||||
str(name_stack), None, True, donated_invars, jaxpr=jaxpr,
|
||||
consts=closed_jaxpr.consts, global_in_avals=global_in_avals,
|
||||
@ -2309,7 +2261,7 @@ def lower_sharding_computation(
|
||||
def _to_logical_op_sharding(
|
||||
aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTOAxisResource]
|
||||
) -> Optional[xc.OpSharding]:
|
||||
if _is_unspecified(sharding) or is_auto(sharding):
|
||||
if is_unspecified(sharding) or is_auto(sharding):
|
||||
return None
|
||||
elif isinstance(aval, ShapedArray):
|
||||
assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
|
||||
@ -2419,15 +2371,16 @@ def lower_mesh_computation(
|
||||
in_partitions = map(_to_logical_op_sharding, global_in_avals, in_shardings)
|
||||
out_partitions = map(_to_logical_op_sharding, global_out_avals, out_shardings)
|
||||
replicated_args = [False] * len(in_jaxpr_avals)
|
||||
axis_ctx = mlir.SPMDAxisContext(mesh, manual_axes)
|
||||
axis_ctx = sharding_impls.SPMDAxisContext(mesh, manual_axes)
|
||||
else:
|
||||
replicated_args = [not get_array_mapping(i.spec) for i in in_shardings] # type: ignore
|
||||
in_partitions = None
|
||||
out_partitions = None
|
||||
axis_env = xla.AxisEnv(nreps=mesh.size,
|
||||
names=tuple(global_axis_sizes.keys()),
|
||||
sizes=tuple(global_axis_sizes.values()))
|
||||
axis_ctx = mlir.ReplicaAxisContext(axis_env)
|
||||
axis_env = sharding_impls.AxisEnv(
|
||||
nreps=mesh.size,
|
||||
names=tuple(global_axis_sizes.keys()),
|
||||
sizes=tuple(global_axis_sizes.values()))
|
||||
axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
|
||||
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
module: Union[str, xc.XlaComputation]
|
||||
module_name = f"{api_name}_{fun_name}"
|
||||
@ -2834,7 +2787,7 @@ class UnloadedMeshExecutable:
|
||||
for x, o in safe_zip(out_shardings_xla, out_shardings)
|
||||
]
|
||||
out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple)
|
||||
elif (out_shardings and any(_is_unspecified(o) for o in out_shardings)
|
||||
elif (out_shardings and any(is_unspecified(o) for o in out_shardings)
|
||||
and pmap_nreps == 1):
|
||||
assert mesh is None
|
||||
_, out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
|
||||
@ -2844,7 +2797,7 @@ class UnloadedMeshExecutable:
|
||||
out_shardings, are_out_shardings_from_xla = [], [] # type: ignore
|
||||
for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
|
||||
global_out_avals):
|
||||
if _is_unspecified(orig):
|
||||
if is_unspecified(orig):
|
||||
out_shardings.append(xla_s)
|
||||
are_out_shardings_from_xla.append(True)
|
||||
else:
|
||||
@ -3133,7 +3086,7 @@ def create_mesh_pspec_sharding(
|
||||
|
||||
def check_device_backend_on_shardings(shardings) -> bool:
|
||||
for i in shardings:
|
||||
if _is_unspecified(i) or is_auto(i):
|
||||
if is_unspecified(i) or is_auto(i):
|
||||
continue
|
||||
if hasattr(i, '_original_sharding') and getattr(
|
||||
i._original_sharding, '_device_backend', False):
|
||||
@ -3163,10 +3116,8 @@ def check_gda_or_array_xla_sharding_match(
|
||||
|
||||
|
||||
def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
|
||||
# Import here to avoid cyclic import error when importing gda in pjit.py.
|
||||
from jax.experimental.pjit import get_array_mapping as _get_array_mapping, _prepare_axis_resources
|
||||
|
||||
parsed_pspec, _, _ = _prepare_axis_resources(pspec, "pspec to array_mapping")
|
||||
parsed_pspec, _, _ = sharding_impls.prepare_axis_resources(
|
||||
pspec, "pspec to array_mapping")
|
||||
return _get_array_mapping(parsed_pspec)
|
||||
|
||||
|
||||
|
@ -22,7 +22,7 @@ import itertools as it
|
||||
import math
|
||||
import operator
|
||||
import re
|
||||
from typing import (Any, Callable, Dict, NamedTuple, Optional, Protocol,
|
||||
from typing import (Any, Callable, Dict, Optional, Protocol,
|
||||
Sequence, Set, Type, Tuple, Union, TYPE_CHECKING)
|
||||
|
||||
import numpy as np
|
||||
@ -34,6 +34,7 @@ from jax._src import dtypes
|
||||
from jax._src import source_info_util
|
||||
from jax._src.abstract_arrays import numpy_scalar_types
|
||||
from jax._src.core import ConcreteArray, ShapedArray
|
||||
from jax._src.sharding_impls import AxisEnv
|
||||
from jax._src.util import safe_zip, safe_map
|
||||
|
||||
from jax._src.typing import Shape
|
||||
@ -254,13 +255,6 @@ def primitive_subcomputation(platform: str, axis_env: 'AxisEnv',
|
||||
|
||||
### compiling jaxprs
|
||||
|
||||
|
||||
class AxisEnv(NamedTuple):
|
||||
"""Represents a pmap mesh (only along the replica axes)."""
|
||||
nreps: int
|
||||
names: Tuple[Any, ...]
|
||||
sizes: Tuple[int, ...]
|
||||
|
||||
@dataclasses.dataclass
|
||||
class TranslationContext:
|
||||
builder: xc.XlaBuilder
|
||||
@ -272,8 +266,6 @@ class TranslationContext:
|
||||
|
||||
def replace(self, **kw): return dataclasses.replace(self, **kw)
|
||||
|
||||
|
||||
|
||||
def xla_destructure(c, ans):
|
||||
num_elements = len(c.get_shape(ans).tuple_shapes())
|
||||
return [xops.GetTupleElement(ans, i) for i in range(num_elements)]
|
||||
|
@ -28,6 +28,7 @@ from jax import tree_util
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import util
|
||||
from jax._src.core import ShapedArray, AxisName, raise_to_shaped
|
||||
from jax._src.interpreters import ad
|
||||
@ -750,8 +751,10 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
|
||||
_replica_groups(ctx.module_context.axis_env, named_axes,
|
||||
axis_index_groups))
|
||||
axis_context = ctx.module_context.axis_context
|
||||
is_spmd = isinstance(axis_context,
|
||||
(mlir.SPMDAxisContext, mlir.ShardingContext))
|
||||
is_spmd = isinstance(
|
||||
axis_context,
|
||||
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
|
||||
)
|
||||
|
||||
def all_reduce(aval, x):
|
||||
if is_spmd:
|
||||
@ -880,7 +883,10 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm):
|
||||
full_perm = full_perm.reshape((-1, 2))
|
||||
|
||||
axis_context = ctx.module_context.axis_context
|
||||
is_manual = isinstance(axis_context, mlir.SPMDAxisContext) and axis_context.manual_axes
|
||||
is_manual = (
|
||||
isinstance(axis_context, sharding_impls.SPMDAxisContext)
|
||||
and axis_context.manual_axes
|
||||
)
|
||||
if is_manual:
|
||||
channel = ctx.module_context.new_channel()
|
||||
other_args = dict(
|
||||
@ -978,8 +984,10 @@ def _all_to_all_lowering(ctx, x, *,
|
||||
split_count = len(replica_groups[0])
|
||||
if not all(split_count == len(g) for g in replica_groups):
|
||||
raise ValueError('Replica groups must be equally sized')
|
||||
is_spmd = isinstance(ctx.module_context.axis_context,
|
||||
(mlir.SPMDAxisContext, mlir.ShardingContext))
|
||||
is_spmd = isinstance(
|
||||
ctx.module_context.axis_context,
|
||||
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
|
||||
)
|
||||
if is_spmd:
|
||||
# We want to emit the all-gather with global device IDs and a unique
|
||||
# channel ID, as otherwise it interprets the devices as replicas instead
|
||||
@ -1209,8 +1217,10 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
|
||||
x_aval, = ctx.avals_in
|
||||
out_aval, = ctx.avals_out
|
||||
axis_context = ctx.module_context.axis_context
|
||||
is_spmd = isinstance(axis_context,
|
||||
(mlir.SPMDAxisContext, mlir.ShardingContext))
|
||||
is_spmd = isinstance(
|
||||
axis_context,
|
||||
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
|
||||
)
|
||||
if (ctx.module_context.platform == 'tpu' or
|
||||
ctx.module_context.platform in ('cuda', 'rocm')
|
||||
and all_gather_dimension == 0):
|
||||
@ -1354,8 +1364,10 @@ def _reduce_scatter_lowering(prim, reducer, ctx, x,
|
||||
scatter_out_shape = list(x_aval.shape)
|
||||
scatter_out_shape[scatter_dimension] //= axis_size
|
||||
axis_context = ctx.module_context.axis_context
|
||||
is_spmd = isinstance(axis_context,
|
||||
(mlir.SPMDAxisContext, mlir.ShardingContext))
|
||||
is_spmd = isinstance(
|
||||
axis_context,
|
||||
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
|
||||
)
|
||||
if is_spmd:
|
||||
# We want to emit the all-gather with global device IDs and a unique
|
||||
# channel ID, as otherwise it interprets the devices as replicas instead
|
||||
@ -1572,8 +1584,10 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
|
||||
)
|
||||
mod = mlir.ir_constant(np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
|
||||
axis_context = ctx.module_context.axis_context
|
||||
is_spmd = isinstance(axis_context,
|
||||
(mlir.SPMDAxisContext, mlir.ShardingContext))
|
||||
is_spmd = isinstance(
|
||||
axis_context,
|
||||
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
|
||||
)
|
||||
if is_spmd:
|
||||
device_id = hlo.PartitionIdOp()
|
||||
else:
|
||||
|
@ -662,6 +662,7 @@ def _select_and_gather_add_lowering(
|
||||
max_bits=64):
|
||||
_, operand_aval, = ctx.avals_in
|
||||
out_aval, = ctx.avals_out
|
||||
assert isinstance(operand_aval, core.ShapedArray), operand_aval
|
||||
dtype = operand_aval.dtype
|
||||
etype = mlir.dtype_to_ir_type(dtype)
|
||||
nbits = dtypes.finfo(dtype).bits
|
||||
|
@ -26,9 +26,10 @@ from jax import numpy as jnp
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import effects
|
||||
from jax._src import mesh
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import op_shardings
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import source_info_util
|
||||
from jax._src import stages
|
||||
from jax._src import traceback_util
|
||||
@ -47,10 +48,11 @@ from jax._src.interpreters.partial_eval import (
|
||||
convert_constvars_jaxpr, new_jaxpr_eqn)
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.pjit import (
|
||||
sharding_constraint_p, ParsedPartitionSpec, get_unconstrained_dims,
|
||||
GSPMDSharding)
|
||||
from jax._src.sharding_impls import NamedSharding
|
||||
from jax._src.pjit import (sharding_constraint_p, get_unconstrained_dims,
|
||||
GSPMDSharding)
|
||||
from jax._src.sharding_impls import (
|
||||
ArrayMapping, NamedSharding, ParsedPartitionSpec,
|
||||
array_mapping_to_axis_resources)
|
||||
from jax._src.tree_util import (tree_flatten, tree_unflatten, all_leaves,
|
||||
tree_map, treedef_tuple)
|
||||
from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3,
|
||||
@ -90,12 +92,12 @@ class FrozenDict(abc.Mapping):
|
||||
# Multi-dimensional generalized map
|
||||
|
||||
AxisName = core.AxisName
|
||||
ResourceAxisName = mesh.ResourceAxisName # Different name just for documentation purposes
|
||||
Mesh = mesh.Mesh
|
||||
MeshAxisName = mesh.MeshAxisName
|
||||
ResourceEnv = mesh.ResourceEnv
|
||||
EMPTY_ENV = mesh.EMPTY_ENV
|
||||
thread_resources = mesh.thread_resources
|
||||
ResourceAxisName = mesh_lib.ResourceAxisName # Different name just for documentation purposes
|
||||
Mesh = mesh_lib.Mesh
|
||||
MeshAxisName = mesh_lib.MeshAxisName
|
||||
ResourceEnv = mesh_lib.ResourceEnv
|
||||
EMPTY_ENV = mesh_lib.EMPTY_ENV
|
||||
thread_resources = mesh_lib.thread_resources
|
||||
|
||||
|
||||
class SerialLoop:
|
||||
@ -165,7 +167,7 @@ def serial_loop(name: ResourceAxisName, length: int):
|
||||
axis_resources={'i': 'l'})(x)
|
||||
"""
|
||||
old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV)
|
||||
thread_resources.env = old_env.with_extra_loop(mesh.Loop(name, length))
|
||||
thread_resources.env = old_env.with_extra_loop(mesh_lib.Loop(name, length))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
@ -677,9 +679,9 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
tiling_method = pxla.TileManual(manual_mesh_axes)
|
||||
else:
|
||||
tiling_method = pxla.TileVectorize()
|
||||
in_shardings = [NamedSharding(mesh, pxla.array_mapping_to_axis_resources(i))
|
||||
in_shardings = [NamedSharding(mesh, array_mapping_to_axis_resources(i))
|
||||
for i in mesh_in_axes]
|
||||
out_shardings = [NamedSharding(mesh, pxla.array_mapping_to_axis_resources(o))
|
||||
out_shardings = [NamedSharding(mesh, array_mapping_to_axis_resources(o))
|
||||
for o in mesh_out_axes]
|
||||
return pxla.lower_mesh_computation(
|
||||
f, 'xmap', name, mesh,
|
||||
@ -937,7 +939,7 @@ def _resource_typing_xmap(avals,
|
||||
raise JAXTypeError(
|
||||
f"Detected disallowed xmap axis name shadowing at "
|
||||
f"{source_info_util.summarize(source_info)} "
|
||||
f"(shadowed axes: {mesh.show_axes(overlap)})")
|
||||
f"(shadowed axes: {mesh_lib.show_axes(overlap)})")
|
||||
|
||||
if resource_env.physical_mesh != params['resource_env'].physical_mesh:
|
||||
raise RuntimeError("Changing the physical mesh is not allowed inside xmap.")
|
||||
@ -965,9 +967,9 @@ def _resource_typing_xmap(avals,
|
||||
raise JAXTypeError(
|
||||
f"One of xmapped function ({params['name']}) outputs is broadcast "
|
||||
f"along axis `{baxis}` which is assigned to resources "
|
||||
f"{mesh.show_axes(baxis_resources)}, but the output is already "
|
||||
f"partitioned along {mesh.show_axes(overlap)}, because its "
|
||||
f"named shape contains {mesh.show_axes(partitioning_axes)}")
|
||||
f"{mesh_lib.show_axes(baxis_resources)}, but the output is already "
|
||||
f"partitioned along {mesh_lib.show_axes(overlap)}, because its "
|
||||
f"named shape contains {mesh_lib.show_axes(partitioning_axes)}")
|
||||
pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap
|
||||
|
||||
|
||||
@ -1269,7 +1271,7 @@ batching.BatchTrace.post_process_xmap = _batch_trace_post_process_xmap
|
||||
# -------- nested xmap handling --------
|
||||
|
||||
def _xmap_lowering_rule(ctx, *args, **kwargs):
|
||||
if isinstance(ctx.module_context.axis_context, mlir.SPMDAxisContext):
|
||||
if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext):
|
||||
if config.experimental_xmap_spmd_lowering_manual:
|
||||
return _xmap_lowering_rule_spmd_manual(ctx, *args, **kwargs)
|
||||
else:
|
||||
@ -1277,9 +1279,9 @@ def _xmap_lowering_rule(ctx, *args, **kwargs):
|
||||
# Here ShardingContext is used in place of ReplicaAxisContext because when
|
||||
# axis_resources and mesh is not used with xmap, `make_xmap_callable` will
|
||||
# go via `dispatch.sharded_lowering` path which sets the context to
|
||||
# ShardingContext. mlir.ShardingContext is not used for SPMD.
|
||||
# ShardingContext. sharding_impls.ShardingContext is not used for SPMD.
|
||||
elif isinstance(ctx.module_context.axis_context,
|
||||
(mlir.ReplicaAxisContext, mlir.ShardingContext)):
|
||||
(sharding_impls.ReplicaAxisContext, sharding_impls.ShardingContext)):
|
||||
return _xmap_lowering_rule_replica(ctx, *args, **kwargs)
|
||||
else:
|
||||
raise AssertionError("Unrecognized axis context type!")
|
||||
@ -1382,7 +1384,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
|
||||
|
||||
# XXX: We modify mesh_in_axes and mesh_out_axes here
|
||||
def add_spmd_axes(
|
||||
flat_mesh_axes: Sequence[pxla.ArrayMapping],
|
||||
flat_mesh_axes: Sequence[ArrayMapping],
|
||||
flat_extra_axes: Optional[Sequence[Sequence[Sequence[MeshAxisName]]]]):
|
||||
if flat_extra_axes is None:
|
||||
return
|
||||
@ -1456,7 +1458,8 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
|
||||
|
||||
# We in-line here rather than generating a Call HLO as in the xla_call
|
||||
# translation rule just because the extra tuple stuff is a pain.
|
||||
assert isinstance(ctx.module_context.axis_context, mlir.SPMDAxisContext)
|
||||
assert isinstance(ctx.module_context.axis_context,
|
||||
sharding_impls.SPMDAxisContext)
|
||||
sub_ctx = ctx.module_context.replace(
|
||||
name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')),
|
||||
axis_context=ctx.module_context.axis_context.extend_manual(manual_mesh_axes))
|
||||
@ -1755,7 +1758,7 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
|
||||
|
||||
s = arg.sharding
|
||||
xmap_sharding = pxla.create_mesh_pspec_sharding(
|
||||
mesh, pxla.array_mapping_to_axis_resources(xmap_array_mapping))
|
||||
mesh, array_mapping_to_axis_resources(xmap_array_mapping))
|
||||
# This check is cached because comparing OpSharding is expensive during
|
||||
# dispatch and if the shardings are the same, then there is no need to
|
||||
# compare twice.
|
||||
|
340
jax/_src/pjit.py
340
jax/_src/pjit.py
@ -13,10 +13,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import dataclasses
|
||||
from enum import IntEnum
|
||||
import inspect
|
||||
import numpy as np
|
||||
from collections import OrderedDict, Counter
|
||||
from typing import (Callable, Sequence, Tuple, Union, cast, List, Optional,
|
||||
Iterable, NamedTuple, Any)
|
||||
import itertools as it
|
||||
@ -30,6 +28,7 @@ from jax._src import dispatch
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import op_shardings
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
@ -51,10 +50,13 @@ from jax._src.interpreters import pxla
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src.sharding_impls import (
|
||||
NamedSharding, XLACompatibleSharding, GSPMDSharding,
|
||||
XLADeviceAssignment, SingleDeviceSharding, PmapSharding)
|
||||
XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
|
||||
AUTOAxisResource, UNSPECIFIED, UnspecifiedValue,
|
||||
CanonicalizedParsedPartitionSpec, ParsedPartitionSpec,
|
||||
SpecSync, get_single_pspec, is_auto, is_unspecified, is_unspecified_or_auto,
|
||||
prepare_axis_resources)
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.tree_util import (
|
||||
tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
|
||||
@ -62,7 +64,7 @@ from jax._src.tree_util import (
|
||||
prefix_errors, generate_key_paths)
|
||||
from jax._src.util import (
|
||||
HashableFunction, safe_map, safe_zip, wraps,
|
||||
distributed_debug_log, split_list, tuple_insert, weakref_lru_cache,
|
||||
distributed_debug_log, split_list, weakref_lru_cache,
|
||||
merge_lists)
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
@ -70,41 +72,10 @@ zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
|
||||
_AUTOAxisResource = pxla.AUTOAxisResource
|
||||
AUTO = pxla.AUTO # type: ignore
|
||||
is_auto = pxla.is_auto
|
||||
|
||||
_UnspecifiedValue = pxla.UnspecifiedValue
|
||||
_UNSPECIFIED = pxla._UNSPECIFIED # type: ignore
|
||||
_is_unspecified = pxla._is_unspecified
|
||||
|
||||
def _is_unspecified_or_auto(x):
|
||||
return is_auto(x) or _is_unspecified(x)
|
||||
|
||||
|
||||
PjitSharding = Union[GSPMDSharding, _UnspecifiedValue, _AUTOAxisResource]
|
||||
PjitShardingMinusUnspecified = Union[GSPMDSharding, _AUTOAxisResource]
|
||||
MeshSharding = Union[NamedSharding, _UnspecifiedValue, _AUTOAxisResource]
|
||||
MeshShardingMinusUnspecified = Union[NamedSharding, _AUTOAxisResource]
|
||||
|
||||
|
||||
def _check_all_or_none_unspecified(axis_resources, name):
|
||||
if not axis_resources:
|
||||
return False
|
||||
unspecified_count = 0
|
||||
unspecified = _is_unspecified(axis_resources[0])
|
||||
for resource in axis_resources:
|
||||
current_is_unspecified = _is_unspecified(resource)
|
||||
if current_is_unspecified:
|
||||
unspecified_count += 1
|
||||
assert unspecified_count == 1
|
||||
if current_is_unspecified != unspecified:
|
||||
raise ValueError(f'`pjit._UNSPECIFIED` exists in {name}. '
|
||||
f'Make sure that every entry in {name} is '
|
||||
'`pjit._UNSPECIFIED`.')
|
||||
return unspecified
|
||||
|
||||
PjitSharding = Union[GSPMDSharding, UnspecifiedValue, AUTOAxisResource]
|
||||
PjitShardingMinusUnspecified = Union[GSPMDSharding, AUTOAxisResource]
|
||||
MeshSharding = Union[NamedSharding, UnspecifiedValue, AUTOAxisResource]
|
||||
MeshShardingMinusUnspecified = Union[NamedSharding, AUTOAxisResource]
|
||||
|
||||
def _try_infer_args(f, tree):
|
||||
dummy_args = tree_unflatten(tree, [False] * tree.num_leaves)
|
||||
@ -281,27 +252,27 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
|
||||
|
||||
def _resolve_axis_resources_and_shardings_arg(
|
||||
in_shardings, out_shardings, in_axis_resources, out_axis_resources):
|
||||
if not _is_unspecified(in_shardings) and not _is_unspecified(in_axis_resources):
|
||||
if not is_unspecified(in_shardings) and not is_unspecified(in_axis_resources):
|
||||
raise ValueError(
|
||||
'Setting both in_shardings and in_axis_resources is not '
|
||||
'allowed. in_axis_resources is deprecated. Please use in_shardings.')
|
||||
if not _is_unspecified(out_shardings) and not _is_unspecified(out_axis_resources):
|
||||
if not is_unspecified(out_shardings) and not is_unspecified(out_axis_resources):
|
||||
raise ValueError(
|
||||
'Setting both out_shardings and out_axis_resources is not '
|
||||
'allowed. out_axis_resources is deprecated. Please use out_shardings.')
|
||||
if (not _is_unspecified(in_axis_resources) or
|
||||
not _is_unspecified(out_axis_resources)):
|
||||
if (not is_unspecified(in_axis_resources) or
|
||||
not is_unspecified(out_axis_resources)):
|
||||
warnings.warn(
|
||||
'in_axis_resources and out_axis_resources are deprecated. Please use '
|
||||
'in_shardings and out_shardings as their replacement.',
|
||||
DeprecationWarning)
|
||||
|
||||
if not _is_unspecified(in_axis_resources):
|
||||
if not is_unspecified(in_axis_resources):
|
||||
final_in_shardings = in_axis_resources
|
||||
else:
|
||||
final_in_shardings = in_shardings
|
||||
|
||||
if not _is_unspecified(out_axis_resources):
|
||||
if not is_unspecified(out_axis_resources):
|
||||
final_out_shardings = out_axis_resources
|
||||
else:
|
||||
final_out_shardings = out_shardings
|
||||
@ -326,10 +297,10 @@ def pre_infer_params(fun, in_shardings, out_shardings,
|
||||
if device is not None and backend is not None:
|
||||
raise ValueError("can't specify both a device and a backend for jit, "
|
||||
f"got {device=} and {backend=}")
|
||||
if not _is_unspecified(in_shardings):
|
||||
if not is_unspecified(in_shardings):
|
||||
raise ValueError('If backend or device is specified on jit, then '
|
||||
'in_shardings should not be specified.')
|
||||
if not _is_unspecified(out_shardings):
|
||||
if not is_unspecified(out_shardings):
|
||||
raise ValueError('If backend or device is specified on jit, then '
|
||||
'out_shardings should not be specified.')
|
||||
|
||||
@ -341,8 +312,8 @@ def pre_infer_params(fun, in_shardings, out_shardings,
|
||||
# rather than raising an error. https://github.com/google/jax/issues/2367
|
||||
in_shardings = tuple(in_shardings)
|
||||
|
||||
in_shardings, _, _ = _prepare_axis_resources(in_shardings, 'in_shardings')
|
||||
out_shardings, _, _ = _prepare_axis_resources(out_shardings, 'out_shardings')
|
||||
in_shardings, _, _ = prepare_axis_resources(in_shardings, 'in_shardings')
|
||||
out_shardings, _, _ = prepare_axis_resources(out_shardings, 'out_shardings')
|
||||
|
||||
donate_argnums, static_argnums, static_argnames = resolve_argnums(
|
||||
fun, donate_argnums, static_argnums, static_argnames)
|
||||
@ -394,8 +365,8 @@ def _pjit_explicit_sharding(in_shardings, out_shardings, device,
|
||||
out_shardings_flat, _ = tree_flatten(out_shardings)
|
||||
return (device is not None or
|
||||
backend is not None or
|
||||
any(not _is_unspecified(i) for i in in_shardings_flat) or
|
||||
any(not _is_unspecified(i) for i in out_shardings_flat))
|
||||
any(not is_unspecified(i) for i in in_shardings_flat) or
|
||||
any(not is_unspecified(i) for i in out_shardings_flat))
|
||||
|
||||
|
||||
class PjitInfo(NamedTuple):
|
||||
@ -418,7 +389,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
|
||||
donate_argnums, device, backend, keep_unused, inline,
|
||||
resource_env, abstracted_axes) = pjit_info_args
|
||||
|
||||
if kwargs and not _is_unspecified(user_in_shardings):
|
||||
if kwargs and not is_unspecified(user_in_shardings):
|
||||
raise ValueError(
|
||||
"pjit does not support kwargs when in_shardings is specified.")
|
||||
|
||||
@ -511,7 +482,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
|
||||
|
||||
num_extra_args = len(implicit_args) + len(consts)
|
||||
canonicalized_in_shardings_flat = \
|
||||
(_UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat
|
||||
(UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat
|
||||
donated_invars = (False,) * num_extra_args + donated_invars
|
||||
assert (len(canonicalized_in_shardings_flat) == len(donated_invars) ==
|
||||
len(consts) + len(args_flat))
|
||||
@ -574,10 +545,10 @@ def _flat_axes_specs(abstracted_axes, *args, **kwargs
|
||||
# because `None` means that the input is fully replicated.
|
||||
def pjit(
|
||||
fun: Callable,
|
||||
in_shardings=_UNSPECIFIED,
|
||||
out_shardings=_UNSPECIFIED,
|
||||
in_axis_resources=_UNSPECIFIED,
|
||||
out_axis_resources=_UNSPECIFIED,
|
||||
in_shardings=UNSPECIFIED,
|
||||
out_shardings=UNSPECIFIED,
|
||||
in_axis_resources=UNSPECIFIED,
|
||||
out_axis_resources=UNSPECIFIED,
|
||||
static_argnums: Union[int, Sequence[int], None] = None,
|
||||
static_argnames: Union[str, Iterable[str], None] = None,
|
||||
donate_argnums: Union[int, Sequence[int]] = (),
|
||||
@ -775,13 +746,13 @@ def hashable_pytree(pytree):
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
def _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x):
|
||||
if _is_unspecified_or_auto(x):
|
||||
if is_unspecified_or_auto(x):
|
||||
return x
|
||||
return pxla.create_mesh_pspec_sharding(mesh, x.user_spec, x)
|
||||
|
||||
|
||||
def _create_sharding_for_array(mesh, x, name):
|
||||
if isinstance(x, XLACompatibleSharding) or _is_unspecified_or_auto(x):
|
||||
if isinstance(x, XLACompatibleSharding) or is_unspecified_or_auto(x):
|
||||
return x
|
||||
if mesh is None:
|
||||
msg = ('jax.jit only supports `XLACompatibleSharding`s being passed to'
|
||||
@ -803,7 +774,7 @@ def _create_sharding_for_array(mesh, x, name):
|
||||
' call site? Alternatively, provide `XLACompatibleSharding`s to'
|
||||
' `in_shardings` and `out_shardings` and then the mesh context manager'
|
||||
' is not required.')
|
||||
# A nice user error is raised in _prepare_axis_resources.
|
||||
# A nice user error is raised in prepare_axis_resources.
|
||||
assert isinstance(x, ParsedPartitionSpec), x
|
||||
return _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x)
|
||||
|
||||
@ -884,7 +855,7 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree,
|
||||
orig_in_shardings = in_shardings_thunk()
|
||||
# Only do this if original in_shardings are unspecified. If it is AUTO, go
|
||||
# via flatten_axis_resources.
|
||||
if _is_unspecified(orig_in_shardings):
|
||||
if is_unspecified(orig_in_shardings):
|
||||
in_shardings_flat = (orig_in_shardings,) * len(in_avals)
|
||||
else:
|
||||
in_shardings_flat = flatten_axis_resources(
|
||||
@ -895,7 +866,7 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree,
|
||||
pjit_check_aval_sharding(in_shardings_flat, in_avals,
|
||||
"pjit arguments", allow_uneven_sharding=False)
|
||||
canonicalized_shardings = tuple(
|
||||
i if _is_unspecified_or_auto(i) else to_gspmd_sharding(i, aval.ndim)
|
||||
i if is_unspecified_or_auto(i) else to_gspmd_sharding(i, aval.ndim)
|
||||
for i, aval in zip(in_shardings_flat, in_avals))
|
||||
return canonicalized_shardings
|
||||
|
||||
@ -933,7 +904,7 @@ def _check_and_canonicalize_out_shardings(
|
||||
# instead. This condition exists because flatten_axis_resources passes in an
|
||||
# `object()` while unflattening which breaks assertion is user defined
|
||||
# pytrees (which shouldn't exist but they do).
|
||||
if (_is_unspecified(orig_out_shardings) or
|
||||
if (is_unspecified(orig_out_shardings) or
|
||||
isinstance(orig_out_shardings, XLACompatibleSharding)):
|
||||
out_shardings_flat = (orig_out_shardings,) * len(out_type)
|
||||
else:
|
||||
@ -946,7 +917,7 @@ def _check_and_canonicalize_out_shardings(
|
||||
allow_uneven_sharding=False)
|
||||
|
||||
canonicalized_out_shardings_flat = tuple(
|
||||
o if _is_unspecified(o) or is_auto(o) else to_gspmd_sharding(o, aval.ndim)
|
||||
o if is_unspecified(o) or is_auto(o) else to_gspmd_sharding(o, aval.ndim)
|
||||
for o, aval in zip(out_shardings_flat, out_type)
|
||||
)
|
||||
return canonicalized_out_shardings_flat
|
||||
@ -965,7 +936,7 @@ def _pjit_jaxpr(fun, out_shardings_thunk, in_type, debug_info, out_tree,
|
||||
def pjit_check_aval_sharding(
|
||||
shardings, flat_avals, what_aval: str, allow_uneven_sharding: bool):
|
||||
for aval, s in zip(flat_avals, shardings):
|
||||
if _is_unspecified_or_auto(s):
|
||||
if is_unspecified_or_auto(s):
|
||||
continue
|
||||
shape = aval.shape
|
||||
try:
|
||||
@ -994,170 +965,6 @@ def pjit_check_aval_sharding(
|
||||
f"(full shape: {shape}) ")
|
||||
|
||||
|
||||
class SpecSync(IntEnum):
|
||||
"""Encodes how much out of sync the real value of partitions is compared to the user specified one.
|
||||
|
||||
We use this to make sure we don't show garbage modified values while claiming
|
||||
that the users have specified them like that.
|
||||
"""
|
||||
OUT_OF_SYNC = 0 # Arbitrary changes, including new axes inserted
|
||||
DIM_PERMUTE = 1 # Dimensions permuted, but no new sharding axes
|
||||
IN_SYNC = 2 # Entirely in sync
|
||||
|
||||
class ParsedPartitionSpec:
|
||||
__slots__ = ('unsafe_user_spec', 'partitions', 'sync')
|
||||
|
||||
def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC):
|
||||
self.unsafe_user_spec = user_spec
|
||||
# None in partitions represents unconstrained dim.
|
||||
# TODO(yashkatariya): May use a sentinel value.
|
||||
self.partitions = tuple(partitions)
|
||||
self.sync = sync
|
||||
|
||||
@property
|
||||
def user_spec(self):
|
||||
return self.unsynced_user_spec(SpecSync.IN_SYNC)
|
||||
|
||||
def get_partition_spec(self) -> PartitionSpec:
|
||||
if self.sync < SpecSync.IN_SYNC:
|
||||
return _get_single_pspec(self)
|
||||
else:
|
||||
if isinstance(self.unsafe_user_spec, PartitionSpec):
|
||||
return self.unsafe_user_spec
|
||||
else:
|
||||
return _get_single_pspec(self)
|
||||
|
||||
def unsynced_user_spec(self, min_sync):
|
||||
if self.sync < min_sync:
|
||||
raise AssertionError(f"Please open a bug report! ({self.sync} >= {min_sync})")
|
||||
return self.unsafe_user_spec
|
||||
|
||||
def insert_axis_partitions(self, dim, val):
|
||||
parts = self.partitions
|
||||
too_short = dim - len(parts)
|
||||
if too_short > 0:
|
||||
parts += ((),) * too_short
|
||||
new_partitions = tuple_insert(parts, dim, val)
|
||||
new_sync = SpecSync.DIM_PERMUTE if (val == () or val is None) else SpecSync.OUT_OF_SYNC
|
||||
return ParsedPartitionSpec(self.unsafe_user_spec, new_partitions, sync=new_sync)
|
||||
|
||||
@classmethod
|
||||
def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False):
|
||||
if entry is None:
|
||||
return cls(entry, ())
|
||||
if not isinstance(entry, PartitionSpec):
|
||||
raise TypeError(f"{arg_name} are expected to be "
|
||||
f"PartitionSpec instances or None, but got {entry}")
|
||||
axis_specs = []
|
||||
for axis_spec in entry:
|
||||
if axis_spec is None:
|
||||
axis_spec = ()
|
||||
elif isinstance(axis_spec, (list, tuple)):
|
||||
axis_spec = tuple(axis_spec)
|
||||
elif axis_spec == PartitionSpec.UNCONSTRAINED:
|
||||
if not allow_unconstrained_dims:
|
||||
raise ValueError(f"Unconstrained dims are not allowed: {entry}")
|
||||
axis_spec = None
|
||||
else:
|
||||
axis_spec = (axis_spec,)
|
||||
axis_specs.append(axis_spec)
|
||||
return cls(entry, axis_specs)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.partitions, self.sync))
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.partitions == other.partitions and
|
||||
self.sync == other.sync)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.partitions)
|
||||
|
||||
def __getitem__(self, i):
|
||||
return self.partitions[i]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.partitions)
|
||||
|
||||
def __repr__(self):
|
||||
return (f"ParsedPartitionSpec(partitions={self.partitions}, "
|
||||
f"unsafe_user_spec={self.unsafe_user_spec}, "
|
||||
f"sync={self.sync})")
|
||||
|
||||
class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
|
||||
"""ParsedPartitionSpecs that are canonicalized.
|
||||
|
||||
ParsedPartitionSpecs may contain trailing empty tuples, that make them
|
||||
semantically different in general, and yet in some situations we prefer
|
||||
to regard them as equivalent. For example, partitions of () and ((),)
|
||||
cannot be always considered equivalent, since the first one is a valid
|
||||
spec for a scalar value, while the second is not! However, when either of
|
||||
those are applied to a 2D array, they both mean that the array is fully
|
||||
replicated.
|
||||
|
||||
So CanonicalizedParsedPartitionSpecs removes the trailing empty tuples from
|
||||
partitions.
|
||||
"""
|
||||
|
||||
def __init__(self, parsed_pspec: ParsedPartitionSpec):
|
||||
partitions = list(parsed_pspec.partitions)
|
||||
while partitions and partitions[-1] == ():
|
||||
partitions.pop()
|
||||
|
||||
super().__init__(parsed_pspec.unsafe_user_spec, partitions,
|
||||
parsed_pspec.sync)
|
||||
|
||||
def __repr__(self):
|
||||
return (f"CanonicalizedParsedPartitionSpec(partitions={self.partitions}, "
|
||||
f"unsafe_user_spec={self.unsafe_user_spec}, "
|
||||
f"sync={self.sync})")
|
||||
|
||||
|
||||
def _prepare_axis_resources(axis_resources,
|
||||
arg_name,
|
||||
allow_unconstrained_dims=False):
|
||||
# PyTrees don't treat None values as leaves, so we use an is_leaf function.
|
||||
entries, treedef = tree_flatten(axis_resources, is_leaf=lambda x: x is None)
|
||||
what = f"{arg_name} leaf specifications"
|
||||
# All entries should be specified or if unspecified then there should only
|
||||
# be 1 entry for that since _UNSPECIFIED is a private API.
|
||||
_check_all_or_none_unspecified(entries, arg_name)
|
||||
|
||||
new_entries = []
|
||||
for entry in entries:
|
||||
if _is_unspecified_or_auto(entry):
|
||||
new_entries.append(entry)
|
||||
elif isinstance(entry, Sharding):
|
||||
if isinstance(entry, PmapSharding):
|
||||
raise ValueError(f'One of {what} got sharding {entry} which is not '
|
||||
'allowed.')
|
||||
if not isinstance(entry, XLACompatibleSharding):
|
||||
raise ValueError(f'One of {what} got sharding {entry} which is not a '
|
||||
'subclass of XLACompatibleSharding.')
|
||||
new_entries.append(entry)
|
||||
else:
|
||||
new_entries.append(ParsedPartitionSpec.from_user_input(
|
||||
entry, what, allow_unconstrained_dims=allow_unconstrained_dims))
|
||||
|
||||
_check_unique_resources(new_entries, arg_name)
|
||||
return tree_unflatten(treedef, new_entries), new_entries, treedef
|
||||
|
||||
|
||||
def _check_unique_resources(axis_resources, arg_name):
|
||||
for arg_axis_resources in axis_resources:
|
||||
if not arg_axis_resources: continue
|
||||
if (_is_unspecified_or_auto(arg_axis_resources) or
|
||||
isinstance(arg_axis_resources, XLACompatibleSharding)):
|
||||
continue
|
||||
constrained_dims = [d for d in arg_axis_resources if d is not None]
|
||||
resource_counts = Counter(it.chain.from_iterable(constrained_dims))
|
||||
if not resource_counts: continue
|
||||
if resource_counts.most_common(1)[0][1] > 1:
|
||||
multiple_uses = [r for r, c in resource_counts.items() if c > 1]
|
||||
if multiple_uses:
|
||||
raise ValueError(f"A single {arg_name} specification can map every mesh axis "
|
||||
f"to at most one positional dimension, but {arg_axis_resources.user_spec} "
|
||||
f"has duplicate entries for {mesh_lib.show_axes(multiple_uses)}")
|
||||
|
||||
# -------------------- pjit rules --------------------
|
||||
|
||||
@ -1203,21 +1010,21 @@ def _resolve_in_shardings(
|
||||
resolved_in_shardings = []
|
||||
for arg, pjit_in_s in zip(args, pjit_in_shardings):
|
||||
arg_s, committed = ((arg.sharding, getattr(arg, '_committed', True))
|
||||
if hasattr(arg, 'sharding') else (_UNSPECIFIED, False))
|
||||
if _is_unspecified(pjit_in_s):
|
||||
if _is_unspecified(arg_s):
|
||||
if hasattr(arg, 'sharding') else (UNSPECIFIED, False))
|
||||
if is_unspecified(pjit_in_s):
|
||||
if is_unspecified(arg_s):
|
||||
resolved_in_shardings.append(arg_s)
|
||||
else:
|
||||
if committed:
|
||||
# If the arg has a PmapSharding, then reshard it unconditionally.
|
||||
if isinstance(arg_s, PmapSharding):
|
||||
resolved_in_shardings.append(_UNSPECIFIED)
|
||||
resolved_in_shardings.append(UNSPECIFIED)
|
||||
else:
|
||||
resolved_in_shardings.append(to_gspmd_sharding(
|
||||
cast(XLACompatibleSharding, arg_s), arg.ndim))
|
||||
else:
|
||||
if dispatch.is_single_device_sharding(arg_s):
|
||||
resolved_in_shardings.append(_UNSPECIFIED)
|
||||
resolved_in_shardings.append(UNSPECIFIED)
|
||||
else:
|
||||
raise NotImplementedError('Having uncommitted Array sharded on '
|
||||
'multiple devices is not supported.')
|
||||
@ -1239,7 +1046,7 @@ def _resolve_in_shardings(
|
||||
'Please see the jax.Array migration guide for more information '
|
||||
'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. '
|
||||
f'Got arg shape: {arg.shape}, arg value: {arg}')
|
||||
if not _is_unspecified(arg_s):
|
||||
if not is_unspecified(arg_s):
|
||||
if (committed and
|
||||
not isinstance(arg_s, PmapSharding) and
|
||||
not op_shardings.are_op_shardings_equal(
|
||||
@ -1265,7 +1072,7 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
args, in_shardings, out_shardings,
|
||||
resource_env.physical_mesh if resource_env is not None else None)
|
||||
|
||||
_allow_propagation_to_outputs = [_is_unspecified(o) for o in out_shardings]
|
||||
_allow_propagation_to_outputs = [is_unspecified(o) for o in out_shardings]
|
||||
compiled = _pjit_lower(
|
||||
jaxpr, in_shardings, out_shardings, resource_env,
|
||||
donated_invars, name, keep_unused,
|
||||
@ -1428,8 +1235,8 @@ def _pjit_lower_cached(
|
||||
|
||||
def pjit_staging_rule(trace, *args, **params):
|
||||
if (params["inline"] and
|
||||
all(_is_unspecified(i) for i in params["in_shardings"]) and
|
||||
all(_is_unspecified(o) for o in params["out_shardings"])):
|
||||
all(is_unspecified(i) for i in params["in_shardings"]) and
|
||||
all(is_unspecified(o) for o in params["out_shardings"])):
|
||||
jaxpr = params['jaxpr']
|
||||
return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
|
||||
elif config.jax_dynamic_shapes:
|
||||
@ -1488,9 +1295,9 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
|
||||
output_types = [mlir.token_type()] * len(effects) + output_types
|
||||
flat_output_types = util.flatten(output_types)
|
||||
|
||||
arg_shardings = [None if _is_unspecified(i) else i._to_xla_op_sharding(aval.ndim)
|
||||
arg_shardings = [None if is_unspecified(i) else i._to_xla_op_sharding(aval.ndim)
|
||||
for aval, i in zip(ctx.avals_in, in_shardings)]
|
||||
result_shardings = [None if _is_unspecified(o) else o._to_xla_op_sharding(aval.ndim)
|
||||
result_shardings = [None if is_unspecified(o) else o._to_xla_op_sharding(aval.ndim)
|
||||
for aval, o in zip(ctx.avals_out, out_shardings)]
|
||||
|
||||
# TODO(b/228598865): inlined calls cannot have shardings set directly on the
|
||||
@ -1557,9 +1364,9 @@ batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False, None)
|
||||
pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None)
|
||||
|
||||
def _pjit_batcher_for_sharding(
|
||||
s: Union[GSPMDSharding, _UnspecifiedValue],
|
||||
s: Union[GSPMDSharding, UnspecifiedValue],
|
||||
dim: int, val: Tuple[str, ...], mesh, ndim: int):
|
||||
if _is_unspecified(s):
|
||||
if is_unspecified(s):
|
||||
return s
|
||||
if not val:
|
||||
new_op = s._op_sharding.clone() # type: ignore
|
||||
@ -1631,7 +1438,7 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
def keep_where(l, should_keep):
|
||||
return tuple(x for x, keep in unsafe_zip(l, should_keep) if keep)
|
||||
|
||||
residual_shardings = (_UNSPECIFIED,) * num_residuals
|
||||
residual_shardings = (UNSPECIFIED,) * num_residuals
|
||||
# Compute the known outputs
|
||||
known_params = dict(
|
||||
jaxpr=known_jaxpr,
|
||||
@ -1649,7 +1456,7 @@ def _pjit_partial_eval(trace, *in_tracers,
|
||||
# Only forward the outvars where the out_sharding is UNSPECIFIED.
|
||||
known_user_out_shardings = keep_where(known_params['out_shardings'], known_outs)
|
||||
fwds_known_user = [
|
||||
fwd if _is_unspecified(os) else None
|
||||
fwd if is_unspecified(os) else None
|
||||
for os, fwd in zip(known_user_out_shardings,
|
||||
fwds_known[:len(known_user_out_shardings)])]
|
||||
fwds_known = fwds_known_user + fwds_known[len(known_user_out_shardings):]
|
||||
@ -1728,7 +1535,7 @@ def _pjit_partial_eval_custom_params_updater(
|
||||
if num_res == 0:
|
||||
residual_shardings = []
|
||||
else:
|
||||
residual_shardings = [_UNSPECIFIED] * num_res
|
||||
residual_shardings = [UNSPECIFIED] * num_res
|
||||
_, out_shardings_known = pe.partition_list(kept_outs_known, params_known['out_shardings'])
|
||||
new_params_known = dict(params_known,
|
||||
in_shardings=tuple(in_shardings_known),
|
||||
@ -1862,7 +1669,7 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
|
||||
raise RuntimeError("Changing the physical mesh is not allowed inside pjit.")
|
||||
|
||||
for aval, s in zip(jaxpr.in_avals, params['in_shardings']):
|
||||
if _is_unspecified(s) or is_auto(s):
|
||||
if is_unspecified(s) or is_auto(s):
|
||||
continue
|
||||
elif hasattr(s, '_original_sharding') and hasattr(
|
||||
s._original_sharding, '_parsed_pspec'):
|
||||
@ -1884,7 +1691,7 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
|
||||
|
||||
what = "pjit output"
|
||||
for aval, s in zip(jaxpr.out_avals, params['out_shardings']):
|
||||
if _is_unspecified(s) or is_auto(s):
|
||||
if is_unspecified(s) or is_auto(s):
|
||||
continue
|
||||
elif hasattr(s, '_original_sharding') and hasattr(
|
||||
s._original_sharding, '_parsed_pspec'):
|
||||
@ -1907,9 +1714,9 @@ def _pjit_pp_rule(eqn, context, settings):
|
||||
del params['inline']
|
||||
if not any(params['donated_invars']):
|
||||
del params['donated_invars']
|
||||
if all(pxla._is_unspecified(s) for s in params['in_shardings']):
|
||||
if all(is_unspecified(s) for s in params['in_shardings']):
|
||||
del params['in_shardings']
|
||||
if all(pxla._is_unspecified(s) for s in params['out_shardings']):
|
||||
if all(is_unspecified(s) for s in params['out_shardings']):
|
||||
del params['out_shardings']
|
||||
if not params['keep_unused']:
|
||||
del params['keep_unused']
|
||||
@ -1923,17 +1730,17 @@ core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
|
||||
# -------------------- with_sharding_constraint --------------------
|
||||
|
||||
def _resolve_wsc_args(axis_resources, shardings):
|
||||
if not _is_unspecified(axis_resources) and not _is_unspecified(shardings):
|
||||
if not is_unspecified(axis_resources) and not is_unspecified(shardings):
|
||||
raise ValueError(
|
||||
'Setting both axis_resources and shardings is not '
|
||||
'allowed. axis_resources is deprecated. Please use shardings.')
|
||||
if _is_unspecified(axis_resources) and _is_unspecified(shardings):
|
||||
if is_unspecified(axis_resources) and is_unspecified(shardings):
|
||||
raise ValueError(
|
||||
'Not specifying shardings to `with_sharding_constraint` is not allowed. '
|
||||
'Please specify the shardings argument with a concrete sharding. Note '
|
||||
'that axis_resources is deprecated, so use the shardings argument.')
|
||||
|
||||
if not _is_unspecified(axis_resources):
|
||||
if not is_unspecified(axis_resources):
|
||||
warnings.warn(
|
||||
'axis_resources is deprecated. Please use shardings argument instead.',
|
||||
DeprecationWarning)
|
||||
@ -1946,11 +1753,11 @@ def _resolve_wsc_args(axis_resources, shardings):
|
||||
# TODO(yashkatariya): Remove the axis_resources argument and make the signature
|
||||
# `with_sharding_constraint(x, shardings)` with no defaults after deprecation
|
||||
# period is finished. The deprecation period expires 3 months from Feb 13, 2023.
|
||||
def with_sharding_constraint(x, shardings=_UNSPECIFIED,
|
||||
axis_resources=_UNSPECIFIED):
|
||||
def with_sharding_constraint(x, shardings=UNSPECIFIED,
|
||||
axis_resources=UNSPECIFIED):
|
||||
final_shardings = _resolve_wsc_args(axis_resources, shardings)
|
||||
x_flat, tree = tree_flatten(x)
|
||||
user_shardings, _, _ = _prepare_axis_resources(
|
||||
user_shardings, _, _ = prepare_axis_resources(
|
||||
final_shardings, "shardings", allow_unconstrained_dims=True)
|
||||
del final_shardings
|
||||
|
||||
@ -1996,7 +1803,7 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
|
||||
# axis_ctx and manual_axes is *only used with xmap* and xmap only works with
|
||||
# NamedSharding. So convert the GSPMDSharding to NamedSharding
|
||||
# and then convert it back with the added special axes.
|
||||
if isinstance(axis_ctx, mlir.SPMDAxisContext):
|
||||
if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
|
||||
mesh = resource_env.physical_mesh
|
||||
parsed_pspec = parse_flatten_op_sharding(sharding._op_sharding, mesh)[0]
|
||||
mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
|
||||
@ -2054,18 +1861,6 @@ pxla.custom_resource_typing_rules[sharding_constraint_p] = \
|
||||
|
||||
# -------------------- helpers --------------------
|
||||
|
||||
def get_array_mapping(
|
||||
axis_resources: Union[ParsedPartitionSpec, _AUTOAxisResource, _UnspecifiedValue]
|
||||
) -> pxla.ArrayMappingOrAutoOrUnspecified:
|
||||
# TODO(yashkatariya): Use `TypeGuard` on `is_auto` when it is supported.
|
||||
# Don't use `is_auto` here to satisfy pytype and mypy.
|
||||
if isinstance(axis_resources, (_AUTOAxisResource, _UnspecifiedValue)):
|
||||
return axis_resources
|
||||
return OrderedDict((axis, i)
|
||||
for i, axes in enumerate(axis_resources)
|
||||
if axes is not None for axis in axes)
|
||||
|
||||
|
||||
def to_gspmd_sharding(s: XLACompatibleSharding, ndim: int) -> GSPMDSharding:
|
||||
if isinstance(s, GSPMDSharding):
|
||||
return s
|
||||
@ -2085,7 +1880,7 @@ def _fast_path_get_device_assignment(
|
||||
shardings: Iterable[PjitSharding]) -> Optional[XLADeviceAssignment]:
|
||||
da = None
|
||||
for i in shardings:
|
||||
if is_auto(i) or _is_unspecified(i):
|
||||
if is_auto(i) or is_unspecified(i):
|
||||
continue
|
||||
da = i._device_assignment # type: ignore
|
||||
break
|
||||
@ -2227,11 +2022,8 @@ def parse_flatten_op_sharding(op_sharding: xc.OpSharding,
|
||||
raise AssertionError("Unhandled OpSharding type. Please open a bug report!")
|
||||
|
||||
|
||||
_get_single_pspec = lambda p: pxla.array_mapping_to_axis_resources(
|
||||
cast(pxla.ArrayMapping, get_array_mapping(p)))
|
||||
|
||||
def _get_partition_spec(ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[PartitionSpec]:
|
||||
return [_get_single_pspec(p) for p in ppspec]
|
||||
return [get_single_pspec(p) for p in ppspec]
|
||||
|
||||
|
||||
def _get_op_sharding_from_executable(
|
||||
|
@ -14,24 +14,35 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
from collections import Counter
|
||||
import itertools
|
||||
import operator as op
|
||||
from typing import (Any, Sequence, List, Tuple, Optional, Mapping, Dict, Set,
|
||||
FrozenSet, Union, cast)
|
||||
import sys
|
||||
from typing import (Any, Dict, FrozenSet, List, Mapping, Optional, OrderedDict,
|
||||
NamedTuple, Sequence, Set, Tuple, Union, cast)
|
||||
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src import op_shardings
|
||||
from jax._src import sharding
|
||||
from jax._src import sharding_specs
|
||||
from jax._src import tree_util
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
|
||||
import numpy as np
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
OrderedDictType = OrderedDict
|
||||
else:
|
||||
OrderedDictType = Dict
|
||||
|
||||
|
||||
Shape = Tuple[int, ...]
|
||||
Device = xc.Device
|
||||
Index = Tuple[slice, ...]
|
||||
@ -134,7 +145,7 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
|
||||
'create a device to index mapping for your sharding from which replica '
|
||||
'ids will be calculated.') from None
|
||||
|
||||
index_to_replica: Dict[int, int] = Counter()
|
||||
index_to_replica: Dict[int, int] = collections.Counter()
|
||||
out = {}
|
||||
for device, index in device_indices_map_fn(global_shape).items():
|
||||
h_index = hashed_index(index)
|
||||
@ -204,8 +215,7 @@ class NamedSharding(XLACompatibleSharding):
|
||||
# TODO(yashkatariya): Remove this and replace it with a normalized
# representation of Parsed Pspec
|
||||
if self._parsed_pspec is None:
|
||||
from jax._src import pjit
|
||||
self._parsed_pspec, _, _ = pjit._prepare_axis_resources(
|
||||
self._parsed_pspec, _, _ = prepare_axis_resources(
|
||||
self.spec, "NamedSharding spec")
|
||||
|
||||
_check_mesh_resource_axis(self.mesh, self._parsed_pspec)
|
||||
@ -263,9 +273,8 @@ class NamedSharding(XLACompatibleSharding):
|
||||
def _to_xla_op_sharding(
|
||||
self,
|
||||
num_dimensions: int,
|
||||
axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None
|
||||
axis_ctx: Optional[Union[SPMDAxisContext, ShardingContext]] = None
|
||||
) -> xc.OpSharding:
|
||||
from jax._src.pjit import get_array_mapping
|
||||
assert self._parsed_pspec is not None
|
||||
array_mapping = get_array_mapping(self._parsed_pspec)
|
||||
# TODO(yashkatariya): Move away from sharding spec in NamedSharding
|
||||
@ -275,7 +284,7 @@ class NamedSharding(XLACompatibleSharding):
|
||||
# Used in `with_sharding_constraint`.
|
||||
special_axes = {}
|
||||
# Manual axes are only used with xmap.
|
||||
if axis_ctx is not None and isinstance(axis_ctx, mlir.SPMDAxisContext):
|
||||
if axis_ctx is not None and isinstance(axis_ctx, SPMDAxisContext):
|
||||
axis_names = self.mesh.axis_names
|
||||
# Ignore type because mypy doesn't recognize the `hasattr` check above.
|
||||
for manual_axis in axis_ctx.manual_axes: # type: ignore
|
||||
@ -634,3 +643,324 @@ class GSPMDSharding(XLACompatibleSharding):
|
||||
def get_replicated(cls, device_assignment):
|
||||
proto = get_replicated_op_sharding()
|
||||
return cls(device_assignment, proto)
|
||||
|
||||
|
||||
class AUTOAxisResource:
|
||||
pass
|
||||
AUTO = AUTOAxisResource()
|
||||
|
||||
def is_auto(x):
|
||||
return isinstance(x, AUTOAxisResource)
|
||||
|
||||
|
||||
class UnspecifiedValue:
|
||||
def __repr__(self):
|
||||
return "UnspecifiedValue"
|
||||
UNSPECIFIED = UnspecifiedValue()
|
||||
|
||||
def is_unspecified(x):
|
||||
return isinstance(x, UnspecifiedValue)
|
||||
|
||||
def is_unspecified_or_auto(x):
|
||||
return is_auto(x) or is_unspecified(x)
|
||||
|
||||
|
||||
MeshAxisName = Any
|
||||
|
||||
"""
|
||||
ArrayMapping specifies how an ndarray should map to mesh axes.
|
||||
|
||||
Note that the ordering is crucial for the cases when this mapping is non-injective
|
||||
(i.e. when multiple mesh axes map to the same positional axis). Then, the
|
||||
order of entries of the mapping determines a major-to-minor order on mesh axes,
|
||||
according to which chunks of the value along the repeated dimension will be assigned.
|
||||
|
||||
For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
|
||||
The second dimension of the value would get chunked into 6 pieces, and assigned to the
|
||||
mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
|
||||
that would mean that a flat list of chunks would get assigned to a flattened list of
|
||||
mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
|
||||
mesh devices ndarray would have to be transposed before flattening and assignment.
|
||||
"""
|
||||
ArrayMapping = OrderedDictType[MeshAxisName, int]
|
||||
ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTOAxisResource,
|
||||
UnspecifiedValue]
|
||||
|
||||
def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
|
||||
if not array_mapping:
|
||||
return PartitionSpec()
|
||||
max_index = -1
|
||||
reverse_map = collections.defaultdict(list)
|
||||
for axis, index in array_mapping.items():
|
||||
reverse_map[index].append(axis)
|
||||
if index > max_index:
|
||||
max_index = index
|
||||
partitions = tuple(tuple(reverse_map[i]) if reverse_map[i] else None
|
||||
for i in range(max_index + 1))
|
||||
return PartitionSpec(*partitions)
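# Illustrative sketch (not part of this change): per the ArrayMapping semantics
# described above, each entry maps a mesh axis to a positional dimension, and
# dimensions with no entry come back as None. The expected value below matches
# the existing test_array_mapping_to_axis_resources case in pjit_test.py.
#
#   >>> mapping = collections.OrderedDict([('x', 0), ('y', 1), ('z', 3)])
#   >>> array_mapping_to_axis_resources(mapping) == PartitionSpec(
#   ...     ('x',), ('y',), None, ('z',))
#   True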
|
||||
|
||||
def get_array_mapping(
|
||||
axis_resources: Union[ParsedPartitionSpec, AUTOAxisResource, UnspecifiedValue]
|
||||
) -> ArrayMappingOrAutoOrUnspecified:
|
||||
# TODO(yashkatariya): Use `TypeGuard` on `is_auto` when it is supported.
|
||||
# Don't use `is_auto` here to satisfy pytype and mypy.
|
||||
if isinstance(axis_resources, (AUTOAxisResource, UnspecifiedValue)):
|
||||
return axis_resources
|
||||
return OrderedDict((axis, i)
|
||||
for i, axes in enumerate(axis_resources)
|
||||
if axes is not None for axis in axes)
|
||||
|
||||
|
||||
get_single_pspec = lambda p: array_mapping_to_axis_resources(
|
||||
cast(ArrayMapping, get_array_mapping(p)))
|
||||
|
||||
|
||||
class SpecSync(enum.IntEnum):
|
||||
"""Encodes how much out of sync the real value of partitions is compared to the user specified one.
|
||||
|
||||
We use this to make sure we don't show garbage modified values while claiming
|
||||
that the users have specified them like that.
|
||||
"""
|
||||
OUT_OF_SYNC = 0 # Arbitrary changes, including new axes inserted
|
||||
DIM_PERMUTE = 1 # Dimensions permuted, but no new sharding axes
|
||||
IN_SYNC = 2 # Entirely in sync
|
||||
|
||||
class ParsedPartitionSpec:
|
||||
__slots__ = ('unsafe_user_spec', 'partitions', 'sync')
|
||||
|
||||
def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC):
|
||||
self.unsafe_user_spec = user_spec
|
||||
# None in partitions represents unconstrained dim.
|
||||
# TODO(yashkatariya): May use a sentinel value.
|
||||
self.partitions = tuple(partitions)
|
||||
self.sync = sync
|
||||
|
||||
@property
|
||||
def user_spec(self):
|
||||
return self.unsynced_user_spec(SpecSync.IN_SYNC)
|
||||
|
||||
def get_partition_spec(self) -> PartitionSpec:
|
||||
if self.sync < SpecSync.IN_SYNC:
|
||||
return get_single_pspec(self)
|
||||
else:
|
||||
if isinstance(self.unsafe_user_spec, PartitionSpec):
|
||||
return self.unsafe_user_spec
|
||||
else:
|
||||
return get_single_pspec(self)
|
||||
|
||||
def unsynced_user_spec(self, min_sync):
|
||||
if self.sync < min_sync:
|
||||
raise AssertionError(f"Please open a bug report! ({self.sync} >= {min_sync})")
|
||||
return self.unsafe_user_spec
|
||||
|
||||
def insert_axis_partitions(self, dim, val):
|
||||
parts = self.partitions
|
||||
too_short = dim - len(parts)
|
||||
if too_short > 0:
|
||||
parts += ((),) * too_short
|
||||
new_partitions = util.tuple_insert(parts, dim, val)
|
||||
new_sync = SpecSync.DIM_PERMUTE if (val == () or val is None) else SpecSync.OUT_OF_SYNC
|
||||
return ParsedPartitionSpec(self.unsafe_user_spec, new_partitions, sync=new_sync)
|
||||
|
||||
@classmethod
|
||||
def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False):
|
||||
if entry is None:
|
||||
return cls(entry, ())
|
||||
if not isinstance(entry, PartitionSpec):
|
||||
raise TypeError(f"{arg_name} are expected to be "
|
||||
f"PartitionSpec instances or None, but got {entry}")
|
||||
axis_specs = []
|
||||
for axis_spec in entry:
|
||||
if axis_spec is None:
|
||||
axis_spec = ()
|
||||
elif isinstance(axis_spec, (list, tuple)):
|
||||
axis_spec = tuple(axis_spec)
|
||||
elif axis_spec == PartitionSpec.UNCONSTRAINED:
|
||||
if not allow_unconstrained_dims:
|
||||
raise ValueError(f"Unconstrained dims are not allowed: {entry}")
|
||||
axis_spec = None
|
||||
else:
|
||||
axis_spec = (axis_spec,)
|
||||
axis_specs.append(axis_spec)
|
||||
return cls(entry, axis_specs)
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.partitions, self.sync))
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.partitions == other.partitions and
|
||||
self.sync == other.sync)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.partitions)
|
||||
|
||||
def __getitem__(self, i):
|
||||
return self.partitions[i]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.partitions)
|
||||
|
||||
def __repr__(self):
|
||||
return (f"ParsedPartitionSpec(partitions={self.partitions}, "
|
||||
f"unsafe_user_spec={self.unsafe_user_spec}, "
|
||||
f"sync={self.sync})")
|
||||
|
||||
class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
|
||||
"""ParsedPartitionSpecs that are canonicalized.
|
||||
|
||||
ParsedPartitionSpecs may contain trailing empty tuples that make them
semantically different in general, and yet in some situations we prefer
to regard them as equivalent. For example, partitions of () and ((),)
cannot always be considered equivalent, since the first one is a valid
spec for a scalar value, while the second is not! However, when either of
them is applied to a 2D array, they both mean that the array is fully
replicated.

So CanonicalizedParsedPartitionSpec removes the trailing empty tuples from
partitions.
|
||||
"""
|
||||
|
||||
def __init__(self, parsed_pspec: ParsedPartitionSpec):
|
||||
partitions = list(parsed_pspec.partitions)
|
||||
while partitions and partitions[-1] == ():
|
||||
partitions.pop()
|
||||
|
||||
super().__init__(parsed_pspec.unsafe_user_spec, partitions,
|
||||
parsed_pspec.sync)
|
||||
|
||||
def __repr__(self):
|
||||
return (f"CanonicalizedParsedPartitionSpec(partitions={self.partitions}, "
|
||||
f"unsafe_user_spec={self.unsafe_user_spec}, "
|
||||
f"sync={self.sync})")
|
||||
|
||||
|
||||
def check_all_or_none_unspecified(axis_resources, name):
|
||||
if not axis_resources:
|
||||
return False
|
||||
unspecified_count = 0
|
||||
unspecified = is_unspecified(axis_resources[0])
|
||||
for resource in axis_resources:
|
||||
current_is_unspecified = is_unspecified(resource)
|
||||
if current_is_unspecified:
|
||||
unspecified_count += 1
|
||||
assert unspecified_count == 1
|
||||
if current_is_unspecified != unspecified:
|
||||
raise ValueError(f'`pjit.UNSPECIFIED` exists in {name}. '
|
||||
f'Make sure that every entry in {name} is '
|
||||
'`pjit.UNSPECIFIED`.')
|
||||
return unspecified
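# Illustrative sketch (not part of this change), matching the existing
# test_all_or_non_unspecified cases: fully specified entries pass through
# (returning False), a lone UNSPECIFIED entry returns True, and mixing
# UNSPECIFIED with concrete specs such as [PartitionSpec('x'), UNSPECIFIED]
# raises a ValueError.
#
#   >>> check_all_or_none_unspecified(
#   ...     [PartitionSpec('x'), PartitionSpec('y')], 'in_shardings')
#   False
#   >>> check_all_or_none_unspecified([UNSPECIFIED], 'in_shardings')
#   True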
|
||||
|
||||
|
||||
def prepare_axis_resources(axis_resources,
|
||||
arg_name,
|
||||
allow_unconstrained_dims=False):
|
||||
# PyTrees don't treat None values as leaves, so we use an is_leaf function.
|
||||
entries, treedef = tree_util.tree_flatten(
|
||||
axis_resources, is_leaf=lambda x: x is None)
|
||||
what = f"{arg_name} leaf specifications"
|
||||
# Either no entry is UNSPECIFIED, or there is exactly one entry and it is
# UNSPECIFIED, since UNSPECIFIED is a private API.
|
||||
check_all_or_none_unspecified(entries, arg_name)
|
||||
|
||||
new_entries = []
|
||||
for entry in entries:
|
||||
if is_unspecified_or_auto(entry):
|
||||
new_entries.append(entry)
|
||||
elif isinstance(entry, sharding.Sharding):
|
||||
if isinstance(entry, PmapSharding):
|
||||
raise ValueError(f'One of {what} got sharding {entry} which is not '
|
||||
'allowed.')
|
||||
if not isinstance(entry, XLACompatibleSharding):
|
||||
raise ValueError(f'One of {what} got sharding {entry} which is not a '
|
||||
'subclass of XLACompatibleSharding.')
|
||||
new_entries.append(entry)
|
||||
else:
|
||||
new_entries.append(ParsedPartitionSpec.from_user_input(
|
||||
entry, what, allow_unconstrained_dims=allow_unconstrained_dims))
|
||||
|
||||
_check_unique_resources(new_entries, arg_name)
|
||||
return tree_util.tree_unflatten(treedef, new_entries), new_entries, treedef
|
||||
|
||||
|
||||
def _check_unique_resources(axis_resources, arg_name):
|
||||
for arg_axis_resources in axis_resources:
|
||||
if not arg_axis_resources: continue
|
||||
if (is_unspecified_or_auto(arg_axis_resources) or
|
||||
isinstance(arg_axis_resources, XLACompatibleSharding)):
|
||||
continue
|
||||
constrained_dims = [d for d in arg_axis_resources if d is not None]
|
||||
resource_counts = collections.Counter(
|
||||
itertools.chain.from_iterable(constrained_dims))
|
||||
if not resource_counts: continue
|
||||
if resource_counts.most_common(1)[0][1] > 1:
|
||||
multiple_uses = [r for r, c in resource_counts.items() if c > 1]
|
||||
if multiple_uses:
|
||||
raise ValueError(f"A single {arg_name} specification can map every mesh axis "
|
||||
f"to at most one positional dimension, but {arg_axis_resources.user_spec} "
|
||||
f"has duplicate entries for {mesh_lib.show_axes(multiple_uses)}")
|
||||
|
||||
# Axis environments
|
||||
|
||||
class AxisEnv(NamedTuple):
|
||||
"""Represents a pmap mesh (only along the replica axes)."""
|
||||
nreps: int
|
||||
names: Tuple[Any, ...]
|
||||
sizes: Tuple[int, ...]
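# Illustrative sketch (not part of this change), mirroring the existing
# testAxisGroups case: a pmap over 8 replicas whose named axes 'i' and 'j'
# form a 4 x 2 grid.
#
#   >>> env = AxisEnv(nreps=8, names=('i', 'j'), sizes=(4, 2))
#   >>> env.nreps == env.sizes[0] * env.sizes[1]
#   True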
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class SPMDAxisContext:
|
||||
"""A hardware axis context for parallel computations that use the GSPMD partitioner.
|
||||
|
||||
This includes the mesh that will later by used to execute this computation,
|
||||
as well as a set of mesh axes that are currently (e.g. because the current lowering
|
||||
is invoked inside an xmap) lowered in the MANUAL sharding mode.
|
||||
"""
|
||||
mesh: mesh_lib.Mesh
|
||||
manual_axes: FrozenSet[MeshAxisName] = frozenset()
|
||||
|
||||
@property
|
||||
def axis_env(self):
|
||||
# All collectives that touch axis_env should remember to set use_global_device_ids
|
||||
# when this context is enabled!
|
||||
if self.manual_axes != frozenset(self.mesh.axis_names):
|
||||
raise NotImplementedError(
|
||||
"Collectives in manually partitioned computations are only supported "
|
||||
"when all mesh axes are partitioned manually (no partial automatic sharding). "
|
||||
"Make sure that you mention all mesh axes in axis_resources!")
|
||||
return self.unsafe_axis_env
|
||||
|
||||
@property
|
||||
def unsafe_axis_env(self):
|
||||
return AxisEnv(
|
||||
nreps=self.mesh.size,
|
||||
names=self.mesh.axis_names,
|
||||
sizes=tuple(self.mesh.shape.values()))
|
||||
|
||||
def extend_manual(self, axes: FrozenSet[MeshAxisName]) -> SPMDAxisContext:
|
||||
return SPMDAxisContext(self.mesh, self.manual_axes | axes)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ReplicaAxisContext:
|
||||
"""A hardware axis context for parallel computations that are partitioned by JAX.
|
||||
|
||||
Unlike in the SPMDAxisContext, this means that JAX might need to emit calls to
|
||||
explicit collectives.
|
||||
"""
|
||||
axis_env: AxisEnv
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ShardingContext:
|
||||
"""A hardware axis context for parallel computations that use the sharding
|
||||
interface.
|
||||
|
||||
This context also uses the GSPMD partitioner.
|
||||
"""
|
||||
device_assignment: Sequence[xc.Device]
|
||||
|
||||
# Similar to SPMDAxisContext, as ShardingContext also uses the GSPMD partitioner.
|
||||
@property
|
||||
def axis_env(self):
|
||||
return AxisEnv(nreps=1, names=(), sizes=())
|
||||
|
@ -38,6 +38,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Protocol, Sequence, Tu
|
||||
import jax
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import tree_util
|
||||
@ -596,7 +597,7 @@ class Lowered(Stage):
|
||||
if isinstance(self._lowering, pxla.MeshComputation):
|
||||
kw.update(
|
||||
_allow_propagation_to_outputs=[
|
||||
pxla._is_unspecified(o)
|
||||
sharding_impls.is_unspecified(o)
|
||||
for o in self._lowering.compile_args["out_shardings"]
|
||||
]
|
||||
)
|
||||
|
@ -17,6 +17,7 @@ import inspect
|
||||
from jax._src import core
|
||||
from jax import tree_util
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import sharding_impls
|
||||
from jax.experimental import pjit
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax._src import mesh as mesh_lib
|
||||
@ -97,7 +98,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
|
||||
]
|
||||
closed_jaxpr = jax.make_jaxpr(
|
||||
lower_fn, axis_env=list(info.mesh.shape.items()))(*tiled_args)
|
||||
axis_context = mlir.SPMDAxisContext(info.mesh)
|
||||
axis_context = sharding_impls.SPMDAxisContext(info.mesh)
|
||||
built = mlir.build_xla_computation_helper(
|
||||
closed_jaxpr,
|
||||
name="tmp_xla_computation",
|
||||
@ -373,9 +374,9 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
|
||||
mesh = mesh_lib.thread_resources.env.physical_mesh
|
||||
axis_context = ctx.module_context.axis_context
|
||||
|
||||
if isinstance(axis_context, mlir.ShardingContext):
|
||||
if isinstance(axis_context, sharding_impls.ShardingContext):
|
||||
devices = axis_context.device_assignment
|
||||
elif isinstance(axis_context, mlir.SPMDAxisContext):
|
||||
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
|
||||
devices = list(axis_context.mesh.devices.flat)
|
||||
else:
|
||||
devices = None
|
||||
|
@ -518,9 +518,9 @@ from jax._src.interpreters import xla
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import dispatch
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src import lib as jaxlib
|
||||
from jax._src.lib import pytree
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_client
|
||||
@ -1245,8 +1245,10 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext,
|
||||
result_arrays = ()
|
||||
return result_arrays
|
||||
|
||||
if isinstance(ctx.module_context.axis_context,
|
||||
(mlir.SPMDAxisContext, mlir.ShardingContext)):
|
||||
if isinstance(
|
||||
ctx.module_context.axis_context,
|
||||
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
|
||||
):
|
||||
# Apply maximal sharding so pjit only executes the callback on device device_index.
|
||||
sharding = xla_client.OpSharding()
|
||||
sharding.type = xla_client.OpSharding.Type.MAXIMAL
|
||||
@ -1715,11 +1717,17 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
|
||||
eqn.params,
|
||||
jaxpr=_rewrite_closed_jaxpr(jaxpr, True, True),
|
||||
donated_invars=eqn.params["donated_invars"] + (False, False),
|
||||
in_shardings=(eqn.params["in_shardings"] +
|
||||
(pjit._UNSPECIFIED, pjit._UNSPECIFIED)),
|
||||
out_shardings=(eqn.params["out_shardings"] +
|
||||
(pjit._UNSPECIFIED, pjit._UNSPECIFIED)),
|
||||
)))
|
||||
in_shardings=(
|
||||
eqn.params["in_shardings"]
|
||||
+ (sharding_impls.UNSPECIFIED, sharding_impls.UNSPECIFIED)
|
||||
),
|
||||
out_shardings=(
|
||||
eqn.params["out_shardings"]
|
||||
+ (sharding_impls.UNSPECIFIED, sharding_impls.UNSPECIFIED)
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
elif eqn.primitive is ad_checkpoint.remat_p:
|
||||
jaxpr_ = cast(core.Jaxpr, eqn.params["jaxpr"])
|
||||
eqns.append(
|
||||
|
@ -49,6 +49,7 @@ from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import op_shardings
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import pjit
|
||||
from jax._src import prng
|
||||
from jax._src import random as random_internal
|
||||
@ -3096,7 +3097,7 @@ def _shard_value(val: TfVal,
|
||||
sd: sharding.XLACompatibleSharding, *,
|
||||
skip_replicated_sharding: bool) -> TfVal:
|
||||
"""Apply sharding to a TfVal."""
|
||||
if pxla._is_unspecified(sd):
|
||||
if sharding_impls.is_unspecified(sd):
|
||||
return val
|
||||
|
||||
sharding_proto: xla_client.OpSharding = cast(
|
||||
|
@ -28,6 +28,7 @@ from jax import sharding
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import pjit
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src import xla_bridge as xb
|
||||
@ -270,7 +271,7 @@ def _add_dim_arg_computation(module: mlir.ir.Module,
|
||||
with ir.InsertionPoint(entry_block):
|
||||
orig_main_args: List[mlir.ir.Value] = []
|
||||
module_context = mlir.ModuleContext(
|
||||
"cpu", "cpu", mlir.ShardingContext([]),
|
||||
"cpu", "cpu", sharding_impls.ShardingContext([]),
|
||||
source_info_util.new_name_stack(),
|
||||
[], itertools.count(1), [], module=new_module, context=context)
|
||||
ctx = mlir.LoweringRuleContext(module_context=module_context,
|
||||
@ -467,7 +468,7 @@ def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported:
|
||||
map(lambda a: a.at_least_vspace(), primal.out_avals)))
|
||||
|
||||
# Expand in_shardings to all in_avals even not kept ones.
|
||||
all_in_shardings = [pxla._UNSPECIFIED] * len(primal.in_avals)
|
||||
all_in_shardings = [sharding_impls.UNSPECIFIED] * len(primal.in_avals)
|
||||
for idx, in_s in zip(sorted(primal.module_kept_var_idx),
|
||||
primal.in_shardings):
|
||||
all_in_shardings[idx] = in_s # type: ignore
|
||||
@ -475,13 +476,13 @@ def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported:
|
||||
# Cannot mix unspecified and specified shardings. Make the unspecified
|
||||
# ones replicated.
|
||||
specified_shardings = [
|
||||
s for s in all_shardings if not pxla._is_unspecified(s)]
|
||||
s for s in all_shardings if not sharding_impls.is_unspecified(s)]
|
||||
|
||||
vjp_in_shardings: Any # The primal inputs followed by output cotangents
|
||||
vjp_out_shardings: Any # The primal output cotangents
|
||||
if 0 == len(specified_shardings):
|
||||
vjp_in_shardings = pxla._UNSPECIFIED
|
||||
vjp_out_shardings = pxla._UNSPECIFIED
|
||||
vjp_in_shardings = sharding_impls.UNSPECIFIED
|
||||
vjp_out_shardings = sharding_impls.UNSPECIFIED
|
||||
else:
|
||||
if len(specified_shardings) < len(all_shardings):
|
||||
# There are some specified, but not all; pjit front-end does not like
|
||||
@ -489,13 +490,13 @@ def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported:
|
||||
assert isinstance(in_s, sharding.XLACompatibleSharding)
|
||||
replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment)
|
||||
all_shardings = [
|
||||
s if not pxla._is_unspecified(s) else replicated_s
|
||||
s if not sharding_impls.is_unspecified(s) else replicated_s
|
||||
for s in all_shardings]
|
||||
|
||||
vjp_in_shardings = tuple(all_shardings)
|
||||
vjp_out_shardings = tuple(all_shardings[:len(primal.in_avals)])
|
||||
if all(pxla._is_unspecified(s) for s in vjp_out_shardings):
|
||||
vjp_out_shardings = pxla._UNSPECIFIED
|
||||
if all(sharding_impls.is_unspecified(s) for s in vjp_out_shardings):
|
||||
vjp_out_shardings = sharding_impls.UNSPECIFIED
|
||||
|
||||
fun_vjp_jax = pjit.pjit(fun_vjp_jax,
|
||||
in_shardings=vjp_in_shardings,
|
||||
|
@ -68,6 +68,7 @@ from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import sharding_impls
|
||||
from jax._src.api_util import shaped_abstractify
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.lax import lax as lax_internal
|
||||
@ -734,8 +735,12 @@ def _pjit_jet_rule(primals_in, series_in, **params):
|
||||
**params,
|
||||
'jaxpr': jaxpr_jet,
|
||||
'in_shardings': (
|
||||
params['in_shardings'] + (pjit._UNSPECIFIED,) * num_series_in),
|
||||
'out_shardings': params['out_shardings'] + (pjit._UNSPECIFIED,) * num_series_out,
|
||||
params['in_shardings'] + (sharding_impls.UNSPECIFIED,) * num_series_in
|
||||
),
|
||||
'out_shardings': (
|
||||
params['out_shardings']
|
||||
+ (sharding_impls.UNSPECIFIED,) * num_series_out
|
||||
),
|
||||
'donated_invars': params['donated_invars'] + (False,) * num_series_in,
|
||||
}
|
||||
result = pjit.pjit_p.bind(*primals_and_series, **new_params)
|
||||
|
@ -15,34 +15,36 @@
|
||||
# flake8: noqa
|
||||
|
||||
from jax._src.pjit import (
|
||||
AUTO as AUTO,
|
||||
ParsedPartitionSpec as ParsedPartitionSpec,
|
||||
get_array_mapping as get_array_mapping,
|
||||
hashable_pytree as hashable_pytree,
|
||||
parse_flatten_op_sharding as parse_flatten_op_sharding,
|
||||
pjit as pjit,
|
||||
pjit_p as pjit_p,
|
||||
with_sharding_constraint as with_sharding_constraint,
|
||||
)
|
||||
from jax._src.sharding_impls import (
|
||||
AUTO as AUTO,
|
||||
UNSPECIFIED as _UNSPECIFIED,
|
||||
ParsedPartitionSpec as ParsedPartitionSpec,
|
||||
get_array_mapping as get_array_mapping,
|
||||
prepare_axis_resources as _prepare_axis_resources
|
||||
)
|
||||
|
||||
from jax._src.pjit import (_UNSPECIFIED, _prepare_axis_resources,
|
||||
_get_op_sharding_from_executable,
|
||||
from jax._src.pjit import (_get_op_sharding_from_executable,
|
||||
_get_pspec_from_executable, _pjit_lower_cached,
|
||||
_pjit_lower, _pjit_jaxpr,
|
||||
_process_in_axis_resources)
|
||||
|
||||
|
||||
from jax._src.pjit import (
|
||||
from jax._src.sharding_impls import (
|
||||
NamedSharding as _deprecated_NamedSharding,
|
||||
)
|
||||
from jax._src.partition_spec import (
|
||||
PartitionSpec as _deprecated_PartitionSpec,
|
||||
)
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
from jax._src.pjit import (
|
||||
NamedSharding as NamedSharding,
|
||||
PartitionSpec as PartitionSpec,
|
||||
)
|
||||
from jax._src.sharding_impls import NamedSharding as NamedSharding
|
||||
from jax._src.partition_spec import PartitionSpec as PartitionSpec
|
||||
del typing
|
||||
|
||||
_deprecations = {
|
||||
|
@ -34,6 +34,7 @@ from jax._src import linear_util as lu
|
||||
from jax._src import ops
|
||||
from jax._src import pjit
|
||||
from jax._src import prng
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
@ -47,7 +48,6 @@ from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.interpreters import ad
|
||||
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
||||
@ -468,7 +468,9 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
|
||||
sharded_avals = [v.aval for v in jaxpr.invars]
|
||||
in_nodes_ = map(partial(_xla_shard, mesh), in_names, ctx.avals_in,
|
||||
sharded_avals, in_nodes)
|
||||
new_axis_context = mlir.SPMDAxisContext(mesh, frozenset(mesh.axis_names))
|
||||
new_axis_context = sharding_impls.SPMDAxisContext(
|
||||
mesh, frozenset(mesh.axis_names)
|
||||
)
|
||||
sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
|
||||
with core.extend_axis_env_nd(tuple(mesh.shape.items())):
|
||||
out_nodes_, _ = mlir.jaxpr_subcomp(sub_ctx, jaxpr, mlir.TokenSet(),
|
||||
@ -483,8 +485,9 @@ def _xla_shard(mesh, names, aval_in, aval_out, x):
|
||||
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names), mesh)
|
||||
result_type, = mlir.aval_to_ir_types(aval_out)
|
||||
axes = {name: i for i, ns in names.items() for name in ns}
|
||||
shard_proto = NamedSharding(mesh, pxla.array_mapping_to_axis_resources(axes) # type: ignore
|
||||
)._to_xla_op_sharding(aval_in.ndim)
|
||||
shard_proto = NamedSharding(
|
||||
mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore
|
||||
)._to_xla_op_sharding(aval_in.ndim)
|
||||
if core.is_opaque_dtype(aval_in.dtype):
|
||||
shard_proto = aval_in.dtype._rules.physical_op_sharding(aval_in, shard_proto)
|
||||
sx = mlir.wrap_with_sharding_op(x, shard_proto, unspecified_dims=set())
|
||||
@ -496,8 +499,9 @@ def _xla_unshard(mesh, names, aval_in, aval_out, xs):
|
||||
result_type, = mlir.aval_to_ir_types(aval_out)
|
||||
sx = mlir.wrap_with_sharding_op(x, manual_proto, unspecified_dims=set())
|
||||
axes = {name: i for i, ns in names.items() for name in ns}
|
||||
shard_proto = NamedSharding(mesh, pxla.array_mapping_to_axis_resources(axes) # type: ignore
|
||||
)._to_xla_op_sharding(aval_out.ndim)
|
||||
shard_proto = NamedSharding(
|
||||
mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore
|
||||
)._to_xla_op_sharding(aval_out.ndim)
|
||||
if core.is_opaque_dtype(aval_out.dtype):
|
||||
shard_proto = aval_out.dtype._rules.physical_op_sharding(aval_out, shard_proto)
|
||||
return mlir.wrap_with_shard_to_full_op(result_type, sx, shard_proto, set())
|
||||
|
@ -53,18 +53,16 @@ from typing import (
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax._src import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import pjit
|
||||
from jax._src import sharding_impls
|
||||
from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_sparse
|
||||
import jax.numpy as jnp
|
||||
from jax._src.api_util import flatten_fun_nokwargs
|
||||
from jax._src.lib import pytree
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
|
||||
from jax.util import safe_map, safe_zip, split_list
|
||||
from jax._src.config import config
|
||||
@ -777,9 +775,13 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
|
||||
# TODO(yashkatariya, vanderplas): Flatten twice and set the correct sharding
|
||||
# for data and indices.
|
||||
in_shardings = in_shardings + tuple(
|
||||
pxla._UNSPECIFIED for _ in range(len(args_flat) - len(in_shardings)))
|
||||
sharding_impls.UNSPECIFIED
|
||||
for _ in range(len(args_flat) - len(in_shardings))
|
||||
)
|
||||
out_shardings = out_shardings + tuple(
|
||||
pxla._UNSPECIFIED for _ in range(len(sp_call_jaxpr.out_avals) - len(out_shardings)))
|
||||
sharding_impls.UNSPECIFIED
|
||||
for _ in range(len(sp_call_jaxpr.out_avals) - len(out_shardings))
|
||||
)
|
||||
|
||||
out_flat = pjit.pjit_p.bind(
|
||||
*args_flat,
|
||||
|
@ -19,14 +19,9 @@ from jax._src.interpreters.mlir import (
|
||||
LoweringResult as LoweringResult,
|
||||
LoweringRule as LoweringRule,
|
||||
LoweringRuleContext as LoweringRuleContext,
|
||||
Mesh as Mesh,
|
||||
MeshAxisName as MeshAxisName,
|
||||
ModuleContext as ModuleContext,
|
||||
RECV_FROM_HOST_TYPE as RECV_FROM_HOST_TYPE,
|
||||
ReplicaAxisContext as ReplicaAxisContext,
|
||||
SEND_TO_HOST_TYPE as SEND_TO_HOST_TYPE,
|
||||
SPMDAxisContext as SPMDAxisContext,
|
||||
ShardingContext as ShardingContext,
|
||||
Token as Token,
|
||||
TokenSet as TokenSet,
|
||||
Value as Value,
|
||||
@ -64,3 +59,11 @@ from jax._src.interpreters.mlir import (
|
||||
token_type as token_type,
|
||||
xla_computation_to_mlir_module as xla_computation_to_mlir_module,
|
||||
)
|
||||
|
||||
from jax._src.mesh import Mesh as Mesh
|
||||
from jax._src.sharding_impls import (
|
||||
MeshAxisName as MeshAxisName,
|
||||
ReplicaAxisContext as ReplicaAxisContext,
|
||||
SPMDAxisContext as SPMDAxisContext,
|
||||
ShardingContext as ShardingContext,
|
||||
)
|
||||
|
@ -13,9 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.interpreters.pxla import (
|
||||
AUTO as AUTO,
|
||||
ArrayMapping as ArrayMapping,
|
||||
ArrayMappingOrAutoOrUnspecified as ArrayMappingOrAutoOrUnspecified,
|
||||
AvalDimSharding as AvalDimSharding,
|
||||
EmapInfo as EmapInfo,
|
||||
ExecuteReplicated as ExecuteReplicated,
|
||||
@ -26,7 +23,6 @@ from jax._src.interpreters.pxla import (
|
||||
MeshComputation as MeshComputation,
|
||||
MeshDimAssignment as MeshDimAssignment,
|
||||
MeshExecutable as MeshExecutable,
|
||||
OrderedDictType as OrderedDictType,
|
||||
ParallelCallableInfo as ParallelCallableInfo,
|
||||
PartitionInfo as PartitionInfo,
|
||||
PartitionsOrReplicated as PartitionsOrReplicated,
|
||||
@ -43,12 +39,9 @@ from jax._src.interpreters.pxla import (
|
||||
UnloadedMeshExecutable as UnloadedMeshExecutable,
|
||||
UnloadedPmapExecutable as UnloadedPmapExecutable,
|
||||
WeakRefList as WeakRefList,
|
||||
_UNSPECIFIED as _UNSPECIFIED,
|
||||
_create_pmap_sharding_spec as _create_pmap_sharding_spec,
|
||||
_get_and_check_device_assignment as _get_and_check_device_assignment,
|
||||
_is_unspecified as _is_unspecified,
|
||||
_pmap_sharding_spec as _pmap_sharding_spec,
|
||||
array_mapping_to_axis_resources as array_mapping_to_axis_resources,
|
||||
array_types as array_types,
|
||||
custom_resource_typing_rules as custom_resource_typing_rules,
|
||||
device_put as _deprecated_device_put,
|
||||
@ -103,6 +96,15 @@ from jax._src.op_shardings import (
|
||||
op_sharding_to_indices as op_sharding_to_indices,
|
||||
)
|
||||
|
||||
from jax._src.sharding_impls import (
|
||||
ArrayMapping as ArrayMapping,
|
||||
ArrayMappingOrAutoOrUnspecified as ArrayMappingOrAutoOrUnspecified,
|
||||
AUTO as AUTO,
|
||||
UNSPECIFIED as _UNSPECIFIED,
|
||||
array_mapping_to_axis_resources as array_mapping_to_axis_resources,
|
||||
is_unspecified as _is_unspecified,
|
||||
)
|
||||
|
||||
from jax._src.sharding_specs import (
|
||||
Chunked as Chunked,
|
||||
NoSharding as NoSharding,
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.interpreters.xla import (
|
||||
AxisEnv as AxisEnv,
|
||||
DeviceArray as _deprecated_DeviceArray,
|
||||
TranslationContext as TranslationContext,
|
||||
TranslationRule as TranslationRule,
|
||||
@ -46,6 +45,10 @@ from jax._src.dispatch import (
|
||||
backend_compile as backend_compile,
|
||||
)
|
||||
|
||||
from jax._src.sharding_impls import (
|
||||
AxisEnv as AxisEnv,
|
||||
)
|
||||
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc # type: ignore
|
||||
|
||||
|
@ -44,10 +44,12 @@ from jax.experimental.custom_partitioning import custom_partitioning
|
||||
from jax._src import array
|
||||
from jax._src.sharding import Sharding
|
||||
from jax._src import op_shardings
|
||||
from jax._src.sharding_impls import (NamedSharding, GSPMDSharding,
|
||||
PositionalSharding, SingleDeviceSharding)
|
||||
from jax._src import sharding_impls
|
||||
from jax._src.sharding_impls import (
|
||||
AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding,
|
||||
SingleDeviceSharding)
|
||||
import jax._src.pjit as pjit_lib
|
||||
from jax._src.pjit import (pjit, pjit_p, AUTO)
|
||||
from jax._src.pjit import pjit, pjit_p
|
||||
from jax._src import mesh
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.interpreters import mlir
|
||||
@ -3519,7 +3521,9 @@ class UtilTest(jtu.JaxTestCase):
|
||||
("multi_skip", {'x': 0, 'y': 1, 'z': 3}, P(('x',), ('y',), None, ('z',))),
|
||||
)
|
||||
def test_array_mapping_to_axis_resources(self, inp, expected_out):
|
||||
self.assertEqual(pxla.array_mapping_to_axis_resources(inp), expected_out)
|
||||
self.assertEqual(
|
||||
sharding_impls.array_mapping_to_axis_resources(inp), expected_out
|
||||
)
|
||||
|
||||
def test_get_input_indices_fully_replicated(self):
|
||||
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
|
||||
@ -3552,22 +3556,22 @@ class UtilTest(jtu.JaxTestCase):
|
||||
pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval, array_mapping)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("all_unspecified", (pjit_lib._UNSPECIFIED, pjit_lib._UNSPECIFIED), AssertionError),
|
||||
("only_unspecified", pjit_lib._UNSPECIFIED),
|
||||
("all_unspecified", (UNSPECIFIED, UNSPECIFIED), AssertionError),
|
||||
("only_unspecified", UNSPECIFIED),
|
||||
("all_specified", (P('x'), P('y'))),
|
||||
("only_specified", P('x')),
|
||||
("mix_1", (P('x'), pjit_lib._UNSPECIFIED), ValueError),
|
||||
("mix_2", (P('x'), pjit_lib._UNSPECIFIED, P('y')), ValueError),
|
||||
("mix_3", (pjit_lib._UNSPECIFIED, P('x'), P('y')), ValueError),
|
||||
("mix_4", (pjit_lib._UNSPECIFIED, P('x'), pjit_lib._UNSPECIFIED), ValueError),
|
||||
("mix_1", (P('x'), UNSPECIFIED), ValueError),
|
||||
("mix_2", (P('x'), UNSPECIFIED, P('y')), ValueError),
|
||||
("mix_3", (UNSPECIFIED, P('x'), P('y')), ValueError),
|
||||
("mix_4", (UNSPECIFIED, P('x'), UNSPECIFIED), ValueError),
|
||||
)
|
||||
def test_all_or_non_unspecified(self, axis_resources, error=None):
|
||||
entries, _ = jax.tree_util.tree_flatten(axis_resources, is_leaf=lambda x: x is None)
|
||||
if error is not None:
|
||||
with self.assertRaises(error):
|
||||
pjit_lib._check_all_or_none_unspecified(entries, 'test axis resources')
|
||||
sharding_impls.check_all_or_none_unspecified(entries, 'test axis resources')
|
||||
else:
|
||||
pjit_lib._check_all_or_none_unspecified(entries, 'test axis resources')
|
||||
sharding_impls.check_all_or_none_unspecified(entries, 'test axis resources')
|
||||
|
||||
def test_op_sharding_equality_and_hash_equality(self):
|
||||
op1 = xc.OpSharding()
|
||||
@ -3673,7 +3677,6 @@ class UtilTest(jtu.JaxTestCase):
|
||||
cache_info4 = GSPMDSharding.devices_indices_map.cache_info()
|
||||
self.assertEqual(cache_info4.hits, cache_info3.hits + 1)
|
||||
|
||||
|
||||
def test_op_sharding_semantically_replicated(self):
|
||||
op1 = xc.OpSharding()
|
||||
op1.type = xc.OpSharding.Type.OTHER
|
||||
@ -3769,8 +3772,8 @@ class UtilTest(jtu.JaxTestCase):
|
||||
self.assertEqual(recovered_parsed_pspec[0].get_partition_spec(),
|
||||
P(('x',), ('y',)))
|
||||
|
||||
out_of_sync_parsed_pspec = pjit_lib.ParsedPartitionSpec(
|
||||
P('x', 'y'), ('x', 'y'), pjit_lib.SpecSync.OUT_OF_SYNC)
|
||||
out_of_sync_parsed_pspec = sharding_impls.ParsedPartitionSpec(
|
||||
P('x', 'y'), ('x', 'y'), sharding_impls.SpecSync.OUT_OF_SYNC)
|
||||
self.assertEqual(out_of_sync_parsed_pspec.get_partition_spec(),
|
||||
P(('x',), ('y',)))
|
||||
|
||||
|
@ -43,6 +43,7 @@ from jax._src import core
|
||||
from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
|
||||
linearize, device_put)
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import sharding_impls
|
||||
from jax._src import sharding_specs
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_extension
|
||||
@ -1091,7 +1092,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected)
|
||||
|
||||
def testAxisGroups(self):
|
||||
axis_env = xla.AxisEnv(8, ('i', 'j'), (4, 2))
|
||||
axis_env = sharding_impls.AxisEnv(8, ('i', 'j'), (4, 2))
|
||||
groups = xla.axis_groups(axis_env, 'i')
|
||||
self.assertEqual(groups, ((0, 2, 4, 6), (1, 3, 5, 7)))