From a3edfb43efe5761adb127d6b51343ee5ef67057a Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Wed, 19 Feb 2025 06:52:52 -0800 Subject: [PATCH] Now that sharding_in_types config flag is True, remove the config and all the conditionals PiperOrigin-RevId: 728653433 --- jax/_src/api.py | 3 - jax/_src/config.py | 8 -- jax/_src/core.py | 68 +++++------ jax/_src/interpreters/batching.py | 9 +- jax/_src/interpreters/mlir.py | 2 +- jax/_src/interpreters/partial_eval.py | 4 - jax/_src/interpreters/pxla.py | 7 +- jax/_src/lax/control_flow/loops.py | 6 +- jax/_src/lax/lax.py | 159 ++++++++------------------ jax/_src/lax/linalg.py | 101 +++++++--------- jax/_src/lax/parallel.py | 24 ++-- jax/_src/lax/slicing.py | 12 +- jax/_src/lax/utils.py | 33 +++--- jax/_src/mesh.py | 3 +- jax/_src/nn/functions.py | 7 +- jax/_src/numpy/einsum.py | 6 +- jax/_src/numpy/lax_numpy.py | 3 +- jax/_src/numpy/util.py | 3 +- jax/_src/pallas/pallas_call.py | 3 +- jax/_src/pjit.py | 13 +-- jax/_src/sharding_impls.py | 3 - jax/experimental/shard_map.py | 33 ++---- 22 files changed, 170 insertions(+), 340 deletions(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 8d47a0e96..b284d8c28 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -1029,9 +1029,6 @@ def vmap(fun: F, return cast(F, vmap_f) def _mapped_axis_spec(args_flat, in_axes_flat): - if not config.sharding_in_types.value: - return None - def _get_spec(arg, i): try: # Duck type arrays like BCOO arrays can be passed to vmap. diff --git a/jax/_src/config.py b/jax/_src/config.py index 474ab0c88..6f1ec3586 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -234,7 +234,6 @@ def trace_context(): default_device.value, random_seed_offset.value, threefry_partitionable.value, threefry_gpu_kernel_lowering.value, - sharding_in_types.value, use_direct_linearize.value, softmax_custom_jvp.value, disable_jit.value, @@ -1067,13 +1066,6 @@ threefry_gpu_kernel_lowering = bool_state( 'cost.'), include_in_jit_key=True) -sharding_in_types = bool_state( - name='jax_sharding_in_types', - default=True, - help=('When True, enables forward only sharding propagation in JAX and ' - 'avals have sharding on them.'), - include_in_jit_key=True) - use_direct_linearize = bool_state( name='jax_use_direct_linearize', default=False, diff --git a/jax/_src/core.py b/jax/_src/core.py index 6b5bd88e9..9aa5e42d5 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -498,8 +498,6 @@ class Primitive: return f'{self.name}' def bind(self, *args, **params): - if not config.sharding_in_types.value: - return self._true_bind(*args, **params) args = args if self.skip_canonicalization else map(canonicalize_value, args) return self._true_bind(*args, **params) @@ -598,22 +596,21 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True) -> list[ return map(read, jaxpr.outvars) def check_avals_context_mesh(avals, prim_name): - if config.sharding_in_types.value: - cur_mesh = mesh_lib.get_abstract_mesh() - for a in avals: - # TODO(yashkatariya): Should be cur_mesh.unset - if cur_mesh.empty or a.sharding.mesh.empty: - continue - # avals can have meshes with different axis_names so allow that in - # full auto mode. - if a.sharding.mesh._are_all_axes_auto and cur_mesh._are_all_axes_auto: - continue - if a.sharding.mesh != cur_mesh: - raise ValueError( - f"For primitive {prim_name}, context mesh {cur_mesh} should match" - f" the aval mesh {a.sharding.mesh} for shape {a.str_short()}. 
This" - " error occurs at source: " - f" {source_info_util.summarize(source_info_util.current())}") + cur_mesh = mesh_lib.get_abstract_mesh() + for a in avals: + # TODO(yashkatariya): Should be cur_mesh.unset + if cur_mesh.empty or a.sharding.mesh.empty: + continue + # avals can have meshes with different axis_names so allow that in + # full auto mode. + if a.sharding.mesh._are_all_axes_auto and cur_mesh._are_all_axes_auto: + continue + if a.sharding.mesh != cur_mesh: + raise ValueError( + f"For primitive {prim_name}, context mesh {cur_mesh} should match" + f" the aval mesh {a.sharding.mesh} for shape {a.str_short()}. This" + " error occurs at source: " + f" {source_info_util.summarize(source_info_util.current())}") # -------------------- tracing -------------------- @@ -1498,7 +1495,7 @@ def check_valid_jaxtype(x): f"Value {x!r} of type {type(x)} is not a valid JAX type") def update_aval_with_sharding(aval, sharding): - if config.sharding_in_types.value and isinstance(sharding, NamedSharding): + if isinstance(sharding, NamedSharding): aval = aval.update(sharding=NamedSharding( sharding.mesh.abstract_mesh, sharding.spec._normalized_spec_for_aval(aval.ndim))) @@ -1761,9 +1758,6 @@ def _make_lengths_same(sharding, ndim): # TODO(dougalm): Cast scalar, numpy arrays, etc to jax arrays so that values # passed to primitives are always have avals, etc i.e. they are canonical. def canonicalize_value(val): - if not config.sharding_in_types.value: - return val - try: aval = get_aval(val) except TypeError: @@ -1783,8 +1777,6 @@ def canonicalize_value(val): def get_cur_mesh_sharding(spec=None): - if not config.sharding_in_types.value: - return None spec = P() if spec is None else spec return NamedSharding(mesh_lib.get_abstract_mesh(), spec) @@ -1845,8 +1837,7 @@ class ShapedArray(UnshapedArray): self.shape = canonicalize_shape(shape) self.dtype = _dtype_object(dtype) self.weak_type = weak_type - if config.sharding_in_types.value: - self.sharding = get_sharding(sharding, len(self.shape)) + self.sharding = get_sharding(sharding, len(self.shape)) def update(self, shape=None, dtype=None, weak_type=None, **kwargs): if shape is None: @@ -1856,7 +1847,7 @@ class ShapedArray(UnshapedArray): if weak_type is None: weak_type = self.weak_type if 'sharding' not in kwargs: - kwargs['sharding'] = getattr(self, 'sharding', None) + kwargs['sharding'] = self.sharding return ShapedArray(shape, dtype, weak_type, **kwargs) ndim = property(lambda self: len(self.shape)) @@ -1873,28 +1864,23 @@ class ShapedArray(UnshapedArray): return (type(self) is type(other) and self.dtype == other.dtype and self.shape == other.shape and self.weak_type == other.weak_type - and getattr(self, 'sharding', None) == getattr(other, 'sharding', None)) + and self.sharding == other.sharding) def __hash__(self): # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype # objects, e.g. 
`np.zeros(3).dtype is np.zeros(4).dtype`, or we can use # the unique character code via hash(self.dtype.char) - return hash((self.shape, self.dtype, self.weak_type, - getattr(self, 'sharding', None))) + return hash((self.shape, self.dtype, self.weak_type, self.sharding)) def to_tangent_aval(self): - if config.sharding_in_types.value: - return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), - self.weak_type, sharding=self.sharding) - else: - return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), - self.weak_type) + return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), + self.weak_type, sharding=self.sharding) def str_short(self, short_dtypes=False): dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name) dt_str = dt_str.replace('void', 'float0') - if hasattr(self, 'sharding') and self.sharding is not None: + if self.sharding is not None: shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec) return f'{dt_str}[{shapestr}]' else: @@ -2484,8 +2470,7 @@ def _map_shaped_array( assert axis is None or aval.shape[axis] == size # TODO: Extend the named shape if axis is None: return aval - sharding = (aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis)) - if config.sharding_in_types.value else None) + sharding = aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis)) return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype, weak_type=aval.weak_type, sharding=sharding) @@ -2494,9 +2479,8 @@ def _unmap_shaped_array( ) -> ShapedArray: if axis is None: return aval elif type(axis) is int: - sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis, - explicit_mesh_axis)) - if config.sharding_in_types.value else None) + sharding = aval.sharding.with_spec(tuple_insert( + aval.sharding.spec, axis, explicit_mesh_axis)) return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype, weak_type=aval.weak_type, sharding=sharding) else: raise TypeError(axis) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index f308f507a..6e5619399 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -1091,12 +1091,9 @@ def broadcast(x, sz, axis, mesh_axis=None): shape = list(np.shape(x)) shape.insert(axis, sz) broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis)) - if config.sharding_in_types.value: - x_aval = core.get_aval(x) - new_spec = P(*tuple_insert(x_aval.sharding.spec, axis, mesh_axis)) - sharding = x_aval.sharding.with_spec(new_spec) - else: - sharding = None + x_aval = core.get_aval(x) + new_spec = P(*tuple_insert(x_aval.sharding.spec, axis, mesh_axis)) + sharding = x_aval.sharding.with_spec(new_spec) return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding) def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False): diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 629cda3bf..8507e3e91 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1719,7 +1719,7 @@ def lower_jaxpr_to_fun( for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings, output_avals, unconstrained_shardings): # type: ignore if us[0] and not us[1]: - if config.use_shardy_partitioner.value and config.sharding_in_types.value: + if config.use_shardy_partitioner.value: s = modify_sdy_sharding_wrt_axis_types(s, o_aval.sharding.mesh) temp_flat_outputs.append(wrap_with_sharding_op( entry_lowering_ctx, o, o_aval, s, unspecified_dims=us[2])) 
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index b77817d99..f5a9fc6ba 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -974,10 +974,6 @@ def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr, core.check_jaxpr(jaxpr_unknown) def check(first, second): - if not config.sharding_in_types.value: - assert first == second - return - for f, s in zip(first, second): if (not isinstance(f, core.ShapedArray) and not isinstance(s, core.ShapedArray)): diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index eb6c6b940..07571c7ec 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2311,10 +2311,9 @@ def lower_sharding_computation( propagated_out_mem_kinds = get_out_memory_kinds_via_propagation( closed_jaxpr, in_shardings) - if config.sharding_in_types.value: - out_shardings = _concretize_abstract_out_shardings( - out_shardings, global_out_avals, device_assignment, - propagated_out_mem_kinds) + out_shardings = _concretize_abstract_out_shardings( + out_shardings, global_out_avals, device_assignment, + propagated_out_mem_kinds) # 2. Build up the HLO diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index f7930b31d..8418f4501 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -231,8 +231,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]], xs_avals = [core.get_aval(x) for x in xs_flat] - if (config.sharding_in_types.value and - not all(a.sharding.spec[0] is None for a in xs_avals)): + if not all(a.sharding.spec[0] is None for a in xs_avals): raise ValueError('0th dimension of all xs should be replicated. Got ' f'{", ".join(str(a.sharding.spec) for a in xs_avals)}') @@ -504,8 +503,7 @@ def _split_leading(sz, x): def _concat(a, b): return lax.concatenate([a, b], 0) def _empty_array(prefix, length_spec, aval): - sharding = (aval.sharding.with_spec((length_spec, *aval.sharding.spec)) - if config.sharding_in_types.value else None) + sharding = aval.sharding.with_spec((length_spec, *aval.sharding.spec)) return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape), out_sharding=sharding) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index cd99c404e..86f240d15 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2025,9 +2025,6 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs`` non-contracting/non-batch dimensions. """ - if out_sharding is not None and not config.sharding_in_types.value: - raise NotImplementedError("out_sharding only works when sharding_in_types " - "config is True.") if out_sharding is not None and not isinstance(out_sharding, NamedSharding): raise NotImplementedError( '`out_sharding` argument of `dot_general` only supports NamedSharding ' @@ -2116,9 +2113,6 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape, See Also: jax.lax.broadcast : simpler interface to add new leading dimensions. 
""" - if not config.sharding_in_types.value and out_sharding is not None: - raise NotImplementedError("out_sharding argument to broadcast_in_dim is only " - "allowed when sharding_in_types config is on.") out_sharding = canonicalize_sharding(out_sharding) if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and isinstance(operand, Array) and out_sharding is None): @@ -2748,8 +2742,7 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *, shard = broadcast(fill_value, broadcast_shape) return array.make_array_from_callback(shape, sharding, lambda _: shard) - if (config.sharding_in_types.value and sharding is not None and - not sharding._is_concrete): + if sharding is not None and not sharding._is_concrete: return broadcast(fill_value, shape, out_sharding=sharding) else: return broadcast(fill_value, shape) @@ -2762,9 +2755,7 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array: scalar_zero = np.zeros((), dtype=aval.dtype) else: scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type) - if config.sharding_in_types.value: - return broadcast(scalar_zero, aval.shape, out_sharding=aval.sharding) - return broadcast(scalar_zero, aval.shape) + return broadcast(scalar_zero, aval.shape, out_sharding=aval.sharding) ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array @@ -2793,9 +2784,6 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int, static_shape = [None if isinstance(d, core.Tracer) else d for d in shape] dimension = core.concrete_or_error( int, dimension, "dimension argument of lax.broadcasted_iota") - if not config.sharding_in_types.value and out_sharding is not None: - raise NotImplementedError('sharding support for broadcasted_iota is not ' - 'implemented outside of sharding_in_types mode.') out_sharding = canonicalize_sharding(out_sharding) return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape), dimension=dimension, sharding=out_sharding) @@ -2959,8 +2947,7 @@ def full_like(x: ArrayLike | DuckTypedArray, if dtypes.issubdtype(dtype, dtypes.extended): return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr] - if (config.sharding_in_types.value and sharding is None and shape is None and - isinstance(x, core.Tracer)): + if sharding is None and shape is None and isinstance(x, core.Tracer): sharding = x.aval.sharding else: # If `x` has a sharding but no `_committed` attribute @@ -3461,14 +3448,10 @@ def _nary_lower_hlo(op: Callable, ctx, del params avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out args = mlir.multi_broadcast_in_dim(ctx, args, avals_in, aval_out.shape) - if config.sharding_in_types.value: - args = multi_sharding_in_dim(ctx, args, avals_in, aval_out) + args = multi_sharding_in_dim(ctx, args, avals_in, aval_out) out = op(*args) - if config.sharding_in_types.value: - return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] - else: - return [out] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] _float = {np.floating} @@ -3876,10 +3859,8 @@ def _integer_pow_lowering(ctx, x, *, y): if builtins.abs(y) >= 3: lowering = mlir.cache_lowering(lowering) out, = lowering(ctx, x, y=y) - if config.sharding_in_types.value: - aval_out, = ctx.avals_out - return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] - return [out] + aval_out, = ctx.avals_out + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] mlir.register_lowering(integer_pow_p, _integer_pow_lowering) @@ -4267,9 +4248,7 @@ def _convert_element_type_lower(ctx, operand, *, 
new_dtype, weak_type, operand = hlo.real(operand) aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype)) out = mlir.convert_hlo(ctx, operand, aval_in, aval_out) - if config.sharding_in_types.value: - return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] - return [out] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] mlir.register_lowering(convert_element_type_p, _convert_element_type_lower) @@ -4630,12 +4609,9 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y out_axes = np.argsort(unsorted_axes) - if config.sharding_in_types.value: - xs = x.aval.sharding - inverse_spec = tuple(xs.spec[o] for o in unsorted_axes) - ds = xs.with_spec(inverse_spec) - else: - ds = None + xs = x.aval.sharding + inverse_spec = tuple(xs.spec[o] for o in unsorted_axes) + ds = xs.with_spec(inverse_spec) dot_general_out = dot_general(g, y, dims, precision=precision, preferred_element_type=preferred_element_type, out_sharding=ds) @@ -5020,8 +4996,8 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, precision_config=precision_attr(precision), **algorithm_kwarg, ) - if config.sharding_in_types.value: - result = mlir.lower_sharding_under_shit(ctx, result, aval_out) + + result = mlir.lower_sharding_under_shit(ctx, result, aval_out) if accumulation_aval.dtype != aval_out.dtype: result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out) return [result] @@ -5487,9 +5463,7 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions, aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape)) out = mlir.broadcast_in_dim(ctx, x, aval_out, broadcast_dimensions=broadcast_dimensions) - if config.sharding_in_types.value: - return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] - return [out] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, sharding): @@ -5498,12 +5472,9 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions, type(core.get_aval(d).dtype) is core.bint for d in shape)): shape = _broadcast_in_dim_shape_rule( # error checking x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=None) - if config.sharding_in_types.value: - new_sharding = _broadcast_in_dim_sharding_rule( - x, shape=shape, broadcast_dimensions=broadcast_dimensions, - sharding=sharding) - else: - new_sharding = None + new_sharding = _broadcast_in_dim_sharding_rule( + x, shape=shape, broadcast_dimensions=broadcast_dimensions, + sharding=sharding) return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding) # If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray # (even if x is a ShapedArray) @@ -5680,9 +5651,7 @@ pe.padding_rules[concatenate_p] = _concatenate_pad_rule def _concatenate_lower(ctx, *xs, dimension): aval_out, = ctx.avals_out out = hlo.concatenate(xs, mlir.i64_attr(dimension)) - if config.sharding_in_types.value: - return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] - return [out] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] mlir.register_lowering(concatenate_p, _concatenate_lower) @@ -5734,8 +5703,7 @@ def _split_lower(ctx, x, *, sizes, axis): limit_indices[axis] = start_indices[axis] + aval_out.shape[axis] out = mlir.slice_op(ctx, x, aval_out, start_indices=start_indices, limit_indices=limit_indices, strides=strides) - 
outs.append(mlir.lower_sharding_under_shit(ctx, out, aval_out) - if config.sharding_in_types.value else out) + outs.append(mlir.lower_sharding_under_shit(ctx, out, aval_out)) start_indices[axis] = limit_indices[axis] return outs @@ -5840,9 +5808,7 @@ def _pad_lower(ctx, x, padding_value, *, padding_config): aval_out, = ctx.avals_out low, high, interior = util.unzip3(padding_config) out = mlir.pad(ctx, aval_out, x, padding_value, low, high, interior) - if config.sharding_in_types.value: - return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] - return [out] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] mlir.register_lowering(pad_p, _pad_lower) @@ -5908,9 +5874,7 @@ def _squeeze_lower(ctx, operand, *, dimensions): del dimensions # Implied by the output aval. aval_out, = ctx.avals_out out = mlir.reshape(ctx, operand, aval_out) - if config.sharding_in_types.value: - return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] - return [out] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] mlir.register_lowering(squeeze_p, _squeeze_lower) @@ -6073,16 +6037,11 @@ def _reshape_dtype_rule(operand, *, new_sizes, dimensions, sharding): def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding): assert ad.is_undefined_primal(operand) if dimensions is None: - if config.sharding_in_types.value: - return [reshape(t, operand.aval.shape, out_sharding=operand.aval.sharding)] - return [reshape(t, operand.aval.shape)] + return [reshape(t, operand.aval.shape, out_sharding=operand.aval.sharding)] else: - if config.sharding_in_types.value: - t_s = operand.aval.sharding.with_spec( - tuple(map(lambda s: s if s is None else str(s), - np.take(operand.aval.sharding.spec, dimensions)))) - else: - t_s = None + t_s = operand.aval.sharding.with_spec( + tuple(map(lambda s: s if s is None else str(s), + np.take(operand.aval.sharding.spec, dimensions)))) return [transpose(reshape(t, np.take(operand.aval.shape, dimensions), out_sharding=t_s), np.argsort(dimensions))] @@ -6110,9 +6069,7 @@ def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions, sharding): if dyn_shape: aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape)) out = mlir.reshape(ctx, x, aval_out) - if config.sharding_in_types.value: - return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] - return [out] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] def _reshape_staging_rule( trace, x, *dyn, new_sizes, dimensions, sharding): @@ -6162,9 +6119,7 @@ batching.primitive_batchers[rev_p] = _rev_batch_rule def _rev_lower(ctx, x, *, dimensions): aval_out, = ctx.avals_out out = hlo.reverse(x, mlir.dense_int_array(dimensions)) - if config.sharding_in_types.value: - return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] - return [out] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] mlir.register_lowering(rev_p, _rev_lower) @@ -6201,9 +6156,7 @@ def _transpose_lower(ctx, x, *, permutation): trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))] permutation = [*permutation, *trailing_dims] out = hlo.transpose(x, mlir.dense_int_array(permutation)) - if config.sharding_in_types.value: - return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] - return [out] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] transpose_p = standard_primitive( _transpose_shape_rule, _input_dtype, 'transpose', @@ -6345,10 +6298,6 @@ def _select_hlo_lowering_opaque(ctx, which, *cases): avals_in=[aval_which_bcast, *physical_avals_cases], 
avals_out=[physical_aval_out])[0] -def _add_shit_to_select(ctx, op, aval_out): - if config.sharding_in_types.value: - return mlir.lower_sharding_under_shit(ctx, op, aval_out) - return op def _select_hlo_lowering(ctx, which, *cases): which_aval = ctx.avals_in[0] @@ -6356,13 +6305,13 @@ def _select_hlo_lowering(ctx, which, *cases): if dtypes.issubdtype(aval_out.dtype, dtypes.extended): op = _select_hlo_lowering_opaque(ctx, which, *cases) - return [_add_shit_to_select(ctx, op, aval_out)] + return [mlir.lower_sharding_under_shit(ctx, op, aval_out)] if which_aval.dtype == np.dtype(np.bool_): assert len(cases) <= 2 if len(cases) == 1: return cases op = hlo.select(which, cases[1], cases[0]) - return [_add_shit_to_select(ctx, op, aval_out)] + return [mlir.lower_sharding_under_shit(ctx, op, aval_out)] if dtypes.issubdtype(which_aval.dtype, np.signedinteger): compare_type = 'SIGNED' @@ -6382,7 +6331,7 @@ def _select_hlo_lowering(ctx, which, *cases): _select(offset + mid, cases[mid:])) op = _select(0, cases) - return [_add_shit_to_select(ctx, op, aval_out)] + return [mlir.lower_sharding_under_shit(ctx, op, aval_out)] select_n_p = standard_primitive( _select_shape_rule, _select_dtype_rule, 'select_n', @@ -6514,10 +6463,8 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions): *reducer.arguments, dim_var_values=ctx.dim_var_values) hlo.return_(mlir.flatten_ir_values(out_nodes)) - if config.sharding_in_types.value: - return [mlir.lower_sharding_under_shit(ctx, r, aval) - for r, aval in safe_zip(op.results, ctx.avals_out)] - return op.results + return [mlir.lower_sharding_under_shit(ctx, r, aval) + for r, aval in safe_zip(op.results, ctx.avals_out)] mlir.register_lowering(reduce_p, _reduce_lower) @@ -6532,11 +6479,8 @@ def _reduce_sum_transpose_rule(cotangent, operand, *, axes): assert ad.is_undefined_primal(operand) input_shape = operand.aval.shape broadcast_dimensions = tuple(np.delete(np.arange(len(input_shape)), axes)) - if config.sharding_in_types.value: - result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions, - out_sharding=operand.aval.sharding) - else: - result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions) + result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions, + out_sharding=operand.aval.sharding) assert result.shape == input_shape return [result] @@ -6674,7 +6618,7 @@ def _compute_argminmax(value_comparator, get_identity, axis, = axes indices = broadcasted_iota( index_dtype, np.shape(operand), axis, - out_sharding=operand.aval.sharding if config.sharding_in_types.value else None) + out_sharding=operand.aval.sharding) res = reduce([operand, indices], [get_identity(operand.dtype), np.array(0, index_dtype)], _ArgMinMaxReducer(value_comparator), @@ -6739,9 +6683,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes): reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(reducer_region): hlo.return_([reducer(*reducer_region.arguments)]) - if config.sharding_in_types.value: - return [mlir.lower_sharding_under_shit(ctx, op.result, aval_out)] - return op.results + return [mlir.lower_sharding_under_shit(ctx, op.result, aval_out)] mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, hlo.AddOp, _get_sum_identity)) @@ -6782,9 +6724,7 @@ def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits): aval_out, = ctx.avals_out out = hlo.reduce_precision(operand, mlir.i32_attr(exponent_bits), mlir.i32_attr(mantissa_bits)) - if config.sharding_in_types.value: - return 
[mlir.lower_sharding_under_shit(ctx, out, aval_out)] - return [out] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] mlir.register_lowering(reduce_precision_p, _reduce_precision_lower) @@ -6924,8 +6864,7 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys): mlir.flatten_ir_values(operands), dimension=mlir.i64_attr(dimension), is_stable=ir.BoolAttr.get(is_stable)) - scalar_s = (lambda a: a.sharding.with_spec(P()) - if config.sharding_in_types.value else lambda _: None) + scalar_s = lambda a: a.sharding.with_spec(P()) scalar_avals = [aval.update(shape=(), sharding=scalar_s(aval)) for aval in ctx.avals_in] scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals) @@ -7458,9 +7397,7 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension, sharding): if dyn_shape: aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape)) out = mlir.iota(ctx, aval_out, dimension=dimension) - if config.sharding_in_types.value: - return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] - return [out] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] mlir.register_lowering(iota_p, _iota_lower) def _iota_batching_rule(in_vals, in_dims, *, dtype, shape, dimension, @@ -7668,20 +7605,16 @@ def _const(example, val): _zeros: Callable = partial(full_like, fill_value=0) def _zero(x): - if config.sharding_in_types.value: - x_aval = core.get_aval(x) - return full_like(x, shape=(), fill_value=0, - sharding=x_aval.sharding.with_spec(P())) - return full_like(x, shape=(), fill_value=0) + x_aval = core.get_aval(x) + return full_like(x, shape=(), fill_value=0, + sharding=x_aval.sharding.with_spec(P())) _ones: Callable = partial(full_like, fill_value=1) def _one(x): - if config.sharding_in_types.value: - x_aval = core.get_aval(x) - return full_like(x, shape=(), fill_value=1, - sharding=x_aval.sharding.with_spec(P())) - return full_like(x, shape=(), fill_value=1) + x_aval = core.get_aval(x) + return full_like(x, shape=(), fill_value=1, + sharding=x_aval.sharding.with_spec(P())) _twos: Callable = partial(full_like, fill_value=2) _two: Callable = partial(full_like, shape=(), fill_value=2) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 0702ae516..4f628568a 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -1041,16 +1041,13 @@ def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues): batch_dims = operand.shape[:-2] n = operand.shape[-1] - if config.sharding_in_types.value: - batch_s = operand.sharding.spec[:-2] - ns = operand.sharding.spec[-1] - if ns is not None: - raise ValueError(f'n should be unsharded. Got n: {ns}' - ' specs. Try marking their specs as None.') - w_s = operand.sharding.with_spec(P(*batch_s + (ns,))) - v_s = operand.sharding.with_spec(P(*batch_s + (ns, ns))) - else: - w_s, v_s = None, None + batch_s = operand.sharding.spec[:-2] + ns = operand.sharding.spec[-1] + if ns is not None: + raise ValueError(f'n should be unsharded. Got n: {ns}' + ' specs. 
Try marking their specs as None.') + w_s = operand.sharding.with_spec(P(*batch_s + (ns,))) + v_s = operand.sharding.with_spec(P(*batch_s + (ns, ns))) w = operand.update(shape=batch_dims + (n,), dtype=lax_internal._complex_basetype(operand.dtype), sharding=w_s) @@ -1123,16 +1120,13 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index): n = operand.shape[-1] d = (n if subset_by_index is None else subset_by_index[1] - subset_by_index[0]) - if config.sharding_in_types.value: - batch_s = operand.sharding.spec[:-2] - ns, ds = operand.sharding.spec[-1], None - if ns is not None: - raise ValueError(f'n should be unsharded. Got n: {ns} specs. Try ' - 'marking their specs as None.') - v_s = operand.sharding.with_spec(P(*batch_s + (ns, ds))) - w_s = operand.sharding.with_spec(P(*batch_s + (ds,))) - else: - v_s, w_s = None, None + batch_s = operand.sharding.spec[:-2] + ns, ds = operand.sharding.spec[-1], None + if ns is not None: + raise ValueError(f'n should be unsharded. Got n: {ns} specs. Try ' + 'marking their specs as None.') + v_s = operand.sharding.with_spec(P(*batch_s + (ns, ds))) + w_s = operand.sharding.with_spec(P(*batch_s + (ds,))) v = operand.update(shape=batch_dims + (n, d), sharding=v_s) w = operand.update( shape=batch_dims + (d,), @@ -1450,9 +1444,7 @@ def _triangular_solve_lowering( ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal), hlo.TransposeAttr.get(transpose)) - if config.sharding_in_types.value: - return [mlir.lower_sharding_under_shit(ctx, out, out_aval)] - return [out] + return [mlir.lower_sharding_under_shit(ctx, out, out_aval)] def _triangular_solve_cpu_lower( @@ -1901,15 +1893,12 @@ def _geqrf_abstract_eval(operand): if operand.ndim < 2: raise ValueError("Argument to QR decomposition must have ndims >= 2") *batch_dims, m, n = operand.shape - if config.sharding_in_types.value: - spec = operand.sharding.spec - batch_s, ms, ns = spec[:-2], spec[-2], spec[-1] - if ms is not None or ns is not None: - raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns}' - ' specs. Try marking their specs as None.') - taus_s = operand.sharding.with_spec(P(*(*batch_s, None))) - else: - taus_s = None + spec = operand.sharding.spec + batch_s, ms, ns = spec[:-2], spec[-2], spec[-1] + if ms is not None or ns is not None: + raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns}' + ' specs. Try marking their specs as None.') + taus_s = operand.sharding.with_spec(P(*(*batch_s, None))) taus = operand.update(shape=(*batch_dims, core.min_dim(m, n)), sharding=taus_s) return operand, taus @@ -2117,17 +2106,16 @@ def _qr_abstract_eval(operand, *, pivoting, full_matrices, use_magma): raise ValueError("Argument to QR decomposition must have ndims >= 2") *batch_dims, m, n = operand.shape k = m if full_matrices else core.min_dim(m, n) - if config.sharding_in_types.value: - *batch_s, ms, ns = operand.sharding.spec - ks = None - if ms is not None or ns is not None: - raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns}' - ' specs. Try marking their specs as None.') - q_s = operand.sharding.with_spec(P(*(*batch_s, ms, ks))) - r_s = operand.sharding.with_spec(P(*(*batch_s, ks, ns))) - p_s = operand.sharding.with_spec(P(*(*batch_s, ns))) - else: - q_s, r_s, p_s = None, None, None + + *batch_s, ms, ns = operand.sharding.spec + ks = None + if ms is not None or ns is not None: + raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns}' + ' specs. 
Try marking their specs as None.') + q_s = operand.sharding.with_spec(P(*(*batch_s, ms, ks))) + r_s = operand.sharding.with_spec(P(*(*batch_s, ks, ns))) + p_s = operand.sharding.with_spec(P(*(*batch_s, ns))) + q = operand.update(shape=(*batch_dims, m, k), sharding=q_s) r = operand.update(shape=(*batch_dims, k, n), sharding=r_s) p = operand.update(shape=(*batch_dims, n), dtype=np.dtype(np.int32), @@ -2241,21 +2229,18 @@ def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index, raise ValueError("full_matrices and subset_by_index cannot both be set") rank = min(rank, subset_by_index[1] - subset_by_index[0]) - if config.sharding_in_types.value: - batch_s = operand.sharding.spec[:-2] - ms = operand.sharding.spec[-2] - ns = operand.sharding.spec[-1] - if ms is not None or ns is not None: - raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns}' - ' specs. Try marking their specs as None.') - rank_s = None - s_sharding = operand.sharding.with_spec(P(*batch_s + (rank_s,))) - u_sharding = operand.sharding.with_spec( - P(*batch_s + (ms, ms if full_matrices else rank_s))) - vt_sharding = operand.sharding.with_spec( - P(*batch_s + (ns if full_matrices else rank_s, ns))) - else: - s_sharding, u_sharding, vt_sharding = None, None, None + batch_s = operand.sharding.spec[:-2] + ms = operand.sharding.spec[-2] + ns = operand.sharding.spec[-1] + if ms is not None or ns is not None: + raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns}' + ' specs. Try marking their specs as None.') + rank_s = None + s_sharding = operand.sharding.with_spec(P(*batch_s + (rank_s,))) + u_sharding = operand.sharding.with_spec( + P(*batch_s + (ms, ms if full_matrices else rank_s))) + vt_sharding = operand.sharding.with_spec( + P(*batch_s + (ns if full_matrices else rank_s, ns))) s = operand.update( shape=batch_dims + (rank,), diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 5abd55036..b556042fe 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -24,7 +24,6 @@ import math from jax import tree_util from jax._src import core -from jax._src import config from jax._src import dispatch from jax._src import dtypes from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext, @@ -732,16 +731,12 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups): if len(pos_axes) != 0: raise ValueError(f"axis_index_groups can only be used with reductions over " f"named axes, but got: {axes}") - if config.sharding_in_types.value: - core.check_avals_context_mesh(args, 'all_reduce') - out_avals = [ - ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype, - sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes)) - for arg in args - ] - else: - out_avals = [ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype) - for arg in args] + core.check_avals_context_mesh(args, 'all_reduce') + out_avals = [ + ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype, + sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes)) + for arg in args + ] return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes} def _check_axis_names(axes): @@ -795,11 +790,8 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups): else: op = hlo.AllReduceOp( [x.type], [x], replica_groups=replica_groups, **other_args) - if config.sharding_in_types.value: - scalar_aval = core.ShapedArray( - (), aval.dtype, sharding=NamedSharding(aval.sharding.mesh, P())) - else: - scalar_aval = 
core.ShapedArray((), aval.dtype) + scalar_aval = core.ShapedArray( + (), aval.dtype, sharding=NamedSharding(aval.sharding.mesh, P())) scalar_type = mlir.aval_to_ir_type(scalar_aval) reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(reducer_block): diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 5dfd4938b..2d6182ead 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1370,9 +1370,7 @@ def _slice_lower(ctx, x, *, start_indices, limit_indices, strides): aval_out, = ctx.avals_out out = mlir.slice_op(ctx, x, aval_out, start_indices=start_indices, limit_indices=limit_indices, strides=strides) - if config.sharding_in_types.value: - return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] - return [out] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] mlir.register_lowering(slice_p, _slice_lower) @@ -1525,9 +1523,7 @@ def _dynamic_slice_lower(ctx, x, *starts_and_dyn_sizes, slice_sizes): if dyn: aval_out = aval_out.update(shape=lax._merge_dyn_shape(slice_sizes, dyn)) out = mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices) - if config.sharding_in_types.value: - return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] - return [out] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] mlir.register_lowering(dynamic_slice_p, _dynamic_slice_lower) @@ -1642,9 +1638,7 @@ def _dynamic_update_slice_lower(ctx, x, update, *start_indices): aval_out, = ctx.avals_out out = mlir.dynamic_update_slice(ctx, aval_out, x, update, start_indices=start_indices) - if config.sharding_in_types.value: - return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] - return [out] + return [mlir.lower_sharding_under_shit(ctx, out, aval_out)] mlir.register_lowering(dynamic_update_slice_p, _dynamic_update_slice_lower) diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 44760cffd..2b329291b 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -20,7 +20,6 @@ from functools import partial from jax._src import core from jax._src import dispatch -from jax._src import config from jax._src import dtypes from jax._src import mesh as mesh_lib from jax._src.util import safe_zip @@ -50,8 +49,6 @@ def standard_primitive(shape_rule, dtype_rule, name, def _get_array_abstraction_level(a): return a.array_abstraction_level def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh: - if not config.sharding_in_types.value: - return None # type: ignore m = None for a in in_avals: if a is core.abstract_token: @@ -69,22 +66,20 @@ def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh: def call_sharding_rule(prim, rule, num_out, *avals, **kwargs): - if config.sharding_in_types.value: - cur_mesh = mesh_lib.get_abstract_mesh() - aval_mesh = _get_abstract_mesh_from_avals(avals) - if ((cur_mesh.empty or cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual) and - (aval_mesh.empty or aval_mesh._are_all_axes_auto or aval_mesh._are_all_axes_manual)): - aval_mesh = cur_mesh if aval_mesh.empty else aval_mesh - s = NamedSharding(aval_mesh, P()) - return s if num_out is None else [s] * num_out - if rule is None: - raise ValueError( - f'sharding rule for {prim.name} is not implemented. Please file a' - ' bug at https://github.com/jax-ml/jax/issues. 
You can work around' - ' this error by dropping that operation into full auto sharding' - ' mode via: `jax.experimental.shard.auto_axes(fun, out_shardings=...)`') - return rule(*avals, **kwargs) - return None if num_out is None else [None] * num_out + cur_mesh = mesh_lib.get_abstract_mesh() + aval_mesh = _get_abstract_mesh_from_avals(avals) + if ((cur_mesh.empty or cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual) and + (aval_mesh.empty or aval_mesh._are_all_axes_auto or aval_mesh._are_all_axes_manual)): + aval_mesh = cur_mesh if aval_mesh.empty else aval_mesh + s = NamedSharding(aval_mesh, P()) + return s if num_out is None else [s] * num_out + if rule is None: + raise ValueError( + f'sharding rule for {prim.name} is not implemented. Please file a' + ' bug at https://github.com/jax-ml/jax/issues. You can work around' + ' this error by dropping that operation into full auto sharding' + ' mode via: `jax.experimental.shard.auto_axes(fun, out_shardings=...)`') + return rule(*avals, **kwargs) def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule, *avals, **kwargs): diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 512d88d07..8f23b8ba0 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -563,6 +563,5 @@ def get_concrete_mesh(): @contextlib.contextmanager def use_mesh(mesh: Mesh): - with (set_abstract_mesh(mesh.abstract_mesh), - jax_config.sharding_in_types(True), set_concrete_mesh(mesh)): + with set_abstract_mesh(mesh.abstract_mesh), set_concrete_mesh(mesh): yield diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 6138503bc..92c176014 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -665,11 +665,8 @@ def _one_hot(x: Array, num_classes: int, *, lhs = lax.expand_dims(x, (axis,)) rhs_shape = [1] * x.ndim rhs_shape.insert(output_pos_axis, num_classes) - if config.sharding_in_types.value: - # TODO(yashkatariya): Maybe expose `out_sharding` on `one_hot` too? - rhs_sharding = NamedSharding(x.aval.sharding.mesh, P(*[None] * len(rhs_shape))) # pytype: disable=attribute-error - else: - rhs_sharding = None + # TODO(yashkatariya): Maybe expose `out_sharding` on `one_hot` too? 
+ rhs_sharding = NamedSharding(x.aval.sharding.mesh, P(*[None] * len(rhs_shape))) # pytype: disable=attribute-error rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis, out_sharding=rhs_sharding) return (lhs == rhs).astype(dtype) diff --git a/jax/_src/numpy/einsum.py b/jax/_src/numpy/einsum.py index 1882aeb72..dcf52fe64 100644 --- a/jax/_src/numpy/einsum.py +++ b/jax/_src/numpy/einsum.py @@ -411,9 +411,6 @@ def _einsum( _dot_general=lax.dot_general, out_sharding=None, ): - if out_sharding is not None and not config.sharding_in_types.value: - raise NotImplementedError("out_sharding only works when sharding_in_types " - "config is True.") out_sharding = canonicalize_sharding(out_sharding) if out_sharding is not None and not isinstance(out_sharding, NamedSharding): raise NotImplementedError( @@ -546,8 +543,7 @@ def _einsum( **k_out_sharding) else: names = batch_names_str + remaining_lhs_names + remaining_rhs_names - if (config.sharding_in_types.value and out_sharding is not None and - names != result_names): + if out_sharding is not None and names != result_names: spec = out_sharding.spec inverse_spec = tuple(spec[result_names.index(name)] for name in names) dot_general_out_sharding = NamedSharding(out_sharding.mesh, diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 6bbef9b2a..47791a216 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -5463,8 +5463,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, # whenever x is weak, but avoids introducing weak types with something like # array([1, 2, 3]) weak_type = dtype is None and dtypes.is_weakly_typed(object) - if (config.sharding_in_types.value and device is None and - isinstance(object, core.Tracer)): + if device is None and isinstance(object, core.Tracer): sharding = object.aval.sharding sharding = None if sharding.mesh.empty else sharding else: diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 82170ae1e..1db2e0bde 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -250,8 +250,7 @@ def _broadcast_arrays(*args: ArrayLike) -> list[Array]: if not shapes or all(core.definitely_equal_shape(shapes[0], s) for s in shapes): return [lax.asarray(arg) for arg in args] result_shape = lax.broadcast_shapes(*shapes) - result_sharding = (lax.broadcast_shardings(*avals) - if config.sharding_in_types.value else None) + result_sharding = lax.broadcast_shardings(*avals) return [_broadcast_to(arg, result_shape, result_sharding) for arg in args] diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 302378ea6..26eeadd7d 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -891,8 +891,7 @@ def _pallas_call_batching_rule( batched_out_avals = [] for aval in out_avals: - sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, 0, None)) - if config.sharding_in_types.value else None) + sharding = aval.sharding.with_spec(tuple_insert(aval.sharding.spec, 0, None)) shape = tuple_insert(aval.shape, 0, axis_size) batched_out_avals.append(aval.update(shape=shape, sharding=sharding)) batched_out_avals = tuple(batched_out_avals) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index a57405947..785eb4086 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1746,14 +1746,10 @@ def _pjit_lower( lowering_parameters: mlir.LoweringParameters, pgle_profiler: profiler.PGLEProfiler | None): util.test_event("pjit_lower") - if config.sharding_in_types.value: - if resource_env is not None: - 
mesh, api_name = resource_env.physical_mesh, 'pjit' - else: - mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit' + if resource_env is not None: + mesh, api_name = resource_env.physical_mesh, 'pjit' else: - mesh, api_name = ((resource_env.physical_mesh, 'pjit') - if resource_env is not None else (None, 'jit')) + mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit' return pxla.lower_sharding_computation( jaxpr, api_name, name, in_shardings, out_shardings, in_layouts, out_layouts, tuple(donated_invars), @@ -2494,9 +2490,6 @@ state_discharge.register_discharge_rule(pjit_p)(_pjit_state_discharge_rule) # -------------------- with_sharding_constraint -------------------- def check_shardings_are_auto(shardings_flat): - if not config.sharding_in_types.value: - return - for s in shardings_flat: if not isinstance(s, NamedSharding): continue diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 764018a69..c86a4bcad 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -22,7 +22,6 @@ import math from typing import Any, NamedTuple, cast from jax._src import core -from jax._src import config from jax._src import mesh as mesh_lib from jax._src import sharding as jsharding from jax._src import sharding_specs @@ -1254,8 +1253,6 @@ def flatten_spec(spec): def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None, check_mesh_consistency: bool = True ) -> NamedSharding | None: - if not config.sharding_in_types.value: - return sharding # type: ignore if sharding is None: return sharding if isinstance(sharding, NamedSharding) and sharding.mesh.empty: diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 547ed7f0c..8359f1b10 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -480,10 +480,8 @@ shard_map_p = ShardMapPrimitive('shard_map') # Staging def _as_manual_mesh(mesh): - if config.sharding_in_types.value: - return AbstractMesh( - mesh.shape_tuple, axis_types={AxisTypes.Manual: mesh.axis_names}) - return None + return AbstractMesh( + mesh.shape_tuple, axis_types={AxisTypes.Manual: mesh.axis_names}) def _extend_axis_env(mesh, auto): return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items() @@ -554,12 +552,9 @@ def _shard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue assert isinstance(aval, core.ShapedArray) new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) for i, sz in enumerate(aval.shape)) - if config.sharding_in_types.value: - new_mesh = AbstractMesh( - mesh.shape_tuple, axis_types={AxisTypes.Manual: mesh.axis_names}) - new_sharding = NamedSharding(new_mesh, P(*[None] * aval.ndim)) - else: - new_sharding = None + new_mesh = AbstractMesh( + mesh.shape_tuple, axis_types={AxisTypes.Manual: mesh.axis_names}) + new_sharding = NamedSharding(new_mesh, P(*[None] * aval.ndim)) return aval.update(shape=new_shape, sharding=new_sharding) core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array @@ -568,13 +563,10 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames, assert isinstance(aval, core.ShapedArray) new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) for i, sz in enumerate(aval.shape)) - if config.sharding_in_types.value: - spec = _names_to_pspec(names)._normalized_spec_for_aval(aval.ndim) - new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else - get_abstract_mesh()) - new_sharding = NamedSharding(new_mesh, spec) - else: - new_sharding = None + spec = 
_names_to_pspec(names)._normalized_spec_for_aval(aval.ndim) + new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else + get_abstract_mesh()) + new_sharding = NamedSharding(new_mesh, spec) return aval.update(shape=new_shape, sharding=new_sharding) core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array @@ -979,11 +971,8 @@ class ShardMapTracer(core.Tracer): def aval(self): aval = core.get_aval(self.val) out = core.mapped_aval(self._trace.mesh.size, 0, aval) - if config.sharding_in_types.value: - new_sharding = NamedSharding(_as_manual_mesh(self._trace.mesh), - out.sharding.spec) # pytype: disable=attribute-error - else: - new_sharding = None + new_sharding = NamedSharding(_as_manual_mesh(self._trace.mesh), + out.sharding.spec) # pytype: disable=attribute-error return out.update(sharding=new_sharding) def to_concrete_value(self):
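
For readers skimming the diff, the net effect is that forward sharding propagation is always on: abstract values such as `ShapedArray` now unconditionally carry a `sharding`, and `out_sharding=` arguments on primitives like `broadcast_in_dim` and `dot_general` no longer depend on any config flag. Below is a minimal sketch (not part of the patch) of what user code can now rely on; it uses only public `jax`/`jax.numpy` APIs, and the exact repr of the default sharding is an assumption that may vary by version.

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x):
      # With sharding-in-types always enabled, the traced abstract value
      # exposes `.sharding` directly, so shardings propagate forward
      # during tracing without the removed `jax_sharding_in_types` flag.
      assert hasattr(x.aval, 'sharding')
      # With no mesh in scope this is a fully replicated spec,
      # e.g. PartitionSpec(None, None) for a rank-2 input.
      return x * 2

    f(jnp.zeros((8, 4)))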