Now that sharding_in_types config flag is True, remove the config and all the conditionals

PiperOrigin-RevId: 728653433
This commit is contained in:
Yash Katariya 2025-02-19 06:52:52 -08:00 committed by jax authors
parent d5e5b42de8
commit a3edfb43ef
22 changed files with 170 additions and 340 deletions

View File

@ -1029,9 +1029,6 @@ def vmap(fun: F,
return cast(F, vmap_f)
def _mapped_axis_spec(args_flat, in_axes_flat):
if not config.sharding_in_types.value:
return None
def _get_spec(arg, i):
try:
# Duck type arrays like BCOO arrays can be passed to vmap.

View File

@ -234,7 +234,6 @@ def trace_context():
default_device.value, random_seed_offset.value,
threefry_partitionable.value,
threefry_gpu_kernel_lowering.value,
sharding_in_types.value,
use_direct_linearize.value,
softmax_custom_jvp.value,
disable_jit.value,
@ -1067,13 +1066,6 @@ threefry_gpu_kernel_lowering = bool_state(
'cost.'),
include_in_jit_key=True)
sharding_in_types = bool_state(
name='jax_sharding_in_types',
default=True,
help=('When True, enables forward only sharding propagation in JAX and '
'avals have sharding on them.'),
include_in_jit_key=True)
use_direct_linearize = bool_state(
name='jax_use_direct_linearize',
default=False,

View File

@ -498,8 +498,6 @@ class Primitive:
return f'{self.name}'
def bind(self, *args, **params):
if not config.sharding_in_types.value:
return self._true_bind(*args, **params)
args = args if self.skip_canonicalization else map(canonicalize_value, args)
return self._true_bind(*args, **params)
@ -598,7 +596,6 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True) -> list[
return map(read, jaxpr.outvars)
def check_avals_context_mesh(avals, prim_name):
if config.sharding_in_types.value:
cur_mesh = mesh_lib.get_abstract_mesh()
for a in avals:
# TODO(yashkatariya): Should be cur_mesh.unset
@ -1498,7 +1495,7 @@ def check_valid_jaxtype(x):
f"Value {x!r} of type {type(x)} is not a valid JAX type")
def update_aval_with_sharding(aval, sharding):
if config.sharding_in_types.value and isinstance(sharding, NamedSharding):
if isinstance(sharding, NamedSharding):
aval = aval.update(sharding=NamedSharding(
sharding.mesh.abstract_mesh,
sharding.spec._normalized_spec_for_aval(aval.ndim)))
@ -1761,9 +1758,6 @@ def _make_lengths_same(sharding, ndim):
# TODO(dougalm): Cast scalar, numpy arrays, etc to jax arrays so that values
# passed to primitives always have avals, etc i.e. they are canonical.
def canonicalize_value(val):
if not config.sharding_in_types.value:
return val
try:
aval = get_aval(val)
except TypeError:
@ -1783,8 +1777,6 @@ def canonicalize_value(val):
def get_cur_mesh_sharding(spec=None):
if not config.sharding_in_types.value:
return None
spec = P() if spec is None else spec
return NamedSharding(mesh_lib.get_abstract_mesh(), spec)
@ -1845,7 +1837,6 @@ class ShapedArray(UnshapedArray):
self.shape = canonicalize_shape(shape)
self.dtype = _dtype_object(dtype)
self.weak_type = weak_type
if config.sharding_in_types.value:
self.sharding = get_sharding(sharding, len(self.shape))
def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
@ -1856,7 +1847,7 @@ class ShapedArray(UnshapedArray):
if weak_type is None:
weak_type = self.weak_type
if 'sharding' not in kwargs:
kwargs['sharding'] = getattr(self, 'sharding', None)
kwargs['sharding'] = self.sharding
return ShapedArray(shape, dtype, weak_type, **kwargs)
ndim = property(lambda self: len(self.shape))
@ -1873,28 +1864,23 @@ class ShapedArray(UnshapedArray):
return (type(self) is type(other)
and self.dtype == other.dtype and self.shape == other.shape
and self.weak_type == other.weak_type
and getattr(self, 'sharding', None) == getattr(other, 'sharding', None))
and self.sharding == other.sharding)
def __hash__(self):
# can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
# objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
# the unique character code via hash(self.dtype.char)
return hash((self.shape, self.dtype, self.weak_type,
getattr(self, 'sharding', None)))
return hash((self.shape, self.dtype, self.weak_type, self.sharding))
def to_tangent_aval(self):
if config.sharding_in_types.value:
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type, sharding=self.sharding)
else:
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)
def str_short(self, short_dtypes=False):
dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else
self.dtype.name)
dt_str = dt_str.replace('void', 'float0')
if hasattr(self, 'sharding') and self.sharding is not None:
if self.sharding is not None:
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
return f'{dt_str}[{shapestr}]'
else:
@ -2484,8 +2470,7 @@ def _map_shaped_array(
assert axis is None or aval.shape[axis] == size
# TODO: Extend the named shape
if axis is None: return aval
sharding = (aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis))
if config.sharding_in_types.value else None)
sharding = aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis))
return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
weak_type=aval.weak_type, sharding=sharding)
@ -2494,9 +2479,8 @@ def _unmap_shaped_array(
) -> ShapedArray:
if axis is None: return aval
elif type(axis) is int:
sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis,
explicit_mesh_axis))
if config.sharding_in_types.value else None)
sharding = aval.sharding.with_spec(tuple_insert(
aval.sharding.spec, axis, explicit_mesh_axis))
return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
weak_type=aval.weak_type, sharding=sharding)
else: raise TypeError(axis)

