mirror of https://github.com/ROCm/jax.git synced 2025-04-20 05:46:06 +00:00

sharding cleanup: use inline checks for unspecified and auto

Jake VanderPlas 2024-10-23 15:52:20 -07:00
parent bb5fbec64b
commit 8948e6de58
8 changed files with 66 additions and 82 deletions
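Every hunk below applies the same mechanical rewrite: calls to the small predicate helpers in jax._src.sharding_impls are replaced by inline isinstance checks against UnspecifiedValue and AUTO, and the helpers themselves are deleted. A minimal sketch of the pattern (illustrative only, not a literal line from the diff):

    # before: predicate helpers
    if is_unspecified(s): ...
    if is_auto(s): ...
    if is_unspecified_or_auto(s): ...

    # after: inline isinstance checks
    if isinstance(s, UnspecifiedValue): ...
    if isinstance(s, AUTO): ...
    if isinstance(s, (UnspecifiedValue, AUTO)): ...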

@@ -801,9 +801,9 @@ def _export_lowered(
nr_devices = len(lowering.compile_args["device_assignment"])
def export_sharding(s: LoweringSharding,
aval: core.ShapedArray) -> HloSharding | None:
if sharding_impls.is_unspecified(s):
if isinstance(s, sharding_impls.UnspecifiedValue):
return None
return s._to_xla_hlo_sharding(aval.ndim) # type: ignore[union-attr]
return s._to_xla_hlo_sharding(aval.ndim)
all_in_shardings = expand_in_shardings(lowering.compile_args["in_shardings"],
module_kept_var_idx,
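A side effect visible in this first hunk: once the check is an isinstance call, type checkers narrow the union type of s on the fall-through path, so the # type: ignore[union-attr] on s._to_xla_hlo_sharding(...) can be dropped. A self-contained sketch of that narrowing, using hypothetical class names rather than JAX's own types:

    from __future__ import annotations

    class Unspecified:  # stands in for sharding_impls.UnspecifiedValue
        pass

    class Concrete:     # stands in for a concrete Sharding
        def to_hlo(self) -> str:
            return "hlo"

    def export(s: Concrete | Unspecified) -> str | None:
        if isinstance(s, Unspecified):
            return None    # the checker narrows s to Unspecified here
        return s.to_hlo()  # ...and to Concrete here, so no ignore comment is needed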

@@ -68,8 +68,8 @@ from jax._src.partition_spec import PartitionSpec
from jax._src.sharding import Sharding as JSharding
from jax._src.sharding_impls import (
ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources,
UnspecifiedValue, get_array_mapping as _get_array_mapping,
array_mapping_to_axis_resources,
SingleDeviceSharding, GSPMDSharding, NamedSharding, PositionalSharding)
from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name,
tuple_update, tuple_delete, distributed_debug_log,
@@ -149,7 +149,7 @@ shard_arg_handlers: dict[
@lru_cache(maxsize=2048)
def is_default_layout(curr_layout, sharding, aval):
if curr_layout is None or sharding is None or is_unspecified(sharding):
if curr_layout is None or sharding is None or isinstance(sharding, UnspecifiedValue):
return True
if (aval is core.abstract_token or aval.dtype == dtypes.float0 or
dtypes.issubdtype(aval.dtype, dtypes.extended)):
@@ -1643,7 +1643,7 @@ def _shard_to_full_lowering(ctx: mlir.LoweringRuleContext, x, *, axes: ArrayMapp
def check_if_any_auto(
shardings: Iterable[(JSharding | AUTO | UnspecifiedValue)]) -> bool:
for s in shardings:
if is_auto(s):
if isinstance(s, AUTO):
return True
return False
@@ -1727,14 +1727,14 @@ def _get_and_check_device_assignment(
devices = tuple(devices)
for i, s_type, source_info in shardings:
if is_unspecified(i):
if isinstance(i, UnspecifiedValue):
continue
if first_sharding_info is None:
first_sharding_info = (
(i.mesh._flat_devices_tuple, s_type, source_info) if is_auto(i) # type: ignore
else (i._device_assignment, s_type, source_info)) # type: ignore
arr_device_assignment = i.mesh._flat_devices_tuple if is_auto(i) else i._device_assignment # type: ignore
(i.mesh._flat_devices_tuple, s_type, source_info) if isinstance(i, AUTO)
else (i._device_assignment, s_type, source_info))
arr_device_assignment = i.mesh._flat_devices_tuple if isinstance(i, AUTO) else i._device_assignment
if not devices:
if first_sharding_info[0] != arr_device_assignment:
raise DeviceAssignmentMismatchError([
@@ -1836,7 +1836,8 @@ class SemanticallyEqualShardings:
def __init__(self, shardings: tuple[GSPMDSharding | UnspecifiedValue, ...],
avals: tuple[core.AbstractValue]):
gspmd_shardings = [
s if is_unspecified_or_auto(s) else to_gspmd_sharding(s, a.ndim) # type: ignore
s if isinstance(s, (UnspecifiedValue, AUTO))
else to_gspmd_sharding(s, a.ndim) # pytype: disable=attribute-error
for s, a in zip(shardings, avals)]
self._gspmd_shardings = gspmd_shardings
self.shardings = shardings
@@ -2004,7 +2005,7 @@ def are_all_shardings_default_mem_kind(da_object, shardings):
except:
return True
for i in shardings:
if is_unspecified_or_auto(i):
if isinstance(i, (UnspecifiedValue, AUTO)):
continue
if i.memory_kind is None: # pytype: disable=attribute-error
continue
@@ -2034,7 +2035,7 @@ def get_out_memory_kinds_via_propagation(closed_jaxpr: core.ClosedJaxpr,
if in_shardings is None:
invar_mem_kind = [None] * len(jaxpr.invars)
else:
invar_mem_kind = [None if is_unspecified_or_auto(s) else s.memory_kind
invar_mem_kind = [None if isinstance(s, (UnspecifiedValue, AUTO)) else s.memory_kind
for s in in_shardings]
safe_map(write, jaxpr.invars, invar_mem_kind)
safe_map(write, jaxpr.constvars, [None] * len(jaxpr.constvars))
@@ -2129,7 +2130,7 @@ def _concretize_abstract_shardings(shardings, avals, device_assignment):
out = []
for s, a in zip(shardings, avals):
if is_unspecified(s) and a.sharding is not None:
if isinstance(s, UnspecifiedValue) and a.sharding is not None:
out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
a.sharding.spec))
else:
@@ -2216,9 +2217,9 @@ def lower_sharding_computation(
committed = bool(
devices_from_context or
len(device_assignment) > 1 or
any(not is_unspecified(i) for i in unique_in_shardings) or
any(not is_unspecified(js) for js, _ in unique_intermediate_shardings) or
any(not is_unspecified(o) for o in unique_out_shardings))
any(not isinstance(i, UnspecifiedValue) for i in unique_in_shardings) or
any(not isinstance(js, UnspecifiedValue) for js, _ in unique_intermediate_shardings) or
any(not isinstance(o, UnspecifiedValue) for o in unique_out_shardings))
da_object = _create_da_object(tuple(device_assignment))
@@ -2690,7 +2691,7 @@ def _maybe_get_and_check_in_shardings(
new_in_shardings = []
for xla_s, orig, aval in safe_zip(in_shardings_xla, in_shardings,
global_in_avals):
if is_unspecified(orig):
if isinstance(orig, UnspecifiedValue):
if (aval is not core.abstract_token and
dtypes.issubdtype(aval.dtype, dtypes.extended)):
xla_s = sharding_impls.logical_sharding(aval, xla_s)
@@ -2726,7 +2727,7 @@ def _maybe_get_and_check_out_shardings(
new_out_shardings = []
for xla_s, orig, aval in safe_zip(out_shardings_xla, out_shardings,
global_out_avals):
if is_unspecified(orig):
if isinstance(orig, UnspecifiedValue):
if (aval is not core.abstract_token and
dtypes.issubdtype(aval.dtype, dtypes.extended)):
xla_s = sharding_impls.logical_sharding(aval, xla_s)
@@ -2839,16 +2840,16 @@ class UnloadedMeshExecutable:
da = _create_da_object(tuple(device_assignment))
del device_assignment
allow_prop_to_inputs = tuple(is_unspecified(i) or is_auto(i)
allow_prop_to_inputs = tuple(isinstance(i, (UnspecifiedValue, AUTO))
for i in in_shardings)
allow_prop_to_outputs = tuple(is_unspecified(o) or is_auto(o)
allow_prop_to_outputs = tuple(isinstance(o, (UnspecifiedValue, AUTO))
for o in out_shardings)
mesh = None
if auto_spmd_lowering:
for i in it.chain.from_iterable([in_shardings, out_shardings]):
if is_auto(i):
mesh = i.mesh # type: ignore
if isinstance(i, AUTO):
mesh = i.mesh
break
xla_executable = _cached_compilation(
@@ -2861,9 +2862,9 @@ class UnloadedMeshExecutable:
assert mesh is not None
in_shardings_xla, out_shardings_xla = _get_mesh_pspec_shardings_from_executable(
xla_executable, mesh)
in_shardings = [x if is_auto(i) else i
in_shardings = [x if isinstance(i, AUTO) else i
for x, i in safe_zip(in_shardings_xla, in_shardings)]
out_shardings = [x if is_auto(o) else o
out_shardings = [x if isinstance(o, AUTO) else o
for x, o in safe_zip(out_shardings_xla, out_shardings)]
else:
if pmap_nreps == 1:
@@ -2954,8 +2955,8 @@ class JitGlobalCppCacheKeys:
self.donate_argnames is not None or
self.device is not None or
self.backend is not None or
any(not is_unspecified(i) for i in self.in_shardings_leaves) or
any(not is_unspecified(o) for o in self.out_shardings_leaves) or
any(not isinstance(i, UnspecifiedValue) for i in self.in_shardings_leaves) or
any(not isinstance(o, UnspecifiedValue) for o in self.out_shardings_leaves) or
any(i is not None for i in self.in_layouts_leaves) or
any(o is not None for o in self.out_layouts_leaves))
@@ -3130,7 +3131,7 @@ create_mesh_pspec_sharding = sharding_impls.create_mesh_pspec_sharding
def check_device_backend_on_shardings(shardings) -> bool:
for i in shardings:
if is_unspecified(i) or is_auto(i):
if isinstance(i, (UnspecifiedValue, AUTO)):
continue
if getattr(i, '_device_backend', False):
return True
@@ -3156,7 +3157,7 @@ def check_array_xla_sharding_layout_match(
args_after_dce, in_xla_shardings, in_xla_layouts, arg_names):
if not isinstance(arg, ArrayImpl):
continue
if is_unspecified_or_auto(xs):
if isinstance(xs, (UnspecifiedValue, AUTO)):
continue
db_xs = check_device_backend_on_shardings([xs])

@@ -19,7 +19,7 @@ from typing import Union
import numpy as np
from jax._src.dtypes import iinfo, issubdtype
from jax._src.sharding import Sharding
from jax._src.sharding_impls import AUTO as AutoSharding, is_auto
from jax._src.sharding_impls import AUTO as AutoSharding
from jax._src.lib import xla_client as xc
Shape = tuple[int, ...]
@@ -101,7 +101,7 @@ class Layout:
sharding: ShardingOptions = None):
# If layout is concrete and sharding is not, error.
if (isinstance(device_local_layout, DeviceLocalLayout) and
(sharding is None or is_auto(sharding))):
(sharding is None or isinstance(sharding, AutoSharding))):
raise ValueError(
'Sharding has to be concrete when layout is of type'
f' {type(device_local_layout)}. Please pass a'

@@ -67,8 +67,7 @@ from jax._src.mesh import AbstractMesh
from jax._src.sharding_impls import (
NamedSharding, GSPMDSharding,
SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue,
ParsedPartitionSpec, get_single_pspec, is_unspecified,
is_unspecified_or_auto, prepare_axis_resources, parse_flatten_op_sharding)
ParsedPartitionSpec, get_single_pspec, prepare_axis_resources, parse_flatten_op_sharding)
from jax._src.layout import Layout, DeviceLocalLayout, AutoLayout
from jax._src.state import discharge as state_discharge, RefEffect, AbstractRef
from jax._src.traceback_util import api_boundary
@@ -418,10 +417,10 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
f"got {device=} and {backend=}")
if in_shardings is not None and not is_unspecified(in_shardings):
if in_shardings is not None and not isinstance(in_shardings, UnspecifiedValue):
raise ValueError('If backend or device is specified on jit, then '
'in_shardings should not be specified.')
if out_shardings is not None and not is_unspecified(out_shardings):
if out_shardings is not None and not isinstance(out_shardings, UnspecifiedValue):
raise ValueError('If backend or device is specified on jit, then '
'out_shardings should not be specified.')
@@ -440,7 +439,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
out_shardings = prepare_axis_resources(out_shardings, 'out_shardings')
user_specified_in_shardings = (in_shardings is not None and
not is_unspecified(in_shardings))
not isinstance(in_shardings, UnspecifiedValue))
in_shardings_leaves, in_shardings_treedef = none_lr.flatten(in_shardings)
out_shardings_leaves, out_shardings_treedef = none_lr.flatten(out_shardings)
@@ -483,7 +482,7 @@ def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo):
@api_boundary
def eval_shape(*args, **kwargs):
p, _ = _infer_params(fun, jit_info, args, kwargs)
out_s = [None if is_unspecified(s) else s for s in p.params['out_shardings']]
out_s = [None if isinstance(s, UnspecifiedValue) else s for s in p.params['out_shardings']]
# TODO(yashkatariya): Add `Layout` to SDS.
out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s,
weak_type=x.weak_type)
@@ -1001,7 +1000,7 @@ def hashable_pytree(pytree):
def _create_sharding_for_array(mesh, x, name, api_name):
if x is None and (mesh is None or mesh.empty):
return UNSPECIFIED
if isinstance(x, sharding.Sharding) or is_unspecified_or_auto(x):
if isinstance(x, (AUTO, UnspecifiedValue, sharding.Sharding)):
return x
if mesh is None:
msg = ('jax.jit only supports `Sharding`s being passed to'
@@ -1110,7 +1109,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
orig_in_shardings = tree_unflatten(in_shardings_treedef, in_shardings_leaves)
# Only do this if original in_shardings are unspecified. If it is AUTO, go
# via flatten_axis_resources.
if is_unspecified(orig_in_shardings):
if isinstance(orig_in_shardings, UnspecifiedValue):
in_shardings_flat = (orig_in_shardings,) * len(in_avals)
else:
in_shardings_flat = flatten_axis_resources(
@@ -1312,8 +1311,7 @@ def _check_and_canonicalize_out_shardings(
out_shardings_treedef, out_shardings_leaves, out_layouts_treedef,
out_layouts_leaves, out_tree, out_avals, debug_info, device_or_backend_set):
orig_out_shardings = tree_unflatten(out_shardings_treedef, out_shardings_leaves)
if (is_unspecified(orig_out_shardings) or
isinstance(orig_out_shardings, sharding.Sharding)):
if isinstance(orig_out_shardings, (UnspecifiedValue, sharding.Sharding)):
out_shardings_flat = (orig_out_shardings,) * len(out_avals)
else:
out_shardings_flat = flatten_axis_resources(
@@ -1391,7 +1389,7 @@ def pjit_check_aval_sharding(
what_aval: str, allow_uneven_sharding: bool):
new_names = [''] * len(shardings) if names is None else names
for aval, s, name in zip(flat_avals, shardings, new_names):
if is_unspecified_or_auto(s):
if isinstance(s, (UnspecifiedValue, AUTO)):
continue
name_str = f' with pytree key path {name}' if name else ''
shape = aval.shape
@@ -1466,7 +1464,7 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
else:
arg_layout, dispatch_arg_layout = None, None
# Sharding can be unspecified when array is committed if it's a PmapSharding.
is_pmap_sharding = (is_unspecified(rs) or
is_pmap_sharding = (isinstance(rs, UnspecifiedValue) or
isinstance(getattr(arg, 'sharding', None), PmapSharding))
if jit_in_l is None:
if committed:
@@ -1527,15 +1525,15 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
if getattr(a, '_committed', True):
committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None))
resolved_in_shardings = []
resolved_in_shardings: list[PjitSharding] = []
for arg, pjit_in_s in zip(args, pjit_in_shardings):
# arg sharding can be None in case of ShapeDtypeStruct. jax.Array does
# not allow None as the sharding.
arg_s, committed = ((arg.sharding, getattr(arg, '_committed', True))
if hasattr(arg, 'sharding') and arg.sharding is not None
else (UNSPECIFIED, False))
if is_unspecified(pjit_in_s):
if is_unspecified(arg_s):
if isinstance(pjit_in_s, UnspecifiedValue):
if isinstance(arg_s, UnspecifiedValue):
resolved_in_shardings.append(arg_s)
else:
if committed:
@@ -1553,7 +1551,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
'multiple devices is not supported.')
else:
if (isinstance(arg, np.ndarray) and
not pjit_in_s.is_fully_replicated and # type: ignore
not pjit_in_s.is_fully_replicated and # type: ignore[union-attr]
xb.process_count() > 1):
raise ValueError(
'Passing non-trivial shardings for numpy '
@@ -1572,16 +1570,16 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
# jax.jit does not allow resharding across different memory kinds even
# if the argument is uncommitted. Use jax.device_put for those cases,
# either outside or inside jax.jit.
if pjit_in_s.memory_kind != arg_s.memory_kind: # type: ignore
if pjit_in_s.memory_kind != arg_s.memory_kind: # type: ignore[union-attr]
raise ValueError(
'Memory kinds passed to jax.jit does not match memory kind on the'
f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore
f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore[union-attr]
f'arg memory kind: {arg_s.memory_kind} for '
f'arg shape: {shaped_abstractify(arg).str_short()}')
if (committed and
not isinstance(arg_s, PmapSharding) and
not op_shardings.are_op_shardings_equal(
pjit_in_s._to_xla_hlo_sharding(arg.ndim), # type: ignore
pjit_in_s._to_xla_hlo_sharding(arg.ndim), # type: ignore[union-attr]
arg_s._to_xla_hlo_sharding(arg.ndim))):
raise ValueError('Sharding passed to pjit does not match the sharding '
'on the respective arg. '
@@ -1780,8 +1778,8 @@ def pjit_staging_rule(trace, *args, **params):
params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings,
out_layouts=out_layouts)
if (params["inline"] and
all(is_unspecified(i) for i in params["in_shardings"]) and
all(is_unspecified(o) for o in params["out_shardings"]) and
all(isinstance(i, UnspecifiedValue) for i in params["in_shardings"]) and
all(isinstance(o, UnspecifiedValue) for o in params["out_shardings"]) and
all(i is None for i in params["in_layouts"]) and
all(o is None for o in params["out_layouts"])):
if config.dynamic_shapes.value:
@@ -1830,7 +1828,7 @@ pe.custom_staging_rules[pjit_p] = pjit_staging_rule
def _pjit_forwarding(jaxpr, out_shardings, out_layouts):
in_fwd: list[int | None] = pe._jaxpr_forwarding(jaxpr.jaxpr)
in_fwd = [fwd if is_unspecified(os) and ol is None else None for fwd, os, ol
in_fwd = [fwd if isinstance(os, UnspecifiedValue) and ol is None else None for fwd, os, ol
in zip(in_fwd, out_shardings, out_layouts)]
keep = [f is None for f in in_fwd]
jaxpr = pe.prune_closed_jaxpr_outputs(jaxpr, keep)
@@ -1896,8 +1894,8 @@ def _pjit_cached_lower_jaxpr_to_fun(ctx, name, jaxpr, effects, in_shardings,
func = mod_ctx.cached_primitive_lowerings.get(key, None)
if func is None:
arg_shardings = [None if is_unspecified(i) else i for i in in_shardings]
result_shardings = [None if is_unspecified(o) else o for o in out_shardings]
arg_shardings = [None if isinstance(i, UnspecifiedValue) else i for i in in_shardings]
result_shardings = [None if isinstance(o, UnspecifiedValue) else o for o in out_shardings]
# TODO(b/228598865): inlined calls cannot have shardings set directly on the
# inputs or outputs because they are lost during MLIR->HLO conversion.
# using_sharding_annotation=False means we add an identity operation instead.
@@ -1990,9 +1988,9 @@ batching.axis_primitive_batchers[pjit_p] = partial(_pjit_batcher, None)
def _pjit_batcher_for_sharding(
s: sharding.Sharding | UnspecifiedValue,
dim: int, spmd_axis_name: tuple[str, ...] | None, mesh, ndim: int):
if is_unspecified(s):
if isinstance(s, UnspecifiedValue):
return s
hlo_s = s._to_xla_hlo_sharding(ndim) # type: ignore
hlo_s = s._to_xla_hlo_sharding(ndim)
if spmd_axis_name is None:
if sharding_impls.is_op_sharding_replicated(hlo_s):
return s
@@ -2004,7 +2002,7 @@ def _pjit_batcher_for_sharding(
tad.insert(dim, 1)
new_op.tile_assignment_dimensions = tad
new_gs = GSPMDSharding(
s._device_assignment, new_op, # type: ignore
s._device_assignment, new_op,
_device_list=getattr(s, '_internal_device_list', None))
return pxla._get_out_sharding_from_orig_sharding([new_gs], [None], s, None)[0]
else:
@@ -2107,7 +2105,7 @@ def _pjit_partial_eval(trace, *in_tracers,
# Only forward primal outputs when corresponding out_sharding is UNSPECIFIED.
in_fwd_primal, in_fwd_res = split_list(in_fwd, [num_out_primals])
in_fwd = [
fwd if is_unspecified(os) and ol is None else None
fwd if isinstance(os, UnspecifiedValue) and ol is None else None
for os, ol, fwd in zip(
keep_where(out_shardings, known_outs),
keep_where(out_layouts, known_outs), in_fwd_primal)
@@ -2358,9 +2356,9 @@ def _pjit_pp_rule(eqn, context, settings):
del params['inline']
if not any(params['donated_invars']):
del params['donated_invars']
if all(is_unspecified(s) for s in params['in_shardings']):
if all(isinstance(s, UnspecifiedValue) for s in params['in_shardings']):
del params['in_shardings']
if all(is_unspecified(s) for s in params['out_shardings']):
if all(isinstance(s, UnspecifiedValue) for s in params['out_shardings']):
del params['out_shardings']
if all(l is None for l in params['in_layouts']):
del params['in_layouts']
@@ -2382,8 +2380,7 @@ core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
def _pjit_state_discharge_rule(
in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings,
in_layouts, out_layouts, **params):
if not (all(map(is_unspecified, in_shardings)) and
all(map(is_unspecified, out_shardings))):
if not all(isinstance(s, UnspecifiedValue) for s in (*in_shardings, *out_shardings)):
raise NotImplementedError
if not (all(l is None for l in in_layouts) and

@@ -957,21 +957,11 @@ class AUTO:
return SdyArraySharding(self.mesh.shape_tuple, dim_shardings)
def is_auto(x):
return isinstance(x, AUTO)
class UnspecifiedValue:
def __repr__(self):
return "UnspecifiedValue"
UNSPECIFIED = UnspecifiedValue()
def is_unspecified(x):
return isinstance(x, UnspecifiedValue)
def is_unspecified_or_auto(x):
return is_auto(x) or is_unspecified(x)
MeshAxisName = Any
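Because the helpers is_auto, is_unspecified, and is_unspecified_or_auto are deleted outright in the hunk above (and the _is_unspecified re-export goes away in the last hunk of this diff), any remaining caller has to adopt the inline form. A hedged sketch of drop-in equivalents, assuming imports from jax._src.sharding_impls, which is an internal module rather than a stable API:

    from jax._src.sharding_impls import AUTO, UnspecifiedValue

    def is_unspecified(x) -> bool:            # equivalent of the removed helper
        return isinstance(x, UnspecifiedValue)

    def is_unspecified_or_auto(x) -> bool:    # equivalent of the removed helper
        return isinstance(x, (UnspecifiedValue, AUTO))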
@@ -1014,8 +1004,6 @@ def array_mapping_to_axis_resources(array_mapping: ArrayMapping):
def get_array_mapping(
axis_resources: ParsedPartitionSpec | AUTO | UnspecifiedValue
) -> ArrayMappingOrAutoOrUnspecified:
# TODO(yashkatariya): Use `TypeGuard` on `is_auto` when it is supported.
# Don't use `is_auto` here to satisfy pytype and mypy.
if isinstance(axis_resources, (AUTO, UnspecifiedValue)):
return axis_resources
return OrderedDict((axis, i)
@@ -1113,7 +1101,7 @@ def prepare_axis_resources(axis_resources, arg_name,
new_entries = []
for entry in entries:
if is_unspecified_or_auto(entry) or entry is None:
if isinstance(entry, (UnspecifiedValue, AUTO)) or entry is None:
new_entries.append(entry)
elif isinstance(entry, sharding.Sharding):
if isinstance(entry, PmapSharding):
@@ -1131,8 +1119,7 @@ def _check_unique_resources(axis_resources, arg_name):
def _check_unique_resources(axis_resources, arg_name):
for arg_axis_resources in axis_resources:
if not arg_axis_resources: continue
if (is_unspecified_or_auto(arg_axis_resources) or
isinstance(arg_axis_resources, sharding.Sharding)):
if isinstance(arg_axis_resources, (UnspecifiedValue, AUTO, sharding.Sharding)):
continue
constrained_dims = [d for d in arg_axis_resources if d is not None]
resource_counts = collections.Counter(

@@ -43,7 +43,7 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util
from jax._src import util
from jax._src.sharding_impls import is_unspecified_or_auto
from jax._src.sharding_impls import UnspecifiedValue, AUTO
from jax._src.layout import Layout
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
@@ -649,7 +649,7 @@ class Lowered(Stage):
out_avals = self._lowering.compile_args["global_out_avals"]
out_shardings = self._lowering.compile_args["out_shardings"]
return self.out_tree.unflatten(
[OutInfo(o.shape, o.dtype, None if is_unspecified_or_auto(s) else s)
[OutInfo(o.shape, o.dtype, None if isinstance(s, (UnspecifiedValue, AUTO)) else s)
for o, s in zip(out_avals, out_shardings)])
def compile(

@@ -3537,7 +3537,7 @@ def split_to_logical_devices(tensor: TfVal,
def _xla_compatible_sharding_to_hlo_sharding(
s: sharding.Sharding,
aval: core.ShapedArray) -> xla_client.HloSharding | None:
if sharding_impls.is_unspecified(s):
if isinstance(s, sharding_impls.UnspecifiedValue):
return None
return s._to_xla_hlo_sharding(aval.ndim)

@@ -40,7 +40,6 @@ from jax._src.sharding_impls import (
ArrayMapping as ArrayMapping,
UNSPECIFIED as _UNSPECIFIED, # noqa: F401
array_mapping_to_axis_resources as array_mapping_to_axis_resources,
is_unspecified as _is_unspecified, # noqa: F401
)
from jax._src.sharding_specs import (