Split sharding_impls into its own Bazel target.

* Move dependencies of sharding_impls into sharding_impls to avoid creating cyclic dependencies.
* Fix a handful of new pytype errors.

PiperOrigin-RevId: 523146076
This commit is contained in:
Peter Hawkins 2023-04-10 10:15:08 -07:00 committed by jax authors
parent a1797170af
commit be1cf46a49
28 changed files with 673 additions and 583 deletions

View File

@ -110,7 +110,6 @@ py_library_providing_imports_info(
"_src/prng.py",
"_src/public_test_util.py",
"_src/random.py",
"_src/sharding_impls.py",
"_src/stages.py",
] + glob(
[
@ -183,6 +182,7 @@ py_library_providing_imports_info(
":pretty_printer",
":profiler",
":sharding",
":sharding_impls",
":sharding_specs",
":source_info_util",
":traceback_util",
@ -345,6 +345,7 @@ pytype_strict_library(
":effects",
":op_shardings",
":partial_eval",
":sharding_impls",
":source_info_util",
":util",
":xla",
@ -418,6 +419,22 @@ pytype_strict_library(
],
)
pytype_strict_library(
name = "sharding_impls",
srcs = ["_src/sharding_impls.py"],
deps = [
":mesh",
":op_shardings",
":partition_spec",
":sharding",
":sharding_specs",
":tree_util",
":util",
":xla_bridge",
"//jax/_src/lib",
] + py_deps("numpy"),
)
pytype_strict_library(
name = "sharding_specs",
srcs = ["_src/sharding_specs.py"],
@ -504,6 +521,7 @@ pytype_strict_library(
":abstract_arrays",
":config",
":core",
":sharding_impls",
":source_info_util",
":typing",
":util",

View File

@ -46,6 +46,7 @@ from jax._src import dispatch
from jax._src import effects
from jax._src import array
from jax._src import dtypes
from jax._src import sharding_impls
from jax._src import sharding_specs
from jax._src import source_info_util
from jax._src import traceback_util
@ -146,8 +147,8 @@ float0 = dtypes.float0
def jit(
fun: Callable,
in_shardings=pxla._UNSPECIFIED,
out_shardings=pxla._UNSPECIFIED,
in_shardings=sharding_impls.UNSPECIFIED,
out_shardings=sharding_impls.UNSPECIFIED,
static_argnums: Union[int, Sequence[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
donate_argnums: Union[int, Sequence[int]] = (),
@ -503,11 +504,11 @@ def xla_computation(fun: Callable,
def make_axis_env(nreps):
if axis_env is None:
return xla.AxisEnv(nreps, (), ())
return sharding_impls.AxisEnv(nreps, (), ())
else:
nreps = nreps * math.prod(size for name, size in axis_env)
names, sizes = unzip2(axis_env)
return xla.AxisEnv(nreps, names, sizes)
return sharding_impls.AxisEnv(nreps, names, sizes)
@wraps(fun)
@api_boundary
@ -553,7 +554,7 @@ def xla_computation(fun: Callable,
ordered_effects=ordered_effects,
backend_or_name=backend,
platform=platform,
axis_context=mlir.ReplicaAxisContext(axis_env_),
axis_context=sharding_impls.ReplicaAxisContext(axis_env_),
name_stack=source_info_util.new_name_stack(
wrap_name(fun_name, "xla_computation")),
donated_args=donated_invars,

View File

@ -23,6 +23,7 @@ from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import sharding_impls
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import ad
@ -107,13 +108,13 @@ def pure_callback_lowering(ctx, *args, callback, **params):
sharding = None
axis_context = ctx.module_context.axis_context
if isinstance(axis_context, mlir.ShardingContext):
if isinstance(axis_context, sharding_impls.ShardingContext):
if len(axis_context.device_assignment) > 1:
raise NotImplementedError(
"pure_callback is only supported in spmd computations when all mesh"
" axes are partitioned manually (no partial automatic sharding)."
)
if isinstance(axis_context, mlir.SPMDAxisContext):
if isinstance(axis_context, sharding_impls.SPMDAxisContext):
if axis_context.manual_axes != frozenset(axis_context.mesh.axis_names):
raise NotImplementedError(
"pure_callback is only supported in spmd computations when all mesh"
@ -272,7 +273,7 @@ def io_callback_lowering(ctx, *args, callback, ordered, **params):
# can only safely maximally shard. Should we allow device_index to be passed
# in like host_callback?
if isinstance(ctx.module_context.axis_context,
(mlir.SPMDAxisContext, mlir.ShardingContext)):
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext)):
# Apply maximal sharding so pjit only executes the callback on device 0.
sharding = xc.OpSharding()
sharding.type = xc.OpSharding.Type.MAXIMAL

View File

@ -31,6 +31,7 @@ from jax._src import custom_derivatives
from jax._src import effects
from jax._src import pjit
from jax._src import prng
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util as jtu
@ -869,7 +870,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
# Update pjit params to account for extra error values.
num_error_vals = len(err_vals)
num_out_error_vals = out_tree.num_leaves - len(out_shardings)
sharding = pjit._UNSPECIFIED
sharding = sharding_impls.UNSPECIFIED
new_in_shardings = (*[sharding] * num_error_vals, *in_shardings)
new_out_shardings = (*[sharding] * num_out_error_vals, *out_shardings)

View File

@ -29,6 +29,7 @@ from jax._src import effects
from jax._src import linear_util as lu
from jax._src import mesh as mesh_lib
from jax._src import pjit
from jax._src import sharding_impls
from jax._src import tree_util
from jax._src import util
from jax._src.interpreters import ad
@ -123,13 +124,16 @@ ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule
def debug_callback_lowering(ctx, *args, effect, callback, **params):
axis_context = ctx.module_context.axis_context
if (isinstance(axis_context, mlir.SPMDAxisContext) and
if (isinstance(axis_context, sharding_impls.SPMDAxisContext) and
set(axis_context.manual_axes) == set(axis_context.mesh.axis_names)):
# If we have fully manual sharding during lowering, that means the JAX
# program has per-device semantics, so we run the callback on each device.
sharding = xc.OpSharding()
sharding.type = xc.OpSharding.Type.MANUAL
elif isinstance(axis_context, (mlir.ShardingContext, mlir.SPMDAxisContext)):
elif isinstance(
axis_context,
(sharding_impls.ShardingContext, sharding_impls.SPMDAxisContext),
):
# If we have fully automatic sharding during lowering, that means the JAX
# program has bulk array semantics, so we run the callback with a MAXIMAL
# sharding and hence execute it only once on the full logical value).
@ -312,9 +316,9 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
mesh = mesh_lib.thread_resources.env.physical_mesh
axis_context = ctx.module_context.axis_context
if isinstance(axis_context, mlir.ShardingContext):
if isinstance(axis_context, sharding_impls.ShardingContext):
devices = axis_context.device_assignment
elif isinstance(axis_context, mlir.SPMDAxisContext):
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
devices = list(axis_context.mesh.devices.flat)
else:
raise NotImplementedError(type(axis_context))

View File

@ -53,7 +53,8 @@ from jax._src.monitoring import record_event_duration_secs
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding)
PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding,
UNSPECIFIED)
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
@ -211,13 +212,13 @@ def xla_primitive_callable(prim, *arg_specs: ArgSpec, **params):
def sharded_lowering(fun, name, donated_invars, keep_unused,
*arg_specs, lowering_platform: Optional[str]):
in_avals, in_shardings = util.unzip2(arg_specs)
in_shardings = [pxla._UNSPECIFIED if i is None else i for i in in_shardings] # type: ignore
in_shardings = [UNSPECIFIED if i is None else i for i in in_shardings] # type: ignore
# Pass in a singleton `_UNSPECIFIED` for out_shardings because we don't know
# Pass in a singleton `UNSPECIFIED` for out_shardings because we don't know
# the number of output avals at this stage. lower_sharding_computation will
# apply it to all out_avals.
return pxla.lower_sharding_computation(
fun, 'jit', name, in_shardings, pxla._UNSPECIFIED, donated_invars,
fun, 'jit', name, in_shardings, UNSPECIFIED, donated_invars,
tuple(in_avals), keep_unused=keep_unused, always_lower=False,
devices_from_context=None, lowering_platform=lowering_platform)

View File

@ -24,7 +24,7 @@ import itertools
import re
import typing
from typing import (Any, Callable, Dict, Iterator, List, NamedTuple, Optional,
Protocol, Sequence, Set, Tuple, Type, Union, FrozenSet)
Protocol, Sequence, Set, Tuple, Type, Union)
import warnings
import numpy as np
@ -35,6 +35,7 @@ from jax._src import dtypes
from jax._src import effects as effects_lib
from jax._src import linear_util as lu
from jax._src import op_shardings
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import util
from jax._src import xla_bridge as xb
@ -339,68 +340,11 @@ def make_ir_context() -> ir.Context:
return context
Mesh = Any
MeshAxisName = Any
@dataclasses.dataclass(frozen=True)
class SPMDAxisContext:
"""A hardware axis context for parallel computations that use the GSPMD partitioner.
This includes the mesh that will later by used to execute this computation,
as well as a set of mesh axes that are currently (e.g. because the current lowering
is invoked inside an xmap) lowered in the MANUAL sharding mode.
"""
mesh: Mesh
manual_axes: FrozenSet[MeshAxisName] = frozenset()
@property
def axis_env(self):
# All collectives that touch axis_env should remember to set use_global_device_ids
# when this context is enabled!
if self.manual_axes != frozenset(self.mesh.axis_names):
raise NotImplementedError(
"Collectives in manually partitioned computations are only supported "
"when all mesh axes are partitioned manually (no partial automatic sharding). "
"Make sure that you mention all mesh axes in axis_resources!")
return self.unsafe_axis_env
@property
def unsafe_axis_env(self):
return xla.AxisEnv(
nreps=self.mesh.size,
names=self.mesh.axis_names,
sizes=tuple(self.mesh.shape.values()))
def extend_manual(self, axes: FrozenSet[MeshAxisName]) -> SPMDAxisContext:
return SPMDAxisContext(self.mesh, self.manual_axes | axes)
@dataclasses.dataclass(frozen=True)
class ReplicaAxisContext:
"""A hardware axis context for parallel computations that are partitioned by JAX.
Unlike in the SPMDAxisContext, this means that JAX might need to emit calls to
explicit collectives.
"""
axis_env: xla.AxisEnv
@dataclasses.dataclass(frozen=True)
class ShardingContext:
"""A hardware axis context for parallel computations that use the sharding
interface.
This context also uses the GSPMD partitioner.
"""
device_assignment: Sequence[xc.Device]
# Similar to SPMDContext as ShardingContext also uses the GSPMD partitioner.
@property
def axis_env(self):
return xla.AxisEnv(nreps=1, names=(), sizes=())
AxisContext = Union[SPMDAxisContext, ReplicaAxisContext, ShardingContext]
AxisContext = Union[
sharding_impls.SPMDAxisContext,
sharding_impls.ReplicaAxisContext,
sharding_impls.ShardingContext,
]
@dataclasses.dataclass
class ModuleContext:
@ -428,7 +372,7 @@ class ModuleContext:
@property
def axis_env(self) -> xla.AxisEnv:
def axis_env(self) -> sharding_impls.AxisEnv:
return self.axis_context.axis_env
def __init__(
@ -1539,7 +1483,7 @@ def xla_fallback_lowering(prim: core.Primitive):
def fallback(ctx: LoweringRuleContext, *args, **params):
module_ctx = ctx.module_context
axis_ctx = module_ctx.axis_context
if isinstance(axis_ctx, SPMDAxisContext):
if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
axis_env = axis_ctx.unsafe_axis_env
else:
axis_env = module_ctx.axis_env
@ -1645,7 +1589,7 @@ def _emit_tpu_python_callback(
ctx: LoweringRuleContext,
callback,
token: Optional[Any],
operands: List[ir.Value],
operands: Sequence[ir.Value],
operand_avals: List[core.ShapedArray],
operand_shapes: List[xc.Shape],
result_avals: List[core.ShapedArray],
@ -1716,7 +1660,7 @@ def _aval_to_default_layout(aval):
def emit_python_callback(
ctx: LoweringRuleContext, callback, token: Optional[Any],
operands: List[ir.Value], operand_avals: List[core.ShapedArray],
operands: Sequence[ir.Value], operand_avals: List[core.ShapedArray],
result_avals: List[core.ShapedArray],
has_side_effect: bool, *, sharding: Optional[xc.OpSharding] = None,
operand_layouts: Optional[Sequence[Optional[Sequence[int]]]] = None,

View File

@ -17,13 +17,12 @@ from __future__ import annotations
import enum
from contextlib import contextmanager
from collections import defaultdict, OrderedDict, namedtuple
from collections import defaultdict, namedtuple
import dataclasses
from functools import partial, lru_cache, cached_property
import itertools as it
import logging
import math
import sys
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Sequence, Set, Tuple, Type, Union, Iterable,
TYPE_CHECKING, cast)
@ -40,7 +39,7 @@ from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import mesh
from jax._src import mesh as mesh_lib
from jax._src import op_shardings
from jax._src import sharding_specs
from jax._src import profiler
@ -61,6 +60,11 @@ from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding_impls import (
ArrayMapping, ArrayMappingOrAutoOrUnspecified,
AUTOAxisResource, UnspecifiedValue, UNSPECIFIED,
get_array_mapping as _get_array_mapping, is_auto, is_unspecified
)
from jax._src.util import (unzip3, safe_map, safe_zip, partition_list,
wrap_name, tuple_delete, distributed_debug_log,
unzip2, HashableFunction, weakref_lru_cache)
@ -71,11 +75,6 @@ class WeakRefList(list):
pass
if sys.version_info >= (3, 9):
OrderedDictType = OrderedDict
else:
OrderedDictType = Dict
xe = xc._xla
unsafe_map, map = map, safe_map # type: ignore
@ -92,8 +91,8 @@ ShardedAxis = sharding_specs.ShardedAxis
Replicated = sharding_specs.Replicated
AvalDimSharding = Union[Unstacked, Chunked, NoSharding]
Mesh = jax._src.mesh.Mesh
MeshAxisName = mesh.MeshAxisName
Mesh = mesh_lib.Mesh
MeshAxisName = sharding_impls.MeshAxisName
MeshDimAssignment = Union[ShardedAxis, Replicated]
ShardingSpec = sharding_specs.ShardingSpec
@ -233,56 +232,6 @@ def _shard_abstract_array(size, axis: int, x):
shard_aval_handlers[ShapedArray] = _shard_abstract_array
class AUTOAxisResource:
pass
AUTO = AUTOAxisResource()
def is_auto(x):
return isinstance(x, AUTOAxisResource)
class UnspecifiedValue:
def __repr__(self):
return "UnspecifiedValue"
_UNSPECIFIED = UnspecifiedValue()
def _is_unspecified(x):
return isinstance(x, UnspecifiedValue)
"""
ArrayMapping specifies how an ndarray should map to mesh axes.
Note that the ordering is crucial for the cases when this mapping is non-injective
(i.e. when multiple mesh axes map to the same positional axis). Then, the
order of entries of the mapping determines a major-to-minor order on mesh axes,
according to which chunks of the value along the repeated dimension will be assigned.
For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
The second dimension of the value would get chunked into 6 pieces, and assigned to the
mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
that would mean that a flat list of chunks would get assigned to a flattened list of
mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
mesh devices ndarray would have to be transposed before flattening and assignment.
"""
ArrayMapping = OrderedDictType[mesh.MeshAxisName, int]
ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTOAxisResource,
UnspecifiedValue]
def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
if not array_mapping:
return PartitionSpec()
max_index = -1
reverse_map = defaultdict(list)
for axis, index in array_mapping.items():
reverse_map[index].append(axis)
if index > max_index:
max_index = index
partitions = tuple(tuple(reverse_map[i]) if reverse_map[i] else None
for i in range(max_index + 1))
return PartitionSpec(*partitions)
def local_aval_to_result_handler(
aval: core.AbstractValue,
sharding: sharding_impls.XLACompatibleSharding,
@ -370,7 +319,7 @@ def make_sharded_device_array(
if sharding_spec is None:
sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
mesh = jax._src.mesh.thread_resources.env.physical_mesh
mesh = mesh_lib.thread_resources.env.physical_mesh
sharding: sharding_impls.XLACompatibleSharding
if mesh.empty:
sharding = sharding_impls.PmapSharding(
@ -930,7 +879,7 @@ def lower_parallel_callable(
shards.num_global_shards, avals, replicas.num_global_replicas,
parts.num_partitions)
axis_env = xla.AxisEnv(
axis_env = sharding_impls.AxisEnv(
replicas.num_global_replicas, (axis_name,), (global_axis_size,))
name_stack = source_info_util.new_name_stack(wrap_name(name, 'pmap'))
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
@ -952,7 +901,7 @@ def lower_parallel_callable(
ordered_effects,
backend,
lowering_platform or backend.platform,
mlir.ReplicaAxisContext(axis_env),
sharding_impls.ReplicaAxisContext(axis_env),
name_stack,
donated_invars,
replicated_args=replicated_args,
@ -1161,6 +1110,7 @@ class UnloadedPmapExecutable:
input_indices = []
for aval, spec in safe_zip(self.local_input_avals, self.input_shardings):
assert isinstance(spec, sharding_impls.PmapSharding), spec
assert isinstance(aval, core.ShapedArray), aval
input_indices.append(
sharding_specs.spec_to_indices(aval.shape, spec.sharding_spec)
if spec.sharding_spec is not None else None)
@ -1723,7 +1673,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
with maybe_extend_axis_env(axis_name, global_axis_size, None): # type: ignore
sub_ctx = ctx.module_context.replace(
axis_context=mlir.ReplicaAxisContext(new_env),
axis_context=sharding_impls.ReplicaAxisContext(new_env),
name_stack=ctx.module_context.name_stack.extend(
util.wrap_name(name, 'pmap')))
sharded_outs, _ = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, mlir.TokenSet(), (),
@ -1805,7 +1755,9 @@ def _full_to_shard_abstract_eval(x, axes, mesh, **_):
# TODO: Assert x is a global aval! Or ideally check that it's global in dims from axes!
return tile_aval_nd(mesh.shape, axes, x)
def manual_proto(aval: core.ShapedArray, manual_axes_set: FrozenSet[mesh.MeshAxisName], mesh: Mesh):
def manual_proto(
aval: core.ShapedArray,
manual_axes_set: FrozenSet[sharding_impls.MeshAxisName], mesh: Mesh):
"""Create an OpSharding proto that declares all mesh axes from `axes` as manual
and all others as replicated.
"""
@ -1832,7 +1784,7 @@ def manual_proto(aval: core.ShapedArray, manual_axes_set: FrozenSet[mesh.MeshAxi
@partial(mlir.register_lowering, full_to_shard_p)
def _full_to_shard_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
manual_axes: FrozenSet[mesh.MeshAxisName]):
manual_axes: FrozenSet[sharding_impls.MeshAxisName]):
# TODO: Can we short-circuit for replicated values? Probably not.
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
@ -1852,7 +1804,7 @@ def _shard_to_full_abstract_eval(x, axes, mesh, **_):
@partial(mlir.register_lowering, shard_to_full_p)
def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
manual_axes: FrozenSet[mesh.MeshAxisName]):
manual_axes: FrozenSet[sharding_impls.MeshAxisName]):
aval_in, = ctx.avals_in
aval_out, = ctx.avals_out
proto = manual_proto(aval_in, manual_axes, mesh)
@ -1863,7 +1815,7 @@ def _shard_to_full_lowering(ctx, x, *, axes: ArrayMapping, mesh: Mesh,
return mlir.wrap_with_shard_to_full_op(result_type, sx, sharding_proto, unspecified_dims),
@lu.transformation
def vtile_manual(manual_axes: FrozenSet[mesh.MeshAxisName],
def vtile_manual(manual_axes: FrozenSet[sharding_impls.MeshAxisName],
mesh: Mesh,
in_axes: Sequence[ArrayMapping],
out_axes: Sequence[ArrayMapping],
@ -1882,7 +1834,7 @@ class TileVectorize:
@dataclasses.dataclass(frozen=True)
class TileManual:
manual_axes: FrozenSet[mesh.MeshAxisName]
manual_axes: FrozenSet[sharding_impls.MeshAxisName]
TilingMethod = Union[TileVectorize, TileManual]
@ -1971,7 +1923,7 @@ def _get_and_check_device_assignment(
devices = list(devices)
for i, s_type, source_info in shardings:
if is_auto(i) or _is_unspecified(i):
if is_auto(i) or is_unspecified(i):
continue
# Assign `first_sharding_info` after `AUTO` and `UNSPECIFIED` have been
# skipped.
@ -2096,14 +2048,14 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
in_op_shardings = map(_to_logical_op_sharding, global_in_avals, in_shardings)
out_op_shardings = map(_to_logical_op_sharding, global_out_avals, out_shardings)
replicated_args = [False] * len(global_in_avals)
axis_ctx = mlir.ShardingContext(device_assignment)
axis_ctx = sharding_impls.ShardingContext(device_assignment)
else:
# This path is triggered for `jit(pmap)` cases.
replicated_args = None
in_op_shardings = None
out_op_shardings = None
axis_env = xla.AxisEnv(nreps, (), ())
axis_ctx = mlir.ReplicaAxisContext(axis_env)
axis_env = sharding_impls.AxisEnv(nreps, (), ())
axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
module_name = f"{api_name}_{fun_name}"
@ -2189,10 +2141,10 @@ def lower_sharding_computation(
) -> MeshComputation:
"""Lowers a computation to XLA. It can take arbitrary shardings as input.
The caller of this code can pass in a singleton _UNSPECIFIED because the
The caller of this code can pass in a singleton UNSPECIFIED because the
number of out_avals might not be known at that time and
lower_sharding_computation calculates the number of out_avals so it can apply
the singleton _UNSPECIFIED to all out_avals.
the singleton UNSPECIFIED to all out_avals.
"""
# 1. Trace to jaxpr and preprocess/verify it
(closed_jaxpr, global_in_avals, global_out_avals, donated_invars,
@ -2202,8 +2154,8 @@ def lower_sharding_computation(
jaxpr = closed_jaxpr.jaxpr
in_shardings = tuple(s for i, s in enumerate(in_shardings) if i in kept_var_idx)
if _is_unspecified(out_shardings):
out_shardings = (_UNSPECIFIED,) * len(global_out_avals)
if is_unspecified(out_shardings):
out_shardings = (UNSPECIFIED,) * len(global_out_avals)
assert isinstance(out_shardings, tuple)
assert len(out_shardings) == len(global_out_avals), (
len(out_shardings), len(global_out_avals))
@ -2221,12 +2173,12 @@ def lower_sharding_computation(
committed = bool(
devices_from_context or
len(device_assignment) > 1 or
any(not _is_unspecified(i) for i in in_shardings) or
any(not _is_unspecified(js) for js, _ in jaxpr_sharding) or
any(not _is_unspecified(o) for o in out_shardings))
any(not is_unspecified(i) for i in in_shardings) or
any(not is_unspecified(js) for js, _ in jaxpr_sharding) or
any(not is_unspecified(o) for o in out_shardings))
in_shardings = tuple(sharding_impls.GSPMDSharding.get_replicated(device_assignment)
if _is_unspecified(i) else i for i in in_shardings)
if is_unspecified(i) else i for i in in_shardings)
da_object = _create_da_object(tuple(device_assignment))
@ -2260,7 +2212,7 @@ def lower_sharding_computation(
# and don't need to evaluate their arguments.
if (not always_lower and not (jaxpr.effects or has_outfeed) and
(not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and
all(_is_unspecified(o) for o in out_shardings)):
all(is_unspecified(o) for o in out_shardings)):
return MeshComputation(
str(name_stack), None, True, donated_invars, jaxpr=jaxpr,
consts=closed_jaxpr.consts, global_in_avals=global_in_avals,
@ -2309,7 +2261,7 @@ def lower_sharding_computation(
def _to_logical_op_sharding(
aval: core.AbstractValue, sharding: Union[MaybeSharding, AUTOAxisResource]
) -> Optional[xc.OpSharding]:
if _is_unspecified(sharding) or is_auto(sharding):
if is_unspecified(sharding) or is_auto(sharding):
return None
elif isinstance(aval, ShapedArray):
assert isinstance(sharding, sharding_impls.XLACompatibleSharding)
@ -2419,15 +2371,16 @@ def lower_mesh_computation(
in_partitions = map(_to_logical_op_sharding, global_in_avals, in_shardings)
out_partitions = map(_to_logical_op_sharding, global_out_avals, out_shardings)
replicated_args = [False] * len(in_jaxpr_avals)
axis_ctx = mlir.SPMDAxisContext(mesh, manual_axes)
axis_ctx = sharding_impls.SPMDAxisContext(mesh, manual_axes)
else:
replicated_args = [not get_array_mapping(i.spec) for i in in_shardings] # type: ignore
in_partitions = None
out_partitions = None
axis_env = xla.AxisEnv(nreps=mesh.size,
names=tuple(global_axis_sizes.keys()),
sizes=tuple(global_axis_sizes.values()))
axis_ctx = mlir.ReplicaAxisContext(axis_env)
axis_env = sharding_impls.AxisEnv(
nreps=mesh.size,
names=tuple(global_axis_sizes.keys()),
sizes=tuple(global_axis_sizes.values()))
axis_ctx = sharding_impls.ReplicaAxisContext(axis_env)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
module: Union[str, xc.XlaComputation]
module_name = f"{api_name}_{fun_name}"
@ -2834,7 +2787,7 @@ class UnloadedMeshExecutable:
for x, o in safe_zip(out_shardings_xla, out_shardings)
]
out_shardings, are_out_shardings_from_xla = unzip2(out_shardings_tuple)
elif (out_shardings and any(_is_unspecified(o) for o in out_shardings)
elif (out_shardings and any(is_unspecified(o) for o in out_shardings)
and pmap_nreps == 1):
assert mesh is None
_, out_shardings_xla = get_gspmd_shardings_from_executable( # type: ignore
@ -2844,7 +2797,7 @@ class UnloadedMeshExecutable:
out_shardings, are_out_shardings_from_xla = [], [] # type: ignore
for xla_s, orig, aval in safe_zip(out_shardings_xla, orig_out_shardings,
global_out_avals):
if _is_unspecified(orig):
if is_unspecified(orig):
out_shardings.append(xla_s)
are_out_shardings_from_xla.append(True)
else:
@ -3133,7 +3086,7 @@ def create_mesh_pspec_sharding(
def check_device_backend_on_shardings(shardings) -> bool:
for i in shardings:
if _is_unspecified(i) or is_auto(i):
if is_unspecified(i) or is_auto(i):
continue
if hasattr(i, '_original_sharding') and getattr(
i._original_sharding, '_device_backend', False):
@ -3163,10 +3116,8 @@ def check_gda_or_array_xla_sharding_match(
def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
# Import here to avoid cyclic import error when importing gda in pjit.py.
from jax.experimental.pjit import get_array_mapping as _get_array_mapping, _prepare_axis_resources
parsed_pspec, _, _ = _prepare_axis_resources(pspec, "pspec to array_mapping")
parsed_pspec, _, _ = sharding_impls.prepare_axis_resources(
pspec, "pspec to array_mapping")
return _get_array_mapping(parsed_pspec)

View File

@ -22,7 +22,7 @@ import itertools as it
import math
import operator
import re
from typing import (Any, Callable, Dict, NamedTuple, Optional, Protocol,
from typing import (Any, Callable, Dict, Optional, Protocol,
Sequence, Set, Type, Tuple, Union, TYPE_CHECKING)
import numpy as np
@ -34,6 +34,7 @@ from jax._src import dtypes
from jax._src import source_info_util
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.sharding_impls import AxisEnv
from jax._src.util import safe_zip, safe_map
from jax._src.typing import Shape
@ -254,13 +255,6 @@ def primitive_subcomputation(platform: str, axis_env: 'AxisEnv',
### compiling jaxprs
class AxisEnv(NamedTuple):
"""Represents a pmap mesh (only along the replica axes)."""
nreps: int
names: Tuple[Any, ...]
sizes: Tuple[int, ...]
@dataclasses.dataclass
class TranslationContext:
builder: xc.XlaBuilder
@ -272,8 +266,6 @@ class TranslationContext:
def replace(self, **kw): return dataclasses.replace(self, **kw)
def xla_destructure(c, ans):
num_elements = len(c.get_shape(ans).tuple_shapes())
return [xops.GetTupleElement(ans, i) for i in range(num_elements)]

View File

@ -28,6 +28,7 @@ from jax import tree_util
from jax._src import core
from jax._src import dtypes
from jax._src import sharding_impls
from jax._src import util
from jax._src.core import ShapedArray, AxisName, raise_to_shaped
from jax._src.interpreters import ad
@ -750,8 +751,10 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
_replica_groups(ctx.module_context.axis_env, named_axes,
axis_index_groups))
axis_context = ctx.module_context.axis_context
is_spmd = isinstance(axis_context,
(mlir.SPMDAxisContext, mlir.ShardingContext))
is_spmd = isinstance(
axis_context,
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
)
def all_reduce(aval, x):
if is_spmd:
@ -880,7 +883,10 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm):
full_perm = full_perm.reshape((-1, 2))
axis_context = ctx.module_context.axis_context
is_manual = isinstance(axis_context, mlir.SPMDAxisContext) and axis_context.manual_axes
is_manual = (
isinstance(axis_context, sharding_impls.SPMDAxisContext)
and axis_context.manual_axes
)
if is_manual:
channel = ctx.module_context.new_channel()
other_args = dict(
@ -978,8 +984,10 @@ def _all_to_all_lowering(ctx, x, *,
split_count = len(replica_groups[0])
if not all(split_count == len(g) for g in replica_groups):
raise ValueError('Replica groups must be equally sized')
is_spmd = isinstance(ctx.module_context.axis_context,
(mlir.SPMDAxisContext, mlir.ShardingContext))
is_spmd = isinstance(
ctx.module_context.axis_context,
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
)
if is_spmd:
# We want to emit the all-gather with global device IDs and a unique
# channel ID, as otherwise it interprets the devices as replicas instead
@ -1209,8 +1217,10 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
x_aval, = ctx.avals_in
out_aval, = ctx.avals_out
axis_context = ctx.module_context.axis_context
is_spmd = isinstance(axis_context,
(mlir.SPMDAxisContext, mlir.ShardingContext))
is_spmd = isinstance(
axis_context,
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
)
if (ctx.module_context.platform == 'tpu' or
ctx.module_context.platform in ('cuda', 'rocm')
and all_gather_dimension == 0):
@ -1354,8 +1364,10 @@ def _reduce_scatter_lowering(prim, reducer, ctx, x,
scatter_out_shape = list(x_aval.shape)
scatter_out_shape[scatter_dimension] //= axis_size
axis_context = ctx.module_context.axis_context
is_spmd = isinstance(axis_context,
(mlir.SPMDAxisContext, mlir.ShardingContext))
is_spmd = isinstance(
axis_context,
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
)
if is_spmd:
# We want to emit the all-gather with global device IDs and a unique
# channel ID, as otherwise it interprets the devices as replicas instead
@ -1572,8 +1584,10 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
)
mod = mlir.ir_constant(np.array(axis_env.sizes[axis_pos], dtype=np.uint32))
axis_context = ctx.module_context.axis_context
is_spmd = isinstance(axis_context,
(mlir.SPMDAxisContext, mlir.ShardingContext))
is_spmd = isinstance(
axis_context,
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
)
if is_spmd:
device_id = hlo.PartitionIdOp()
else:

View File

@ -662,6 +662,7 @@ def _select_and_gather_add_lowering(
max_bits=64):
_, operand_aval, = ctx.avals_in
out_aval, = ctx.avals_out
assert isinstance(operand_aval, core.ShapedArray), operand_aval
dtype = operand_aval.dtype
etype = mlir.dtype_to_ir_type(dtype)
nbits = dtypes.finfo(dtype).bits

View File

@ -26,9 +26,10 @@ from jax import numpy as jnp
from jax._src import core
from jax._src import dispatch
from jax._src import effects
from jax._src import mesh
from jax._src import mesh as mesh_lib
from jax._src import linear_util as lu
from jax._src import op_shardings
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import stages
from jax._src import traceback_util
@ -47,10 +48,11 @@ from jax._src.interpreters.partial_eval import (
convert_constvars_jaxpr, new_jaxpr_eqn)
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.pjit import (
sharding_constraint_p, ParsedPartitionSpec, get_unconstrained_dims,
GSPMDSharding)
from jax._src.sharding_impls import NamedSharding
from jax._src.pjit import (sharding_constraint_p, get_unconstrained_dims,
GSPMDSharding)
from jax._src.sharding_impls import (
ArrayMapping, NamedSharding, ParsedPartitionSpec,
array_mapping_to_axis_resources)
from jax._src.tree_util import (tree_flatten, tree_unflatten, all_leaves,
tree_map, treedef_tuple)
from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3,
@ -90,12 +92,12 @@ class FrozenDict(abc.Mapping):
# Multi-dimensional generalized map
AxisName = core.AxisName
ResourceAxisName = mesh.ResourceAxisName # Different name just for documentation purposes
Mesh = mesh.Mesh
MeshAxisName = mesh.MeshAxisName
ResourceEnv = mesh.ResourceEnv
EMPTY_ENV = mesh.EMPTY_ENV
thread_resources = mesh.thread_resources
ResourceAxisName = mesh_lib.ResourceAxisName # Different name just for documentation purposes
Mesh = mesh_lib.Mesh
MeshAxisName = mesh_lib.MeshAxisName
ResourceEnv = mesh_lib.ResourceEnv
EMPTY_ENV = mesh_lib.EMPTY_ENV
thread_resources = mesh_lib.thread_resources
class SerialLoop:
@ -165,7 +167,7 @@ def serial_loop(name: ResourceAxisName, length: int):
axis_resources={'i': 'l'})(x)
"""
old_env: ResourceEnv = getattr(thread_resources, "env", EMPTY_ENV)
thread_resources.env = old_env.with_extra_loop(mesh.Loop(name, length))
thread_resources.env = old_env.with_extra_loop(mesh_lib.Loop(name, length))
try:
yield
finally:
@ -677,9 +679,9 @@ def make_xmap_callable(fun: lu.WrappedFun,
tiling_method = pxla.TileManual(manual_mesh_axes)
else:
tiling_method = pxla.TileVectorize()
in_shardings = [NamedSharding(mesh, pxla.array_mapping_to_axis_resources(i))
in_shardings = [NamedSharding(mesh, array_mapping_to_axis_resources(i))
for i in mesh_in_axes]
out_shardings = [NamedSharding(mesh, pxla.array_mapping_to_axis_resources(o))
out_shardings = [NamedSharding(mesh, array_mapping_to_axis_resources(o))
for o in mesh_out_axes]
return pxla.lower_mesh_computation(
f, 'xmap', name, mesh,
@ -937,7 +939,7 @@ def _resource_typing_xmap(avals,
raise JAXTypeError(
f"Detected disallowed xmap axis name shadowing at "
f"{source_info_util.summarize(source_info)} "
f"(shadowed axes: {mesh.show_axes(overlap)})")
f"(shadowed axes: {mesh_lib.show_axes(overlap)})")
if resource_env.physical_mesh != params['resource_env'].physical_mesh:
raise RuntimeError("Changing the physical mesh is not allowed inside xmap.")
@ -965,9 +967,9 @@ def _resource_typing_xmap(avals,
raise JAXTypeError(
f"One of xmapped function ({params['name']}) outputs is broadcast "
f"along axis `{baxis}` which is assigned to resources "
f"{mesh.show_axes(baxis_resources)}, but the output is already "
f"partitioned along {mesh.show_axes(overlap)}, because its "
f"named shape contains {mesh.show_axes(partitioning_axes)}")
f"{mesh_lib.show_axes(baxis_resources)}, but the output is already "
f"partitioned along {mesh_lib.show_axes(overlap)}, because its "
f"named shape contains {mesh_lib.show_axes(partitioning_axes)}")
pxla.custom_resource_typing_rules[xmap_p] = _resource_typing_xmap
@ -1269,7 +1271,7 @@ batching.BatchTrace.post_process_xmap = _batch_trace_post_process_xmap
# -------- nested xmap handling --------
def _xmap_lowering_rule(ctx, *args, **kwargs):
if isinstance(ctx.module_context.axis_context, mlir.SPMDAxisContext):
if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext):
if config.experimental_xmap_spmd_lowering_manual:
return _xmap_lowering_rule_spmd_manual(ctx, *args, **kwargs)
else:
@ -1277,9 +1279,9 @@ def _xmap_lowering_rule(ctx, *args, **kwargs):
# Here ShardingContext is used in place of ReplicaAxisContext because when
# axis_resources and mesh is not used with xmap, `make_xmap_callable` will
# go via `dispatch.sharded_lowering` path which sets the context to
# ShardingContext. mlir.ShardingContext is not used for SPMD.
# ShardingContext. sharding_impls.ShardingContext is not used for SPMD.
elif isinstance(ctx.module_context.axis_context,
(mlir.ReplicaAxisContext, mlir.ShardingContext)):
(sharding_impls.ReplicaAxisContext, sharding_impls.ShardingContext)):
return _xmap_lowering_rule_replica(ctx, *args, **kwargs)
else:
raise AssertionError("Unrecognized axis context type!")
@ -1382,7 +1384,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
# XXX: We modify mesh_in_axes and mesh_out_axes here
def add_spmd_axes(
flat_mesh_axes: Sequence[pxla.ArrayMapping],
flat_mesh_axes: Sequence[ArrayMapping],
flat_extra_axes: Optional[Sequence[Sequence[Sequence[MeshAxisName]]]]):
if flat_extra_axes is None:
return
@ -1456,7 +1458,8 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
assert isinstance(ctx.module_context.axis_context, mlir.SPMDAxisContext)
assert isinstance(ctx.module_context.axis_context,
sharding_impls.SPMDAxisContext)
sub_ctx = ctx.module_context.replace(
name_stack=ctx.module_context.name_stack.extend(wrap_name(name, 'xmap')),
axis_context=ctx.module_context.axis_context.extend_manual(manual_mesh_axes))
@ -1755,7 +1758,7 @@ def _check_gda_or_array_xmap_partitioning(axis_resources, resource_env,
s = arg.sharding
xmap_sharding = pxla.create_mesh_pspec_sharding(
mesh, pxla.array_mapping_to_axis_resources(xmap_array_mapping))
mesh, array_mapping_to_axis_resources(xmap_array_mapping))
# This check is cached because comparing OpSharding is expensive during
# dispatch and if the shardings are the same, then there is no need to
# compare twice.

View File

@ -13,10 +13,8 @@
# limitations under the License.
import dataclasses
from enum import IntEnum
import inspect
import numpy as np
from collections import OrderedDict, Counter
from typing import (Callable, Sequence, Tuple, Union, cast, List, Optional,
Iterable, NamedTuple, Any)
import itertools as it
@ -30,6 +28,7 @@ from jax._src import dispatch
from jax._src import mesh as mesh_lib
from jax._src import linear_util as lu
from jax._src import op_shardings
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
@ -51,10 +50,13 @@ from jax._src.interpreters import pxla
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import xla_client as xc
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
NamedSharding, XLACompatibleSharding, GSPMDSharding,
XLADeviceAssignment, SingleDeviceSharding, PmapSharding)
XLADeviceAssignment, SingleDeviceSharding, PmapSharding,
AUTOAxisResource, UNSPECIFIED, UnspecifiedValue,
CanonicalizedParsedPartitionSpec, ParsedPartitionSpec,
SpecSync, get_single_pspec, is_auto, is_unspecified, is_unspecified_or_auto,
prepare_axis_resources)
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import (
tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
@ -62,7 +64,7 @@ from jax._src.tree_util import (
prefix_errors, generate_key_paths)
from jax._src.util import (
HashableFunction, safe_map, safe_zip, wraps,
distributed_debug_log, split_list, tuple_insert, weakref_lru_cache,
distributed_debug_log, split_list, weakref_lru_cache,
merge_lists)
map, unsafe_map = safe_map, map
@ -70,41 +72,10 @@ zip, unsafe_zip = safe_zip, zip
traceback_util.register_exclusion(__file__)
_AUTOAxisResource = pxla.AUTOAxisResource
AUTO = pxla.AUTO # type: ignore
is_auto = pxla.is_auto
_UnspecifiedValue = pxla.UnspecifiedValue
_UNSPECIFIED = pxla._UNSPECIFIED # type: ignore
_is_unspecified = pxla._is_unspecified
def _is_unspecified_or_auto(x):
return is_auto(x) or _is_unspecified(x)
PjitSharding = Union[GSPMDSharding, _UnspecifiedValue, _AUTOAxisResource]
PjitShardingMinusUnspecified = Union[GSPMDSharding, _AUTOAxisResource]
MeshSharding = Union[NamedSharding, _UnspecifiedValue, _AUTOAxisResource]
MeshShardingMinusUnspecified = Union[NamedSharding, _AUTOAxisResource]
def _check_all_or_none_unspecified(axis_resources, name):
if not axis_resources:
return False
unspecified_count = 0
unspecified = _is_unspecified(axis_resources[0])
for resource in axis_resources:
current_is_unspecified = _is_unspecified(resource)
if current_is_unspecified:
unspecified_count += 1
assert unspecified_count == 1
if current_is_unspecified != unspecified:
raise ValueError(f'`pjit._UNSPECIFIED` exists in {name}. '
f'Make sure that every entry in {name} is '
'`pjit._UNSPECIFIED`.')
return unspecified
PjitSharding = Union[GSPMDSharding, UnspecifiedValue, AUTOAxisResource]
PjitShardingMinusUnspecified = Union[GSPMDSharding, AUTOAxisResource]
MeshSharding = Union[NamedSharding, UnspecifiedValue, AUTOAxisResource]
MeshShardingMinusUnspecified = Union[NamedSharding, AUTOAxisResource]
def _try_infer_args(f, tree):
dummy_args = tree_unflatten(tree, [False] * tree.num_leaves)
@ -281,27 +252,27 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
def _resolve_axis_resources_and_shardings_arg(
in_shardings, out_shardings, in_axis_resources, out_axis_resources):
if not _is_unspecified(in_shardings) and not _is_unspecified(in_axis_resources):
if not is_unspecified(in_shardings) and not is_unspecified(in_axis_resources):
raise ValueError(
'Setting both in_shardings and in_axis_resources is not '
'allowed. in_axis_resources is deprecated. Please use in_shardings.')
if not _is_unspecified(out_shardings) and not _is_unspecified(out_axis_resources):
if not is_unspecified(out_shardings) and not is_unspecified(out_axis_resources):
raise ValueError(
'Setting both out_shardings and out_axis_resources is not '
'allowed. out_axis_resources is deprecated. Please use out_shardings.')
if (not _is_unspecified(in_axis_resources) or
not _is_unspecified(out_axis_resources)):
if (not is_unspecified(in_axis_resources) or
not is_unspecified(out_axis_resources)):
warnings.warn(
'in_axis_resources and out_axis_resources are deprecated. Please use '
'in_shardings and out_shardings as their replacement.',
DeprecationWarning)
if not _is_unspecified(in_axis_resources):
if not is_unspecified(in_axis_resources):
final_in_shardings = in_axis_resources
else:
final_in_shardings = in_shardings
if not _is_unspecified(out_axis_resources):
if not is_unspecified(out_axis_resources):
final_out_shardings = out_axis_resources
else:
final_out_shardings = out_shardings
@ -326,10 +297,10 @@ def pre_infer_params(fun, in_shardings, out_shardings,
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
f"got {device=} and {backend=}")
if not _is_unspecified(in_shardings):
if not is_unspecified(in_shardings):
raise ValueError('If backend or device is specified on jit, then '
'in_shardings should not be specified.')
if not _is_unspecified(out_shardings):
if not is_unspecified(out_shardings):
raise ValueError('If backend or device is specified on jit, then '
'out_shardings should not be specified.')
@ -341,8 +312,8 @@ def pre_infer_params(fun, in_shardings, out_shardings,
# rather than raising an error. https://github.com/google/jax/issues/2367
in_shardings = tuple(in_shardings)
in_shardings, _, _ = _prepare_axis_resources(in_shardings, 'in_shardings')
out_shardings, _, _ = _prepare_axis_resources(out_shardings, 'out_shardings')
in_shardings, _, _ = prepare_axis_resources(in_shardings, 'in_shardings')
out_shardings, _, _ = prepare_axis_resources(out_shardings, 'out_shardings')
donate_argnums, static_argnums, static_argnames = resolve_argnums(
fun, donate_argnums, static_argnums, static_argnames)
@ -394,8 +365,8 @@ def _pjit_explicit_sharding(in_shardings, out_shardings, device,
out_shardings_flat, _ = tree_flatten(out_shardings)
return (device is not None or
backend is not None or
any(not _is_unspecified(i) for i in in_shardings_flat) or
any(not _is_unspecified(i) for i in out_shardings_flat))
any(not is_unspecified(i) for i in in_shardings_flat) or
any(not is_unspecified(i) for i in out_shardings_flat))
class PjitInfo(NamedTuple):
@ -418,7 +389,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
donate_argnums, device, backend, keep_unused, inline,
resource_env, abstracted_axes) = pjit_info_args
if kwargs and not _is_unspecified(user_in_shardings):
if kwargs and not is_unspecified(user_in_shardings):
raise ValueError(
"pjit does not support kwargs when in_shardings is specified.")
@ -511,7 +482,7 @@ def common_infer_params(pjit_info_args, *args, **kwargs):
num_extra_args = len(implicit_args) + len(consts)
canonicalized_in_shardings_flat = \
(_UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat
(UNSPECIFIED,) * num_extra_args + canonicalized_in_shardings_flat
donated_invars = (False,) * num_extra_args + donated_invars
assert (len(canonicalized_in_shardings_flat) == len(donated_invars) ==
len(consts) + len(args_flat))
@ -574,10 +545,10 @@ def _flat_axes_specs(abstracted_axes, *args, **kwargs
# because `None` means that the input is fully replicated.
def pjit(
fun: Callable,
in_shardings=_UNSPECIFIED,
out_shardings=_UNSPECIFIED,
in_axis_resources=_UNSPECIFIED,
out_axis_resources=_UNSPECIFIED,
in_shardings=UNSPECIFIED,
out_shardings=UNSPECIFIED,
in_axis_resources=UNSPECIFIED,
out_axis_resources=UNSPECIFIED,
static_argnums: Union[int, Sequence[int], None] = None,
static_argnames: Union[str, Iterable[str], None] = None,
donate_argnums: Union[int, Sequence[int]] = (),
@ -775,13 +746,13 @@ def hashable_pytree(pytree):
@lru_cache(maxsize=4096)
def _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x):
if _is_unspecified_or_auto(x):
if is_unspecified_or_auto(x):
return x
return pxla.create_mesh_pspec_sharding(mesh, x.user_spec, x)
def _create_sharding_for_array(mesh, x, name):
if isinstance(x, XLACompatibleSharding) or _is_unspecified_or_auto(x):
if isinstance(x, XLACompatibleSharding) or is_unspecified_or_auto(x):
return x
if mesh is None:
msg = ('jax.jit only supports `XLACompatibleSharding`s being passed to'
@ -803,7 +774,7 @@ def _create_sharding_for_array(mesh, x, name):
' call site? Alternatively, provide `XLACompatibleSharding`s to'
' `in_shardings` and `out_shardings` and then the mesh context manager'
' is not required.')
# A nice user error is raised in _prepare_axis_resources.
# A nice user error is raised in prepare_axis_resources.
assert isinstance(x, ParsedPartitionSpec), x
return _create_mesh_pspec_sharding_from_parsed_pspec(mesh, x)
@ -884,7 +855,7 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree,
orig_in_shardings = in_shardings_thunk()
# Only do this if original in_shardings are unspecified. If it is AUTO, go
# via flatten_axis_resources.
if _is_unspecified(orig_in_shardings):
if is_unspecified(orig_in_shardings):
in_shardings_flat = (orig_in_shardings,) * len(in_avals)
else:
in_shardings_flat = flatten_axis_resources(
@ -895,7 +866,7 @@ def _process_in_axis_resources(in_shardings_thunk, in_avals, in_tree,
pjit_check_aval_sharding(in_shardings_flat, in_avals,
"pjit arguments", allow_uneven_sharding=False)
canonicalized_shardings = tuple(
i if _is_unspecified_or_auto(i) else to_gspmd_sharding(i, aval.ndim)
i if is_unspecified_or_auto(i) else to_gspmd_sharding(i, aval.ndim)
for i, aval in zip(in_shardings_flat, in_avals))
return canonicalized_shardings
@ -933,7 +904,7 @@ def _check_and_canonicalize_out_shardings(
# instead. This condition exists because flatten_axis_resources passes in an
# `object()` while unflattening which breaks assertion is user defined
# pytrees (which shouldn't exist but they do).
if (_is_unspecified(orig_out_shardings) or
if (is_unspecified(orig_out_shardings) or
isinstance(orig_out_shardings, XLACompatibleSharding)):
out_shardings_flat = (orig_out_shardings,) * len(out_type)
else:
@ -946,7 +917,7 @@ def _check_and_canonicalize_out_shardings(
allow_uneven_sharding=False)
canonicalized_out_shardings_flat = tuple(
o if _is_unspecified(o) or is_auto(o) else to_gspmd_sharding(o, aval.ndim)
o if is_unspecified(o) or is_auto(o) else to_gspmd_sharding(o, aval.ndim)
for o, aval in zip(out_shardings_flat, out_type)
)
return canonicalized_out_shardings_flat
@ -965,7 +936,7 @@ def _pjit_jaxpr(fun, out_shardings_thunk, in_type, debug_info, out_tree,
def pjit_check_aval_sharding(
shardings, flat_avals, what_aval: str, allow_uneven_sharding: bool):
for aval, s in zip(flat_avals, shardings):
if _is_unspecified_or_auto(s):
if is_unspecified_or_auto(s):
continue
shape = aval.shape
try:
@ -994,170 +965,6 @@ def pjit_check_aval_sharding(
f"(full shape: {shape}) ")
class SpecSync(IntEnum):
"""Encodes how much out of sync the real value of partitions is compared to the user specified one.
We use this to make sure we don't show garbage modified values while claiming
that the users have specified them like that.
"""
OUT_OF_SYNC = 0 # Arbitrary changes, including new axes inserted
DIM_PERMUTE = 1 # Dimensions permuted, but no new sharding axes
IN_SYNC = 2 # Entirely in sync
class ParsedPartitionSpec:
__slots__ = ('unsafe_user_spec', 'partitions', 'sync')
def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC):
self.unsafe_user_spec = user_spec
# None in partitions represents unconstrained dim.
# TODO(yashkatariya): May use a sentinel value.
self.partitions = tuple(partitions)
self.sync = sync
@property
def user_spec(self):
return self.unsynced_user_spec(SpecSync.IN_SYNC)
def get_partition_spec(self) -> PartitionSpec:
if self.sync < SpecSync.IN_SYNC:
return _get_single_pspec(self)
else:
if isinstance(self.unsafe_user_spec, PartitionSpec):
return self.unsafe_user_spec
else:
return _get_single_pspec(self)
def unsynced_user_spec(self, min_sync):
if self.sync < min_sync:
raise AssertionError(f"Please open a bug report! ({self.sync} >= {min_sync})")
return self.unsafe_user_spec
def insert_axis_partitions(self, dim, val):
parts = self.partitions
too_short = dim - len(parts)
if too_short > 0:
parts += ((),) * too_short
new_partitions = tuple_insert(parts, dim, val)
new_sync = SpecSync.DIM_PERMUTE if (val == () or val is None) else SpecSync.OUT_OF_SYNC
return ParsedPartitionSpec(self.unsafe_user_spec, new_partitions, sync=new_sync)
@classmethod
def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False):
if entry is None:
return cls(entry, ())
if not isinstance(entry, PartitionSpec):
raise TypeError(f"{arg_name} are expected to be "
f"PartitionSpec instances or None, but got {entry}")
axis_specs = []
for axis_spec in entry:
if axis_spec is None:
axis_spec = ()
elif isinstance(axis_spec, (list, tuple)):
axis_spec = tuple(axis_spec)
elif axis_spec == PartitionSpec.UNCONSTRAINED:
if not allow_unconstrained_dims:
raise ValueError(f"Unconstrained dims are not allowed: {entry}")
axis_spec = None
else:
axis_spec = (axis_spec,)
axis_specs.append(axis_spec)
return cls(entry, axis_specs)
def __hash__(self):
return hash((self.partitions, self.sync))
def __eq__(self, other):
return (self.partitions == other.partitions and
self.sync == other.sync)
def __len__(self):
return len(self.partitions)
def __getitem__(self, i):
return self.partitions[i]
def __iter__(self):
return iter(self.partitions)
def __repr__(self):
return (f"ParsedPartitionSpec(partitions={self.partitions}, "
f"unsafe_user_spec={self.unsafe_user_spec}, "
f"sync={self.sync})")
class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
"""ParsedPartitionSpecs that are canonicalized.
ParsedPartitionSpecs may contain trailing empty tuples, that make them
semantically different in general, and yet in some situations we prefer
to regard them as equivalent. For example, partitions of () and ((),)
cannot be always considered equivalent, since the first one is a valid
spec for a scalar value, while the second is not! However, when either of
those are applied to a 2D array, they both mean that the array is fully
replicated.
So CanonicalizedParsedPartitionSpecs removes the trailing empty tuples from
partitions.
"""
def __init__(self, parsed_pspec: ParsedPartitionSpec):
partitions = list(parsed_pspec.partitions)
while partitions and partitions[-1] == ():
partitions.pop()
super().__init__(parsed_pspec.unsafe_user_spec, partitions,
parsed_pspec.sync)
def __repr__(self):
return (f"CanonicalizedParsedPartitionSpec(partitions={self.partitions}, "
f"unsafe_user_spec={self.unsafe_user_spec}, "
f"sync={self.sync})")
def _prepare_axis_resources(axis_resources,
arg_name,
allow_unconstrained_dims=False):
# PyTrees don't treat None values as leaves, so we use an is_leaf function.
entries, treedef = tree_flatten(axis_resources, is_leaf=lambda x: x is None)
what = f"{arg_name} leaf specifications"
# All entries should be specified or if unspecified then there should only
# be 1 entry for that since _UNSPECIFIED is a private API.
_check_all_or_none_unspecified(entries, arg_name)
new_entries = []
for entry in entries:
if _is_unspecified_or_auto(entry):
new_entries.append(entry)
elif isinstance(entry, Sharding):
if isinstance(entry, PmapSharding):
raise ValueError(f'One of {what} got sharding {entry} which is not '
'allowed.')
if not isinstance(entry, XLACompatibleSharding):
raise ValueError(f'One of {what} got sharding {entry} which is not a '
'subclass of XLACompatibleSharding.')
new_entries.append(entry)
else:
new_entries.append(ParsedPartitionSpec.from_user_input(
entry, what, allow_unconstrained_dims=allow_unconstrained_dims))
_check_unique_resources(new_entries, arg_name)
return tree_unflatten(treedef, new_entries), new_entries, treedef
def _check_unique_resources(axis_resources, arg_name):
for arg_axis_resources in axis_resources:
if not arg_axis_resources: continue
if (_is_unspecified_or_auto(arg_axis_resources) or
isinstance(arg_axis_resources, XLACompatibleSharding)):
continue
constrained_dims = [d for d in arg_axis_resources if d is not None]
resource_counts = Counter(it.chain.from_iterable(constrained_dims))
if not resource_counts: continue
if resource_counts.most_common(1)[0][1] > 1:
multiple_uses = [r for r, c in resource_counts.items() if c > 1]
if multiple_uses:
raise ValueError(f"A single {arg_name} specification can map every mesh axis "
f"to at most one positional dimension, but {arg_axis_resources.user_spec} "
f"has duplicate entries for {mesh_lib.show_axes(multiple_uses)}")
# -------------------- pjit rules --------------------
@ -1203,21 +1010,21 @@ def _resolve_in_shardings(
resolved_in_shardings = []
for arg, pjit_in_s in zip(args, pjit_in_shardings):
arg_s, committed = ((arg.sharding, getattr(arg, '_committed', True))
if hasattr(arg, 'sharding') else (_UNSPECIFIED, False))
if _is_unspecified(pjit_in_s):
if _is_unspecified(arg_s):
if hasattr(arg, 'sharding') else (UNSPECIFIED, False))
if is_unspecified(pjit_in_s):
if is_unspecified(arg_s):
resolved_in_shardings.append(arg_s)
else:
if committed:
# If the arg has a PmapSharding, then reshard it unconditionally.
if isinstance(arg_s, PmapSharding):
resolved_in_shardings.append(_UNSPECIFIED)
resolved_in_shardings.append(UNSPECIFIED)
else:
resolved_in_shardings.append(to_gspmd_sharding(
cast(XLACompatibleSharding, arg_s), arg.ndim))
else:
if dispatch.is_single_device_sharding(arg_s):
resolved_in_shardings.append(_UNSPECIFIED)
resolved_in_shardings.append(UNSPECIFIED)
else:
raise NotImplementedError('Having uncommitted Array sharded on '
'multiple devices is not supported.')
@ -1239,7 +1046,7 @@ def _resolve_in_shardings(
'Please see the jax.Array migration guide for more information '
'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. '
f'Got arg shape: {arg.shape}, arg value: {arg}')
if not _is_unspecified(arg_s):
if not is_unspecified(arg_s):
if (committed and
not isinstance(arg_s, PmapSharding) and
not op_shardings.are_op_shardings_equal(
@ -1265,7 +1072,7 @@ def _pjit_call_impl(*args, jaxpr,
args, in_shardings, out_shardings,
resource_env.physical_mesh if resource_env is not None else None)
_allow_propagation_to_outputs = [_is_unspecified(o) for o in out_shardings]
_allow_propagation_to_outputs = [is_unspecified(o) for o in out_shardings]
compiled = _pjit_lower(
jaxpr, in_shardings, out_shardings, resource_env,
donated_invars, name, keep_unused,
@ -1428,8 +1235,8 @@ def _pjit_lower_cached(
def pjit_staging_rule(trace, *args, **params):
if (params["inline"] and
all(_is_unspecified(i) for i in params["in_shardings"]) and
all(_is_unspecified(o) for o in params["out_shardings"])):
all(is_unspecified(i) for i in params["in_shardings"]) and
all(is_unspecified(o) for o in params["out_shardings"])):
jaxpr = params['jaxpr']
return core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
elif config.jax_dynamic_shapes:
@ -1488,9 +1295,9 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
output_types = [mlir.token_type()] * len(effects) + output_types
flat_output_types = util.flatten(output_types)
arg_shardings = [None if _is_unspecified(i) else i._to_xla_op_sharding(aval.ndim)
arg_shardings = [None if is_unspecified(i) else i._to_xla_op_sharding(aval.ndim)
for aval, i in zip(ctx.avals_in, in_shardings)]
result_shardings = [None if _is_unspecified(o) else o._to_xla_op_sharding(aval.ndim)
result_shardings = [None if is_unspecified(o) else o._to_xla_op_sharding(aval.ndim)
for aval, o in zip(ctx.avals_out, out_shardings)]
# TODO(b/228598865): inlined calls cannot have shardings set directly on the
@ -1557,9 +1364,9 @@ batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, False, None)
pxla.spmd_primitive_batchers[pjit_p] = partial(_pjit_batcher, True, None)
def _pjit_batcher_for_sharding(
s: Union[GSPMDSharding, _UnspecifiedValue],
s: Union[GSPMDSharding, UnspecifiedValue],
dim: int, val: Tuple[str, ...], mesh, ndim: int):
if _is_unspecified(s):
if is_unspecified(s):
return s
if not val:
new_op = s._op_sharding.clone() # type: ignore
@ -1631,7 +1438,7 @@ def _pjit_partial_eval(trace, *in_tracers,
def keep_where(l, should_keep):
return tuple(x for x, keep in unsafe_zip(l, should_keep) if keep)
residual_shardings = (_UNSPECIFIED,) * num_residuals
residual_shardings = (UNSPECIFIED,) * num_residuals
# Compute the known outputs
known_params = dict(
jaxpr=known_jaxpr,
@ -1649,7 +1456,7 @@ def _pjit_partial_eval(trace, *in_tracers,
# Only forward the outvars where the out_sharding is UNSPECIFIED.
known_user_out_shardings = keep_where(known_params['out_shardings'], known_outs)
fwds_known_user = [
fwd if _is_unspecified(os) else None
fwd if is_unspecified(os) else None
for os, fwd in zip(known_user_out_shardings,
fwds_known[:len(known_user_out_shardings)])]
fwds_known = fwds_known_user + fwds_known[len(known_user_out_shardings):]
@ -1728,7 +1535,7 @@ def _pjit_partial_eval_custom_params_updater(
if num_res == 0:
residual_shardings = []
else:
residual_shardings = [_UNSPECIFIED] * num_res
residual_shardings = [UNSPECIFIED] * num_res
_, out_shardings_known = pe.partition_list(kept_outs_known, params_known['out_shardings'])
new_params_known = dict(params_known,
in_shardings=tuple(in_shardings_known),
@ -1862,7 +1669,7 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
raise RuntimeError("Changing the physical mesh is not allowed inside pjit.")
for aval, s in zip(jaxpr.in_avals, params['in_shardings']):
if _is_unspecified(s) or is_auto(s):
if is_unspecified(s) or is_auto(s):
continue
elif hasattr(s, '_original_sharding') and hasattr(
s._original_sharding, '_parsed_pspec'):
@ -1884,7 +1691,7 @@ def _resource_typing_pjit(avals, params, source_info, resource_env, named_axis_r
what = "pjit output"
for aval, s in zip(jaxpr.out_avals, params['out_shardings']):
if _is_unspecified(s) or is_auto(s):
if is_unspecified(s) or is_auto(s):
continue
elif hasattr(s, '_original_sharding') and hasattr(
s._original_sharding, '_parsed_pspec'):
@ -1907,9 +1714,9 @@ def _pjit_pp_rule(eqn, context, settings):
del params['inline']
if not any(params['donated_invars']):
del params['donated_invars']
if all(pxla._is_unspecified(s) for s in params['in_shardings']):
if all(is_unspecified(s) for s in params['in_shardings']):
del params['in_shardings']
if all(pxla._is_unspecified(s) for s in params['out_shardings']):
if all(is_unspecified(s) for s in params['out_shardings']):
del params['out_shardings']
if not params['keep_unused']:
del params['keep_unused']
@ -1923,17 +1730,17 @@ core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
# -------------------- with_sharding_constraint --------------------
def _resolve_wsc_args(axis_resources, shardings):
if not _is_unspecified(axis_resources) and not _is_unspecified(shardings):
if not is_unspecified(axis_resources) and not is_unspecified(shardings):
raise ValueError(
'Setting both axis_resources and shardings is not '
'allowed. axis_resources is deprecated. Please use shardings.')
if _is_unspecified(axis_resources) and _is_unspecified(shardings):
if is_unspecified(axis_resources) and is_unspecified(shardings):
raise ValueError(
'Not specifying shardings to `with_sharding_constraint` is not allowed. '
'Please specify the shardings argument with a concrete sharding. Note '
'that axis_resources is deprecated, so use the shardings argument.')
if not _is_unspecified(axis_resources):
if not is_unspecified(axis_resources):
warnings.warn(
'axis_resources is deprecated. Please use shardings argument instead.',
DeprecationWarning)
@ -1946,11 +1753,11 @@ def _resolve_wsc_args(axis_resources, shardings):
# TODO(yashkatariya): Remove the axis_resources argument and make the signature
# `with_sharding_constraint(x, shardings)` with no defaults after deprecation
# period is finished. The deprecation period expires 3 months from Feb 13, 2023.
def with_sharding_constraint(x, shardings=_UNSPECIFIED,
axis_resources=_UNSPECIFIED):
def with_sharding_constraint(x, shardings=UNSPECIFIED,
axis_resources=UNSPECIFIED):
final_shardings = _resolve_wsc_args(axis_resources, shardings)
x_flat, tree = tree_flatten(x)
user_shardings, _, _ = _prepare_axis_resources(
user_shardings, _, _ = prepare_axis_resources(
final_shardings, "shardings", allow_unconstrained_dims=True)
del final_shardings
@ -1996,7 +1803,7 @@ def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
# axis_ctx and manual_axes is *only used with xmap* and xmap only works with
# NamedSharding. So convert the GSPMDSharding to NamedSharding
# and then convert it back with the added special axes.
if isinstance(axis_ctx, mlir.SPMDAxisContext):
if isinstance(axis_ctx, sharding_impls.SPMDAxisContext):
mesh = resource_env.physical_mesh
parsed_pspec = parse_flatten_op_sharding(sharding._op_sharding, mesh)[0]
mps = NamedSharding._from_parsed_pspec(mesh, parsed_pspec)
@ -2054,18 +1861,6 @@ pxla.custom_resource_typing_rules[sharding_constraint_p] = \
# -------------------- helpers --------------------
def get_array_mapping(
axis_resources: Union[ParsedPartitionSpec, _AUTOAxisResource, _UnspecifiedValue]
) -> pxla.ArrayMappingOrAutoOrUnspecified:
# TODO(yashkatariya): Use `TypeGuard` on `is_auto` when it is supported.
# Don't use `is_auto` here to satisfy pytype and mypy.
if isinstance(axis_resources, (_AUTOAxisResource, _UnspecifiedValue)):
return axis_resources
return OrderedDict((axis, i)
for i, axes in enumerate(axis_resources)
if axes is not None for axis in axes)
def to_gspmd_sharding(s: XLACompatibleSharding, ndim: int) -> GSPMDSharding:
if isinstance(s, GSPMDSharding):
return s
@ -2085,7 +1880,7 @@ def _fast_path_get_device_assignment(
shardings: Iterable[PjitSharding]) -> Optional[XLADeviceAssignment]:
da = None
for i in shardings:
if is_auto(i) or _is_unspecified(i):
if is_auto(i) or is_unspecified(i):
continue
da = i._device_assignment # type: ignore
break
@ -2227,11 +2022,8 @@ def parse_flatten_op_sharding(op_sharding: xc.OpSharding,
raise AssertionError("Unhandled OpSharding type. Please open a bug report!")
_get_single_pspec = lambda p: pxla.array_mapping_to_axis_resources(
cast(pxla.ArrayMapping, get_array_mapping(p)))
def _get_partition_spec(ppspec: Sequence[ParsedPartitionSpec]) -> Sequence[PartitionSpec]:
return [_get_single_pspec(p) for p in ppspec]
return [get_single_pspec(p) for p in ppspec]
def _get_op_sharding_from_executable(

View File

@ -14,24 +14,35 @@
from __future__ import annotations
import collections
import dataclasses
import enum
import functools
from collections import Counter
import itertools
import operator as op
from typing import (Any, Sequence, List, Tuple, Optional, Mapping, Dict, Set,
FrozenSet, Union, cast)
import sys
from typing import (Any, Dict, FrozenSet, List, Mapping, Optional, OrderedDict,
NamedTuple, Sequence, Set, Tuple, Union, cast)
from jax._src import mesh as mesh_lib
from jax._src import op_shardings
from jax._src import sharding
from jax._src import sharding_specs
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge
from jax._src.util import safe_map, safe_zip, use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
from jax._src.interpreters import mlir
from jax._src.partition_spec import PartitionSpec
import numpy as np
if sys.version_info >= (3, 9):
OrderedDictType = OrderedDict
else:
OrderedDictType = Dict
Shape = Tuple[int, ...]
Device = xc.Device
Index = Tuple[slice, ...]
@ -134,7 +145,7 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
'create a device to index mapping for your sharding from which replica '
'ids will be calculated.') from None
index_to_replica: Dict[int, int] = Counter()
index_to_replica: Dict[int, int] = collections.Counter()
out = {}
for device, index in device_indices_map_fn(global_shape).items():
h_index = hashed_index(index)
@ -204,8 +215,7 @@ class NamedSharding(XLACompatibleSharding):
# TODO(yaskatariya): Remove this and replace this with a normalized
# representation of Parsed Pspec
if self._parsed_pspec is None:
from jax._src import pjit
self._parsed_pspec, _, _ = pjit._prepare_axis_resources(
self._parsed_pspec, _, _ = prepare_axis_resources(
self.spec, "NamedSharding spec")
_check_mesh_resource_axis(self.mesh, self._parsed_pspec)
@ -263,9 +273,8 @@ class NamedSharding(XLACompatibleSharding):
def _to_xla_op_sharding(
self,
num_dimensions: int,
axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None
axis_ctx: Optional[Union[SPMDAxisContext, ShardingContext]] = None
) -> xc.OpSharding:
from jax._src.pjit import get_array_mapping
assert self._parsed_pspec is not None
array_mapping = get_array_mapping(self._parsed_pspec)
# TODO(yashkatariya): Move away from sharding spec in NamedSharding
@ -275,7 +284,7 @@ class NamedSharding(XLACompatibleSharding):
# Used in `with_sharding_constraint`.
special_axes = {}
# Manual axes is only used with xmap.
if axis_ctx is not None and isinstance(axis_ctx, mlir.SPMDAxisContext):
if axis_ctx is not None and isinstance(axis_ctx, SPMDAxisContext):
axis_names = self.mesh.axis_names
# Ignore type because mypy doesn't recognize the `hasattr` check above.
for manual_axis in axis_ctx.manual_axes: # type: ignore
@ -634,3 +643,324 @@ class GSPMDSharding(XLACompatibleSharding):
def get_replicated(cls, device_assignment):
proto = get_replicated_op_sharding()
return cls(device_assignment, proto)
class AUTOAxisResource:
pass
AUTO = AUTOAxisResource()
def is_auto(x):
return isinstance(x, AUTOAxisResource)
class UnspecifiedValue:
def __repr__(self):
return "UnspecifiedValue"
UNSPECIFIED = UnspecifiedValue()
def is_unspecified(x):
return isinstance(x, UnspecifiedValue)
def is_unspecified_or_auto(x):
return is_auto(x) or is_unspecified(x)
MeshAxisName = Any
"""
ArrayMapping specifies how an ndarray should map to mesh axes.
Note that the ordering is crucial for the cases when this mapping is non-injective
(i.e. when multiple mesh axes map to the same positional axis). Then, the
order of entries of the mapping determines a major-to-minor order on mesh axes,
according to which chunks of the value along the repeated dimension will be assigned.
For example, consider a mapping {'x': 1, 'y': 1} and a mesh with shape {'x': 2, 'y': 3}.
The second dimension of the value would get chunked into 6 pieces, and assigned to the
mesh in a way that treats 'y' as the fastest changing (minor) dimension. In this case,
that would mean that a flat list of chunks would get assigned to a flattened list of
mesh devices without any modifications. If the mapping was {'y': 1, 'x': 1}, then the
mesh devices ndarray would have to be transposed before flattening and assignment.
"""
ArrayMapping = OrderedDictType[MeshAxisName, int]
ArrayMappingOrAutoOrUnspecified = Union[ArrayMapping, AUTOAxisResource,
UnspecifiedValue]
def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
if not array_mapping:
return PartitionSpec()
max_index = -1
reverse_map = collections.defaultdict(list)
for axis, index in array_mapping.items():
reverse_map[index].append(axis)
if index > max_index:
max_index = index
partitions = tuple(tuple(reverse_map[i]) if reverse_map[i] else None
for i in range(max_index + 1))
return PartitionSpec(*partitions)
def get_array_mapping(
axis_resources: Union[ParsedPartitionSpec, AUTOAxisResource, UnspecifiedValue]
) -> ArrayMappingOrAutoOrUnspecified:
# TODO(yashkatariya): Use `TypeGuard` on `is_auto` when it is supported.
# Don't use `is_auto` here to satisfy pytype and mypy.
if isinstance(axis_resources, (AUTOAxisResource, UnspecifiedValue)):
return axis_resources
return OrderedDict((axis, i)
for i, axes in enumerate(axis_resources)
if axes is not None for axis in axes)
get_single_pspec = lambda p: array_mapping_to_axis_resources(
cast(ArrayMapping, get_array_mapping(p)))
class SpecSync(enum.IntEnum):
"""Encodes how much out of sync the real value of partitions is compared to the user specified one.
We use this to make sure we don't show garbage modified values while claiming
that the users have specified them like that.
"""
OUT_OF_SYNC = 0 # Arbitrary changes, including new axes inserted
DIM_PERMUTE = 1 # Dimensions permuted, but no new sharding axes
IN_SYNC = 2 # Entirely in sync
class ParsedPartitionSpec:
__slots__ = ('unsafe_user_spec', 'partitions', 'sync')
def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC):
self.unsafe_user_spec = user_spec
# None in partitions represents unconstrained dim.
# TODO(yashkatariya): May use a sentinel value.
self.partitions = tuple(partitions)
self.sync = sync
@property
def user_spec(self):
return self.unsynced_user_spec(SpecSync.IN_SYNC)
def get_partition_spec(self) -> PartitionSpec:
if self.sync < SpecSync.IN_SYNC:
return get_single_pspec(self)
else:
if isinstance(self.unsafe_user_spec, PartitionSpec):
return self.unsafe_user_spec
else:
return get_single_pspec(self)
def unsynced_user_spec(self, min_sync):
if self.sync < min_sync:
raise AssertionError(f"Please open a bug report! ({self.sync} >= {min_sync})")
return self.unsafe_user_spec
def insert_axis_partitions(self, dim, val):
parts = self.partitions
too_short = dim - len(parts)
if too_short > 0:
parts += ((),) * too_short
new_partitions = util.tuple_insert(parts, dim, val)
new_sync = SpecSync.DIM_PERMUTE if (val == () or val is None) else SpecSync.OUT_OF_SYNC
return ParsedPartitionSpec(self.unsafe_user_spec, new_partitions, sync=new_sync)
@classmethod
def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False):
if entry is None:
return cls(entry, ())
if not isinstance(entry, PartitionSpec):
raise TypeError(f"{arg_name} are expected to be "
f"PartitionSpec instances or None, but got {entry}")
axis_specs = []
for axis_spec in entry:
if axis_spec is None:
axis_spec = ()
elif isinstance(axis_spec, (list, tuple)):
axis_spec = tuple(axis_spec)
elif axis_spec == PartitionSpec.UNCONSTRAINED:
if not allow_unconstrained_dims:
raise ValueError(f"Unconstrained dims are not allowed: {entry}")
axis_spec = None
else:
axis_spec = (axis_spec,)
axis_specs.append(axis_spec)
return cls(entry, axis_specs)
def __hash__(self):
return hash((self.partitions, self.sync))
def __eq__(self, other):
return (self.partitions == other.partitions and
self.sync == other.sync)
def __len__(self):
return len(self.partitions)
def __getitem__(self, i):
return self.partitions[i]
def __iter__(self):
return iter(self.partitions)
def __repr__(self):
return (f"ParsedPartitionSpec(partitions={self.partitions}, "
f"unsafe_user_spec={self.unsafe_user_spec}, "
f"sync={self.sync})")
class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
"""ParsedPartitionSpecs that are canonicalized.
ParsedPartitionSpecs may contain trailing empty tuples, that make them
semantically different in general, and yet in some situations we prefer
to regard them as equivalent. For example, partitions of () and ((),)
cannot be always considered equivalent, since the first one is a valid
spec for a scalar value, while the second is not! However, when either of
those are applied to a 2D array, they both mean that the array is fully
replicated.
So CanonicalizedParsedPartitionSpecs removes the trailing empty tuples from
partitions.
"""
def __init__(self, parsed_pspec: ParsedPartitionSpec):
partitions = list(parsed_pspec.partitions)
while partitions and partitions[-1] == ():
partitions.pop()
super().__init__(parsed_pspec.unsafe_user_spec, partitions,
parsed_pspec.sync)
def __repr__(self):
return (f"CanonicalizedParsedPartitionSpec(partitions={self.partitions}, "
f"unsafe_user_spec={self.unsafe_user_spec}, "
f"sync={self.sync})")
def check_all_or_none_unspecified(axis_resources, name):
if not axis_resources:
return False
unspecified_count = 0
unspecified = is_unspecified(axis_resources[0])
for resource in axis_resources:
current_is_unspecified = is_unspecified(resource)
if current_is_unspecified:
unspecified_count += 1
assert unspecified_count == 1
if current_is_unspecified != unspecified:
raise ValueError(f'`pjit.UNSPECIFIED` exists in {name}. '
f'Make sure that every entry in {name} is '
'`pjit.UNSPECIFIED`.')
return unspecified
def prepare_axis_resources(axis_resources,
arg_name,
allow_unconstrained_dims=False):
# PyTrees don't treat None values as leaves, so we use an is_leaf function.
entries, treedef = tree_util.tree_flatten(
axis_resources, is_leaf=lambda x: x is None)
what = f"{arg_name} leaf specifications"
# All entries should be specified or if unspecified then there should only
# be 1 entry for that since UNSPECIFIED is a private API.
check_all_or_none_unspecified(entries, arg_name)
new_entries = []
for entry in entries:
if is_unspecified_or_auto(entry):
new_entries.append(entry)
elif isinstance(entry, sharding.Sharding):
if isinstance(entry, PmapSharding):
raise ValueError(f'One of {what} got sharding {entry} which is not '
'allowed.')
if not isinstance(entry, XLACompatibleSharding):
raise ValueError(f'One of {what} got sharding {entry} which is not a '
'subclass of XLACompatibleSharding.')
new_entries.append(entry)
else:
new_entries.append(ParsedPartitionSpec.from_user_input(
entry, what, allow_unconstrained_dims=allow_unconstrained_dims))
_check_unique_resources(new_entries, arg_name)
return tree_util.tree_unflatten(treedef, new_entries), new_entries, treedef
def _check_unique_resources(axis_resources, arg_name):
for arg_axis_resources in axis_resources:
if not arg_axis_resources: continue
if (is_unspecified_or_auto(arg_axis_resources) or
isinstance(arg_axis_resources, XLACompatibleSharding)):
continue
constrained_dims = [d for d in arg_axis_resources if d is not None]
resource_counts = collections.Counter(
itertools.chain.from_iterable(constrained_dims))
if not resource_counts: continue
if resource_counts.most_common(1)[0][1] > 1:
multiple_uses = [r for r, c in resource_counts.items() if c > 1]
if multiple_uses:
raise ValueError(f"A single {arg_name} specification can map every mesh axis "
f"to at most one positional dimension, but {arg_axis_resources.user_spec} "
f"has duplicate entries for {mesh_lib.show_axes(multiple_uses)}")
# Axis environments
class AxisEnv(NamedTuple):
"""Represents a pmap mesh (only along the replica axes)."""
nreps: int
names: Tuple[Any, ...]
sizes: Tuple[int, ...]
@dataclasses.dataclass(frozen=True)
class SPMDAxisContext:
"""A hardware axis context for parallel computations that use the GSPMD partitioner.
This includes the mesh that will later by used to execute this computation,
as well as a set of mesh axes that are currently (e.g. because the current lowering
is invoked inside an xmap) lowered in the MANUAL sharding mode.
"""
mesh: mesh_lib.Mesh
manual_axes: FrozenSet[MeshAxisName] = frozenset()
@property
def axis_env(self):
# All collectives that touch axis_env should remember to set use_global_device_ids
# when this context is enabled!
if self.manual_axes != frozenset(self.mesh.axis_names):
raise NotImplementedError(
"Collectives in manually partitioned computations are only supported "
"when all mesh axes are partitioned manually (no partial automatic sharding). "
"Make sure that you mention all mesh axes in axis_resources!")
return self.unsafe_axis_env
@property
def unsafe_axis_env(self):
return AxisEnv(
nreps=self.mesh.size,
names=self.mesh.axis_names,
sizes=tuple(self.mesh.shape.values()))
def extend_manual(self, axes: FrozenSet[MeshAxisName]) -> SPMDAxisContext:
return SPMDAxisContext(self.mesh, self.manual_axes | axes)
@dataclasses.dataclass(frozen=True)
class ReplicaAxisContext:
"""A hardware axis context for parallel computations that are partitioned by JAX.
Unlike in the SPMDAxisContext, this means that JAX might need to emit calls to
explicit collectives.
"""
axis_env: AxisEnv
@dataclasses.dataclass(frozen=True)
class ShardingContext:
"""A hardware axis context for parallel computations that use the sharding
interface.
This context also uses the GSPMD partitioner.
"""
device_assignment: Sequence[xc.Device]
# Similar to SPMDContext as ShardingContext also uses the GSPMD partitioner.
@property
def axis_env(self):
return AxisEnv(nreps=1, names=(), sizes=())

View File

@ -38,6 +38,7 @@ from typing import Any, Dict, List, NamedTuple, Optional, Protocol, Sequence, Tu
import jax
from jax._src import core
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util
@ -596,7 +597,7 @@ class Lowered(Stage):
if isinstance(self._lowering, pxla.MeshComputation):
kw.update(
_allow_propagation_to_outputs=[
pxla._is_unspecified(o)
sharding_impls.is_unspecified(o)
for o in self._lowering.compile_args["out_shardings"]
]
)

View File

@ -17,6 +17,7 @@ import inspect
from jax._src import core
from jax import tree_util
from jax._src import linear_util as lu
from jax._src import sharding_impls
from jax.experimental import pjit
from jax.errors import UnexpectedTracerError
from jax._src import mesh as mesh_lib
@ -97,7 +98,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
]
closed_jaxpr = jax.make_jaxpr(
lower_fn, axis_env=list(info.mesh.shape.items()))(*tiled_args)
axis_context = mlir.SPMDAxisContext(info.mesh)
axis_context = sharding_impls.SPMDAxisContext(info.mesh)
built = mlir.build_xla_computation_helper(
closed_jaxpr,
name="tmp_xla_computation",
@ -373,9 +374,9 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
mesh = mesh_lib.thread_resources.env.physical_mesh
axis_context = ctx.module_context.axis_context
if isinstance(axis_context, mlir.ShardingContext):
if isinstance(axis_context, sharding_impls.ShardingContext):
devices = axis_context.device_assignment
elif isinstance(axis_context, mlir.SPMDAxisContext):
elif isinstance(axis_context, sharding_impls.SPMDAxisContext):
devices = list(axis_context.mesh.devices.flat)
else:
devices = None

View File

@ -518,9 +518,9 @@ from jax._src.interpreters import xla
from jax._src import ad_checkpoint
from jax._src import dispatch
from jax._src import pretty_printer as pp
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import util
from jax._src import lib as jaxlib
from jax._src.lib import pytree
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client
@ -1245,8 +1245,10 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext,
result_arrays = ()
return result_arrays
if isinstance(ctx.module_context.axis_context,
(mlir.SPMDAxisContext, mlir.ShardingContext)):
if isinstance(
ctx.module_context.axis_context,
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
):
# Apply maximal sharding so pjit only executes the callback on device device_index.
sharding = xla_client.OpSharding()
sharding.type = xla_client.OpSharding.Type.MAXIMAL
@ -1715,11 +1717,17 @@ def _rewrite_eqn(eqn: core.JaxprEqn, eqns: List[core.JaxprEqn],
eqn.params,
jaxpr=_rewrite_closed_jaxpr(jaxpr, True, True),
donated_invars=eqn.params["donated_invars"] + (False, False),
in_shardings=(eqn.params["in_shardings"] +
(pjit._UNSPECIFIED, pjit._UNSPECIFIED)),
out_shardings=(eqn.params["out_shardings"] +
(pjit._UNSPECIFIED, pjit._UNSPECIFIED)),
)))
in_shardings=(
eqn.params["in_shardings"]
+ (sharding_impls.UNSPECIFIED, sharding_impls.UNSPECIFIED)
),
out_shardings=(
eqn.params["out_shardings"]
+ (sharding_impls.UNSPECIFIED, sharding_impls.UNSPECIFIED)
),
),
)
)
elif eqn.primitive is ad_checkpoint.remat_p:
jaxpr_ = cast(core.Jaxpr, eqn.params["jaxpr"])
eqns.append(

View File

@ -49,6 +49,7 @@ from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import op_shardings
from jax._src import sharding_impls
from jax._src import pjit
from jax._src import prng
from jax._src import random as random_internal
@ -3096,7 +3097,7 @@ def _shard_value(val: TfVal,
sd: sharding.XLACompatibleSharding, *,
skip_replicated_sharding: bool) -> TfVal:
"""Apply sharding to a TfVal."""
if pxla._is_unspecified(sd):
if sharding_impls.is_unspecified(sd):
return val
sharding_proto: xla_client.OpSharding = cast(

View File

@ -28,6 +28,7 @@ from jax import sharding
from jax._src import core
from jax._src import pjit
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import util
from jax._src import xla_bridge as xb
@ -270,7 +271,7 @@ def _add_dim_arg_computation(module: mlir.ir.Module,
with ir.InsertionPoint(entry_block):
orig_main_args: List[mlir.ir.Value] = []
module_context = mlir.ModuleContext(
"cpu", "cpu", mlir.ShardingContext([]),
"cpu", "cpu", sharding_impls.ShardingContext([]),
source_info_util.new_name_stack(),
[], itertools.count(1), [], module=new_module, context=context)
ctx = mlir.LoweringRuleContext(module_context=module_context,
@ -467,7 +468,7 @@ def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported:
map(lambda a: a.at_least_vspace(), primal.out_avals)))
# Expand in_shardings to all in_avals even not kept ones.
all_in_shardings = [pxla._UNSPECIFIED] * len(primal.in_avals)
all_in_shardings = [sharding_impls.UNSPECIFIED] * len(primal.in_avals)
for idx, in_s in zip(sorted(primal.module_kept_var_idx),
primal.in_shardings):
all_in_shardings[idx] = in_s # type: ignore
@ -475,13 +476,13 @@ def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported:
# Cannot mix unspecified and specified shardings. Make the unspecified
# ones replicated.
specified_shardings = [
s for s in all_shardings if not pxla._is_unspecified(s)]
s for s in all_shardings if not sharding_impls.is_unspecified(s)]
vjp_in_shardings: Any # The primal inputs followed by output cotangents
vjp_out_shardings: Any # The primal output cotangents
if 0 == len(specified_shardings):
vjp_in_shardings = pxla._UNSPECIFIED
vjp_out_shardings = pxla._UNSPECIFIED
vjp_in_shardings = sharding_impls.UNSPECIFIED
vjp_out_shardings = sharding_impls.UNSPECIFIED
else:
if len(specified_shardings) < len(all_shardings):
# There are some specified, but not all; pjit front-end does not liwk
@ -489,13 +490,13 @@ def _export_native_vjp(primal_fun_jax, primal: Exported) -> Exported:
assert isinstance(in_s, sharding.XLACompatibleSharding)
replicated_s = sharding.GSPMDSharding.get_replicated(in_s._device_assignment)
all_shardings = [
s if not pxla._is_unspecified(s) else replicated_s
s if not sharding_impls.is_unspecified(s) else replicated_s
for s in all_shardings]
vjp_in_shardings = tuple(all_shardings)
vjp_out_shardings = tuple(all_shardings[:len(primal.in_avals)])
if all(pxla._is_unspecified(s) for s in vjp_out_shardings):
vjp_out_shardings = pxla._UNSPECIFIED
if all(sharding_impls.is_unspecified(s) for s in vjp_out_shardings):
vjp_out_shardings = sharding_impls.UNSPECIFIED
fun_vjp_jax = pjit.pjit(fun_vjp_jax,
in_shardings=vjp_in_shardings,

View File

@ -68,6 +68,7 @@ from jax._src import ad_util
from jax._src import core
from jax._src import dispatch
from jax._src import linear_util as lu
from jax._src import sharding_impls
from jax._src.api_util import shaped_abstractify
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
@ -734,8 +735,12 @@ def _pjit_jet_rule(primals_in, series_in, **params):
**params,
'jaxpr': jaxpr_jet,
'in_shardings': (
params['in_shardings'] + (pjit._UNSPECIFIED,) * num_series_in),
'out_shardings': params['out_shardings'] + (pjit._UNSPECIFIED,) * num_series_out,
params['in_shardings'] + (sharding_impls.UNSPECIFIED,) * num_series_in
),
'out_shardings': (
params['out_shardings']
+ (sharding_impls.UNSPECIFIED,) * num_series_out
),
'donated_invars': params['donated_invars'] + (False,) * num_series_in,
}
result = pjit.pjit_p.bind(*primals_and_series, **new_params)

View File

@ -15,34 +15,36 @@
# flake8: noqa
from jax._src.pjit import (
AUTO as AUTO,
ParsedPartitionSpec as ParsedPartitionSpec,
get_array_mapping as get_array_mapping,
hashable_pytree as hashable_pytree,
parse_flatten_op_sharding as parse_flatten_op_sharding,
pjit as pjit,
pjit_p as pjit_p,
with_sharding_constraint as with_sharding_constraint,
)
from jax._src.sharding_impls import (
AUTO as AUTO,
UNSPECIFIED as _UNSPECIFIED,
ParsedPartitionSpec as ParsedPartitionSpec,
get_array_mapping as get_array_mapping,
prepare_axis_resources as _prepare_axis_resources
)
from jax._src.pjit import (_UNSPECIFIED, _prepare_axis_resources,
_get_op_sharding_from_executable,
from jax._src.pjit import (_get_op_sharding_from_executable,
_get_pspec_from_executable, _pjit_lower_cached,
_pjit_lower, _pjit_jaxpr,
_process_in_axis_resources)
from jax._src.pjit import (
from jax._src.sharding_impls import (
NamedSharding as _deprecated_NamedSharding,
)
from jax._src.partition_spec import (
PartitionSpec as _deprecated_PartitionSpec,
)
import typing
if typing.TYPE_CHECKING:
from jax._src.pjit import (
NamedSharding as NamedSharding,
PartitionSpec as PartitionSpec,
)
from jax._src.sharding_impls import NamedSharding as NamedSharding
from jax._src.partition_spec import PartitionSpec as PartitionSpec
del typing
_deprecations = {

View File

@ -34,6 +34,7 @@ from jax._src import linear_util as lu
from jax._src import ops
from jax._src import pjit
from jax._src import prng
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
@ -47,7 +48,6 @@ from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax.interpreters import ad
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
@ -468,7 +468,9 @@ def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names,
sharded_avals = [v.aval for v in jaxpr.invars]
in_nodes_ = map(partial(_xla_shard, mesh), in_names, ctx.avals_in,
sharded_avals, in_nodes)
new_axis_context = mlir.SPMDAxisContext(mesh, frozenset(mesh.axis_names))
new_axis_context = sharding_impls.SPMDAxisContext(
mesh, frozenset(mesh.axis_names)
)
sub_ctx = ctx.module_context.replace(axis_context=new_axis_context)
with core.extend_axis_env_nd(tuple(mesh.shape.items())):
out_nodes_, _ = mlir.jaxpr_subcomp(sub_ctx, jaxpr, mlir.TokenSet(),
@ -483,8 +485,9 @@ def _xla_shard(mesh, names, aval_in, aval_out, x):
manual_proto = pxla.manual_proto(aval_in, frozenset(mesh.axis_names), mesh)
result_type, = mlir.aval_to_ir_types(aval_out)
axes = {name: i for i, ns in names.items() for name in ns}
shard_proto = NamedSharding(mesh, pxla.array_mapping_to_axis_resources(axes) # type: ignore
)._to_xla_op_sharding(aval_in.ndim)
shard_proto = NamedSharding(
mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore
)._to_xla_op_sharding(aval_in.ndim)
if core.is_opaque_dtype(aval_in.dtype):
shard_proto = aval_in.dtype._rules.physical_op_sharding(aval_in, shard_proto)
sx = mlir.wrap_with_sharding_op(x, shard_proto, unspecified_dims=set())
@ -496,8 +499,9 @@ def _xla_unshard(mesh, names, aval_in, aval_out, xs):
result_type, = mlir.aval_to_ir_types(aval_out)
sx = mlir.wrap_with_sharding_op(x, manual_proto, unspecified_dims=set())
axes = {name: i for i, ns in names.items() for name in ns}
shard_proto = NamedSharding(mesh, pxla.array_mapping_to_axis_resources(axes) # type: ignore
)._to_xla_op_sharding(aval_out.ndim)
shard_proto = NamedSharding(
mesh, sharding_impls.array_mapping_to_axis_resources(axes) # type: ignore
)._to_xla_op_sharding(aval_out.ndim)
if core.is_opaque_dtype(aval_out.dtype):
shard_proto = aval_out.dtype._rules.physical_op_sharding(aval_out, shard_proto)
return mlir.wrap_with_shard_to_full_op(result_type, sx, shard_proto, set())

View File

@ -53,18 +53,16 @@ from typing import (
import numpy as np
import jax
from jax import lax
from jax._src import core
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import sharding_impls
from jax.experimental.sparse.bcoo import bcoo_multiply_dense, bcoo_multiply_sparse
import jax.numpy as jnp
from jax._src.api_util import flatten_fun_nokwargs
from jax._src.lib import pytree
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax.tree_util import tree_flatten, tree_map, tree_unflatten
from jax.util import safe_map, safe_zip, split_list
from jax._src.config import config
@ -777,9 +775,13 @@ def _pjit_sparse(spenv, *spvalues, jaxpr, in_shardings, out_shardings,
# TODO(yashkatariya, vanderplas): Flatten twice and set the correct sharding
# for data and indices.
in_shardings = in_shardings + tuple(
pxla._UNSPECIFIED for _ in range(len(args_flat) - len(in_shardings)))
sharding_impls.UNSPECIFIED
for _ in range(len(args_flat) - len(in_shardings))
)
out_shardings = out_shardings + tuple(
pxla._UNSPECIFIED for _ in range(len(sp_call_jaxpr.out_avals) - len(out_shardings)))
sharding_impls.UNSPECIFIED
for _ in range(len(sp_call_jaxpr.out_avals) - len(out_shardings))
)
out_flat = pjit.pjit_p.bind(
*args_flat,

View File

@ -19,14 +19,9 @@ from jax._src.interpreters.mlir import (
LoweringResult as LoweringResult,
LoweringRule as LoweringRule,
LoweringRuleContext as LoweringRuleContext,
Mesh as Mesh,
MeshAxisName as MeshAxisName,
ModuleContext as ModuleContext,
RECV_FROM_HOST_TYPE as RECV_FROM_HOST_TYPE,
ReplicaAxisContext as ReplicaAxisContext,
SEND_TO_HOST_TYPE as SEND_TO_HOST_TYPE,
SPMDAxisContext as SPMDAxisContext,
ShardingContext as ShardingContext,
Token as Token,
TokenSet as TokenSet,
Value as Value,
@ -64,3 +59,11 @@ from jax._src.interpreters.mlir import (
token_type as token_type,
xla_computation_to_mlir_module as xla_computation_to_mlir_module,
)
from jax._src.mesh import Mesh as Mesh
from jax._src.sharding_impls import (
MeshAxisName as MeshAxisName,
ReplicaAxisContext as ReplicaAxisContext,
SPMDAxisContext as SPMDAxisContext,
ShardingContext as ShardingContext,
)

View File

@ -13,9 +13,6 @@
# limitations under the License.
from jax._src.interpreters.pxla import (
AUTO as AUTO,
ArrayMapping as ArrayMapping,
ArrayMappingOrAutoOrUnspecified as ArrayMappingOrAutoOrUnspecified,
AvalDimSharding as AvalDimSharding,
EmapInfo as EmapInfo,
ExecuteReplicated as ExecuteReplicated,
@ -26,7 +23,6 @@ from jax._src.interpreters.pxla import (
MeshComputation as MeshComputation,
MeshDimAssignment as MeshDimAssignment,
MeshExecutable as MeshExecutable,
OrderedDictType as OrderedDictType,
ParallelCallableInfo as ParallelCallableInfo,
PartitionInfo as PartitionInfo,
PartitionsOrReplicated as PartitionsOrReplicated,
@ -43,12 +39,9 @@ from jax._src.interpreters.pxla import (
UnloadedMeshExecutable as UnloadedMeshExecutable,
UnloadedPmapExecutable as UnloadedPmapExecutable,
WeakRefList as WeakRefList,
_UNSPECIFIED as _UNSPECIFIED,
_create_pmap_sharding_spec as _create_pmap_sharding_spec,
_get_and_check_device_assignment as _get_and_check_device_assignment,
_is_unspecified as _is_unspecified,
_pmap_sharding_spec as _pmap_sharding_spec,
array_mapping_to_axis_resources as array_mapping_to_axis_resources,
array_types as array_types,
custom_resource_typing_rules as custom_resource_typing_rules,
device_put as _deprecated_device_put,
@ -103,6 +96,15 @@ from jax._src.op_shardings import (
op_sharding_to_indices as op_sharding_to_indices,
)
from jax._src.sharding_impls import (
ArrayMapping as ArrayMapping,
ArrayMappingOrAutoOrUnspecified as ArrayMappingOrAutoOrUnspecified,
AUTO as AUTO,
UNSPECIFIED as _UNSPECIFIED,
array_mapping_to_axis_resources as array_mapping_to_axis_resources,
is_unspecified as _is_unspecified,
)
from jax._src.sharding_specs import (
Chunked as Chunked,
NoSharding as NoSharding,

View File

@ -13,7 +13,6 @@
# limitations under the License.
from jax._src.interpreters.xla import (
AxisEnv as AxisEnv,
DeviceArray as _deprecated_DeviceArray,
TranslationContext as TranslationContext,
TranslationRule as TranslationRule,
@ -46,6 +45,10 @@ from jax._src.dispatch import (
backend_compile as backend_compile,
)
from jax._src.sharding_impls import (
AxisEnv as AxisEnv,
)
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc # type: ignore

View File

@ -44,10 +44,12 @@ from jax.experimental.custom_partitioning import custom_partitioning
from jax._src import array
from jax._src.sharding import Sharding
from jax._src import op_shardings
from jax._src.sharding_impls import (NamedSharding, GSPMDSharding,
PositionalSharding, SingleDeviceSharding)
from jax._src import sharding_impls
from jax._src.sharding_impls import (
AUTO, UNSPECIFIED, NamedSharding, GSPMDSharding, PositionalSharding,
SingleDeviceSharding)
import jax._src.pjit as pjit_lib
from jax._src.pjit import (pjit, pjit_p, AUTO)
from jax._src.pjit import pjit, pjit_p
from jax._src import mesh
from jax._src.interpreters import pxla
from jax.interpreters import mlir
@ -3519,7 +3521,9 @@ class UtilTest(jtu.JaxTestCase):
("multi_skip", {'x': 0, 'y': 1, 'z': 3}, P(('x',), ('y',), None, ('z',))),
)
def test_array_mapping_to_axis_resources(self, inp, expected_out):
self.assertEqual(pxla.array_mapping_to_axis_resources(inp), expected_out)
self.assertEqual(
sharding_impls.array_mapping_to_axis_resources(inp), expected_out
)
def test_get_input_indices_fully_replicated(self):
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
@ -3552,22 +3556,22 @@ class UtilTest(jtu.JaxTestCase):
pxla.mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval, array_mapping)
@parameterized.named_parameters(
("all_unspecified", (pjit_lib._UNSPECIFIED, pjit_lib._UNSPECIFIED), AssertionError),
("only_unspecified", pjit_lib._UNSPECIFIED),
("all_unspecified", (UNSPECIFIED, UNSPECIFIED), AssertionError),
("only_unspecified", UNSPECIFIED),
("all_specified", (P('x'), P('y'))),
("only_specified", P('x')),
("mix_1", (P('x'), pjit_lib._UNSPECIFIED), ValueError),
("mix_2", (P('x'), pjit_lib._UNSPECIFIED, P('y')), ValueError),
("mix_3", (pjit_lib._UNSPECIFIED, P('x'), P('y')), ValueError),
("mix_4", (pjit_lib._UNSPECIFIED, P('x'), pjit_lib._UNSPECIFIED), ValueError),
("mix_1", (P('x'), UNSPECIFIED), ValueError),
("mix_2", (P('x'), UNSPECIFIED, P('y')), ValueError),
("mix_3", (UNSPECIFIED, P('x'), P('y')), ValueError),
("mix_4", (UNSPECIFIED, P('x'), UNSPECIFIED), ValueError),
)
def test_all_or_non_unspecified(self, axis_resources, error=None):
entries, _ = jax.tree_util.tree_flatten(axis_resources, is_leaf=lambda x: x is None)
if error is not None:
with self.assertRaises(error):
pjit_lib._check_all_or_none_unspecified(entries, 'test axis resources')
sharding_impls.check_all_or_none_unspecified(entries, 'test axis resources')
else:
pjit_lib._check_all_or_none_unspecified(entries, 'test axis resources')
sharding_impls.check_all_or_none_unspecified(entries, 'test axis resources')
def test_op_sharding_equality_and_hash_equality(self):
op1 = xc.OpSharding()
@ -3673,7 +3677,6 @@ class UtilTest(jtu.JaxTestCase):
cache_info4 = GSPMDSharding.devices_indices_map.cache_info()
self.assertEqual(cache_info4.hits, cache_info3.hits + 1)
def test_op_sharding_semantically_replicated(self):
op1 = xc.OpSharding()
op1.type = xc.OpSharding.Type.OTHER
@ -3769,8 +3772,8 @@ class UtilTest(jtu.JaxTestCase):
self.assertEqual(recovered_parsed_pspec[0].get_partition_spec(),
P(('x',), ('y',)))
out_of_sync_parsed_pspec = pjit_lib.ParsedPartitionSpec(
P('x', 'y'), ('x', 'y'), pjit_lib.SpecSync.OUT_OF_SYNC)
out_of_sync_parsed_pspec = sharding_impls.ParsedPartitionSpec(
P('x', 'y'), ('x', 'y'), sharding_impls.SpecSync.OUT_OF_SYNC)
self.assertEqual(out_of_sync_parsed_pspec.get_partition_spec(),
P(('x',), ('y',)))

View File

@ -43,6 +43,7 @@ from jax._src import core
from jax import (pmap, jit, vmap, jvp, grad, make_jaxpr,
linearize, device_put)
from jax._src import config as jax_config
from jax._src import sharding_impls
from jax._src import sharding_specs
from jax._src import xla_bridge
from jax._src.lib import xla_extension
@ -1091,7 +1092,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected)
def testAxisGroups(self):
axis_env = xla.AxisEnv(8, ('i', 'j'), (4, 2))
axis_env = sharding_impls.AxisEnv(8, ('i', 'j'), (4, 2))
groups = xla.axis_groups(axis_env, 'i')
self.assertEqual(groups, ((0, 2, 4, 6), (1, 3, 5, 7)))