View File

@ -1091,12 +1091,9 @@ def broadcast(x, sz, axis, mesh_axis=None):
shape = list(np.shape(x))
shape.insert(axis, sz)
broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
if config.sharding_in_types.value:
x_aval = core.get_aval(x)
new_spec = P(*tuple_insert(x_aval.sharding.spec, axis, mesh_axis))
sharding = x_aval.sharding.with_spec(new_spec)
else:
sharding = None
return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding)
def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False):

View File

@ -1719,7 +1719,7 @@ def lower_jaxpr_to_fun(
for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
output_avals, unconstrained_shardings): # type: ignore
if us[0] and not us[1]:
if config.use_shardy_partitioner.value and config.sharding_in_types.value:
if config.use_shardy_partitioner.value:
s = modify_sdy_sharding_wrt_axis_types(s, o_aval.sharding.mesh)
temp_flat_outputs.append(wrap_with_sharding_op(
entry_lowering_ctx, o, o_aval, s, unspecified_dims=us[2]))

View File

@ -974,10 +974,6 @@ def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr,
core.check_jaxpr(jaxpr_unknown)
def check(first, second):
if not config.sharding_in_types.value:
assert first == second
return
for f, s in zip(first, second):
if (not isinstance(f, core.ShapedArray) and
not isinstance(s, core.ShapedArray)):

View File

@ -2311,7 +2311,6 @@ def lower_sharding_computation(
propagated_out_mem_kinds = get_out_memory_kinds_via_propagation(
closed_jaxpr, in_shardings)
if config.sharding_in_types.value:
out_shardings = _concretize_abstract_out_shardings(
out_shardings, global_out_avals, device_assignment,
propagated_out_mem_kinds)

View File

@ -231,8 +231,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
xs_avals = [core.get_aval(x) for x in xs_flat]
if (config.sharding_in_types.value and
not all(a.sharding.spec[0] is None for a in xs_avals)):
if not all(a.sharding.spec[0] is None for a in xs_avals):
raise ValueError('0th dimension of all xs should be replicated. Got '
f'{", ".join(str(a.sharding.spec) for a in xs_avals)}')
@ -504,8 +503,7 @@ def _split_leading(sz, x):
def _concat(a, b): return lax.concatenate([a, b], 0)
def _empty_array(prefix, length_spec, aval):
sharding = (aval.sharding.with_spec((length_spec, *aval.sharding.spec))
if config.sharding_in_types.value else None)
sharding = aval.sharding.with_spec((length_spec, *aval.sharding.spec))
return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape),
out_sharding=sharding)

View File

