diff --git a/jax/_src/config.py b/jax/_src/config.py index 5519e7d79..c457acb0e 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -817,8 +817,8 @@ already_configured_with_absl = False trace_state = config_ext.Config(None, include_in_jit_key=True) axis_env_state = config_ext.Config((), include_in_jit_key=True) mesh_context_manager = config_ext.Config((), include_in_jit_key=True) -abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True) -device_context = config_ext.Config((), include_in_jit_key=True) +abstract_mesh_context_manager = config_ext.Config(None, include_in_jit_key=True) +device_context = config_ext.Config(None, include_in_jit_key=True) compute_on_context_manager = config_ext.Config(None, include_in_jit_key=True) xla_metadata_context_manager = config_ext.Config(None, include_in_jit_key=True) diff --git a/jax/_src/core.py b/jax/_src/core.py index 5a01def7f..a6d530c89 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -591,8 +591,10 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True) -> list[ def check_avals_context_mesh(avals, prim_name): if config.sharding_in_types.value: + cur_mesh = mesh_lib.get_abstract_mesh() for a in avals: - cur_mesh = mesh_lib.get_abstract_mesh() + if a.sharding.mesh.empty or cur_mesh.empty: + continue if a.sharding.mesh != cur_mesh: raise ValueError( f"For primitive {prim_name}, context mesh {cur_mesh} should match" @@ -1778,11 +1780,11 @@ def get_sharding(sharding, ndim): "Length of sharding.spec must be equal to aval's ndim. Got" f" sharding.spec {out_s.spec} and aval.ndim {ndim}") else: - context_mesh = mesh_lib.get_abstract_mesh() - if not context_mesh: + cur_mesh = mesh_lib.get_abstract_mesh() + if cur_mesh.empty: raise RuntimeError("Please set the mesh via `jax.set_mesh` API.") assert sharding is None - out_s = NamedSharding(context_mesh, P(*[None] * ndim)) + out_s = NamedSharding(cur_mesh, P(*[None] * ndim)) if not isinstance(out_s.mesh, mesh_lib.AbstractMesh): raise ValueError("Mesh of an aval must be an AbstractMesh. 
" f"Got {out_s.mesh} of type {type(out_s.mesh)}") @@ -1924,6 +1926,11 @@ class DShapedArray(UnshapedArray): weak_type = self.weak_type return DShapedArray(shape, dtype, weak_type) + @property + def sharding(self): + from jax._src.sharding_impls import NamedSharding # type: ignore + return NamedSharding(mesh_lib.AbstractMesh(()), P()) + def _len(self, tracer): return self.shape[0] diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 3999541ba..66c9fc7a6 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1663,8 +1663,10 @@ def lower_jaxpr_to_fun( flat_args = [ replicate_trailing_dims(entry_lowering_ctx, o, a) if (a is not core.abstract_token and - dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # pytype: disable=attribute-error - for o, s, a in zip(flat_args, ir_arg_shardings, input_avals) + dtypes.issubdtype(a.dtype, dtypes.extended) and + (s is None or all_unconstrained(rs, a))) else o # pytype: disable=attribute-error + for o, s, a, rs in zip(flat_args, ir_arg_shardings, input_avals, + arg_shardings) # type: ignore ] _, token_args, unflattened_args = util.split_list( @@ -1717,8 +1719,10 @@ def lower_jaxpr_to_fun( flat_outputs = [ replicate_trailing_dims(entry_lowering_ctx, o, a) if (a is not core.abstract_token and - dtypes.issubdtype(a.dtype, dtypes.extended) and s is None) else o # pytype: disable=attribute-error - for o, s, a in zip(flat_outputs, ir_result_shardings, output_avals) + dtypes.issubdtype(a.dtype, dtypes.extended) and + (s is None or all_unconstrained(rs, a))) else o # pytype: disable=attribute-error + for o, s, a, rs in zip(flat_outputs, ir_result_shardings, output_avals, + result_shardings) # type: ignore ] func_dialect.return_(flat_outputs) @@ -1917,7 +1921,8 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, source_info = eqn.source_info.replace( name_stack=name_stack + eqn.source_info.name_stack) loc = _source_info_to_location(ctx, eqn.primitive, source_info) - with source_info_util.user_context(eqn.source_info.traceback), loc: + with (source_info_util.user_context(eqn.source_info.traceback), loc, + eqn.ctx.manager): override_rule = get_override_lowering_rule(eqn.primitive) platform_rules: dict[str, LoweringRule] = {} default_rule: LoweringRule | None = None diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index fd9c05d0c..7ca6f31d0 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2102,7 +2102,8 @@ def _get_num_devices( for s in shardings: if isinstance(s, UnspecifiedValue): continue - elif isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh): + elif (isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh) and + not s.mesh.empty): if abstract_mesh is not None and abstract_mesh != s.mesh: raise ValueError("AbstractMesh should be the same across all " f"shardings. 
Got {abstract_mesh} and {s.mesh}") @@ -2158,6 +2159,9 @@ def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts, donated_invars, out_shardings, out_layouts) def _concretize_abstract_out_shardings(shardings, avals, device_assignment): + if len(device_assignment) == 1: + return shardings + np_dev = np.vectorize(lambda i: device_assignment[i], otypes=[object])(np.arange(len(device_assignment))) @@ -2170,11 +2174,14 @@ def _concretize_abstract_out_shardings(shardings, avals, device_assignment): out = [] for s, a in zip(shardings, avals): if isinstance(s, UnspecifiedValue) and a.sharding is not None: - spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp - for sp in a.sharding.spec]) - if a.sharding.mesh._any_axis_auto else a.sharding.spec) - out.append(NamedSharding( - _abstract_to_concrete_mesh(a.sharding.mesh), spec)) + if a.sharding.mesh.empty: + out.append(s) + else: + spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp + for sp in a.sharding.spec]) + if a.sharding.mesh._any_axis_auto else a.sharding.spec) + out.append(NamedSharding( + _abstract_to_concrete_mesh(a.sharding.mesh), spec)) else: out.append(s) return tuple(out) @@ -2792,7 +2799,10 @@ def _maybe_get_and_check_out_shardings( if (aval is not core.abstract_token and dtypes.issubdtype(aval.dtype, dtypes.extended)): xla_s = sharding_impls.logical_sharding(aval, xla_s) - new_out_shardings.append(_gspmd_to_named_sharding(xla_s, orig)) # type: ignore + try: + new_out_shardings.append(_gspmd_to_named_sharding(xla_s, orig)) # type: ignore + except: + new_out_shardings.append(xla_s) else: xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim) orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim) # pytype: disable=attribute-error diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 6c3faa872..7a56dc099 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -229,10 +229,12 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]], msg.format(', '.join(str(x) for x in xs_flat if not hasattr(x, 'shape')))) from err + xs_avals = [core.get_aval(x) for x in xs_flat] + if (config.sharding_in_types.value and - not all(x.aval.sharding.spec[0] is None for x in xs_flat)): + not all(a.sharding.spec[0] is None for a in xs_avals)): raise ValueError('0th dimension of all xs should be replicated. Got ' - f'{", ".join(str(x.aval.sharding.spec) for x in xs_flat)}') + f'{", ".join(str(a.sharding.spec) for a in xs_avals)}') if length is not None: try: @@ -270,7 +272,6 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]], stacked_y = tree_map(stack, *maybe_reversed(ys)) return carry, stacked_y - xs_avals = [core.get_aval(x) for x in xs_flat] x_avals = [core.mapped_aval(length, 0, aval) for aval in xs_avals] dbg_body = api_util.tracing_debug_info("scan", f, (init, xs), {}) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 8dadea012..e85db5b7c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -64,6 +64,7 @@ from jax._src.lax.utils import ( from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo +from jax._src.mesh import get_abstract_mesh from jax._src.sharding_impls import (PmapSharding, NamedSharding, PartitionSpec as P, canonicalize_sharding) from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape @@ -179,7 +180,7 @@ def _broadcast_shapes_uncached(*shapes): # Raise ValueError here for backward compatibility. 
raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}") from err -def broadcast_shardings(*avals) -> NamedSharding: +def broadcast_shardings(*avals): fst, *rst = avals if not rst: return fst.sharding @@ -585,11 +586,8 @@ def _convert_element_type( new_dtype = np.dtype(new_dtype) new_dtype = dtypes.dtype(new_dtype, canonicalize=True) - if (config.sharding_in_types.value and sharding is None and - isinstance(operand, Array)): - sharding = operand.aval.sharding - - sharding = canonicalize_sharding(sharding, check_mesh_consistency=False) # type: ignore + if sharding is not None and not isinstance(sharding, Sharding): + raise ValueError(f'{sharding=} must be an instance of jax.sharding.Sharding') if (warn_on_complex_to_real_cast and dtypes.issubdtype(old_dtype, np.complexfloating) and @@ -1373,6 +1371,7 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN raise NotImplementedError( '`out_sharding` argument of `dot_general` only supports NamedSharding ' 'instances. Please file a bug if this is not enough for your use case.') + out_sharding = canonicalize_sharding(out_sharding) (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers cdims = (api_util._ensure_index_tuple(lhs_contract), api_util._ensure_index_tuple(rhs_contract)) @@ -1933,13 +1932,13 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value)) fill_value = _convert_element_type(fill_value, dtype, weak_type) if (sharding is not None and not isinstance(sharding, PmapSharding) and - isinstance(fill_value, array.ArrayImpl) and - not config.sharding_in_types.value): + isinstance(fill_value, array.ArrayImpl) and sharding.is_concrete): broadcast_shape = sharding.shard_shape(shape) shard = broadcast(fill_value, broadcast_shape) return array.make_array_from_callback(shape, sharding, lambda _: shard) - if config.sharding_in_types.value and sharding is not None: + if (config.sharding_in_types.value and sharding is not None and + not sharding.is_concrete): return broadcast(fill_value, shape, sharding=sharding) else: return broadcast(fill_value, shape) @@ -2150,7 +2149,7 @@ def full_like(x: ArrayLike | DuckTypedArray, return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr] if (config.sharding_in_types.value and sharding is None and shape is None and - isinstance(x, Array)): + isinstance(x, core.Tracer)): sharding = x.aval.sharding else: # If `x` has a sharding but no `_committed` attribute @@ -3184,6 +3183,13 @@ def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type, def _convert_element_type_sharding_rule(operand, *, new_dtype, weak_type, sharding): + if sharding is None: + return operand.sharding + if sharding.is_concrete: + if isinstance(sharding, NamedSharding): + return NamedSharding(sharding.mesh.abstract_mesh, sharding.spec) + else: + return None return sharding def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type, @@ -3268,7 +3274,7 @@ convert_element_type_p = Primitive('convert_element_type') def _convert_element_type_bind_with_trace(trace, args, params): sharding = params['sharding'] operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params) - if sharding is not None and not config.sharding_in_types.value: + if sharding is not None and sharding.is_concrete: with core.set_current_trace(trace): operand = pjit.with_sharding_constraint(operand, sharding) return operand @@ -3303,8 +3309,6 @@ def 
_convert_element_type_lower(ctx, operand, *, new_dtype, weak_type, aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype)) out = mlir.convert_hlo(ctx, operand, aval_in, aval_out) if config.sharding_in_types.value: - if sharding is not None: - assert aval_out.sharding == sharding return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] return [out] @@ -3734,8 +3738,12 @@ def _dot_batch_rule( else: rhs_shape = np.shape(rhs) if out_sharding is not None: - raise NotImplementedError("vmap with out_sharding is not supported. " - "Please open an issue.") + cur_mesh = get_abstract_mesh() + if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual: + out_sharding = None + else: + raise NotImplementedError("vmap with out_sharding is not supported. " + "Please open an issue.") batched_out = invoke_prim( lhs, rhs, @@ -4433,8 +4441,13 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape, dyn_limits.append(bound) new_shape = (stacked_size,) + _merge_dyn_shape(shape, dyn_limits) if sharding is not None: - raise NotImplementedError('Implement broadcast_in_dim_batch_rule') - result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions) + if sharding.mesh._are_all_axes_auto or sharding.mesh._are_all_axes_manual: + sharding = None + else: + raise NotImplementedError('Implement sharding support for ' + 'broadcast_in_dim_batch_rule') + result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions, + sharding=sharding) out_ragged_axes = [idx+1 for idx, s in enumerate(shape) if s is None] out_bdim = batching.make_batch_axis( result.ndim, 0, zip(out_ragged_axes, out_ragged_sizes)) @@ -5108,8 +5121,9 @@ def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding): return [reshape(t, operand.aval.shape)] else: if config.sharding_in_types.value: - t_s = operand.sharding.with_spec( - tuple(map(str, np.take(operand.aval.sharding.spec, dimensions)))) + t_s = operand.aval.sharding.with_spec( + tuple(map(lambda s: s if s is None else str(s), + np.take(operand.aval.sharding.spec, dimensions)))) else: t_s = None return [transpose(reshape(t, np.take(operand.aval.shape, dimensions), @@ -5119,13 +5133,18 @@ def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding): def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions, sharding): if sharding is not None: - raise NotImplementedError + if sharding.mesh._are_all_axes_manual or sharding.mesh._are_all_axes_auto: + sharding = None + else: + raise NotImplementedError('reshape batch sharding support') operand, = batched_args bdim, = batch_dims operand = batching.moveaxis(operand, bdim, 0) if dimensions is not None: dimensions = (0,) + tuple(np.add(1, dimensions)) - return reshape(operand, operand.shape[:1] + new_sizes, dimensions), 0 + out = reshape(operand, operand.shape[:1] + new_sizes, dimensions, + sharding=sharding) + return out, 0 def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions, sharding): @@ -6689,16 +6708,18 @@ _zeros: Callable = partial(full_like, fill_value=0) def _zero(x): if config.sharding_in_types.value: + x_aval = core.get_aval(x) return full_like(x, shape=(), fill_value=0, - sharding=x.aval.sharding.with_spec(P())) # type: ignore + sharding=x_aval.sharding.with_spec(P())) # type: ignore return full_like(x, shape=(), fill_value=0) _ones: Callable = partial(full_like, fill_value=1) def _one(x): if config.sharding_in_types.value: + x_aval = core.get_aval(x) return full_like(x, shape=(), fill_value=1, - sharding=x.aval.sharding.with_spec(P())) + 
sharding=x_aval.sharding.with_spec(P())) return full_like(x, shape=(), fill_value=1) _twos: Callable = partial(full_like, fill_value=2) diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 61080b2a9..bb287e08a 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -49,16 +49,15 @@ def _get_array_abstraction_level(a): return a.array_abstraction_level def call_sharding_rule(prim, rule, num_out, *avals, **kwargs): if config.sharding_in_types.value: + cur_mesh = mesh_lib.get_abstract_mesh() + if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual: + return None if num_out is None else [None] * num_out if rule is None: - cur_mesh = mesh_lib.get_abstract_mesh() - if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual: # type: ignore - return None if num_out is None else [None] * num_out - else: - raise ValueError( - f'sharding rule for {prim.name} is not implemented. Please file a' - ' bug at https://github.com/jax-ml/jax/issues. You can work around' - ' this error by dropping that operation into full hidden sharding' - ' mode via: `jax.experimental.shard.hidden_axes(fun, out_shardings=...)`') + raise ValueError( + f'sharding rule for {prim.name} is not implemented. Please file a' + ' bug at https://github.com/jax-ml/jax/issues. You can work around' + ' this error by dropping that operation into full hidden sharding' + ' mode via: `jax.experimental.shard.hidden_axes(fun, out_shardings=...)`') return rule(*avals, **kwargs) return None if num_out is None else [None] * num_out diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 4f7c58a03..350bfe138 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -120,6 +120,16 @@ def axis_types_to_names(name_to_type: dict[str, AxisTypes]): d[t].append(n) return {t: ns[0] if len(ns) == 1 else tuple(ns) for t, ns in d.items()} +def to_axis_types_tuple(axis_types): + out = [] + for t, names in axis_types.items(): + if isinstance(names, tuple): + new_names = names[0] if len(names) == 1 else names + else: + new_names = names + out.append((t, new_names)) + return tuple(out) + _mesh_object_dict = {} # type: ignore @@ -200,7 +210,7 @@ class Mesh(contextlib.ContextDecorator): axis_types = ({AxisTypes.Auto: axis_names} if axis_types is None else axis_types) - axis_types_tuple = tuple(axis_types.items()) + axis_types_tuple = to_axis_types_tuple(axis_types) if len(axis_names_to_types(axis_types).keys()) != len(axis_names): raise ValueError( "Number of axis names in axis_types should match the number of" @@ -405,7 +415,7 @@ class AbstractMesh: self._axis_names, self._axis_sizes = (), () self.axis_types = ({AxisTypes.Auto: self._axis_names} if axis_types is None else axis_types) - self._axis_types_tuple = tuple(self.axis_types.items()) + self._axis_types_tuple = to_axis_types_tuple(self.axis_types) if len(self._name_to_type.keys()) != len(self._axis_names): raise ValueError( "Number of axis names in axis_types should match the number of" @@ -424,7 +434,8 @@ class AbstractMesh: self._axis_types_tuple == other._axis_types_tuple) def __repr__(self): - mesh_repr = ", ".join(f"'{n}': {v}" for n, v in self.shape_tuple) + mesh_repr = (", ".join(f"'{n}': {v}" for n, v in self.shape_tuple) + if self.shape_tuple else "()") atr = f", axis_types={self.axis_types}" return f"AbstractMesh({mesh_repr}{atr})" @@ -537,8 +548,11 @@ def set_abstract_mesh(mesh: AbstractMesh): finally: jax_config.abstract_mesh_context_manager.set_local(prev_val) +empty_abstract_mesh = AbstractMesh(()) + def get_abstract_mesh(): - return 
jax_config.abstract_mesh_context_manager.value + val = jax_config.abstract_mesh_context_manager.value + return empty_abstract_mesh if val is None else val @contextlib.contextmanager diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index be5475476..4203547c9 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5554,8 +5554,9 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, # array([1, 2, 3]) weak_type = dtype is None and dtypes.is_weakly_typed(object) if (config.sharding_in_types.value and device is None and - isinstance(object, Array)): + isinstance(object, core.Tracer)): sharding = object.aval.sharding + sharding = None if sharding.mesh.empty else sharding else: sharding = canonicalize_device_to_sharding(device) # type: ignore diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 4cf1d7b6e..a6edadf92 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -82,12 +82,7 @@ def promote_dtypes(*args: ArrayLike) -> list[Array]: else: to_dtype, weak_type = dtypes._lattice_result_type(*args) to_dtype = dtypes.canonicalize_dtype(to_dtype, allow_extended_dtype=True) # type: ignore[assignment] - if config.sharding_in_types.value: - return [lax._convert_element_type(x, to_dtype, weak_type, - getattr(x, "sharding", None)) - for x in args] - else: - return [lax._convert_element_type(x, to_dtype, weak_type) for x in args] + return [lax._convert_element_type(x, to_dtype, weak_type) for x in args] def promote_dtypes_inexact(*args: ArrayLike) -> list[Array]: diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 09f3fb892..6e9dc99b5 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -622,8 +622,8 @@ def _infer_params_impl( attr_token = _attr_token(flat_fun, in_type) abstract_mesh = ( - get_abstract_mesh_from_avals(in_type) - if not mesh_lib.get_abstract_mesh() else mesh_lib.get_abstract_mesh()) + get_abstract_mesh_from_avals(in_avals) + if mesh_lib.get_abstract_mesh().empty else mesh_lib.get_abstract_mesh()) with mesh_lib.set_abstract_mesh(abstract_mesh): jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr( flat_fun, in_type, attr_token, dbg, @@ -677,13 +677,16 @@ def get_abstract_mesh_from_avals(in_avals): return None m = None for a in in_avals: + if a is core.abstract_token: + continue + if a.sharding.mesh.empty: # type: ignore + continue if m is not None and m != a.sharding.mesh: raise ValueError( f'Mesh for all inputs should be equal. 
Got one mesh: {m} and' f' another mesh: {a.sharding.mesh}') m = a.sharding.mesh # type: ignore - assert isinstance(m, AbstractMesh) - return m + return AbstractMesh(()) if m is None else m class InferParamsCacheEntry: @@ -722,6 +725,10 @@ def _infer_params( else: resource_env = None pjit_mesh = None + if resource_env is not None and mesh_lib.get_concrete_mesh() is not None: + raise ValueError( + 'Using `with mesh:` context manager and `jax.sharding.use_mesh`' + ' together is not allowed.') dbg = tracing_debug_info( 'jit', fun, args, kwargs, static_argnums=ji.static_argnums, @@ -1561,6 +1568,8 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] arg_s, committed = ((arg.sharding, getattr(arg, '_committed', True)) if hasattr(arg, 'sharding') and arg.sharding is not None else (UNSPECIFIED, False)) + if isinstance(arg_s, NamedSharding) and arg_s.mesh.empty: + arg_s, committed = UNSPECIFIED, False if isinstance(pjit_in_s, UnspecifiedValue): if isinstance(arg_s, UnspecifiedValue): resolved_in_shardings.append(arg_s) @@ -1595,7 +1604,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] 'Please see the jax.Array migration guide for more information ' 'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. ' f'Got arg shape: {arg.shape}, arg value: {arg}') - if not isinstance(arg_s, UnspecifiedValue): + if not isinstance(arg_s, UnspecifiedValue) and arg_s.is_concrete: # jax.jit does not allow resharding across different memory kinds even # if the argument is uncommitted. Use jax.device_put for those cases, # either outside or inside jax.jit. @@ -1789,7 +1798,10 @@ def _pjit_lower( pgle_profiler: profiler.PGLEProfiler | None): util.test_event("pjit_lower") if config.sharding_in_types.value: - mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit' + if resource_env is not None: + mesh, api_name = resource_env.physical_mesh, 'pjit' + else: + mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit' else: mesh, api_name = ((resource_env.physical_mesh, 'pjit') if resource_env is not None else (None, 'jit')) @@ -2683,18 +2695,17 @@ batching.skippable_batchers[sharding_constraint_p] = lambda _: () # TODO(yashkatariya): Make shardings optional. 
def mesh_cast(xs, out_shardings): x_flat, treedef = tree_flatten(xs) - x_avals_flat = [core.shaped_abstractify(x) for x in x_flat] shardings_flat = flatten_axes("mesh_cast shardings", treedef, out_shardings) out_flat = [ mesh_cast_p.bind( - x, src_sharding=x_aval.sharding, - dst_sharding=canonicalize_sharding(s, check_mesh_consistency=False)) - for x, x_aval, s in safe_zip(x_flat, x_avals_flat, shardings_flat) + x, dst_sharding=canonicalize_sharding(s, check_mesh_consistency=False)) + for x, s in safe_zip(x_flat, shardings_flat) ] return tree_unflatten(treedef, out_flat) mesh_cast_p = core.Primitive('mesh_cast') -def _mesh_cast_abstract_eval(aval, src_sharding, dst_sharding): +def _mesh_cast_abstract_eval(aval, dst_sharding): + src_sharding = aval.sharding if src_sharding.mesh.shape_tuple != dst_sharding.mesh.shape_tuple: raise ValueError( f'Mesh shape of the input {src_sharding.mesh.shape_tuple} does not' @@ -2731,17 +2742,15 @@ def _mesh_cast_abstract_eval(aval, src_sharding, dst_sharding): return aval.update(sharding=dst_sharding) mesh_cast_p.def_abstract_eval(_mesh_cast_abstract_eval) -def _mesh_cast_impl(x, src_sharding, dst_sharding): - return dispatch.apply_primitive(mesh_cast_p, x, src_sharding=src_sharding, - dst_sharding=dst_sharding) +def _mesh_cast_impl(x, dst_sharding): + return dispatch.apply_primitive(mesh_cast_p, x, dst_sharding=dst_sharding) mesh_cast_p.def_impl(_mesh_cast_impl) -def _mesh_cast_transpose_rule(ct, _, src_sharding, dst_sharding): - return [mesh_cast_p.bind(ct, src_sharding=dst_sharding, - dst_sharding=src_sharding)] +def _mesh_cast_transpose_rule(ct, x, dst_sharding): + return [mesh_cast_p.bind(ct, dst_sharding=x.aval.sharding)] ad.deflinear2(mesh_cast_p, _mesh_cast_transpose_rule) -def _mesh_cast_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding): +def _mesh_cast_hlo_lowering(ctx, x_node, *, dst_sharding): aval, = ctx.avals_in aval_out, = ctx.avals_out proto = (dst_sharding._to_sdy_sharding(aval.ndim) @@ -2750,27 +2759,18 @@ def _mesh_cast_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding): return [mlir.lower_sharding_under_shit(ctx, x_node, aval_out, proto)] mlir.register_lowering(mesh_cast_p, _mesh_cast_hlo_lowering) -# TODO(yashkatariya): Comment this in after vmap ShiT tests are added. 
-# def _mesh_cast_batcher(axis_data, vals_in, dims_in, src_sharding, -# dst_sharding): -# if axis_data.spmd_name is not None: -# used = {n for ns in dst_sharding.spec -# for n in (ns if isinstance(ns, tuple) else (ns,))} -# if set(axis_data.spmd_name) & used: -# raise ValueError( -# f'vmap spmd_axis_name {axis_data.spmd_name} cannot ' -# f'appear in mesh_cast spec, but got spec {dst_sharding.spec}') -# x, = vals_in -# d, = dims_in +def _mesh_cast_batcher(axis_data, vals_in, dims_in, dst_sharding): + assert axis_data.spmd_name is None + x, = vals_in + d, = dims_in -# val = None if axis_data.spmd_name is None else axis_data.spmd_name -# new_spec = PartitionSpec(*util.tuple_insert(dst_sharding.spec, d, val)) -# vmapped_dst_sharding = NamedSharding(dst_sharding.mesh, new_spec) -# y = mesh_cast_p.bind(x, src_sharding=src_sharding, -# dst_sharding=vmapped_dst_sharding) -# return y, d -# batching.fancy_primitive_batchers[mesh_cast_p] = _mesh_cast_batcher -# batching.skippable_batchers[mesh_cast_p] = lambda _: () + val = None + new_spec = PartitionSpec(*util.tuple_insert(dst_sharding.spec, d, val)) + vmapped_dst_sharding = NamedSharding(dst_sharding.mesh, new_spec) + y = mesh_cast_p.bind(x, dst_sharding=vmapped_dst_sharding) + return y, d +batching.fancy_primitive_batchers[mesh_cast_p] = _mesh_cast_batcher +batching.skippable_batchers[mesh_cast_p] = lambda _: () # -------------------- reshard ------------------------------------ @@ -2782,13 +2782,13 @@ def reshard(xs, out_shardings): for x, x_aval, s in safe_zip(x_flat, x_avals_flat, shardings_flat): ds = canonicalize_sharding(s) ds = ds.with_spec(ds.spec._normalized_spec(x_aval.ndim)) # type: ignore - out_flat.append(reshard_p.bind(x, src_sharding=x_aval.sharding, - dst_sharding=ds)) + out_flat.append(reshard_p.bind(x, dst_sharding=ds)) return tree_unflatten(treedef, out_flat) reshard_p = core.Primitive('reshard') -def _reshard_abstract_eval(aval, src_sharding, dst_sharding): +def _reshard_abstract_eval(aval, dst_sharding): + src_sharding = aval.sharding if src_sharding.mesh.abstract_mesh != dst_sharding.mesh.abstract_mesh: raise ValueError( f'Mesh of the input {src_sharding.mesh.abstract_mesh} does not' @@ -2797,17 +2797,15 @@ def _reshard_abstract_eval(aval, src_sharding, dst_sharding): return aval.update(sharding=dst_sharding) reshard_p.def_abstract_eval(_reshard_abstract_eval) -def _reshard_impl(x, src_sharding, dst_sharding): - return dispatch.apply_primitive(reshard_p, x, src_sharding=src_sharding, - dst_sharding=dst_sharding) +def _reshard_impl(x, dst_sharding): + return dispatch.apply_primitive(reshard_p, x, dst_sharding=dst_sharding) reshard_p.def_impl(_reshard_impl) -def _reshard_transpose_rule(ct, _, src_sharding, dst_sharding): - return [reshard_p.bind(ct, src_sharding=dst_sharding, - dst_sharding=src_sharding)] +def _reshard_transpose_rule(ct, x, dst_sharding): + return [reshard_p.bind(ct, dst_sharding=x.aval.sharding)] ad.deflinear2(reshard_p, _reshard_transpose_rule) -def _reshard_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding): +def _reshard_hlo_lowering(ctx, x_node, *, dst_sharding): aval, = ctx.avals_in aval_out, = ctx.avals_out proto = (dst_sharding._to_sdy_sharding(aval.ndim) diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py index 23f0ef13c..ce5e52c29 100644 --- a/jax/_src/sharding.py +++ b/jax/_src/sharding.py @@ -138,6 +138,10 @@ class Sharding: ############################################################################# # Default implementations below that all subclasses will inherit. 
+ @property + def is_concrete(self) -> bool: + return True + @functools.cached_property def addressable_devices(self) -> set[Device]: """The set of devices in the :class:`Sharding` that are addressable by the diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 511e3e1d7..2d18710eb 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -391,6 +391,12 @@ class NamedSharding(jsharding.Sharding): # mesh across multiple NamedSharding objects will be the same. return not self.mesh.is_multi_process + @property + def is_concrete(self) -> bool: + if isinstance(self.mesh, mesh_lib.AbstractMesh): + return False + return True + @property def addressable_devices(self) -> set[Device]: if isinstance(self.mesh, mesh_lib.AbstractMesh): @@ -1761,14 +1767,18 @@ def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None, return sharding # type: ignore if sharding is None: return sharding + # TODO(yashkatariya): Remove this after vmap + shit works. + if isinstance(sharding, NamedSharding) and sharding.mesh.empty: + return None + cur_mesh = mesh_lib.get_abstract_mesh() if isinstance(sharding, PartitionSpec): - sharding = NamedSharding(mesh_lib.get_abstract_mesh(), sharding) # type: ignore + sharding = NamedSharding(cur_mesh, sharding) # type: ignore else: - if (check_mesh_consistency and - sharding.mesh.abstract_mesh != mesh_lib.get_abstract_mesh()): + if (check_mesh_consistency and not cur_mesh.empty and + sharding.mesh.abstract_mesh != cur_mesh): raise ValueError( - f'Context mesh {mesh_lib.get_abstract_mesh()} should match the mesh' + f'Context mesh {cur_mesh} should match the mesh' f' of sharding {sharding.mesh.abstract_mesh}. This error occurs at' f' source: {source_info_util.summarize(source_info_util.current())}') if isinstance(sharding.mesh, mesh_lib.Mesh): diff --git a/jax/_src/stages.py b/jax/_src/stages.py index ed6febd76..47cad591f 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -737,7 +737,7 @@ class Traced(Stage): "_args_flat", "_arg_names", "_num_consts"] def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree, - lower_callable, abstract_mesh=None, + lower_callable, abstract_mesh=mesh_lib.AbstractMesh(()), args_flat=None, arg_names=None, num_consts: int = 0): self.jaxpr = jaxpr self.args_info = args_info diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index 1120c0bf4..fa169b9e7 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -320,6 +320,15 @@ class AbstractRef(core.AbstractValue): f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`." ) from None + @property + def sharding(self): + try: + return self.inner_aval.sharding # pytype: disable=attribute-error + except AttributeError: + raise AttributeError( + f"`Ref{{{self.inner_aval.str_short()}}} has no `sharding`." 
+ ) from None + @core.aval_property def at(self): return RefIndexer(self) diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index d74795471..9cc067343 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -70,7 +70,6 @@ from jax._src.tree_util import (broadcast_prefix, prefix_errors, PyTreeDef, generate_key_paths, KeyPath) from jax.experimental.multihost_utils import (host_local_array_to_global_array, global_array_to_host_local_array) -from jax._src.pjit import sharding_constraint_p P = PartitionSpec @@ -473,6 +472,12 @@ shard_map_p = ShardMapPrimitive('shard_map') # Staging +def _as_manual_mesh(mesh): + if config.sharding_in_types.value: + return AbstractMesh( + mesh.shape_tuple, axis_types={AxisTypes.Manual: mesh.axis_names}) + return None + def _shard_map_staging( trace: pe.DynamicJaxprTrace, prim: core.Primitive, f: lu.WrappedFun, in_tracers: Sequence[Any], *, mesh: Mesh, @@ -485,8 +490,9 @@ def _shard_map_staging( in_tracers = map(trace.to_jaxpr_tracer, in_tracers) in_avals = [t.aval for t in in_tracers] in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals) + manual_mesh = _as_manual_mesh(mesh) with (core.extend_axis_env_nd(list(mesh.shape.items())), - set_abstract_mesh(pjit.get_abstract_mesh_from_avals(in_avals_))): + set_abstract_mesh(manual_mesh)): jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_) _check_names(out_names_thunk(), out_avals_) if check_rep: @@ -502,7 +508,8 @@ def _shard_map_staging( constvars = map(trace.getvar, map(trace.to_jaxpr_tracer, consts)) outvars = map(trace.makevar, out_tracers) in_names_staged = ({},) * len(consts) + tuple(in_names) # type: ignore - with core.extend_axis_env_nd(list(mesh.shape.items())): + with (core.extend_axis_env_nd(list(mesh.shape.items())), + set_abstract_mesh(manual_mesh)): jaxpr = pe.convert_constvars_jaxpr(jaxpr) params = dict(mesh=mesh, in_names=in_names_staged, out_names=tuple(out_names_thunk()), jaxpr=jaxpr, @@ -552,7 +559,9 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames, for i, sz in enumerate(aval.shape)) if config.sharding_in_types.value: spec = _names_to_pspec(names)._normalized_spec(aval.ndim) - new_sharding = NamedSharding(get_abstract_mesh(), spec) + new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else + get_abstract_mesh()) + new_sharding = NamedSharding(new_mesh, spec) else: new_sharding = None return aval.update(shape=new_shape, sharding=new_sharding) @@ -787,7 +796,8 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, del prim, auto if isinstance(mesh, AbstractMesh): mesh = get_mesh_from_args(args, mesh) - args = map(partial(_unmatch_spec, mesh), in_names, args) + args = map(partial(_unmatch_spec, mesh, context_mesh=get_abstract_mesh()), + in_names, args) in_rep = map(partial(_in_names_to_rep, mesh), in_names) outs, out_rep = _run_shmap(fun, mesh, args, in_rep, check_rep) out_avals = [core.mapped_aval(x.shape[0], 0, core.get_aval(x)) for x in outs] @@ -798,13 +808,16 @@ def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, return map(partial(_match_spec, mesh, check_rep), pspecs, outs) core.EvalTrace.process_shard_map = _shard_map_impl -def _run_shmap(f, mesh, args, reps, check_rep): - trace = ShardMapTrace(mesh, check_rep) +def _run_shmap(f, mesh, args, reps, check_rep, context_mesh=None): + context_mesh = get_abstract_mesh() if context_mesh is None else context_mesh + trace = ShardMapTrace(mesh, check_rep, context_mesh) in_tracers = 
map(partial(ShardMapTracer, trace), reps, args) - with core.set_current_trace(trace): - with core.extend_axis_env_nd(mesh.shape.items()): - ans = f.call_wrapped(*in_tracers) - outs, out_rep = unzip2(map(trace.to_val_rep_pair, ans)) + manual_mesh = _as_manual_mesh(mesh) + with (core.set_current_trace(trace), + core.extend_axis_env_nd(mesh.shape.items()), + set_abstract_mesh(manual_mesh)): + ans = f.call_wrapped(*in_tracers) + outs, out_rep = unzip2(map(trace.to_val_rep_pair, ans)) return outs, out_rep def _names_to_pspec(names: AxisNames) -> PartitionSpec: @@ -812,8 +825,9 @@ def _names_to_pspec(names: AxisNames) -> PartitionSpec: unpack = lambda t: t[0] if t is not None and len(t) == 1 else t return PartitionSpec(*(unpack(names.get(i)) for i in range(ndmin))) -def _unmatch_spec(mesh: Mesh, src: AxisNames, x: JaxType) -> JaxType: - with core.eval_context(), jax.disable_jit(False): +def _unmatch_spec(mesh: Mesh, src: AxisNames, x: JaxType, context_mesh) -> JaxType: + with (core.eval_context(), jax.disable_jit(False), + set_abstract_mesh(context_mesh)): return jax.jit(HashablePartial(_unmatch, mesh, tuple(src.items())))(x) def _unmatch(mesh, src_tup, x): @@ -850,18 +864,20 @@ def _match(mesh, check_rep, pspec, x): # TODO put back (?) needed for rep checking in eager? for now test rewrite return shard_map(_rem_singleton, mesh, (src,), pspec, check_rep=False)(x) -def _rem_singleton(x): return x.reshape(x.shape[1:]) -def _add_singleton(x): return x.reshape(1, *x.shape) +def _rem_singleton(x): return jnp.squeeze(x, axis=0) +def _add_singleton(x): return jnp.expand_dims(x, axis=0) class ShardMapTrace(core.Trace): - __slots__ = ("mesh", "check") + __slots__ = ("mesh", "check", "context_mesh") mesh: Mesh check: bool + context_mesh: AbstractMesh - def __init__(self, mesh, check): + def __init__(self, mesh, check, context_mesh): self.mesh = mesh self.check = check + self.context_mesh = context_mesh def to_val_rep_pair(self, val): if isinstance(val, ShardMapTracer): @@ -869,7 +885,7 @@ class ShardMapTrace(core.Trace): elif isinstance(val, Tracer): raise Exception("Shouldn't have any non-shard_map tracers") else: - val_ = _unmatch_spec(self.mesh, {}, val) + val_ = _unmatch_spec(self.mesh, {}, val, self.context_mesh) return val_, None def process_primitive(self, prim, tracers, params): @@ -879,7 +895,8 @@ class ShardMapTrace(core.Trace): out_vals = eager_rule(self.mesh, *in_vals, **params) else: f = HashablePartial(_prim_applier, prim, tuple(params.items()), self.mesh) - with core.eval_context(), jax.disable_jit(False): + with (core.eval_context(), jax.disable_jit(False), + set_abstract_mesh(self.context_mesh)): out_vals = jax.jit(f)(*in_vals) rep_rule = _check_rules.get(prim, partial(_rule_missing, prim)) out_rep = rep_rule(self.mesh, *in_rep, **params) if self.check else set() @@ -910,7 +927,8 @@ class ShardMapTrace(core.Trace): raise NotImplementedError(msg) del prim, jvp, symbolic_zeros in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) - out_vals, out_rep = _run_shmap(fun, self.mesh, in_vals, in_rep, self.check) + out_vals, out_rep = _run_shmap(fun, self.mesh, in_vals, in_rep, self.check, + self.context_mesh) return map(partial(ShardMapTracer, self), out_rep, out_vals) def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, @@ -922,7 +940,8 @@ class ShardMapTrace(core.Trace): raise NotImplementedError(msg) del prim, fwd, bwd, out_trees, symbolic_zeros in_vals, in_rep = unzip2(map(self.to_val_rep_pair, tracers)) - out_vals, out_rep = _run_shmap(fun, self.mesh, in_vals, 
in_rep, self.check) + out_vals, out_rep = _run_shmap(fun, self.mesh, in_vals, in_rep, self.check, + self.context_mesh) return map(partial(ShardMapTracer, self), out_rep, out_vals) @@ -938,17 +957,23 @@ class ShardMapTracer(core.Tracer): @property def aval(self): aval = core.get_aval(self.val) - return core.mapped_aval(self._trace.mesh.size, 0, aval) + out = core.mapped_aval(self._trace.mesh.size, 0, aval) + if config.sharding_in_types.value: + new_sharding = NamedSharding(_as_manual_mesh(self._trace.mesh), + out.sharding.spec) # type: ignore + else: + new_sharding = None + return out.update(sharding=new_sharding) def to_concrete_value(self): if self.rep == set(self._trace.mesh.axis_names): - with core.eval_context(): + with core.eval_context(), set_abstract_mesh(self._trace.context_mesh): return core.to_concrete_value(self.val[0]) else: return None def __str__(self) -> str: - with core.eval_context(): + with core.eval_context(), set_abstract_mesh(self._trace.context_mesh): blocks = list(self.val) mesh = self._trace.mesh axis_names = f"({', '.join(map(str, mesh.axis_names))},)" @@ -1135,7 +1160,7 @@ for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(), for p in [control_flow.loops.cumsum_p, control_flow.loops.cumlogsumexp_p, control_flow.loops.cumprod_p, control_flow.loops.cummax_p, - control_flow.loops.cummin_p, sharding_constraint_p]: + control_flow.loops.cummin_p, pjit.sharding_constraint_p]: register_standard_check(p) register_standard_rewrite(p) @@ -1706,11 +1731,11 @@ def _partial_eval_jaxpr_custom_rule( for var, w in zip(jaxpr_staged.invars[:num_res], which) if w] eqn_known = pe.new_jaxpr_eqn(ins_known, [*out_binders_known, *residuals], eqn.primitive, params_known, jaxpr_known.effects, - eqn.source_info) + eqn.source_info, eqn.ctx) full_res = subs_list2(in_fwd, out_fwd, ins_known, out_binders_known, residuals) eqn_staged = pe.new_jaxpr_eqn([*full_res, *ins_staged], out_binders_staged, eqn.primitive, params_staged, - jaxpr_staged.effects, eqn.source_info) + jaxpr_staged.effects, eqn.source_info, eqn.ctx) assert len(eqn_staged.invars) == len(jaxpr_staged.invars) new_inst = [x for x, inst in zip(eqn.invars, inst_in) if type(x) is core.Var and not inst] @@ -1795,7 +1820,7 @@ def _shard_map_dce(used_outputs: list[bool], eqn: core.JaxprEqn new_eqn = pe.new_jaxpr_eqn( [v for v, used in zip(eqn.invars, used_inputs) if used], [x for x, used in zip(eqn.outvars, used_outputs) if used], - eqn.primitive, new_params, effs, eqn.source_info) + eqn.primitive, new_params, effs, eqn.source_info, eqn.ctx) return used_inputs, new_eqn pe.dce_rules[shard_map_p] = _shard_map_dce diff --git a/tests/array_test.py b/tests/array_test.py index 045ef7b99..7aabcc831 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -1326,6 +1326,12 @@ class ShardingTest(jtu.JaxTestCase): axis_types={jax.sharding.AxisTypes.Auto: 'x'}) def test_make_mesh_axis_types(self): + mesh1 = jax.sharding.AbstractMesh( + (('x', 2),), axis_types={jax.sharding.AxisTypes.Auto: 'x'}) + mesh2 = jax.sharding.AbstractMesh( + (('x', 2),), axis_types={jax.sharding.AxisTypes.Auto: ('x',)}) + self.assertEqual(mesh1, mesh2) + mesh = jax.make_mesh((1, 1), ('x', 'y')) self.assertDictEqual(mesh.axis_types, {AxisTypes.Auto: ('x', 'y')}) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 221724cf4..d0fda8425 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3523,20 +3523,17 @@ class ArrayPjitTest(jtu.JaxTestCase): f = pjit(mul, in_shardings=ns, out_shardings=ns) - with jtu.count_pjit_cpp_cache_miss() as count: + 
with (jtu.count_pjit_cpp_cache_miss() as count, + jtu.count_jit_and_pmap_lowerings() as lowering_count): out = f(arr) - cache_info1 = pxla._cached_lowering_to_hlo.cache_info() self.assertIsInstance(out.sharding, NamedSharding) out2 = f(np_arr) - cache_info2 = pxla._cached_lowering_to_hlo.cache_info() self.assertIsInstance(out2.sharding, NamedSharding) # Drops out of C++ cache i.e. cache miss self.assertEqual(count(), 2) - # Still gets a hit on pjit_lower cache. - self.assertEqual(cache_info2.hits, cache_info1.hits + 1) - self.assertEqual(cache_info2.misses, cache_info1.misses) + self.assertEqual(lowering_count(), 1) def test_list_in_pspec(self): mesh = jtu.create_mesh((2,), ('x',)) @@ -4913,8 +4910,9 @@ class ShardingInTypesTest(jtu.JaxTestCase): self.assertEqual(out[0].sharding, arr1.sharding) self.assertEqual(out[1].sharding, arr2.sharding) + @parameterized.parameters([True, False]) @jtu.with_user_mesh((4,), ('x',)) - def test_dot_general_out_sharding(self, mesh): + def test_dot_general_out_sharding(self, jit, mesh): np_inp1 = np.arange(16.).reshape(8, 2) arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None))) arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x'))) @@ -4924,34 +4922,38 @@ class ShardingInTypesTest(jtu.JaxTestCase): self.assertEqual(out.aval.sharding.spec, P('x', None)) return jnp.sum(out) - out = f(arr1, arr2) - self.assertArraysEqual(out, np.sum(np_inp1 @ np_inp1.T)) - self.assertEqual(out.sharding, NamedSharding(mesh, P())) - - f = jax.jit(f) + if jit: + f = jax.jit(f) out = f(arr1, arr2) self.assertArraysEqual(out, np.sum(np_inp1 @ np_inp1.T)) self.assertEqual(out.sharding, NamedSharding(mesh, P())) + with self.assertRaisesRegex( + ValueError, + 'PartitionSpec cannot contain axis names that are of type Auto or' + ' Manual'): + auto_axes(f, out_shardings=P())(arr1, arr2) + out = jax.grad(f, argnums=(0, 1))(arr1, arr2) self.assertEqual(out[0].sharding, arr1.sharding) self.assertEqual(out[1].sharding, arr2.sharding) - jitted_grad = jax.jit(jax.grad(f, argnums=(0, 1))) - out = jitted_grad(arr1, arr2) - self.assertEqual(out[0].sharding, arr1.sharding) - self.assertEqual(out[1].sharding, arr2.sharding) + if jit: + jitted_grad = jax.jit(jax.grad(f, argnums=(0, 1))) + out = jitted_grad(arr1, arr2) + self.assertEqual(out[0].sharding, arr1.sharding) + self.assertEqual(out[1].sharding, arr2.sharding) - jaxpr = jitted_grad.trace(arr1, arr2).jaxpr - bwd_jaxpr = jaxpr.eqns[1] - expected_spec = [('broadcast_in_dim', P('x', None)), - ('dot_general', P('x', None)), - ('transpose', P(None, 'x')), - ('dot_general', P('x', None))] - for eqn, spec in zip(bwd_jaxpr.params['jaxpr'].eqns, expected_spec): - self.assertEqual(eqn.primitive.name, spec[0]) - self.assertEqual(eqn.outvars[0].aval.sharding.spec, spec[1]) + jaxpr = jitted_grad.trace(arr1, arr2).jaxpr + bwd_jaxpr = jaxpr.eqns[1] + expected_spec = [('broadcast_in_dim', P('x', None)), + ('dot_general', P('x', None)), + ('transpose', P(None, 'x')), + ('dot_general', P('x', None))] + for eqn, spec in zip(bwd_jaxpr.params['jaxpr'].eqns, expected_spec): + self.assertEqual(eqn.primitive.name, spec[0]) + self.assertEqual(eqn.outvars[0].aval.sharding.spec, spec[1]) @parameterized.named_parameters( ('fail1', P('x', None), P(None, 'x'), @@ -5977,7 +5979,7 @@ class ShardingInTypesTest(jtu.JaxTestCase): {AxisTypes.Explicit: ('x',)}) out = f(arr) self.assertArraysEqual(out, np_inp) - self.assertDictEqual(out.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'}) + self.assertDictEqual(out.sharding.mesh.axis_types, {AxisTypes.Auto: 
('x',)}) @jtu.with_user_mesh((2,), 'x') def test_inputs_different_context(self, mesh): @@ -6379,6 +6381,19 @@ class ShardingInTypesTest(jtu.JaxTestCase): f(arr) # doesn't crash + def test_shaped_array_input_to_jit_no_sharding(self): + # export_test.py has similar test but it's more complicated. This is a + # simplified version of a part of that test. + aval = core.ShapedArray((8,), jnp.int32) + aval2 = core.ShapedArray((8,), jnp.int32) + + @jax.jit + def f(x, y): + return x * y + + lowered_text = f.lower(aval, aval2).as_text() + self.assertNotIn("mhlo.sharding", lowered_text) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index 25d46c5ad..9c6cf0861 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -58,7 +58,7 @@ class ShardAlikeTest(jtu.JaxTestCase): self.assertArraysEqual(out, np_inp * np_inp * 4) def test_output_sharded_alike_input(self): - mesh = jtu.create_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -73,7 +73,7 @@ class ShardAlikeTest(jtu.JaxTestCase): self.assertArraysEqual(out, np_inp * 2) def test_arange_shard_alike_jit(self): - mesh = jtu.create_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s)
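
A minimal usage sketch (not part of the patch, and assuming the patch above is applied) of two behaviors introduced in this change: to_axis_types_tuple normalizing a single axis name written as 'x' versus ('x',) so that equivalent AbstractMesh objects compare equal, and the new is_concrete property distinguishing shardings over an AbstractMesh from shardings over a device Mesh. The single-device jax.make_mesh((1,), ('x',)) call is only for illustration and is not taken from the diff.

    import jax
    from jax.sharding import (AbstractMesh, AxisTypes, NamedSharding,
                              PartitionSpec as P)

    # Equivalent axis_types spellings ('x' vs ('x',)) are normalized to the
    # same tuple by to_axis_types_tuple, so these abstract meshes compare equal.
    m1 = AbstractMesh((('x', 2),), axis_types={AxisTypes.Auto: 'x'})
    m2 = AbstractMesh((('x', 2),), axis_types={AxisTypes.Auto: ('x',)})
    assert m1 == m2

    # A NamedSharding over an AbstractMesh is not concrete; over a device
    # Mesh it keeps the Sharding base-class default of True.
    assert not NamedSharding(m1, P('x')).is_concrete
    concrete_mesh = jax.make_mesh((1,), ('x',))
    assert NamedSharding(concrete_mesh, P('x')).is_concrete

The normalization keeps Mesh and AbstractMesh equality independent of whether a single axis name is given as a string or a one-element tuple, which is what the new assertions in test_make_mesh_axis_types check.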