Mirror of https://github.com/ROCm/jax.git (synced 2025-04-20 05:46:06 +00:00)
sharding cleanup: use inline isinstance checks for unspecified and auto
parent bb5fbec64b
commit 8948e6de58
Changed directories: jax/_src, jax/experimental/jax2tf, jax/interpreters
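The commit replaces the predicate helpers `is_unspecified`, `is_auto`, and `is_unspecified_or_auto` from `jax._src.sharding_impls` with inline `isinstance` checks against `UnspecifiedValue` and `AUTO` at the call sites touched below, and deletes the helpers themselves. A minimal sketch of the substitution, using the real classes but a placeholder variable `s` that is not taken from any single hunk:

from jax._src.sharding_impls import AUTO, UNSPECIFIED, UnspecifiedValue

s = UNSPECIFIED  # placeholder value, for illustration only

# is_unspecified(s)          ->  isinstance(s, UnspecifiedValue)
# is_auto(s)                 ->  isinstance(s, AUTO)
# is_unspecified_or_auto(s)  ->  isinstance(s, (UnspecifiedValue, AUTO))
assert isinstance(s, UnspecifiedValue)
assert not isinstance(s, AUTO)
assert isinstance(s, (UnspecifiedValue, AUTO))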
@@ -801,9 +801,9 @@ def _export_lowered(
   nr_devices = len(lowering.compile_args["device_assignment"])
   def export_sharding(s: LoweringSharding,
                       aval: core.ShapedArray) -> HloSharding | None:
-    if sharding_impls.is_unspecified(s):
+    if isinstance(s, sharding_impls.UnspecifiedValue):
       return None
-    return s._to_xla_hlo_sharding(aval.ndim)  # type: ignore[union-attr]
+    return s._to_xla_hlo_sharding(aval.ndim)

   all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"],
                                          module_kept_var_idx,
@@ -68,8 +68,8 @@ from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding import Sharding as JSharding
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
-    UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
-    is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources,
+    UnspecifiedValue, get_array_mapping as _get_array_mapping,
+    array_mapping_to_axis_resources,
     SingleDeviceSharding, GSPMDSharding, NamedSharding, PositionalSharding)
 from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name,
                            tuple_update, tuple_delete, distributed_debug_log,
@@ -149,7 +149,7 @@ shard_arg_handlers: dict[

 @lru_cache(maxsize=2048)
 def is_default_layout(curr_layout, sharding, aval):
-  if curr_layout is None or sharding is None or is_unspecified(sharding):
+  if curr_layout is None or sharding is None or isinstance(sharding, UnspecifiedValue):
     return True
   if (aval is core.abstract_token or aval.dtype == dtypes.float0 or
       dtypes.issubdtype(aval.dtype, dtypes.extended)):
@@ -1643,7 +1643,7 @@ def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapp
 def check_if_any_auto(
     shardings: Iterable[(JSharding | AUTO | UnspecifiedValue)]) -> bool:
   for s in shardings:
-    if is_auto(s):
+    if isinstance(s, AUTO):
       return True
   return False

@@ -1727,14 +1727,14 @@ def _get_and_check_device_assignment(
   devices = tuple(devices)

   for i, s_type, source_info in shardings:
-    if is_unspecified(i):
+    if isinstance(i, UnspecifiedValue):
       continue

     if first_sharding_info is None:
       first_sharding_info = (
-          (i.mesh._flat_devices_tuple, s_type, source_info) if is_auto(i)  # type: ignore
-          else (i._device_assignment, s_type, source_info))  # type: ignore
-    arr_device_assignment = i.mesh._flat_devices_tuple if is_auto(i) else i._device_assignment  # type: ignore
+          (i.mesh._flat_devices_tuple, s_type, source_info) if isinstance(i, AUTO)
+          else (i._device_assignment, s_type, source_info))
+    arr_device_assignment = i.mesh._flat_devices_tuple if isinstance(i, AUTO) else i._device_assignment
     if not devices:
       if first_sharding_info[0] != arr_device_assignment:
         raise DeviceAssignmentMismatchError([
@@ -1836,7 +1836,8 @@ class SemanticallyEqualShardings:
   def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...],
                avals: tuple[core.AbstractValue]):
     gspmd_shardings = [
-        s if is_unspecified_or_auto(s) else to_gspmd_sharding(s, a.ndim)  # type: ignore
+        s if isinstance(s, (UnspecifiedValue, AUTO))
+        else to_gspmd_sharding(s, a.ndim)  # pytype: disable=attribute-error
         for s, a in zip(shardings, avals)]
     self._gspmd_shardings = gspmd_shardings
     self.shardings = shardings
@@ -2004,7 +2005,7 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
   except:
     return True
   for i in shardings:
-    if is_unspecified_or_auto(i):
+    if isinstance(i, (UnspecifiedValue, AUTO)):
       continue
     if i.memory_kind is None:  # pytype: disable=attribute-error
       continue
@@ -2034,7 +2035,7 @@ def get_out_memory_kinds_via_propagation(closed_jaxpr: core.ClosedJaxpr,
   if in_shardings is None:
     invar_mem_kind = [None] * len(jaxpr.invars)
   else:
-    invar_mem_kind = [None if is_unspecified_or_auto(s) else s.memory_kind
+    invar_mem_kind = [None if isinstance(s, (UnspecifiedValue, AUTO)) else s.memory_kind
                       for s in in_shardings]
   safe_map(write, jaxpr.invars, invar_mem_kind)
   safe_map(write, jaxpr.constvars, [None] * len(jaxpr.constvars))
@@ -2129,7 +2130,7 @@ def _concretize_abstract_shardings(shardings, avals, device_assignment):

   out = []
   for s, a in zip(shardings, avals):
-    if is_unspecified(s) and a.sharding is not None:
+    if isinstance(s, UnspecifiedValue) and a.sharding is not None:
       out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
                                a.sharding.spec))
     else:
@@ -2216,9 +2217,9 @@ def lower_sharding_computation(
   committed = bool(
       devices_from_context or
       len(device_assignment) > 1 or
-      any(not is_unspecified(i) for i in unique_in_shardings) or
-      any(not is_unspecified(js) for js, _ in unique_intermediate_shardings) or
-      any(not is_unspecified(o) for o in unique_out_shardings))
+      any(not isinstance(i, UnspecifiedValue) for i in unique_in_shardings) or
+      any(not isinstance(js, UnspecifiedValue) for js, _ in unique_intermediate_shardings) or
+      any(not isinstance(o, UnspecifiedValue) for o in unique_out_shardings))

   da_object = _create_da_object(tuple(device_assignment))

@@ -2690,7 +2691,7 @@ def _maybe_get_and_check_in_shardings(
   new_in_shardings = []
   for xla_s, orig, aval in safe_zip(in_shardings_xla, in_shardings,
                                     global_in_avals):
-    if is_unspecified(orig):
+    if isinstance(orig, UnspecifiedValue):
       if (aval is not core.abstract_token and
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
         xla_s = sharding_impls.logical_sharding(aval, xla_s)
@@ -2726,7 +2727,7 @@ def _maybe_get_and_check_out_shardings(
   new_out_shardings = []
   for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings,
                                     global_out_avals):
-    if is_unspecified(orig):
+    if isinstance(orig, UnspecifiedValue):
       if (aval is not core.abstract_token and
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
         xla_s = sharding_impls.logical_sharding(aval, xla_s)
@@ -2839,16 +2840,16 @@ class UnloadedMeshExecutable:
     da = _create_da_object(tuple(device_assignment))
     del device_assignment

-    allow_prop_to_inputs = tuple(is_unspecified(i) or is_auto(i)
+    allow_prop_to_inputs = tuple(isinstance(i, (UnspecifiedValue, AUTO))
                                  for i in in_shardings)
-    allow_prop_to_outputs = tuple(is_unspecified(o) or is_auto(o)
+    allow_prop_to_outputs = tuple(isinstance(o, (UnspecifiedValue, AUTO))
                                   for o in out_shardings)

     mesh = None
     if auto_spmd_lowering:
       for i in it.chain.from_iterable([in_shardings, out_shardings]):
-        if is_auto(i):
-          mesh = i.mesh  # type: ignore
+        if isinstance(i, AUTO):
+          mesh = i.mesh
           break

     xla_executable = _cached_compilation(
@@ -2861,9 +2862,9 @@ class UnloadedMeshExecutable:
       assert mesh is not None
       in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
           xla_executable, mesh)
-      in_shardings = [x if is_auto(i) else i
+      in_shardings = [x if isinstance(i, AUTO) else i
                       for x, i in safe_zip(in_shardings_xla, in_shardings)]
-      out_shardings = [x if is_auto(o) else o
+      out_shardings = [x if isinstance(o, AUTO) else o
                        for x, o in safe_zip(out_shardings_xla, out_shardings)]
     else:
       if pmap_nreps == 1:
@@ -2954,8 +2955,8 @@ class JitGlobalCppCacheKeys:
             self.donate_argnames is not None or
             self.device is not None or
             self.backend is not None or
-            any(not is_unspecified(i) for i in self.in_shardings_leaves) or
-            any(not is_unspecified(o) for o in self.out_shardings_leaves) or
+            any(not isinstance(i, UnspecifiedValue) for i in self.in_shardings_leaves) or
+            any(not isinstance(o, UnspecifiedValue) for o in self.out_shardings_leaves) or
             any(i is not None for i in self.in_layouts_leaves) or
             any(o is not None for o in self.out_layouts_leaves))

@@ -3130,7 +3131,7 @@ create_mesh_pspec_sharding = sharding_impls.create_mesh_pspec_sharding

 def check_device_backend_on_shardings(shardings) -> bool:
   for i in shardings:
-    if is_unspecified(i) or is_auto(i):
+    if isinstance(i, (UnspecifiedValue, AUTO)):
       continue
     if getattr(i, '_device_backend', False):
       return True
@@ -3156,7 +3157,7 @@ def check_array_xla_sharding_layout_match(
       args_after_dce, in_xla_shardings, in_xla_layouts, arg_names):
     if not isinstance(arg, ArrayImpl):
       continue
-    if is_unspecified_or_auto(xs):
+    if isinstance(xs, (UnspecifiedValue, AUTO)):
      continue

     db_xs = check_device_backend_on_shardings([xs])
@@ -19,7 +19,7 @@ from typing import Union
 import numpy as np
 from jax._src.dtypes import iinfo, issubdtype
 from jax._src.sharding import Sharding
-from jax._src.sharding_impls import AUTO as AutoSharding, is_auto
+from jax._src.sharding_impls import AUTO as AutoSharding
 from jax._src.lib import xla_client as xc

 Shape = tuple[int, ...]
@@ -101,7 +101,7 @@ class Layout:
                sharding: ShardingOptions = None):
     # If layout is concrete and sharding is not, error.
     if (isinstance(device_local_layout, DeviceLocalLayout) and
-        (sharding is None or is_auto(sharding))):
+        (sharding is None or isinstance(sharding, AutoSharding))):
       raise ValueError(
           'Sharding has to be concrete when layout is of type'
           f' {type(device_local_layout)}. Please pass a'
@@ -67,8 +67,7 @@ from jax._src.mesh import AbstractMesh
 from jax._src.sharding_impls import (
     NamedSharding, GSPMDSharding,
     SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
-    ParsedPartitionSpec, get_single_pspec, is_unspecified,
-    is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
+    ParsedPartitionSpec, get_single_pspec, prepare_axis_resources, parse_flatten_op_sharding)
 from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout
 from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef
 from jax._src.traceback_util import api_boundary
@@ -418,10 +417,10 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
   if device is not None and backend is not None:
     raise ValueError("can't specify both a device and a backend for jit, "
                      f"got {device=} and {backend=}")
-  if in_shardings is not None and not is_unspecified(in_shardings):
+  if in_shardings is not None and not isinstance(in_shardings, UnspecifiedValue):
     raise ValueError('If backend or device is specified on jit, then '
                      'in_shardings should not be specified.')
-  if out_shardings is not None and not is_unspecified(out_shardings):
+  if out_shardings is not None and not isinstance(out_shardings, UnspecifiedValue):
     raise ValueError('If backend or device is specified on jit, then '
                      'out_shardings should not be specified.')

@@ -440,7 +439,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
   out_shardings = prepare_axis_resources(out_shardings, 'out_shardings')

   user_specified_in_shardings = (in_shardings is not None and
-                                 not is_unspecified(in_shardings))
+                                 not isinstance(in_shardings, UnspecifiedValue))

   in_shardings_leaves, in_shardings_treedef = none_lr.flatten(in_shardings)
   out_shardings_leaves, out_shardings_treedef = none_lr.flatten(out_shardings)
@@ -483,7 +482,7 @@ def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
   @api_boundary
   def eval_shape(*args, **kwargs):
     p, _ = _infer_params(fun, jit_info, args, kwargs)
-    out_s = [None if is_unspecified(s) else s for s in p.params['out_shardings']]
+    out_s = [None if isinstance(s, UnspecifiedValue) else s for s in p.params['out_shardings']]
     # TODO(yashkatariya): Add `Layout` to SDS.
     out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s,
                                 weak_type=x.weak_type)
@@ -1001,7 +1000,7 @@ def hashable_pytree(pytree):
 def _create_sharding_for_array(mesh, x, name, api_name):
   if x is None and (mesh is None or mesh.empty):
     return UNSPECIFIED
-  if isinstance(x, sharding.Sharding) or is_unspecified_or_auto(x):
+  if isinstance(x, (AUTO, UnspecifiedValue, sharding.Sharding)):
     return x
   if mesh is None:
     msg = ('jax.jit only supports `Sharding`s being passed to'
@@ -1110,7 +1109,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
   orig_in_shardings = tree_unflatten(in_shardings_treedef, in_shardings_leaves)
   # Only do this if original in_shardings are unspecified. If it is AUTO, go
   # via flatten_axis_resources.
-  if is_unspecified(orig_in_shardings):
+  if isinstance(orig_in_shardings, UnspecifiedValue):
     in_shardings_flat = (orig_in_shardings,) * len(in_avals)
   else:
     in_shardings_flat = flatten_axis_resources(
@@ -1312,8 +1311,7 @@ def _check_and_canonicalize_out_shardings(
     out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
     out_layouts_leaves, out_tree, out_avals, debug_info, device_or_backend_set):
   orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
-  if (is_unspecified(orig_out_shardings) or
-      isinstance(orig_out_shardings, sharding.Sharding)):
+  if isinstance(orig_out_shardings, (UnspecifiedValue, sharding.Sharding)):
     out_shardings_flat = (orig_out_shardings,) * len(out_avals)
   else:
     out_shardings_flat = flatten_axis_resources(
@@ -1391,7 +1389,7 @@ def pjit_check_aval_sharding(
     what_aval: str, allow_uneven_sharding: bool):
   new_names = [''] * len(shardings) if names is None else names
   for aval, s, name in zip(flat_avals, shardings, new_names):
-    if is_unspecified_or_auto(s):
+    if isinstance(s, (UnspecifiedValue, AUTO)):
       continue
     name_str = f' with pytree key path {name}' if name else ''
     shape = aval.shape
@@ -1466,7 +1464,7 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
     else:
       arg_layout, dispatch_arg_layout = None, None
     # Sharding can be unspecified when array is committed if it's a PmapSharding.
-    is_pmap_sharding = (is_unspecified(rs) or
+    is_pmap_sharding = (isinstance(rs, UnspecifiedValue) or
                         isinstance(getattr(arg, 'sharding', None), PmapSharding))
     if jit_in_l is None:
       if committed:
@@ -1527,15 +1525,15 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
       if getattr(a, '_committed', True):
         committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None))

-  resolved_in_shardings = []
+  resolved_in_shardings: list[PjitSharding] = []
   for arg, pjit_in_s in zip(args, pjit_in_shardings):
     # arg sharding can be None in case of ShapeDtypeStruct. jax.Array does
     # not allow None as the sharding.
     arg_s, committed = ((arg.sharding, getattr(arg, '_committed', True))
                         if hasattr(arg, 'sharding') and arg.sharding is not None
                         else (UNSPECIFIED, False))
-    if is_unspecified(pjit_in_s):
-      if is_unspecified(arg_s):
+    if isinstance(pjit_in_s, UnspecifiedValue):
+      if isinstance(arg_s, UnspecifiedValue):
         resolved_in_shardings.append(arg_s)
       else:
         if committed:
@@ -1553,7 +1551,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
                  'multiple devices is not supported.')
       else:
         if (isinstance(arg, np.ndarray) and
-            not pjit_in_s.is_fully_replicated and  # type: ignore
+            not pjit_in_s.is_fully_replicated and  # type: ignore[union-attr]
             xb.process_count() > 1):
           raise ValueError(
               'Passing non-trivial shardings for numpy '
@@ -1572,16 +1570,16 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
         # jax.jit does not allow resharding across different memory kinds even
         # if the argument is uncommitted. Use jax.device_put for those cases,
         # either outside or inside jax.jit.
-        if pjit_in_s.memory_kind != arg_s.memory_kind:  # type: ignore
+        if pjit_in_s.memory_kind != arg_s.memory_kind:  # type: ignore[union-attr]
           raise ValueError(
               'Memory kinds passed to jax.jit does not match memory kind on the'
-              f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, '  # type: ignore
+              f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, '  # type: ignore[union-attr]
               f'arg memory kind: {arg_s.memory_kind} for '
               f'arg shape: {shaped_abstractify(arg).str_short()}')
         if (committed and
             not isinstance(arg_s, PmapSharding) and
             not op_shardings.are_op_shardings_equal(
-                pjit_in_s._to_xla_hlo_sharding(arg.ndim),  # type: ignore
+                pjit_in_s._to_xla_hlo_sharding(arg.ndim),  # type: ignore[union-attr]
                 arg_s._to_xla_hlo_sharding(arg.ndim))):
           raise ValueError('Sharding passed to pjit does not match the sharding '
                            'on the respective arg. '
@@ -1780,8 +1778,8 @@ def pjit_staging_rule(trace, *args, **params):
     params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings,
                   out_layouts=out_layouts)
   if (params["inline"] and
-      all(is_unspecified(i) for i in params["in_shardings"]) and
-      all(is_unspecified(o) for o in params["out_shardings"]) and
+      all(isinstance(i, UnspecifiedValue) for i in params["in_shardings"]) and
+      all(isinstance(o, UnspecifiedValue) for o in params["out_shardings"]) and
       all(i is None for i in params["in_layouts"]) and
       all(o is None for o in params["out_layouts"])):
     if config.dynamic_shapes.value:
@@ -1830,7 +1828,7 @@ pe.custom_staging_rules[pjit_p] = pjit_staging_rule

 def _pjit_forwarding(jaxpr, out_shardings, out_layouts):
   in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr.jaxpr)
-  in_fwd = [fwd if is_unspecified(os) and ol is None else None for fwd, os, ol
+  in_fwd = [fwd if isinstance(os, UnspecifiedValue) and ol is None else None for fwd, os, ol
             in zip(in_fwd, out_shardings, out_layouts)]
   keep = [f is None for f in in_fwd]
   jaxpr = pe.prune_closed_jaxpr_outputs(jaxpr, keep)
@@ -1896,8 +1894,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,

   func = mod_ctx.cached_primitive_lowerings.get(key, None)
   if func is None:
-    arg_shardings = [None if is_unspecified(i) else i for i in in_shardings]
-    result_shardings = [None if is_unspecified(o) else o for o in out_shardings]
+    arg_shardings = [None if isinstance(i, UnspecifiedValue) else i for i in in_shardings]
+    result_shardings = [None if isinstance(o, UnspecifiedValue) else o for o in out_shardings]
     # TODO(b/228598865): inlined calls cannot have shardings set directly on the
     # inputs or outputs because they are lost during MLIR->HLO conversion.
     # using_sharding_annotation=False means we add an identity operation instead.
@@ -1990,9 +1988,9 @@ batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, None)
 def _pjit_batcher_for_sharding(
     s: sharding.Sharding | UnspecifiedValue,
     dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int):
-  if is_unspecified(s):
+  if isinstance(s, UnspecifiedValue):
     return s
-  hlo_s = s._to_xla_hlo_sharding(ndim)  # type: ignore
+  hlo_s = s._to_xla_hlo_sharding(ndim)
   if spmd_axis_name is None:
     if sharding_impls.is_op_sharding_replicated(hlo_s):
       return s
@@ -2004,7 +2002,7 @@ def _pjit_batcher_for_sharding(
     tad.insert(dim, 1)
     new_op.tile_assignment_dimensions = tad
     new_gs = GSPMDSharding(
-        s._device_assignment, new_op,  # type: ignore
+        s._device_assignment, new_op,
         _device_list=getattr(s, '_internal_device_list', None))
     return pxla._get_out_sharding_from_orig_sharding([new_gs], [None], s, None)[0]
   else:
@@ -2107,7 +2105,7 @@ def _pjit_partial_eval(trace, *in_tracers,
   # Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
   in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
   in_fwd = [
-      fwd if is_unspecified(os) and ol is None else None
+      fwd if isinstance(os, UnspecifiedValue) and ol is None else None
      for os, ol, fwd in zip(
          keep_where(out_shardings, known_outs),
          keep_where(out_layouts, known_outs), in_fwd_primal)
@@ -2358,9 +2356,9 @@ def _pjit_pp_rule(eqn, context, settings):
   del params['inline']
   if not any(params['donated_invars']):
     del params['donated_invars']
-  if all(is_unspecified(s) for s in params['in_shardings']):
+  if all(isinstance(s, UnspecifiedValue) for s in params['in_shardings']):
     del params['in_shardings']
-  if all(is_unspecified(s) for s in params['out_shardings']):
+  if all(isinstance(s, UnspecifiedValue) for s in params['out_shardings']):
     del params['out_shardings']
   if all(l is None for l in params['in_layouts']):
     del params['in_layouts']
@@ -2382,8 +2380,7 @@ core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
 def _pjit_state_discharge_rule(
     in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings,
     in_layouts, out_layouts, **params):
-  if not (all(map(is_unspecified, in_shardings)) and
-          all(map(is_unspecified, out_shardings))):
+  if not all(isinstance(s, UnspecifiedValue) for s in (*in_shardings, *out_shardings)):
     raise NotImplementedError

   if not (all(l is None for l in in_layouts) and
@@ -957,21 +957,11 @@ class AUTO:
     return SdyArraySharding(self.mesh.shape_tuple, dim_shardings)


-def is_auto(x):
-  return isinstance(x, AUTO)
-
-
 class UnspecifiedValue:
   def __repr__(self):
     return "UnspecifiedValue"
 UNSPECIFIED = UnspecifiedValue()
-
-def is_unspecified(x):
-  return isinstance(x, UnspecifiedValue)
-
-def is_unspecified_or_auto(x):
-  return is_auto(x) or is_unspecified(x)


 MeshAxisName = Any

@@ -1014,8 +1004,6 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
 def get_array_mapping(
     axis_resources: ParsedPartitionSpec | AUTO | UnspecifiedValue
 ) -> ArrayMappingOrAutoOrUnspecified:
-  # TODO(yashkatariya): Use `TypeGuard` on `is_auto` when it is supported.
-  # Don't use `is_auto` here to satisfy pytype and mypy.
   if isinstance(axis_resources, (AUTO, UnspecifiedValue)):
     return axis_resources
   return OrderedDict((axis, i)
@@ -1113,7 +1101,7 @@ def prepare_axis_resources(axis_resources, arg_name,

   new_entries = []
   for entry in entries:
-    if is_unspecified_or_auto(entry) or entry is None:
+    if isinstance(entry, (UnspecifiedValue, AUTO)) or entry is None:
       new_entries.append(entry)
     elif isinstance(entry, sharding.Sharding):
       if isinstance(entry, PmapSharding):
@@ -1131,8 +1119,7 @@ def prepare_axis_resources(axis_resources, arg_name,
 def _check_unique_resources(axis_resources, arg_name):
   for arg_axis_resources in axis_resources:
     if not arg_axis_resources: continue
-    if (is_unspecified_or_auto(arg_axis_resources) or
-        isinstance(arg_axis_resources, sharding.Sharding)):
+    if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, sharding.Sharding)):
       continue
     constrained_dims = [d for d in arg_axis_resources if d is not None]
     resource_counts = collections.Counter(
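Because `isinstance` checks let pytype and mypy narrow a `Sharding | UnspecifiedValue` union directly (the removed TODO above notes the old helpers could not do this without `TypeGuard`), several `# type: ignore` comments are dropped alongside the substitution in the hunks above. A minimal sketch of the narrowing effect, with a hypothetical helper `maybe_hlo` that is not part of this commit:

from jax._src.sharding import Sharding
from jax._src.sharding_impls import UnspecifiedValue

def maybe_hlo(s: Sharding | UnspecifiedValue, ndim: int):
  # After this check, `s` is narrowed to `Sharding`, so the attribute
  # access below no longer needs a "# type: ignore" comment.
  if isinstance(s, UnspecifiedValue):
    return None
  return s._to_xla_hlo_sharding(ndim)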
@@ -43,7 +43,7 @@ from jax._src import source_info_util
 from jax._src import traceback_util
 from jax._src import tree_util
 from jax._src import util
-from jax._src.sharding_impls import is_unspecified_or_auto
+from jax._src.sharding_impls import UnspecifiedValue, AUTO
 from jax._src.layout import Layout
 from jax._src.interpreters import mlir
 from jax._src.lib.mlir import ir
@@ -649,7 +649,7 @@ class Lowered(Stage):
     out_avals = self._lowering.compile_args["global_out_avals"]
     out_shardings = self._lowering.compile_args["out_shardings"]
     return self.out_tree.unflatten(
-        [OutInfo(o.shape, o.dtype, None if is_unspecified_or_auto(s) else s)
+        [OutInfo(o.shape, o.dtype, None if isinstance(s, (UnspecifiedValue, AUTO)) else s)
          for o, s in zip(out_avals, out_shardings)])

   def compile(
@@ -3537,7 +3537,7 @@ def split_to_logical_devices(tensor: TfVal,
 def _xla_compatible_sharding_to_hlo_sharding(
     s: sharding.Sharding,
     aval: core.ShapedArray) -> xla_client.HloSharding | None:
-  if sharding_impls.is_unspecified(s):
+  if isinstance(s, sharding_impls.UnspecifiedValue):
     return None
   return s._to_xla_hlo_sharding(aval.ndim)

@@ -40,7 +40,6 @@ from jax._src.sharding_impls import (
     ArrayMapping as ArrayMapping,
     UNSPECIFIED as _UNSPECIFIED,  # noqa: F401
     array_mapping_to_axis_resources as array_mapping_to_axis_resources,
-    is_unspecified as _is_unspecified,  # noqa: F401
 )

 from jax._src.sharding_specs import (