@ -2025,9 +2025,6 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
non-contracting/non-batch dimensions.
"""
if out_sharding is not None and not config.sharding_in_types.value:
raise NotImplementedError("out_sharding only works when sharding_in_types "
"config is True.")
if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
raise NotImplementedError(
'`out_sharding` argument of `dot_general` only supports NamedSharding '
@ -2116,9 +2113,6 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
See Also:
jax.lax.broadcast : simpler interface to add new leading dimensions.
"""
if not config.sharding_in_types.value and out_sharding is not None:
raise NotImplementedError("out_sharding argument to broadcast_in_dim is only "
"allowed when sharding_in_types config is on.")
out_sharding = canonicalize_sharding(out_sharding)
if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and
isinstance(operand, Array) and out_sharding is None):
@ -2748,8 +2742,7 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
shard = broadcast(fill_value, broadcast_shape)
return array.make_array_from_callback(shape, sharding, lambda _: shard)
if (config.sharding_in_types.value and sharding is not None and
not sharding._is_concrete):
if sharding is not None and not sharding._is_concrete:
return broadcast(fill_value, shape, out_sharding=sharding)
else:
return broadcast(fill_value, shape)
@ -2762,9 +2755,7 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array:
scalar_zero = np.zeros((), dtype=aval.dtype)
else:
scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type)
if config.sharding_in_types.value:
return broadcast(scalar_zero, aval.shape, out_sharding=aval.sharding)
return broadcast(scalar_zero, aval.shape)
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
@ -2793,9 +2784,6 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int,
static_shape = [None if isinstance(d, core.Tracer) else d for d in shape]
dimension = core.concrete_or_error(
int, dimension, "dimension argument of lax.broadcasted_iota")
if not config.sharding_in_types.value and out_sharding is not None:
raise NotImplementedError('sharding support for broadcasted_iota is not '
'implemented outside of sharding_in_types mode.')
out_sharding = canonicalize_sharding(out_sharding)
return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
dimension=dimension, sharding=out_sharding)
@ -2959,8 +2947,7 @@ def full_like(x: ArrayLike | DuckTypedArray,
if dtypes.issubdtype(dtype, dtypes.extended):
return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr]
if (config.sharding_in_types.value and sharding is None and shape is None and
isinstance(x, core.Tracer)):
if sharding is None and shape is None and isinstance(x, core.Tracer):
sharding = x.aval.sharding
else:
# If `x` has a sharding but no `_committed` attribute
@ -3461,14 +3448,10 @@ def _nary_lower_hlo(op: Callable, ctx,
del params
avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out
args = mlir.multi_broadcast_in_dim(ctx, args, avals_in, aval_out.shape)
if config.sharding_in_types.value:
args = multi_sharding_in_dim(ctx, args, avals_in, aval_out)
out = op(*args)
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
else:
return [out]
_float = {np.floating}
@ -3876,10 +3859,8 @@ def _integer_pow_lowering(ctx, x, *, y):
if builtins.abs(y) >= 3:
lowering = mlir.cache_lowering(lowering)
out, = lowering(ctx, x, y=y)
if config.sharding_in_types.value:
aval_out, = ctx.avals_out
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
mlir.register_lowering(integer_pow_p, _integer_pow_lowering)
@ -4267,9 +4248,7 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
operand = hlo.real(operand)
aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype))
out = mlir.convert_hlo(ctx, operand, aval_in, aval_out)
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
@ -4630,12 +4609,9 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))
unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y
out_axes = np.argsort(unsorted_axes)
if config.sharding_in_types.value:
xs = x.aval.sharding
inverse_spec = tuple(xs.spec[o] for o in unsorted_axes)
ds = xs.with_spec(inverse_spec)
else:
ds = None
dot_general_out = dot_general(g, y, dims, precision=precision,
preferred_element_type=preferred_element_type,
out_sharding=ds)
@ -5020,7 +4996,7 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
precision_config=precision_attr(precision),
**algorithm_kwarg,
)
if config.sharding_in_types.value:
result = mlir.lower_sharding_under_shit(ctx, result, aval_out)
if accumulation_aval.dtype != aval_out.dtype:
result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out)
@ -5487,9 +5463,7 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions,
aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape))
out = mlir.broadcast_in_dim(ctx, x, aval_out,
broadcast_dimensions=broadcast_dimensions)
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions,
sharding):
@ -5498,12 +5472,9 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions,
type(core.get_aval(d).dtype) is core.bint for d in shape)):
shape = _broadcast_in_dim_shape_rule( # error checking
x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=None)
if config.sharding_in_types.value:
new_sharding = _broadcast_in_dim_sharding_rule(
x, shape=shape, broadcast_dimensions=broadcast_dimensions,
sharding=sharding)
else:
new_sharding = None
return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding)
# If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray
# (even if x is a ShapedArray)
@ -5680,9 +5651,7 @@ pe.padding_rules[concatenate_p] = _concatenate_pad_rule
def _concatenate_lower(ctx, *xs, dimension):
aval_out, = ctx.avals_out
out = hlo.concatenate(xs, mlir.i64_attr(dimension))
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
mlir.register_lowering(concatenate_p, _concatenate_lower)
@ -5734,8 +5703,7 @@ def _split_lower(ctx, x, *, sizes, axis):
limit_indices[axis] = start_indices[axis] + aval_out.shape[axis]
out = mlir.slice_op(ctx, x, aval_out, start_indices=start_indices,
limit_indices=limit_indices, strides=strides)
outs.append(mlir.lower_sharding_under_shit(ctx, out, aval_out)
if config.sharding_in_types.value else out)
outs.append(mlir.lower_sharding_under_shit(ctx, out, aval_out))
start_indices[axis] = limit_indices[axis]
return outs
@ -5840,9 +5808,7 @@ def _pad_lower(ctx, x, padding_value, *, padding_config):
aval_out, = ctx.avals_out
low, high, interior = util.unzip3(padding_config)
out = mlir.pad(ctx, aval_out, x, padding_value, low, high, interior)
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
mlir.register_lowering(pad_p, _pad_lower)
@ -5908,9 +5874,7 @@ def _squeeze_lower(ctx, operand, *, dimensions):
del dimensions # Implied by the output aval.
aval_out, = ctx.avals_out
out = mlir.reshape(ctx, operand, aval_out)
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
mlir.register_lowering(squeeze_p, _squeeze_lower)
@ -6073,16 +6037,11 @@ def _reshape_dtype_rule(operand, *, new_sizes, dimensions, sharding):
def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding):
assert ad.is_undefined_primal(operand)
if dimensions is None:
if config.sharding_in_types.value:
return [reshape(t, operand.aval.shape, out_sharding=operand.aval.sharding)]
return [reshape(t, operand.aval.shape)]
else:
if config.sharding_in_types.value:
t_s = operand.aval.sharding.with_spec(
tuple(map(lambda s: s if s is None else str(s),
np.take(operand.aval.sharding.spec, dimensions))))
else:
t_s = None
return [transpose(reshape(t, np.take(operand.aval.shape, dimensions),
out_sharding=t_s),
np.argsort(dimensions))]
@ -6110,9 +6069,7 @@ def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions, sharding):
if dyn_shape:
aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape))
out = mlir.reshape(ctx, x, aval_out)
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
def _reshape_staging_rule(
trace, x, *dyn, new_sizes, dimensions, sharding):
@ -6162,9 +6119,7 @@ batching.primitive_batchers[rev_p] = _rev_batch_rule
def _rev_lower(ctx, x, *, dimensions):
aval_out, = ctx.avals_out
out = hlo.reverse(x, mlir.dense_int_array(dimensions))
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
mlir.register_lowering(rev_p, _rev_lower)
@ -6201,9 +6156,7 @@ def _transpose_lower(ctx, x, *, permutation):
trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))]
permutation = [*permutation, *trailing_dims]
out = hlo.transpose(x, mlir.dense_int_array(permutation))
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
transpose_p = standard_primitive(
_transpose_shape_rule, _input_dtype, 'transpose',
@ -6345,10 +6298,6 @@ def _select_hlo_lowering_opaque(ctx, which, *cases):
avals_in=[aval_which_bcast, *physical_avals_cases],
avals_out=[physical_aval_out])[0]
def _add_shit_to_select(ctx, op, aval_out):
if config.sharding_in_types.value:
return mlir.lower_sharding_under_shit(ctx, op, aval_out)
return op
def _select_hlo_lowering(ctx, which, *cases):
which_aval = ctx.avals_in[0]
@ -6356,13 +6305,13 @@ def _select_hlo_lowering(ctx, which, *cases):
if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
op = _select_hlo_lowering_opaque(ctx, which, *cases)
return [_add_shit_to_select(ctx, op, aval_out)]
return [mlir.lower_sharding_under_shit(ctx, op, aval_out)]
if which_aval.dtype == np.dtype(np.bool_):
assert len(cases) <= 2
if len(cases) == 1: return cases
op = hlo.select(which, cases[1], cases[0])
return [_add_shit_to_select(ctx, op, aval_out)]
return [mlir.lower_sharding_under_shit(ctx, op, aval_out)]
if dtypes.issubdtype(which_aval.dtype, np.signedinteger):
compare_type = 'SIGNED'
@ -6382,7 +6331,7 @@ def _select_hlo_lowering(ctx, which, *cases):
_select(offset + mid, cases[mid:]))
op = _select(0, cases)
return [_add_shit_to_select(ctx, op, aval_out)]
return [mlir.lower_sharding_under_shit(ctx, op, aval_out)]
select_n_p = standard_primitive(
_select_shape_rule, _select_dtype_rule, 'select_n',
@ -6514,10 +6463,8 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions):
*reducer.arguments,
dim_var_values=ctx.dim_var_values)
hlo.return_(mlir.flatten_ir_values(out_nodes))
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, r, aval)
for r, aval in safe_zip(op.results, ctx.avals_out)]
return op.results
mlir.register_lowering(reduce_p, _reduce_lower)
@ -6532,11 +6479,8 @@ def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
assert ad.is_undefined_primal(operand)
input_shape = operand.aval.shape
broadcast_dimensions = tuple(np.delete(np.arange(len(input_shape)), axes))
if config.sharding_in_types.value:
result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions,
out_sharding=operand.aval.sharding)
else:
result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions)
assert result.shape == input_shape
return [result]
@ -6674,7 +6618,7 @@ def _compute_argminmax(value_comparator, get_identity,
axis, = axes
indices = broadcasted_iota(
index_dtype, np.shape(operand), axis,
out_sharding=operand.aval.sharding if config.sharding_in_types.value else None)
out_sharding=operand.aval.sharding)
res = reduce([operand, indices],
[get_identity(operand.dtype), np.array(0, index_dtype)],
_ArgMinMaxReducer(value_comparator),
@ -6739,9 +6683,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer_region):
hlo.return_([reducer(*reducer_region.arguments)])
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, op.result, aval_out)]
return op.results
mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, hlo.AddOp,
_get_sum_identity))
@ -6782,9 +6724,7 @@ def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
aval_out, = ctx.avals_out
out = hlo.reduce_precision(operand, mlir.i32_attr(exponent_bits),
mlir.i32_attr(mantissa_bits))
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
mlir.register_lowering(reduce_precision_p, _reduce_precision_lower)
@ -6924,8 +6864,7 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
mlir.flatten_ir_values(operands),
dimension=mlir.i64_attr(dimension),
is_stable=ir.BoolAttr.get(is_stable))
scalar_s = (lambda a: a.sharding.with_spec(P())
if config.sharding_in_types.value else lambda _: None)
scalar_s = lambda a: a.sharding.with_spec(P())
scalar_avals = [aval.update(shape=(), sharding=scalar_s(aval))
for aval in ctx.avals_in]
scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals)
@ -7458,9 +7397,7 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension, sharding):
if dyn_shape:
aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape))
out = mlir.iota(ctx, aval_out, dimension=dimension)
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
mlir.register_lowering(iota_p, _iota_lower)
def _iota_batching_rule(in_vals, in_dims, *, dtype, shape, dimension,
@ -7668,20 +7605,16 @@ def _const(example, val):
_zeros: Callable = partial(full_like, fill_value=0)
def _zero(x):
if config.sharding_in_types.value:
x_aval = core.get_aval(x)
return full_like(x, shape=(), fill_value=0,
sharding=x_aval.sharding.with_spec(P()))
return full_like(x, shape=(), fill_value=0)
_ones: Callable = partial(full_like, fill_value=1)
def _one(x):
if config.sharding_in_types.value:
x_aval = core.get_aval(x)
return full_like(x, shape=(), fill_value=1,
sharding=x_aval.sharding.with_spec(P()))
return full_like(x, shape=(), fill_value=1)
_twos: Callable = partial(full_like, fill_value=2)
_two: Callable = partial(full_like, shape=(), fill_value=2)

View File

@ -1041,7 +1041,6 @@ def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
batch_dims = operand.shape[:-2]
n = operand.shape[-1]
if config.sharding_in_types.value:
batch_s = operand.sharding.spec[:-2]
ns = operand.sharding.spec[-1]
if ns is not None:
@ -1049,8 +1048,6 @@ def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
' specs. Try marking their specs as None.')
w_s = operand.sharding.with_spec(P(*batch_s + (ns,)))
v_s = operand.sharding.with_spec(P(*batch_s + (ns, ns)))
else:
w_s, v_s = None, None
w = operand.update(shape=batch_dims + (n,),
dtype=lax_internal._complex_basetype(operand.dtype),
sharding=w_s)
@ -1123,7 +1120,6 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index):
n = operand.shape[-1]
d = (n if subset_by_index is None else
subset_by_index[1] - subset_by_index[0])
if config.sharding_in_types.value:
batch_s = operand.sharding.spec[:-2]
ns, ds = operand.sharding.spec[-1], None
if ns is not None:
@ -1131,8 +1127,6 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index):
'marking their specs as None.')
v_s = operand.sharding.with_spec(P(*batch_s + (ns, ds)))
w_s = operand.sharding.with_spec(P(*batch_s + (ds,)))
else:
v_s, w_s = None, None
v = operand.update(shape=batch_dims + (n, d), sharding=v_s)
w = operand.update(
shape=batch_dims + (d,),
@ -1450,9 +1444,7 @@ def _triangular_solve_lowering(
ir.BoolAttr.get(lower),
ir.BoolAttr.get(unit_diagonal),
hlo.TransposeAttr.get(transpose))
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, out, out_aval)]
return [out]
def _triangular_solve_cpu_lower(
@ -1901,15 +1893,12 @@ def _geqrf_abstract_eval(operand):
if operand.ndim < 2:
raise ValueError("Argument to QR decomposition must have ndims >= 2")
*batch_dims, m, n = operand.shape
if config.sharding_in_types.value:
spec = operand.sharding.spec
batch_s, ms, ns = spec[:-2], spec[-2], spec[-1]
if ms is not None or ns is not None:
raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns}'
' specs. Try marking their specs as None.')
taus_s = operand.sharding.with_spec(P(*(*batch_s, None)))
else:
taus_s = None
taus = operand.update(shape=(*batch_dims, core.min_dim(m, n)),
sharding=taus_s)
return operand, taus
@ -2117,7 +2106,7 @@ def _qr_abstract_eval(operand, *, pivoting, full_matrices, use_magma):
raise ValueError("Argument to QR decomposition must have ndims >= 2")
*batch_dims, m, n = operand.shape
k = m if full_matrices else core.min_dim(m, n)
if config.sharding_in_types.value:
*batch_s, ms, ns = operand.sharding.spec
ks = None
if ms is not None or ns is not None:
@ -2126,8 +2115,7 @@ def _qr_abstract_eval(operand, *, pivoting, full_matrices, use_magma):
q_s = operand.sharding.with_spec(P(*(*batch_s, ms, ks)))
r_s = operand.sharding.with_spec(P(*(*batch_s, ks, ns)))
p_s = operand.sharding.with_spec(P(*(*batch_s, ns)))
else:
q_s, r_s, p_s = None, None, None
q = operand.update(shape=(*batch_dims, m, k), sharding=q_s)
r = operand.update(shape=(*batch_dims, k, n), sharding=r_s)
p = operand.update(shape=(*batch_dims, n), dtype=np.dtype(np.int32),
@ -2241,7 +2229,6 @@ def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index,
raise ValueError("full_matrices and subset_by_index cannot both be set")
rank = min(rank, subset_by_index[1] - subset_by_index[0])
if config.sharding_in_types.value:
batch_s = operand.sharding.spec[:-2]
ms = operand.sharding.spec[-2]
ns = operand.sharding.spec[-1]
@ -2254,8 +2241,6 @@ def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index,
P(*batch_s + (ms, ms if full_matrices else rank_s)))
vt_sharding = operand.sharding.with_spec(
P(*batch_s + (ns if full_matrices else rank_s, ns)))
else:
s_sharding, u_sharding, vt_sharding = None, None, None
s = operand.update(
shape=batch_dims + (rank,),

View File

@ -24,7 +24,6 @@ import math
from jax import tree_util
from jax._src import core
from jax._src import config
from jax._src import dispatch
from jax._src import dtypes
from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext,
@ -732,16 +731,12 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
if len(pos_axes) != 0:
raise ValueError(f"axis_index_groups can only be used with reductions over "
f"named axes, but got: {axes}")
if config.sharding_in_types.value:
core.check_avals_context_mesh(args, 'all_reduce')
out_avals = [
ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype,
sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes))
for arg in args
]
else:
out_avals = [ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype)
for arg in args]
return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes}
def _check_axis_names(axes):
@ -795,11 +790,8 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
else:
op = hlo.AllReduceOp(
[x.type], [x], replica_groups=replica_groups, **other_args)
if config.sharding_in_types.value:
scalar_aval = core.ShapedArray(
(), aval.dtype, sharding=NamedSharding(aval.sharding.mesh, P()))
else:
scalar_aval = core.ShapedArray((), aval.dtype)
scalar_type = mlir.aval_to_ir_type(scalar_aval)
reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer_block):

