mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Now that the sharding_in_types config flag is True by default, remove the config flag and all the conditionals that check it
PiperOrigin-RevId: 728653433
This commit is contained in:
parent
d5e5b42de8
commit
a3edfb43ef
@ -1029,9 +1029,6 @@ def vmap(fun: F,
|
||||
return cast(F, vmap_f)
|
||||
|
||||
def _mapped_axis_spec(args_flat, in_axes_flat):
|
||||
if not config.sharding_in_types.value:
|
||||
return None
|
||||
|
||||
def _get_spec(arg, i):
|
||||
try:
|
||||
# Duck type arrays like BCOO arrays can be passed to vmap.
|
||||
|
@ -234,7 +234,6 @@ def trace_context():
|
||||
default_device.value, random_seed_offset.value,
|
||||
threefry_partitionable.value,
|
||||
threefry_gpu_kernel_lowering.value,
|
||||
sharding_in_types.value,
|
||||
use_direct_linearize.value,
|
||||
softmax_custom_jvp.value,
|
||||
disable_jit.value,
|
||||
@ -1067,13 +1066,6 @@ threefry_gpu_kernel_lowering = bool_state(
|
||||
'cost.'),
|
||||
include_in_jit_key=True)
|
||||
|
||||
sharding_in_types = bool_state(
|
||||
name='jax_sharding_in_types',
|
||||
default=True,
|
||||
help=('When True, enables forward only sharding propagation in JAX and '
|
||||
'avals have sharding on them.'),
|
||||
include_in_jit_key=True)
|
||||
|
||||
use_direct_linearize = bool_state(
|
||||
name='jax_use_direct_linearize',
|
||||
default=False,
|
||||
|
@ -498,8 +498,6 @@ class Primitive:
|
||||
return f'{self.name}'
|
||||
|
||||
def bind(self, *args, **params):
|
||||
if not config.sharding_in_types.value:
|
||||
return self._true_bind(*args, **params)
|
||||
args = args if self.skip_canonicalization else map(canonicalize_value, args)
|
||||
return self._true_bind(*args, **params)
|
||||
|
||||
@ -598,7 +596,6 @@ def eval_jaxpr(jaxpr: Jaxpr, consts, *args, propagate_source_info=True) -> list[
|
||||
return map(read, jaxpr.outvars)
|
||||
|
||||
def check_avals_context_mesh(avals, prim_name):
|
||||
if config.sharding_in_types.value:
|
||||
cur_mesh = mesh_lib.get_abstract_mesh()
|
||||
for a in avals:
|
||||
# TODO(yashkatariya): Should be cur_mesh.unset
|
||||
@ -1498,7 +1495,7 @@ def check_valid_jaxtype(x):
|
||||
f"Value {x!r} of type {type(x)} is not a valid JAX type")
|
||||
|
||||
def update_aval_with_sharding(aval, sharding):
|
||||
if config.sharding_in_types.value and isinstance(sharding, NamedSharding):
|
||||
if isinstance(sharding, NamedSharding):
|
||||
aval = aval.update(sharding=NamedSharding(
|
||||
sharding.mesh.abstract_mesh,
|
||||
sharding.spec._normalized_spec_for_aval(aval.ndim)))
|
||||
@ -1761,9 +1758,6 @@ def _make_lengths_same(sharding, ndim):
|
||||
# TODO(dougalm): Cast scalars, numpy arrays, etc. to jax arrays so that values
|
||||
# passed to primitives always have avals, etc., i.e. they are canonical.
|
||||
def canonicalize_value(val):
|
||||
if not config.sharding_in_types.value:
|
||||
return val
|
||||
|
||||
try:
|
||||
aval = get_aval(val)
|
||||
except TypeError:
|
||||
@ -1783,8 +1777,6 @@ def canonicalize_value(val):
|
||||
|
||||
|
||||
def get_cur_mesh_sharding(spec=None):
|
||||
if not config.sharding_in_types.value:
|
||||
return None
|
||||
spec = P() if spec is None else spec
|
||||
return NamedSharding(mesh_lib.get_abstract_mesh(), spec)
|
||||
|
||||
@ -1845,7 +1837,6 @@ class ShapedArray(UnshapedArray):
|
||||
self.shape = canonicalize_shape(shape)
|
||||
self.dtype = _dtype_object(dtype)
|
||||
self.weak_type = weak_type
|
||||
if config.sharding_in_types.value:
|
||||
self.sharding = get_sharding(sharding, len(self.shape))
|
||||
|
||||
def update(self, shape=None, dtype=None, weak_type=None, **kwargs):
|
||||
@ -1856,7 +1847,7 @@ class ShapedArray(UnshapedArray):
|
||||
if weak_type is None:
|
||||
weak_type = self.weak_type
|
||||
if 'sharding' not in kwargs:
|
||||
kwargs['sharding'] = getattr(self, 'sharding', None)
|
||||
kwargs['sharding'] = self.sharding
|
||||
return ShapedArray(shape, dtype, weak_type, **kwargs)
|
||||
|
||||
ndim = property(lambda self: len(self.shape))
|
||||
@ -1873,28 +1864,23 @@ class ShapedArray(UnshapedArray):
|
||||
return (type(self) is type(other)
|
||||
and self.dtype == other.dtype and self.shape == other.shape
|
||||
and self.weak_type == other.weak_type
|
||||
and getattr(self, 'sharding', None) == getattr(other, 'sharding', None))
|
||||
and self.sharding == other.sharding)
|
||||
|
||||
def __hash__(self):
|
||||
# can use hash(self.dtype) and rely on the fact that numpy reuses base dtype
|
||||
# objects, e.g. `np.zeros(3).dtype is np.zeros(4).dtype`, or we can use
|
||||
# the unique character code via hash(self.dtype.char)
|
||||
return hash((self.shape, self.dtype, self.weak_type,
|
||||
getattr(self, 'sharding', None)))
|
||||
return hash((self.shape, self.dtype, self.weak_type, self.sharding))
|
||||
|
||||
def to_tangent_aval(self):
|
||||
if config.sharding_in_types.value:
|
||||
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
|
||||
self.weak_type, sharding=self.sharding)
|
||||
else:
|
||||
return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
|
||||
self.weak_type)
|
||||
|
||||
def str_short(self, short_dtypes=False):
|
||||
dt_str = (dtypes.short_dtype_name(self.dtype) if short_dtypes else
|
||||
self.dtype.name)
|
||||
dt_str = dt_str.replace('void', 'float0')
|
||||
if hasattr(self, 'sharding') and self.sharding is not None:
|
||||
if self.sharding is not None:
|
||||
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
|
||||
return f'{dt_str}[{shapestr}]'
|
||||
else:
|
||||
@ -2484,8 +2470,7 @@ def _map_shaped_array(
|
||||
assert axis is None or aval.shape[axis] == size
|
||||
# TODO: Extend the named shape
|
||||
if axis is None: return aval
|
||||
sharding = (aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis))
|
||||
if config.sharding_in_types.value else None)
|
||||
sharding = aval.sharding.with_spec(tuple_delete(aval.sharding.spec, axis))
|
||||
return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
|
||||
weak_type=aval.weak_type, sharding=sharding)
|
||||
|
||||
@ -2494,9 +2479,8 @@ def _unmap_shaped_array(
|
||||
) -> ShapedArray:
|
||||
if axis is None: return aval
|
||||
elif type(axis) is int:
|
||||
sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis,
|
||||
explicit_mesh_axis))
|
||||
if config.sharding_in_types.value else None)
|
||||
sharding = aval.sharding.with_spec(tuple_insert(
|
||||
aval.sharding.spec, axis, explicit_mesh_axis))
|
||||
return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
|
||||
weak_type=aval.weak_type, sharding=sharding)
|
||||
else: raise TypeError(axis)
|
||||
|
@ -1091,12 +1091,9 @@ def broadcast(x, sz, axis, mesh_axis=None):
|
||||
shape = list(np.shape(x))
|
||||
shape.insert(axis, sz)
|
||||
broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
|
||||
if config.sharding_in_types.value:
|
||||
x_aval = core.get_aval(x)
|
||||
new_spec = P(*tuple_insert(x_aval.sharding.spec, axis, mesh_axis))
|
||||
sharding = x_aval.sharding.with_spec(new_spec)
|
||||
else:
|
||||
sharding = None
|
||||
return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, out_sharding=sharding)
|
||||
|
||||
def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False):
|
||||
|
@ -1719,7 +1719,7 @@ def lower_jaxpr_to_fun(
|
||||
for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
|
||||
output_avals, unconstrained_shardings): # type: ignore
|
||||
if us[0] and not us[1]:
|
||||
if config.use_shardy_partitioner.value and config.sharding_in_types.value:
|
||||
if config.use_shardy_partitioner.value:
|
||||
s = modify_sdy_sharding_wrt_axis_types(s, o_aval.sharding.mesh)
|
||||
temp_flat_outputs.append(wrap_with_sharding_op(
|
||||
entry_lowering_ctx, o, o_aval, s, unspecified_dims=us[2]))
|
||||
|
@ -974,10 +974,6 @@ def _partial_eval_jaxpr_nounits(jaxpr: ClosedJaxpr,
|
||||
core.check_jaxpr(jaxpr_unknown)
|
||||
|
||||
def check(first, second):
|
||||
if not config.sharding_in_types.value:
|
||||
assert first == second
|
||||
return
|
||||
|
||||
for f, s in zip(first, second):
|
||||
if (not isinstance(f, core.ShapedArray) and
|
||||
not isinstance(s, core.ShapedArray)):
|
||||
|
@ -2311,7 +2311,6 @@ def lower_sharding_computation(
|
||||
propagated_out_mem_kinds = get_out_memory_kinds_via_propagation(
|
||||
closed_jaxpr, in_shardings)
|
||||
|
||||
if config.sharding_in_types.value:
|
||||
out_shardings = _concretize_abstract_out_shardings(
|
||||
out_shardings, global_out_avals, device_assignment,
|
||||
propagated_out_mem_kinds)
|
||||
|
@ -231,8 +231,7 @@ def scan(f: Callable[[Carry, X], tuple[Carry, Y]],
|
||||
|
||||
xs_avals = [core.get_aval(x) for x in xs_flat]
|
||||
|
||||
if (config.sharding_in_types.value and
|
||||
not all(a.sharding.spec[0] is None for a in xs_avals)):
|
||||
if not all(a.sharding.spec[0] is None for a in xs_avals):
|
||||
raise ValueError('0th dimension of all xs should be replicated. Got '
|
||||
f'{", ".join(str(a.sharding.spec) for a in xs_avals)}')
|
||||
|
||||
@ -504,8 +503,7 @@ def _split_leading(sz, x):
|
||||
def _concat(a, b): return lax.concatenate([a, b], 0)
|
||||
|
||||
def _empty_array(prefix, length_spec, aval):
|
||||
sharding = (aval.sharding.with_spec((length_spec, *aval.sharding.spec))
|
||||
if config.sharding_in_types.value else None)
|
||||
sharding = aval.sharding.with_spec((length_spec, *aval.sharding.spec))
|
||||
return lax.broadcast(lax.empty(aval.dtype), (*prefix, *aval.shape),
|
||||
out_sharding=sharding)
|
||||
|
||||
|
@ -2025,9 +2025,6 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN
|
||||
by the ``lhs`` non-contracting/non-batch dimensions, and finally the ``rhs``
|
||||
non-contracting/non-batch dimensions.
|
||||
"""
|
||||
if out_sharding is not None and not config.sharding_in_types.value:
|
||||
raise NotImplementedError("out_sharding only works when sharding_in_types "
|
||||
"config is True.")
|
||||
if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
|
||||
raise NotImplementedError(
|
||||
'`out_sharding` argument of `dot_general` only supports NamedSharding '
|
||||
@ -2116,9 +2113,6 @@ def broadcast_in_dim(operand: ArrayLike, shape: Shape,
|
||||
See Also:
|
||||
jax.lax.broadcast : simpler interface to add new leading dimensions.
|
||||
"""
|
||||
if not config.sharding_in_types.value and out_sharding is not None:
|
||||
raise NotImplementedError("out_sharding argument to broadcast_in_dim is only "
|
||||
"allowed when sharding_in_types config is on.")
|
||||
out_sharding = canonicalize_sharding(out_sharding)
|
||||
if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions) and
|
||||
isinstance(operand, Array) and out_sharding is None):
|
||||
@ -2748,8 +2742,7 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
|
||||
shard = broadcast(fill_value, broadcast_shape)
|
||||
return array.make_array_from_callback(shape, sharding, lambda _: shard)
|
||||
|
||||
if (config.sharding_in_types.value and sharding is not None and
|
||||
not sharding._is_concrete):
|
||||
if sharding is not None and not sharding._is_concrete:
|
||||
return broadcast(fill_value, shape, out_sharding=sharding)
|
||||
else:
|
||||
return broadcast(fill_value, shape)
|
||||
@ -2762,9 +2755,7 @@ def zeros_like_shaped_array(aval: ShapedArray) -> Array:
|
||||
scalar_zero = np.zeros((), dtype=aval.dtype)
|
||||
else:
|
||||
scalar_zero = _convert_element_type(0, aval.dtype, aval.weak_type)
|
||||
if config.sharding_in_types.value:
|
||||
return broadcast(scalar_zero, aval.shape, out_sharding=aval.sharding)
|
||||
return broadcast(scalar_zero, aval.shape)
|
||||
|
||||
ad_util.aval_zeros_likers[ShapedArray] = zeros_like_shaped_array
|
||||
|
||||
@ -2793,9 +2784,6 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int,
|
||||
static_shape = [None if isinstance(d, core.Tracer) else d for d in shape]
|
||||
dimension = core.concrete_or_error(
|
||||
int, dimension, "dimension argument of lax.broadcasted_iota")
|
||||
if not config.sharding_in_types.value and out_sharding is not None:
|
||||
raise NotImplementedError('sharding support for broadcasted_iota is not '
|
||||
'implemented outside of sharding_in_types mode.')
|
||||
out_sharding = canonicalize_sharding(out_sharding)
|
||||
return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
|
||||
dimension=dimension, sharding=out_sharding)
|
||||
@ -2959,8 +2947,7 @@ def full_like(x: ArrayLike | DuckTypedArray,
|
||||
if dtypes.issubdtype(dtype, dtypes.extended):
|
||||
return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr]
|
||||
|
||||
if (config.sharding_in_types.value and sharding is None and shape is None and
|
||||
isinstance(x, core.Tracer)):
|
||||
if sharding is None and shape is None and isinstance(x, core.Tracer):
|
||||
sharding = x.aval.sharding
|
||||
else:
|
||||
# If `x` has a sharding but no `_committed` attribute
|
||||
@ -3461,14 +3448,10 @@ def _nary_lower_hlo(op: Callable, ctx,
|
||||
del params
|
||||
avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out
|
||||
args = mlir.multi_broadcast_in_dim(ctx, args, avals_in, aval_out.shape)
|
||||
if config.sharding_in_types.value:
|
||||
args = multi_sharding_in_dim(ctx, args, avals_in, aval_out)
|
||||
|
||||
out = op(*args)
|
||||
if config.sharding_in_types.value:
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
else:
|
||||
return [out]
|
||||
|
||||
|
||||
_float = {np.floating}
|
||||
@ -3876,10 +3859,8 @@ def _integer_pow_lowering(ctx, x, *, y):
|
||||
if builtins.abs(y) >= 3:
|
||||
lowering = mlir.cache_lowering(lowering)
|
||||
out, = lowering(ctx, x, y=y)
|
||||
if config.sharding_in_types.value:
|
||||
aval_out, = ctx.avals_out
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
|
||||
mlir.register_lowering(integer_pow_p, _integer_pow_lowering)
|
||||
|
||||
@ -4267,9 +4248,7 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type,
|
||||
operand = hlo.real(operand)
|
||||
aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype))
|
||||
out = mlir.convert_hlo(ctx, operand, aval_in, aval_out)
|
||||
if config.sharding_in_types.value:
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
|
||||
mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
|
||||
|
||||
@ -4630,12 +4609,9 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision,
|
||||
x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))
|
||||
unsorted_axes = list(x_batch) + x_kept + x_contract_sorted_by_y
|
||||
out_axes = np.argsort(unsorted_axes)
|
||||
if config.sharding_in_types.value:
|
||||
xs = x.aval.sharding
|
||||
inverse_spec = tuple(xs.spec[o] for o in unsorted_axes)
|
||||
ds = xs.with_spec(inverse_spec)
|
||||
else:
|
||||
ds = None
|
||||
dot_general_out = dot_general(g, y, dims, precision=precision,
|
||||
preferred_element_type=preferred_element_type,
|
||||
out_sharding=ds)
|
||||
@ -5020,7 +4996,7 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
|
||||
precision_config=precision_attr(precision),
|
||||
**algorithm_kwarg,
|
||||
)
|
||||
if config.sharding_in_types.value:
|
||||
|
||||
result = mlir.lower_sharding_under_shit(ctx, result, aval_out)
|
||||
if accumulation_aval.dtype != aval_out.dtype:
|
||||
result = mlir.convert_hlo(ctx, result, accumulation_aval, aval_out)
|
||||
@ -5487,9 +5463,7 @@ def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions,
|
||||
aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape))
|
||||
out = mlir.broadcast_in_dim(ctx, x, aval_out,
|
||||
broadcast_dimensions=broadcast_dimensions)
|
||||
if config.sharding_in_types.value:
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
|
||||
def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions,
|
||||
sharding):
|
||||
@ -5498,12 +5472,9 @@ def _broadcast_in_dim_abstract_eval(x, *dyn_shape, shape, broadcast_dimensions,
|
||||
type(core.get_aval(d).dtype) is core.bint for d in shape)):
|
||||
shape = _broadcast_in_dim_shape_rule( # error checking
|
||||
x, shape=shape, broadcast_dimensions=broadcast_dimensions, sharding=None)
|
||||
if config.sharding_in_types.value:
|
||||
new_sharding = _broadcast_in_dim_sharding_rule(
|
||||
x, shape=shape, broadcast_dimensions=broadcast_dimensions,
|
||||
sharding=sharding)
|
||||
else:
|
||||
new_sharding = None
|
||||
return core.ShapedArray(shape, x.dtype, x.weak_type, sharding=new_sharding)
|
||||
# If any BInts in shape, or Tracers in dyn_shape, produce a DShapedArray
|
||||
# (even if x is a ShapedArray)
|
||||
@ -5680,9 +5651,7 @@ pe.padding_rules[concatenate_p] = _concatenate_pad_rule
|
||||
def _concatenate_lower(ctx, *xs, dimension):
|
||||
aval_out, = ctx.avals_out
|
||||
out = hlo.concatenate(xs, mlir.i64_attr(dimension))
|
||||
if config.sharding_in_types.value:
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
mlir.register_lowering(concatenate_p, _concatenate_lower)
|
||||
|
||||
|
||||
@ -5734,8 +5703,7 @@ def _split_lower(ctx, x, *, sizes, axis):
|
||||
limit_indices[axis] = start_indices[axis] + aval_out.shape[axis]
|
||||
out = mlir.slice_op(ctx, x, aval_out, start_indices=start_indices,
|
||||
limit_indices=limit_indices, strides=strides)
|
||||
outs.append(mlir.lower_sharding_under_shit(ctx, out, aval_out)
|
||||
if config.sharding_in_types.value else out)
|
||||
outs.append(mlir.lower_sharding_under_shit(ctx, out, aval_out))
|
||||
start_indices[axis] = limit_indices[axis]
|
||||
return outs
|
||||
|
||||
@ -5840,9 +5808,7 @@ def _pad_lower(ctx, x, padding_value, *, padding_config):
|
||||
aval_out, = ctx.avals_out
|
||||
low, high, interior = util.unzip3(padding_config)
|
||||
out = mlir.pad(ctx, aval_out, x, padding_value, low, high, interior)
|
||||
if config.sharding_in_types.value:
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
|
||||
mlir.register_lowering(pad_p, _pad_lower)
|
||||
|
||||
@ -5908,9 +5874,7 @@ def _squeeze_lower(ctx, operand, *, dimensions):
|
||||
del dimensions # Implied by the output aval.
|
||||
aval_out, = ctx.avals_out
|
||||
out = mlir.reshape(ctx, operand, aval_out)
|
||||
if config.sharding_in_types.value:
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
|
||||
mlir.register_lowering(squeeze_p, _squeeze_lower)
|
||||
|
||||
@ -6073,16 +6037,11 @@ def _reshape_dtype_rule(operand, *, new_sizes, dimensions, sharding):
|
||||
def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding):
|
||||
assert ad.is_undefined_primal(operand)
|
||||
if dimensions is None:
|
||||
if config.sharding_in_types.value:
|
||||
return [reshape(t, operand.aval.shape, out_sharding=operand.aval.sharding)]
|
||||
return [reshape(t, operand.aval.shape)]
|
||||
else:
|
||||
if config.sharding_in_types.value:
|
||||
t_s = operand.aval.sharding.with_spec(
|
||||
tuple(map(lambda s: s if s is None else str(s),
|
||||
np.take(operand.aval.sharding.spec, dimensions))))
|
||||
else:
|
||||
t_s = None
|
||||
return [transpose(reshape(t, np.take(operand.aval.shape, dimensions),
|
||||
out_sharding=t_s),
|
||||
np.argsort(dimensions))]
|
||||
@ -6110,9 +6069,7 @@ def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions, sharding):
|
||||
if dyn_shape:
|
||||
aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape))
|
||||
out = mlir.reshape(ctx, x, aval_out)
|
||||
if config.sharding_in_types.value:
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
|
||||
def _reshape_staging_rule(
|
||||
trace, x, *dyn, new_sizes, dimensions, sharding):
|
||||
@ -6162,9 +6119,7 @@ batching.primitive_batchers[rev_p] = _rev_batch_rule
|
||||
def _rev_lower(ctx, x, *, dimensions):
|
||||
aval_out, = ctx.avals_out
|
||||
out = hlo.reverse(x, mlir.dense_int_array(dimensions))
|
||||
if config.sharding_in_types.value:
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
mlir.register_lowering(rev_p, _rev_lower)
|
||||
|
||||
|
||||
@ -6201,9 +6156,7 @@ def _transpose_lower(ctx, x, *, permutation):
|
||||
trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))]
|
||||
permutation = [*permutation, *trailing_dims]
|
||||
out = hlo.transpose(x, mlir.dense_int_array(permutation))
|
||||
if config.sharding_in_types.value:
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
|
||||
transpose_p = standard_primitive(
|
||||
_transpose_shape_rule, _input_dtype, 'transpose',
|
||||
@ -6345,10 +6298,6 @@ def _select_hlo_lowering_opaque(ctx, which, *cases):
|
||||
avals_in=[aval_which_bcast, *physical_avals_cases],
|
||||
avals_out=[physical_aval_out])[0]
|
||||
|
||||
def _add_shit_to_select(ctx, op, aval_out):
|
||||
if config.sharding_in_types.value:
|
||||
return mlir.lower_sharding_under_shit(ctx, op, aval_out)
|
||||
return op
|
||||
|
||||
def _select_hlo_lowering(ctx, which, *cases):
|
||||
which_aval = ctx.avals_in[0]
|
||||
@ -6356,13 +6305,13 @@ def _select_hlo_lowering(ctx, which, *cases):
|
||||
|
||||
if dtypes.issubdtype(aval_out.dtype, dtypes.extended):
|
||||
op = _select_hlo_lowering_opaque(ctx, which, *cases)
|
||||
return [_add_shit_to_select(ctx, op, aval_out)]
|
||||
return [mlir.lower_sharding_under_shit(ctx, op, aval_out)]
|
||||
|
||||
if which_aval.dtype == np.dtype(np.bool_):
|
||||
assert len(cases) <= 2
|
||||
if len(cases) == 1: return cases
|
||||
op = hlo.select(which, cases[1], cases[0])
|
||||
return [_add_shit_to_select(ctx, op, aval_out)]
|
||||
return [mlir.lower_sharding_under_shit(ctx, op, aval_out)]
|
||||
|
||||
if dtypes.issubdtype(which_aval.dtype, np.signedinteger):
|
||||
compare_type = 'SIGNED'
|
||||
@ -6382,7 +6331,7 @@ def _select_hlo_lowering(ctx, which, *cases):
|
||||
_select(offset + mid, cases[mid:]))
|
||||
|
||||
op = _select(0, cases)
|
||||
return [_add_shit_to_select(ctx, op, aval_out)]
|
||||
return [mlir.lower_sharding_under_shit(ctx, op, aval_out)]
|
||||
|
||||
select_n_p = standard_primitive(
|
||||
_select_shape_rule, _select_dtype_rule, 'select_n',
|
||||
@ -6514,10 +6463,8 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions):
|
||||
*reducer.arguments,
|
||||
dim_var_values=ctx.dim_var_values)
|
||||
hlo.return_(mlir.flatten_ir_values(out_nodes))
|
||||
if config.sharding_in_types.value:
|
||||
return [mlir.lower_sharding_under_shit(ctx, r, aval)
|
||||
for r, aval in safe_zip(op.results, ctx.avals_out)]
|
||||
return op.results
|
||||
|
||||
mlir.register_lowering(reduce_p, _reduce_lower)
|
||||
|
||||
@ -6532,11 +6479,8 @@ def _reduce_sum_transpose_rule(cotangent, operand, *, axes):
|
||||
assert ad.is_undefined_primal(operand)
|
||||
input_shape = operand.aval.shape
|
||||
broadcast_dimensions = tuple(np.delete(np.arange(len(input_shape)), axes))
|
||||
if config.sharding_in_types.value:
|
||||
result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions,
|
||||
out_sharding=operand.aval.sharding)
|
||||
else:
|
||||
result = broadcast_in_dim(cotangent, input_shape, broadcast_dimensions)
|
||||
assert result.shape == input_shape
|
||||
return [result]
|
||||
|
||||
@ -6674,7 +6618,7 @@ def _compute_argminmax(value_comparator, get_identity,
|
||||
axis, = axes
|
||||
indices = broadcasted_iota(
|
||||
index_dtype, np.shape(operand), axis,
|
||||
out_sharding=operand.aval.sharding if config.sharding_in_types.value else None)
|
||||
out_sharding=operand.aval.sharding)
|
||||
res = reduce([operand, indices],
|
||||
[get_identity(operand.dtype), np.array(0, index_dtype)],
|
||||
_ArgMinMaxReducer(value_comparator),
|
||||
@ -6739,9 +6683,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
|
||||
reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(reducer_region):
|
||||
hlo.return_([reducer(*reducer_region.arguments)])
|
||||
if config.sharding_in_types.value:
|
||||
return [mlir.lower_sharding_under_shit(ctx, op.result, aval_out)]
|
||||
return op.results
|
||||
|
||||
mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, hlo.AddOp,
|
||||
_get_sum_identity))
|
||||
@ -6782,9 +6724,7 @@ def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
|
||||
aval_out, = ctx.avals_out
|
||||
out = hlo.reduce_precision(operand, mlir.i32_attr(exponent_bits),
|
||||
mlir.i32_attr(mantissa_bits))
|
||||
if config.sharding_in_types.value:
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
|
||||
mlir.register_lowering(reduce_precision_p, _reduce_precision_lower)
|
||||
|
||||
@ -6924,8 +6864,7 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
|
||||
mlir.flatten_ir_values(operands),
|
||||
dimension=mlir.i64_attr(dimension),
|
||||
is_stable=ir.BoolAttr.get(is_stable))
|
||||
scalar_s = (lambda a: a.sharding.with_spec(P())
|
||||
if config.sharding_in_types.value else lambda _: None)
|
||||
scalar_s = lambda a: a.sharding.with_spec(P())
|
||||
scalar_avals = [aval.update(shape=(), sharding=scalar_s(aval))
|
||||
for aval in ctx.avals_in]
|
||||
scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals)
|
||||
@ -7458,9 +7397,7 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension, sharding):
|
||||
if dyn_shape:
|
||||
aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape))
|
||||
out = mlir.iota(ctx, aval_out, dimension=dimension)
|
||||
if config.sharding_in_types.value:
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
mlir.register_lowering(iota_p, _iota_lower)
|
||||
|
||||
def _iota_batching_rule(in_vals, in_dims, *, dtype, shape, dimension,
|
||||
@ -7668,20 +7605,16 @@ def _const(example, val):
|
||||
_zeros: Callable = partial(full_like, fill_value=0)
|
||||
|
||||
def _zero(x):
|
||||
if config.sharding_in_types.value:
|
||||
x_aval = core.get_aval(x)
|
||||
return full_like(x, shape=(), fill_value=0,
|
||||
sharding=x_aval.sharding.with_spec(P()))
|
||||
return full_like(x, shape=(), fill_value=0)
|
||||
|
||||
_ones: Callable = partial(full_like, fill_value=1)
|
||||
|
||||
def _one(x):
|
||||
if config.sharding_in_types.value:
|
||||
x_aval = core.get_aval(x)
|
||||
return full_like(x, shape=(), fill_value=1,
|
||||
sharding=x_aval.sharding.with_spec(P()))
|
||||
return full_like(x, shape=(), fill_value=1)
|
||||
|
||||
_twos: Callable = partial(full_like, fill_value=2)
|
||||
_two: Callable = partial(full_like, shape=(), fill_value=2)
|
||||
|
@ -1041,7 +1041,6 @@ def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
|
||||
|
||||
batch_dims = operand.shape[:-2]
|
||||
n = operand.shape[-1]
|
||||
if config.sharding_in_types.value:
|
||||
batch_s = operand.sharding.spec[:-2]
|
||||
ns = operand.sharding.spec[-1]
|
||||
if ns is not None:
|
||||
@ -1049,8 +1048,6 @@ def _eigh_jacobi_abstract_eval(operand, *, lower, sort_eigenvalues):
|
||||
' specs. Try marking their specs as None.')
|
||||
w_s = operand.sharding.with_spec(P(*batch_s + (ns,)))
|
||||
v_s = operand.sharding.with_spec(P(*batch_s + (ns, ns)))
|
||||
else:
|
||||
w_s, v_s = None, None
|
||||
w = operand.update(shape=batch_dims + (n,),
|
||||
dtype=lax_internal._complex_basetype(operand.dtype),
|
||||
sharding=w_s)
|
||||
@ -1123,7 +1120,6 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index):
|
||||
n = operand.shape[-1]
|
||||
d = (n if subset_by_index is None else
|
||||
subset_by_index[1] - subset_by_index[0])
|
||||
if config.sharding_in_types.value:
|
||||
batch_s = operand.sharding.spec[:-2]
|
||||
ns, ds = operand.sharding.spec[-1], None
|
||||
if ns is not None:
|
||||
@ -1131,8 +1127,6 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index):
|
||||
'marking their specs as None.')
|
||||
v_s = operand.sharding.with_spec(P(*batch_s + (ns, ds)))
|
||||
w_s = operand.sharding.with_spec(P(*batch_s + (ds,)))
|
||||
else:
|
||||
v_s, w_s = None, None
|
||||
v = operand.update(shape=batch_dims + (n, d), sharding=v_s)
|
||||
w = operand.update(
|
||||
shape=batch_dims + (d,),
|
||||
@ -1450,9 +1444,7 @@ def _triangular_solve_lowering(
|
||||
ir.BoolAttr.get(lower),
|
||||
ir.BoolAttr.get(unit_diagonal),
|
||||
hlo.TransposeAttr.get(transpose))
|
||||
if config.sharding_in_types.value:
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, out_aval)]
|
||||
return [out]
|
||||
|
||||
|
||||
def _triangular_solve_cpu_lower(
|
||||
@ -1901,15 +1893,12 @@ def _geqrf_abstract_eval(operand):
|
||||
if operand.ndim < 2:
|
||||
raise ValueError("Argument to QR decomposition must have ndims >= 2")
|
||||
*batch_dims, m, n = operand.shape
|
||||
if config.sharding_in_types.value:
|
||||
spec = operand.sharding.spec
|
||||
batch_s, ms, ns = spec[:-2], spec[-2], spec[-1]
|
||||
if ms is not None or ns is not None:
|
||||
raise ValueError(f'm and n should be unsharded. Got m: {ms} and n: {ns}'
|
||||
' specs. Try marking their specs as None.')
|
||||
taus_s = operand.sharding.with_spec(P(*(*batch_s, None)))
|
||||
else:
|
||||
taus_s = None
|
||||
taus = operand.update(shape=(*batch_dims, core.min_dim(m, n)),
|
||||
sharding=taus_s)
|
||||
return operand, taus
|
||||
@ -2117,7 +2106,7 @@ def _qr_abstract_eval(operand, *, pivoting, full_matrices, use_magma):
|
||||
raise ValueError("Argument to QR decomposition must have ndims >= 2")
|
||||
*batch_dims, m, n = operand.shape
|
||||
k = m if full_matrices else core.min_dim(m, n)
|
||||
if config.sharding_in_types.value:
|
||||
|
||||
*batch_s, ms, ns = operand.sharding.spec
|
||||
ks = None
|
||||
if ms is not None or ns is not None:
|
||||
@ -2126,8 +2115,7 @@ def _qr_abstract_eval(operand, *, pivoting, full_matrices, use_magma):
|
||||
q_s = operand.sharding.with_spec(P(*(*batch_s, ms, ks)))
|
||||
r_s = operand.sharding.with_spec(P(*(*batch_s, ks, ns)))
|
||||
p_s = operand.sharding.with_spec(P(*(*batch_s, ns)))
|
||||
else:
|
||||
q_s, r_s, p_s = None, None, None
|
||||
|
||||
q = operand.update(shape=(*batch_dims, m, k), sharding=q_s)
|
||||
r = operand.update(shape=(*batch_dims, k, n), sharding=r_s)
|
||||
p = operand.update(shape=(*batch_dims, n), dtype=np.dtype(np.int32),
|
||||
@ -2241,7 +2229,6 @@ def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index,
|
||||
raise ValueError("full_matrices and subset_by_index cannot both be set")
|
||||
rank = min(rank, subset_by_index[1] - subset_by_index[0])
|
||||
|
||||
if config.sharding_in_types.value:
|
||||
batch_s = operand.sharding.spec[:-2]
|
||||
ms = operand.sharding.spec[-2]
|
||||
ns = operand.sharding.spec[-1]
|
||||
@ -2254,8 +2241,6 @@ def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index,
|
||||
P(*batch_s + (ms, ms if full_matrices else rank_s)))
|
||||
vt_sharding = operand.sharding.with_spec(
|
||||
P(*batch_s + (ns if full_matrices else rank_s, ns)))
|
||||
else:
|
||||
s_sharding, u_sharding, vt_sharding = None, None, None
|
||||
|
||||
s = operand.update(
|
||||
shape=batch_dims + (rank,),
|
||||
|
@ -24,7 +24,6 @@ import math
|
||||
|
||||
from jax import tree_util
|
||||
from jax._src import core
|
||||
from jax._src import config
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src.sharding_impls import (SPMDAxisContext, ShardingContext,
|
||||
@ -732,16 +731,12 @@ def _allreduce_effectful_abstract_eval(*args, axes, axis_index_groups):
|
||||
if len(pos_axes) != 0:
|
||||
raise ValueError(f"axis_index_groups can only be used with reductions over "
|
||||
f"named axes, but got: {axes}")
|
||||
if config.sharding_in_types.value:
|
||||
core.check_avals_context_mesh(args, 'all_reduce')
|
||||
out_avals = [
|
||||
ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype,
|
||||
sharding=lax._reduce_op_sharding_rule(arg, axes=pos_axes))
|
||||
for arg in args
|
||||
]
|
||||
else:
|
||||
out_avals = [ShapedArray(lax._reduce_op_shape_rule(arg, axes=pos_axes), arg.dtype)
|
||||
for arg in args]
|
||||
return out_avals, {core.NamedAxisEffect(axis) for axis in named_axes}
|
||||
|
||||
def _check_axis_names(axes):
|
||||
@ -795,11 +790,8 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
|
||||
else:
|
||||
op = hlo.AllReduceOp(
|
||||
[x.type], [x], replica_groups=replica_groups, **other_args)
|
||||
if config.sharding_in_types.value:
|
||||
scalar_aval = core.ShapedArray(
|
||||
(), aval.dtype, sharding=NamedSharding(aval.sharding.mesh, P()))
|
||||
else:
|
||||
scalar_aval = core.ShapedArray((), aval.dtype)
|
||||
scalar_type = mlir.aval_to_ir_type(scalar_aval)
|
||||
reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(reducer_block):
|
||||
|
@ -1370,9 +1370,7 @@ def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
|
||||
aval_out, = ctx.avals_out
|
||||
out = mlir.slice_op(ctx, x, aval_out, start_indices=start_indices,
|
||||
limit_indices=limit_indices, strides=strides)
|
||||
if config.sharding_in_types.value:
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
|
||||
mlir.register_lowering(slice_p, _slice_lower)
|
||||
|
||||
@ -1525,9 +1523,7 @@ def _dynamic_slice_lower(ctx, x, *starts_and_dyn_sizes, slice_sizes):
|
||||
if dyn:
|
||||
aval_out = aval_out.update(shape=lax._merge_dyn_shape(slice_sizes, dyn))
|
||||
out = mlir.dynamic_slice(ctx, aval_out, x, start_indices=start_indices)
|
||||
if config.sharding_in_types.value:
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
|
||||
mlir.register_lowering(dynamic_slice_p, _dynamic_slice_lower)
|
||||
|
||||
@ -1642,9 +1638,7 @@ def _dynamic_update_slice_lower(ctx, x, update, *start_indices):
|
||||
aval_out, = ctx.avals_out
|
||||
out = mlir.dynamic_update_slice(ctx, aval_out, x, update,
|
||||
start_indices=start_indices)
|
||||
if config.sharding_in_types.value:
|
||||
return [mlir.lower_sharding_under_shit(ctx, out, aval_out)]
|
||||
return [out]
|
||||
|
||||
mlir.register_lowering(dynamic_update_slice_p, _dynamic_update_slice_lower)
|
||||
|
||||
|
@ -20,7 +20,6 @@ from functools import partial
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import config
|
||||
from jax._src import dtypes
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src.util import safe_zip
|
||||
@ -50,8 +49,6 @@ def standard_primitive(shape_rule, dtype_rule, name,
|
||||
def _get_array_abstraction_level(a): return a.array_abstraction_level
|
||||
|
||||
def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh:
|
||||
if not config.sharding_in_types.value:
|
||||
return None # type: ignore
|
||||
m = None
|
||||
for a in in_avals:
|
||||
if a is core.abstract_token:
|
||||
@ -69,7 +66,6 @@ def _get_abstract_mesh_from_avals(in_avals) -> mesh_lib.AbstractMesh:
|
||||
|
||||
|
||||
def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
|
||||
if config.sharding_in_types.value:
|
||||
cur_mesh = mesh_lib.get_abstract_mesh()
|
||||
aval_mesh = _get_abstract_mesh_from_avals(avals)
|
||||
if ((cur_mesh.empty or cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual) and
|
||||
@ -84,7 +80,6 @@ def call_sharding_rule(prim, rule, num_out, *avals, **kwargs):
|
||||
' this error by dropping that operation into full auto sharding'
|
||||
' mode via: `jax.experimental.shard.auto_axes(fun, out_shardings=...)`')
|
||||
return rule(*avals, **kwargs)
|
||||
return None if num_out is None else [None] * num_out
|
||||
|
||||
def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
|
||||
sharding_rule, *avals, **kwargs):
|
||||
|
@ -563,6 +563,5 @@ def get_concrete_mesh():
|
||||
|
||||
@contextlib.contextmanager
|
||||
def use_mesh(mesh: Mesh):
|
||||
with (set_abstract_mesh(mesh.abstract_mesh),
|
||||
jax_config.sharding_in_types(True), set_concrete_mesh(mesh)):
|
||||
with set_abstract_mesh(mesh.abstract_mesh), set_concrete_mesh(mesh):
|
||||
yield
|
||||
|
@ -665,11 +665,8 @@ def _one_hot(x: Array, num_classes: int, *,
|
||||
lhs = lax.expand_dims(x, (axis,))
|
||||
rhs_shape = [1] * x.ndim
|
||||
rhs_shape.insert(output_pos_axis, num_classes)
|
||||
if config.sharding_in_types.value:
|
||||
# TODO(yashkatariya): Maybe expose `out_sharding` on `one_hot` too?
|
||||
rhs_sharding = NamedSharding(x.aval.sharding.mesh, P(*[None] * len(rhs_shape))) # pytype: disable=attribute-error
|
||||
else:
|
||||
rhs_sharding = None
|
||||
rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis,
|
||||
out_sharding=rhs_sharding)
|
||||
return (lhs == rhs).astype(dtype)
|
||||
|
@ -411,9 +411,6 @@ def _einsum(
|
||||
_dot_general=lax.dot_general,
|
||||
out_sharding=None,
|
||||
):
|
||||
if out_sharding is not None and not config.sharding_in_types.value:
|
||||
raise NotImplementedError("out_sharding only works when sharding_in_types "
|
||||
"config is True.")
|
||||
out_sharding = canonicalize_sharding(out_sharding)
|
||||
if out_sharding is not None and not isinstance(out_sharding, NamedSharding):
|
||||
raise NotImplementedError(
|
||||
@ -546,8 +543,7 @@ def _einsum(
|
||||
**k_out_sharding)
|
||||
else:
|
||||
names = batch_names_str + remaining_lhs_names + remaining_rhs_names
|
||||
if (config.sharding_in_types.value and out_sharding is not None and
|
||||
names != result_names):
|
||||
if out_sharding is not None and names != result_names:
|
||||
spec = out_sharding.spec
|
||||
inverse_spec = tuple(spec[result_names.index(name)] for name in names)
|
||||
dot_general_out_sharding = NamedSharding(out_sharding.mesh,
|
||||
|
@ -5463,8 +5463,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True,
|
||||
# whenever x is weak, but avoids introducing weak types with something like
|
||||
# array([1, 2, 3])
|
||||
weak_type = dtype is None and dtypes.is_weakly_typed(object)
|
||||
if (config.sharding_in_types.value and device is None and
|
||||
isinstance(object, core.Tracer)):
|
||||
if device is None and isinstance(object, core.Tracer):
|
||||
sharding = object.aval.sharding
|
||||
sharding = None if sharding.mesh.empty else sharding
|
||||
else:
|
||||
|
@ -250,8 +250,7 @@ def _broadcast_arrays(*args: ArrayLike) -> list[Array]:
|
||||
if not shapes or all(core.definitely_equal_shape(shapes[0], s) for s in shapes):
|
||||
return [lax.asarray(arg) for arg in args]
|
||||
result_shape = lax.broadcast_shapes(*shapes)
|
||||
result_sharding = (lax.broadcast_shardings(*avals)
|
||||
if config.sharding_in_types.value else None)
|
||||
result_sharding = lax.broadcast_shardings(*avals)
|
||||
return [_broadcast_to(arg, result_shape, result_sharding) for arg in args]
|
||||
|
||||
|
||||
|
@ -891,8 +891,7 @@ def _pallas_call_batching_rule(
|
||||
|
||||
batched_out_avals = []
|
||||
for aval in out_avals:
|
||||
sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, 0, None))
|
||||
if config.sharding_in_types.value else None)
|
||||
sharding = aval.sharding.with_spec(tuple_insert(aval.sharding.spec, 0, None))
|
||||
shape = tuple_insert(aval.shape, 0, axis_size)
|
||||
batched_out_avals.append(aval.update(shape=shape, sharding=sharding))
|
||||
batched_out_avals = tuple(batched_out_avals)
|
||||
|
@ -1746,14 +1746,10 @@ def _pjit_lower(
|
||||
lowering_parameters: mlir.LoweringParameters,
|
||||
pgle_profiler: profiler.PGLEProfiler | None):
|
||||
util.test_event("pjit_lower")
|
||||
if config.sharding_in_types.value:
|
||||
if resource_env is not None:
|
||||
mesh, api_name = resource_env.physical_mesh, 'pjit'
|
||||
else:
|
||||
mesh, api_name = mesh_lib.get_concrete_mesh(), 'jit'
|
||||
else:
|
||||
mesh, api_name = ((resource_env.physical_mesh, 'pjit')
|
||||
if resource_env is not None else (None, 'jit'))
|
||||
return pxla.lower_sharding_computation(
|
||||
jaxpr, api_name, name, in_shardings, out_shardings,
|
||||
in_layouts, out_layouts, tuple(donated_invars),
|
||||
@ -2494,9 +2490,6 @@ state_discharge.register_discharge_rule(pjit_p)(_pjit_state_discharge_rule)
|
||||
# -------------------- with_sharding_constraint --------------------
|
||||
|
||||
def check_shardings_are_auto(shardings_flat):
|
||||
if not config.sharding_in_types.value:
|
||||
return
|
||||
|
||||
for s in shardings_flat:
|
||||
if not isinstance(s, NamedSharding):
|
||||
continue
|
||||
|
@ -22,7 +22,6 @@ import math
|
||||
from typing import Any, NamedTuple, cast
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import config
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src import sharding as jsharding
|
||||
from jax._src import sharding_specs
|
||||
@ -1254,8 +1253,6 @@ def flatten_spec(spec):
|
||||
def canonicalize_sharding(sharding: NamedSharding | PartitionSpec | None,
|
||||
check_mesh_consistency: bool = True
|
||||
) -> NamedSharding | None:
|
||||
if not config.sharding_in_types.value:
|
||||
return sharding # type: ignore
|
||||
if sharding is None:
|
||||
return sharding
|
||||
if isinstance(sharding, NamedSharding) and sharding.mesh.empty:
|
||||
|
@ -480,10 +480,8 @@ shard_map_p = ShardMapPrimitive('shard_map')
|
||||
# Staging
|
||||
|
||||
def _as_manual_mesh(mesh):
|
||||
if config.sharding_in_types.value:
|
||||
return AbstractMesh(
|
||||
mesh.shape_tuple, axis_types={AxisTypes.Manual: mesh.axis_names})
|
||||
return None
|
||||
|
||||
def _extend_axis_env(mesh, auto):
|
||||
return core.extend_axis_env_nd([(k, v) for k, v in mesh.shape.items()
|
||||
@ -554,12 +552,9 @@ def _shard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
|
||||
assert isinstance(aval, core.ShapedArray)
|
||||
new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
|
||||
for i, sz in enumerate(aval.shape))
|
||||
if config.sharding_in_types.value:
|
||||
new_mesh = AbstractMesh(
|
||||
mesh.shape_tuple, axis_types={AxisTypes.Manual: mesh.axis_names})
|
||||
new_sharding = NamedSharding(new_mesh, P(*[None] * aval.ndim))
|
||||
else:
|
||||
new_sharding = None
|
||||
return aval.update(shape=new_shape, sharding=new_sharding)
|
||||
core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array
|
||||
|
||||
@ -568,13 +563,10 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames,
|
||||
assert isinstance(aval, core.ShapedArray)
|
||||
new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
|
||||
for i, sz in enumerate(aval.shape))
|
||||
if config.sharding_in_types.value:
|
||||
spec = _names_to_pspec(names)._normalized_spec_for_aval(aval.ndim)
|
||||
new_mesh = (mesh.abstract_mesh if get_abstract_mesh().empty else
|
||||
get_abstract_mesh())
|
||||
new_sharding = NamedSharding(new_mesh, spec)
|
||||
else:
|
||||
new_sharding = None
|
||||
return aval.update(shape=new_shape, sharding=new_sharding)
|
||||
core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array
|
||||
|
||||
@ -979,11 +971,8 @@ class ShardMapTracer(core.Tracer):
|
||||
def aval(self):
|
||||
aval = core.get_aval(self.val)
|
||||
out = core.mapped_aval(self._trace.mesh.size, 0, aval)
|
||||
if config.sharding_in_types.value:
|
||||
new_sharding = NamedSharding(_as_manual_mesh(self._trace.mesh),
|
||||
out.sharding.spec) # pytype: disable=attribute-error
|
||||
else:
|
||||
new_sharding = None
|
||||
return out.update(sharding=new_sharding)
|
||||
|
||||
def to_concrete_value(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user