View File

@ -1370,9 +1370,7 @@ def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
aval_out, = ctx.avals_out
out = mlir.slice_op(ctx, x, aval_out, start_indices=start_indices,
limit_indices=limit_indices, strides=strides)
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
mlir.register_lowering(slice_p, _slice_lower)
@ -1525,9 +1523,7 @@ def _dynamic_slice_lower(ctx, x, *starts_and_dyn_sizes, slice_sizes):
if dyn:
aval_out = aval_out.update(shape=lax._merge_dyn_shape(slice_sizes, dyn))
out = mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices)
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
mlir.register_lowering(dynamic_slice_p, _dynamic_slice_lower)
@ -1642,9 +1638,7 @@ def _dynamic_update_slice_lower(ctx, x, update, *start_indices):
aval_out, = ctx.avals_out
out = mlir.dynamic_update_slice(ctx, aval_out, x, update,
start_indices=start_indices)
if config.sharding_in_types.value:
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
return [out]
mlir.register_lowering(dynamic_update_slice_p, _dynamic_update_slice_lower)

View File

@ -20,7 +20,6 @@ from functools import partial
from jax._src import core
from jax._src import dispatch
from jax._src import config
from jax._src import dtypes
from jax._src import mesh as mesh_lib
from jax._src.util import safe_zip
@ -50,8 +49,6 @@ def standard_primitive(shape_rule, dtype_rule, name,
def _get_array_abstraction_level(a): return a.array_abstraction_level
def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh:
if not config.sharding_in_types.value:
return None # type: ignore
m = None
for a in in_avals:
if a is core.abstract_token:
@ -69,7 +66,6 @@ def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh:
def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
if config.sharding_in_types.value:
cur_mesh = mesh_lib.get_abstract_mesh()
aval_mesh = _get_abstract_mesh_from_avals(avals)
if ((cur_mesh.empty or cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual) and
@ -84,7 +80,6 @@ def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
' this error by dropping that operation into full auto sharding'
' mode via: `jax.experimental.shard.auto_axes(fun, out_shardings=...)`')
return rule(*avals, **kwargs)
return None if num_out is None else [None] * num_out
def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
sharding_rule, *avals, **kwargs):

View File

@ -563,6 +563,5 @@ def get_concrete_mesh():
@contextlib.contextmanager
def use_mesh(mesh: Mesh):
with (set_abstract_mesh(mesh.abstract_mesh),
jax_config.sharding_in_types(True), set_concrete_mesh(mesh)):
with set_abstract_mesh(mesh.abstract_mesh), set_concrete_mesh(mesh):
yield

View File

@ -665,11 +665,8 @@ def _one_hot(x: Array, num_classes: int, *,
lhs = lax.expand_dims(x, (axis,))
rhs_shape = [1] * x.ndim
rhs_shape.insert(output_pos_axis, num_classes)
if config.sharding_in_types.value:
# TODO(yashkatariya): Maybe expose `out_sharding` on `one_hot` too?
rhs_sharding = NamedSharding(x.aval.sharding.mesh, P(*[None] * len(rhs_shape))) # pytype: disable=attribute-error
else:
rhs_sharding = None
rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis,
out_sharding=rhs_sharding)
return (lhs == rhs).astype(dtype)

View File

@ -411,9 +411,6 @@ def _einsum(
_dot_general=lax.dot_general,
out_sharding=None,
):
if out_sharding is not None and not config.sharding_in_types.value:
raise NotImplementedError("out_sharding only works when sharding_in_types "
"config is True.")
out_sharding = canonicalize_sharding(out_sharding)
if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
raise NotImplementedError(
@ -546,8 +543,7 @@ def _einsum(
**k_out_sharding)
else:
names = batch_names_str + remaining_lhs_names + remaining_rhs_names
if (config.sharding_in_types.value and out_sharding is not None and
names != result_names):
if out_sharding is not None and names != result_names:
spec = out_sharding.spec
inverse_spec = tuple(spec[result_names.index(name)] for name in names)
dot_general_out_sharding = NamedSharding(out_sharding.mesh,

View File

@ -5463,8 +5463,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
# whenever x is weak, but avoids introducing weak types with something like
# array([1, 2, 3])
weak_type = dtype is None and dtypes.is_weakly_typed(object)
if (config.sharding_in_types.value and device is None and
isinstance(object, core.Tracer)):
if device is None and isinstance(object, core.Tracer):
sharding = object.aval.sharding
sharding = None if sharding.mesh.empty else sharding
else:

View File

@ -250,8 +250,7 @@ def _broadcast_arrays(*args: ArrayLike) -> list[Array]:
if not shapes or all(core.definitely_equal_shape(shapes[0], s) for s in shapes):
return [lax.asarray(arg) for arg in args]
result_shape = lax.broadcast_shapes(*shapes)
result_sharding = (lax.broadcast_shardings(*avals)
if config.sharding_in_types.value else None)
result_sharding = lax.broadcast_shardings(*avals)
return [_broadcast_to(arg, result_shape, result_sharding) for arg in args]

View File

@ -891,8 +891,7 @@ def _pallas_call_batching_rule(
batched_out_avals = []
for aval in out_avals:
sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, 0, None))
if config.sharding_in_types.value else None)
sharding = aval.sharding.with_spec(tuple_insert(aval.sharding.spec, 0, None))
shape = tuple_insert(aval.shape, 0, axis_size)
batched_out_avals.append(aval.update(shape=shape, sharding=sharding))
batched_out_avals = tuple(batched_out_avals)

View File

@ -1746,14 +1746,10 @@ def _pjit_lower(
lowering_parameters: mlir.LoweringParameters,
pgle_profiler: profiler.PGLEProfiler | None):
util.test_event("pjit_lower")
if config.sharding_in_types.value:
if resource_env is not None:
mesh, api_name = resource_env.physical_mesh, 'pjit'
else:
mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit'
else:
mesh, api_name = ((resource_env.physical_mesh, 'pjit')
if resource_env is not None else (None, 'jit'))
return pxla.lower_sharding_computation(
jaxpr, api_name, name, in_shardings, out_shardings,
in_layouts, out_layouts, tuple(donated_invars),
@ -2494,9 +2490,6 @@ state_discharge.register_discharge_rule(pjit_p)(_pjit_state_discharge_rule)
# -------------------- with_sharding_constraint --------------------
def check_shardings_are_auto(shardings_flat):
if not config.sharding_in_types.value:
return
for s in shardings_flat:
if not isinstance(s, NamedSharding):
continue

View File

@ -22,7 +22,6 @@ import math
from typing import Any, NamedTuple, cast
from jax._src import core
from jax._src import config
from jax._src import mesh as mesh_lib
from jax._src import sharding as jsharding
from jax._src import sharding_specs
@ -1254,8 +1253,6 @@ def flatten_spec(spec):
def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
check_mesh_consistency: bool = True
) -> NamedSharding | None:
if not config.sharding_in_types.value:
return sharding # type: ignore
if sharding is None:
return sharding
if isinstance(sharding, NamedSharding) and sharding.mesh.empty:

View File

@ -480,10 +480,8 @@ shard_map_p = ShardMapPrimitive('shard_map')
# Staging
def _as_manual_mesh(mesh):
if config.sharding_in_types.value:
return AbstractMesh(
mesh.shape_tuple, axis_types={AxisTypes.Manual: mesh.axis_names})
return None
def _extend_axis_env(mesh, auto):
return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items()
@ -554,12 +552,9 @@ def _shard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
assert isinstance(aval, core.ShapedArray)
new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape))
if config.sharding_in_types.value:
new_mesh = AbstractMesh(
mesh.shape_tuple, axis_types={AxisTypes.Manual: mesh.axis_names})
new_sharding = NamedSharding(new_mesh, P(*[None] * aval.ndim))
else:
new_sharding = None
return aval.update(shape=new_shape, sharding=new_sharding)
core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array
@ -568,13 +563,10 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames,
assert isinstance(aval, core.ShapedArray)
new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape))
if config.sharding_in_types.value:
spec = _names_to_pspec(names)._normalized_spec_for_aval(aval.ndim)
new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else
get_abstract_mesh())
new_sharding = NamedSharding(new_mesh, spec)
else:
new_sharding = None
return aval.update(shape=new_shape, sharding=new_sharding)
core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array
@ -979,11 +971,8 @@ class ShardMapTracer(core.Tracer):
def aval(self):
aval = core.get_aval(self.val)
out = core.mapped_aval(self._trace.mesh.size, 0, aval)
if config.sharding_in_types.value:
new_sharding = NamedSharding(_as_manual_mesh(self._trace.mesh),
out.sharding.spec) # pytype: disable=attribute-error
else:
new_sharding = None
return out.update(sharding=new_sharding)
def to_concrete_value(self):