[sharding_in_types] Add vmap + explicit sharding support. The main changes are:

* Track `explicit_mesh_axis` on `AxisData`.
* Modify `unmapped_aval` to take the above explicit mesh axis and insert it into the right place in the sharding so that out_shardings are correct.
* Make `matchaxis` also handle shardings correctly.
* All mapped dimensions should be sharded the same way.
* `spmd_axis_name` and explicitly sharded arrays cannot be used together.
* `out_shardings` parameter on `dot_general`, `broadcast_in_dim`, `reshape`, `reshard` and `mesh_cast` is handled correctly in presence of vmap.

This should eventually help us get rid of `spmd_axis_name` from `vmap`.

PiperOrigin-RevId: 721007659
Yash Katariya 2025-01-29 09:33:44 -08:00 committed by jax authors
parent 20843643ab
commit dcb28f1218
14 changed files with 292 additions and 89 deletions
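
As a quick illustration of the new behavior, here is a minimal sketch condensed from the `test_mul_vmap` test added below. It assumes the experimental sharding-in-types mode is enabled and that `mesh` is a 2x2 mesh with explicit axis types, as set up by the `jtu.with_user_mesh((2, 2), ('x', 'y'))` test helper; it is not standalone runnable outside that setup.

import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

np_inp = np.arange(16.).reshape(8, 2)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))  # `mesh` comes from the test helper

def f(x):
  # Inside vmap, the mapped-away 'x' dimension is gone from the element spec...
  assert x.aval.sharding.spec == P('y')
  return x * 2

out = jax.vmap(f)(arr)
# ...and is re-inserted on the way out via AxisData.explicit_mesh_axis,
# so the result keeps the original sharding.
assert out.sharding == NamedSharding(mesh, P('x', 'y'))
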

@ -815,8 +815,7 @@ def vmap(fun: F,
out_axes: Any = 0,
axis_name: AxisName | None = None,
axis_size: int | None = None,
spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None
) -> F:
spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None) -> F:
"""Vectorizing map. Creates a function which maps ``fun`` over argument axes.
Args:
@ -989,8 +988,15 @@ def vmap(fun: F,
in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
axis_size_ = (axis_size if axis_size is not None else
_mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap"))
explicit_mesh_axis = _mapped_axis_spec(args_flat, in_axes_flat)
if spmd_axis_name is not None and explicit_mesh_axis is not None:
raise ValueError(
"Only one of spmd_axis_name or arrays sharded on `Explicit` mesh"
f" axis type is allowed. Got {spmd_axis_name=} and"
f" arrays sharded on {explicit_mesh_axis=}")
try:
axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name)
axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name,
explicit_mesh_axis)
out_flat = batching.batch(
flat_fun, axis_data, in_axes_flat,
lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
@ -1007,6 +1013,28 @@ def vmap(fun: F,
return cast(F, vmap_f)
def _mapped_axis_spec(args_flat, in_axes_flat):
if not config.sharding_in_types.value:
return None
def _get_spec(arg, i):
try:
# Duck type arrays like BCOO arrays can be passed to vmap.
return shaped_abstractify(arg).sharding.spec[i]
except TypeError:
return None
temp_spec = None
for arg, i in zip(args_flat, in_axes_flat):
if i is not None:
spec = _get_spec(arg, i)
if temp_spec is not None and temp_spec != spec:
raise ValueError(
"Mapped away dimension of inputs passed to vmap should be sharded"
f" the same. Got inconsistent axis specs: {temp_spec} vs {spec}")
temp_spec = spec
return temp_spec
def _mapped_axis_size(fn, tree, vals, dims, name):
if not vals:
args, kwargs = tree_unflatten(tree, vals)
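
In other words, `_mapped_axis_spec` requires every dimension being mapped away to carry the same spec entry across all arguments, and that shared entry becomes `AxisData.explicit_mesh_axis`. Below is a simplified standalone model of the check, with plain tuples standing in for avals and PartitionSpecs (the real code also tolerates duck-typed arrays via the TypeError branch above):

def mapped_axis_spec_model(specs_and_axes):
  shared = None
  for spec, axis in specs_and_axes:
    if axis is None:
      continue              # unmapped argument: imposes no constraint
    entry = spec[axis]      # spec entry of the dimension being mapped away
    if shared is not None and shared != entry:
      raise ValueError(
          "Mapped away dimension of inputs passed to vmap should be sharded"
          f" the same. Got inconsistent axis specs: {shared} vs {entry}")
    shared = entry
  return shared             # recorded as AxisData.explicit_mesh_axis

# Both mapped dims sharded on 'x' -> OK; mixing 'x' with None would raise.
assert mapped_axis_spec_model([(('x', 'y'), 0), (('x', None), 0)]) == 'x'
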

@ -2417,10 +2417,10 @@ def mapped_aval(size: AxisSize, axis: int | None,
raise TypeError(f"no mapping handler for {aval} of type {type(aval)}")
def unmapped_aval(size: AxisSize, axis: int | None,
aval: AbstractValue) -> AbstractValue:
aval: AbstractValue, explicit_mesh_axis=None) -> AbstractValue:
_, handler = aval_mapping_handlers.get(type(aval), (None, None))
if handler is not None:
return handler(size, axis, aval)
return handler(size, axis, explicit_mesh_axis, aval)
else:
raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}")
@ -2436,10 +2436,12 @@ def _map_shaped_array(
weak_type=aval.weak_type, sharding=sharding)
def _unmap_shaped_array(
size: int, axis: int | None, aval: ShapedArray) -> ShapedArray:
size: int, axis: int | None, explicit_mesh_axis, aval: ShapedArray
) -> ShapedArray:
if axis is None: return aval
elif type(axis) is int:
sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis, None))
sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis,
explicit_mesh_axis))
if config.sharding_in_types.value else None)
return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
weak_type=aval.weak_type, sharding=sharding)
@ -2452,7 +2454,7 @@ def _map_dshaped_array(
aval.weak_type)
def _unmap_dshaped_array(
size: AxisSize, axis: int | None, aval: DShapedArray
size: AxisSize, axis: int | None, explicit_mesh_axis, aval: DShapedArray
) -> DShapedArray:
if axis is None: return aval
elif type(axis) is int:
@ -2465,7 +2467,7 @@ AvalMapHandlerPair = tuple[Callable, Callable]
aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
ShapedArray: (_map_shaped_array, _unmap_shaped_array),
AbstractToken: (lambda _, __, a: a, lambda _, __, a: a)
AbstractToken: (lambda _, __, a: a, lambda _, __, ____, a: a)
}
# When a mapped function is given no axis name, we generate a name object based
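
The net effect of the new `explicit_mesh_axis` parameter is a `tuple_insert` into the sharding spec at the position where the batched dimension is put back. Here is a toy model of the sharding computation in `_unmap_shaped_array`, with shapes and specs as plain tuples; `unmap_spec` is a hypothetical helper, not JAX API:

def tuple_insert(t, idx, val):
  return t[:idx] + (val,) + t[idx:]

def unmap_spec(shape, spec, axis, size, explicit_mesh_axis):
  # Put the batched dimension of length `size` back at `axis`, and record the
  # mesh axis it was explicitly sharded over (or None) in the same spec slot.
  return tuple_insert(shape, axis, size), tuple_insert(spec, axis, explicit_mesh_axis)

# Element aval of shape (2,) with spec ('y',), unmapped over size 8 on axis 0
# where the batched dimension was sharded on 'x':
assert unmap_spec((2,), ('y',), 0, 8, 'x') == ((8, 2), ('x', 'y'))
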

@ -217,7 +217,7 @@ def maybe_bdim_at_front(x, bdim):
# axes instead of accepting and matching a given spec of output axes. Assumes
# `f` is pytree-flattened
def vmap_unrestricted(f: lu.WrappedFun, *args, in_axes, axis_name, axis_size):
axis_data = batching.AxisData(axis_name, axis_size, None)
axis_data = batching.AxisData(axis_name, axis_size, None, None)
tag = core.TraceTag()
f, out_axes = batching.batch_subtrace(f, tag, axis_data, in_axes)
outs = f.call_wrapped(*args)

@ -26,6 +26,8 @@ from jax._src import config
from jax._src import core
from jax._src import source_info_util
from jax._src import linear_util as lu
from jax._src.partition_spec import PartitionSpec as P
from jax._src.sharding_impls import NamedSharding
from jax._src.ad_util import (Zero, instantiate, SymbolicZero,
replace_rule_output_symbolic_zeros,
add_jaxvals, add_jaxvals_p)
@ -36,7 +38,7 @@ from jax._src.tree_util import (tree_unflatten, tree_flatten,
from jax._src.typing import Array
from jax._src.util import (unzip2, safe_map, safe_zip, split_list,
canonicalize_axis, moveaxis, as_hashable_function,
curry, memoize, weakref_lru_cache)
curry, memoize, weakref_lru_cache, tuple_insert)
map, unsafe_map = safe_map, map
@ -241,6 +243,7 @@ Vmappable = Any
Elt = Any
MapSpec = Any
AxisSize = Any
MeshAxis = Any
GetIdx = Callable[[], Tracer] # TODO(mattjj): revise this laziness
ToEltHandler = Callable[[Callable, GetIdx, Vmappable, MapSpec], Elt]
FromEltHandler = Callable[[Callable, AxisSize, Elt, MapSpec], Vmappable]
@ -277,12 +280,12 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
to_elt_handlers: dict[type, ToEltHandler] = {}
def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int,
x: Elt, spec: MapSpec) -> Vmappable:
def from_elt(trace: BatchTrace, axis_size: AxisSize, mesh_axis: MeshAxis,
i: int, x: Elt, spec: MapSpec) -> Vmappable:
handler = from_elt_handlers.get(type(x))
if handler:
def _cont(axis_size, elt, axis):
return from_elt(trace, axis_size, i, elt, axis)
return from_elt(trace, axis_size, mesh_axis, i, elt, axis)
return handler(_cont, axis_size, x, spec)
val, bdim = trace.to_batch_info(x)
if type(bdim) is RaggedAxis:
@ -292,7 +295,8 @@ def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int,
return _jumble_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val)
else:
try:
return matchaxis(trace.axis_data.name, axis_size, bdim, spec, val)
return matchaxis(trace.axis_data.name, axis_size, mesh_axis,
bdim, spec, val)
except SpecMatchError:
raise SpecMatchError(i, x.batch_dim, spec) from None
from_elt_handlers: dict[type, FromEltHandler] = {}
@ -441,7 +445,17 @@ class BatchTracer(Tracer):
class AxisData:
name : Any
size : Any
# Only one of spmd_axis_name and explicit_mesh_axis is set.
spmd_name : Any
explicit_mesh_axis: Any
def get_sharding_for_vmap(axis_data, orig_sharding, axis):
if orig_sharding.mesh.empty:
return None
val = axis_data.explicit_mesh_axis
new_spec = P(*tuple_insert(orig_sharding.spec, axis, val))
return NamedSharding(orig_sharding.mesh, new_spec)
class BatchTrace(Trace):
@ -472,7 +486,8 @@ class BatchTrace(Trace):
return p.bind_with_trace(self.parent_trace, vals_in, params)
else:
with core.set_current_trace(self.parent_trace):
val_out, dim_out = fancy_primitive_batchers[p](self.axis_data, vals_in, dims_in, **params)
val_out, dim_out = fancy_primitive_batchers[p](
self.axis_data, vals_in, dims_in, **params)
elif args_not_mapped:
# no-op shortcut
return p.bind_with_trace(self.parent_trace, vals_in, params)
@ -605,8 +620,9 @@ def _batch_inner(f: Callable, axis_data, out_dim_dests, tag, in_dims, *in_vals):
outs = f(*in_tracers)
out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
out_vals = map(partial(from_elt, trace, axis_data.size), range(len(outs)),
outs, out_dim_dests)
out_vals = map(partial(from_elt, trace, axis_data.size,
axis_data.explicit_mesh_axis),
range(len(outs)), outs, out_dim_dests)
return out_vals, trace
@ -639,7 +655,7 @@ def vtile(f_flat: lu.WrappedFun,
outputs_flat = f(*map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat))
return map(untile_axis, outputs_flat, out_axes_flat)
axis_data = AxisData(axis_name, tile_size, None)
axis_data = AxisData(axis_name, tile_size, None, None)
return _map_to_tile(batch(f_flat, axis_data, in_axes_flat, out_axes_flat))
### API for batching functions with jaxpr type inputs and outputs
@ -751,7 +767,8 @@ def _batch_jaxpr2(
handle_ragged(closed_jaxpr.in_avals, dim, aval)
if isinstance(dim, RaggedAxis) else (dim, aval)
for dim, aval in zip(in_axes, closed_jaxpr.in_avals)])
avals_in2 = [core.unmapped_aval(axis_data.size, b, aval)
avals_in2 = [core.unmapped_aval(axis_data.size, b, aval,
axis_data.explicit_mesh_axis)
if b is not not_mapped else aval
for aval, b in unsafe_zip(avals_in, in_axes2)]
jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2)
@ -788,7 +805,9 @@ def _batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest):
f, out_axes = _batch_jaxpr_inner(f, axis_data)
f, out_batched = _match_axes_jaxpr(f, axis_data, out_axes_dest, out_axes)
f = _batch_jaxpr_outer(f, axis_data, in_axes)
avals_in = [core.unmapped_aval(axis_data.size, b, aval) if b is not not_mapped
avals_in = [core.unmapped_aval(axis_data.size, b, aval,
axis_data.explicit_mesh_axis)
if b is not not_mapped
else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in)
return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
@ -821,7 +840,8 @@ def _match_axes_jaxpr(f, store, axis_data, out_axes_dest, out_axes, trace, in_ax
if len(out_axes_dest) != len(out_axes):
out_axis_dest, = out_axes_dest
out_axes_dest = [out_axis_dest] * len(out_axes)
out_vals = map(partial(matchaxis, axis_data.name, axis_data.size),
out_vals = map(partial(matchaxis, axis_data.name, axis_data.size,
axis_data.explicit_mesh_axis),
out_axes, out_axes_dest, out_vals)
out_batched = [dst is not None for dst in out_axes_dest]
store.store(out_batched)
@ -853,6 +873,7 @@ zero_if_mapped = ZeroIfMapped()
@lu.transformation_with_aux2
def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals):
size = axis_data.size
mesh_axis = axis_data.explicit_mesh_axis
with core.take_current_trace() as parent_trace:
trace = BatchTrace(parent_trace, tag, axis_data)
in_tracers = [val if dim is None else
@ -869,9 +890,9 @@ def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals):
out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2])
out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2])
out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds)
out_primals = map(partial(matchaxis, trace.axis_data.name, size),
out_primals = map(partial(matchaxis, trace.axis_data.name, size, mesh_axis),
out_primal_bds, out_dims, out_primals)
out_tangents = map(partial(matchaxis, trace.axis_data.name, size),
out_tangents = map(partial(matchaxis, trace.axis_data.name, size, mesh_axis),
out_tangent_bds, out_dims, out_tangents)
store.store(out_dims * 2)
return out_primals + out_tangents
@ -879,6 +900,7 @@ def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals):
def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests):
axis_size = axis_data.size
axis_name = axis_data.name
mesh_axis = axis_data.explicit_mesh_axis
def new_bwd(*args):
in_dims_ = in_dims() if callable(in_dims) else in_dims
args = [SymbolicZero(core.mapped_aval(axis_size, dim, x.aval))
@ -887,19 +909,22 @@ def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests):
in_dims_ = [None if type(x) is SymbolicZero else d
for x, d in zip(args, in_dims_)]
bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd), tag, axis_data, in_dims_)
bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk,
out_dim_dests)
bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, mesh_axis,
out_dims_thunk, out_dim_dests)
return bwd_.call_wrapped(*args)
return new_bwd
@lu.transformation2
def _match_axes_and_sum(f, axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals):
def _match_axes_and_sum(f, axis_size, axis_name, mesh_axis, out_dims_thunk,
out_dim_dests, *in_vals):
# this is like _match_axes, but we do reduce-sums as needed
out_vals = f(*in_vals)
return map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, axis_name,
sum_match=True), out_dims_thunk(), out_dim_dests, out_vals)
return map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, mesh_axis,
axis_name, sum_match=True),
out_dims_thunk(), out_dim_dests, out_vals)
def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False):
def _matchaxis_symbolic_zeros(axis_name, sz, mesh_axis, name, src, dst, x,
sum_match=False):
# Just like `matchaxis`, but handles symbolic zeros using ad_util.py
# TODO(mattjj): dedup with matchaxis
if isinstance(x, (Zero, SymbolicZero)):
@ -907,15 +932,15 @@ def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False)
return x
elif type(src) == type(dst) == int:
aval = core.mapped_aval(sz, src, x.aval)
return Zero(core.unmapped_aval(sz, dst, aval))
return Zero(core.unmapped_aval(sz, dst, aval, mesh_axis))
elif src is not_mapped and dst is not not_mapped:
return Zero(core.unmapped_aval(sz, dst, x.aval))
return Zero(core.unmapped_aval(sz, dst, x.aval, mesh_axis))
elif dst is not_mapped and sum_match:
return Zero(core.mapped_aval(sz, src, x.aval))
else:
raise ValueError((axis_name, x, src, dst))
else:
return matchaxis(axis_name, sz, src, dst, x, sum_match=sum_match)
return matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=sum_match)
### utilities for defining primitives' batching rules
@ -1057,13 +1082,19 @@ def move_stacked_axis(operand, bdim, dst):
### general utilities for manipulating axes on jaxpr types (not vmappables)
def broadcast(x, sz, axis):
def broadcast(x, sz, axis, mesh_axis=None):
shape = list(np.shape(x))
shape.insert(axis, sz)
broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
return jax.lax.broadcast_in_dim(x, shape, broadcast_dims)
if config.sharding_in_types.value:
x_aval = core.get_aval(x)
new_spec = P(*tuple_insert(x_aval.sharding.spec, axis, mesh_axis))
sharding = x_aval.sharding.with_spec(new_spec)
else:
sharding = None
return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, sharding=sharding)
def matchaxis(axis_name, sz, src, dst, x, sum_match=False):
def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False):
if dst == jumble_axis:
x = bdim_at_front(x, src, sz)
elt_ty = x.aval.update(shape=x.shape[1:])
@ -1080,7 +1111,7 @@ def matchaxis(axis_name, sz, src, dst, x, sum_match=False):
elif type(src) == type(dst) == int:
return moveaxis(x, src, dst)
elif src is not_mapped and dst is not not_mapped:
return broadcast(x, sz, canonicalize_axis(dst, np.ndim(x) + 1))
return broadcast(x, sz, canonicalize_axis(dst, np.ndim(x) + 1), mesh_axis)
elif dst is not_mapped and sum_match:
return x.sum(src)
else:
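
The `mesh_axis` threaded through `matchaxis`/`broadcast` above is consumed the same way by the primitive batchers further down (`dot_general`, `broadcast_in_dim`, `reshape`, `mesh_cast`, `reshard`): they rewrite their `out_sharding`/`dst_sharding` parameter with `get_sharding_for_vmap`, which splices the explicit mesh axis into the spec at the batch dimension. A rough standalone model, not the exact implementation:

from jax.sharding import NamedSharding, PartitionSpec as P

def get_sharding_for_vmap_model(explicit_mesh_axis, orig_sharding, batch_axis):
  if orig_sharding.mesh.empty:          # nothing to do without a mesh
    return None
  spec = tuple(orig_sharding.spec)
  new_spec = P(*(spec[:batch_axis] + (explicit_mesh_axis,) + spec[batch_axis:]))
  return NamedSharding(orig_sharding.mesh, new_spec)

# e.g. inserting the vmapped 'x' axis at batch dim 0 of an out_sharding P(None, 'y'):
# get_sharding_for_vmap_model('x', NamedSharding(mesh, P(None, 'y')), 0)
#   -> NamedSharding(mesh, P('x', None, 'y'))
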

@ -291,7 +291,8 @@ def _for_vmap(axis_data, args, dims, *,
batched = map(operator.or_, batched, out_batched)
else:
raise Exception("Invalid fixpoint")
args = [batching.broadcast(x, axis_data.size, 0) if now_bat and not was_bat
args = [batching.broadcast(x, axis_data.size, 0, axis_data.explicit_mesh_axis)
if now_bat and not was_bat
else batching.moveaxis(x, d, 0) if now_bat else x
for x, d, was_bat, now_bat in zip(args, dims, init_batched, batched)]
batched_jaxpr_, _ = batching.batch_jaxpr(

@ -951,7 +951,8 @@ def _scan_batching_rule(axis_data, args,
consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry])
new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
else x for x, d in zip(consts, consts_bdims)]
new_init = [batching.broadcast(x, axis_data.size, 0) if now_batched and not was_batched
new_init = [batching.broadcast(x, axis_data.size, 0, axis_data.explicit_mesh_axis)
if now_batched and not was_batched
else batching.moveaxis(x, d, 0) if now_batched else x
for x, d, was_batched, now_batched in
zip(init, init_bdims, init_batched, carry_batched)]
@ -1492,7 +1493,8 @@ def _while_loop_batching_rule(axis_data, args, dims, cond_nconsts, cond_jaxpr,
new_init = []
for x, old_axis, new_axis in zip(init, init_dims, carry_dims):
if old_axis is batching.not_mapped and new_axis is not batching.not_mapped:
new_init.append(batching.broadcast(x, axis_data.size, new_axis))
new_init.append(batching.broadcast(x, axis_data.size, new_axis,
axis_data.explicit_mesh_axis))
elif old_axis is batching.not_mapped and new_axis is batching.not_mapped:
new_init.append(x)
else:

@ -465,7 +465,8 @@ def _linear_solve_batching_rule(axis_data, args, dims, const_lengths, jaxprs):
]
# Broadcast out b if necessary
new_b = [
batching.broadcast(x, axis_data.size, 0) if now_bat and not was_bat else
batching.broadcast(x, axis_data.size, 0, axis_data.explicit_mesh_axis)
if now_bat and not was_bat else
batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
]

@ -64,7 +64,6 @@ from jax._src.lax.utils import (
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.mesh import get_abstract_mesh
from jax._src.sharding_impls import (PmapSharding, NamedSharding,
PartitionSpec as P, canonicalize_sharding)
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
@ -1932,13 +1931,13 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
fill_value = _convert_element_type(fill_value, dtype, weak_type)
if (sharding is not None and not isinstance(sharding, PmapSharding) and
isinstance(fill_value, array.ArrayImpl) and sharding.is_concrete):
isinstance(fill_value, array.ArrayImpl) and sharding._is_concrete):
broadcast_shape = sharding.shard_shape(shape)
shard = broadcast(fill_value, broadcast_shape)
return array.make_array_from_callback(shape, sharding, lambda _: shard)
if (config.sharding_in_types.value and sharding is not None and
not sharding.is_concrete):
not sharding._is_concrete):
return broadcast(fill_value, shape, sharding=sharding)
else:
return broadcast(fill_value, shape)
@ -3185,7 +3184,7 @@ def _convert_element_type_sharding_rule(operand, *, new_dtype, weak_type,
sharding):
if sharding is None:
return operand.sharding
if sharding.is_concrete:
if sharding._is_concrete:
if isinstance(sharding, NamedSharding):
return NamedSharding(sharding.mesh.abstract_mesh, sharding.spec)
else:
@ -3274,7 +3273,7 @@ convert_element_type_p = Primitive('convert_element_type')
def _convert_element_type_bind_with_trace(trace, args, params):
sharding = params['sharding']
operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params)
if sharding is not None and sharding.is_concrete:
if sharding is not None and sharding._is_concrete:
with core.set_current_trace(trace):
operand = pjit.with_sharding_constraint(operand, sharding)
return operand
@ -3703,6 +3702,7 @@ def _dot_batch_rule(
unpack_args,
unpack_dims,
invoke_prim,
axis_data,
batched_args,
batch_dims,
*,
@ -3737,13 +3737,15 @@ def _dot_batch_rule(
rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
else:
rhs_shape = np.shape(rhs)
result_batch_dim = batching.shape_as_bdim(
result_stack_dim,
_dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers))
if out_sharding is not None:
cur_mesh = get_abstract_mesh()
if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual:
out_sharding = None
else:
raise NotImplementedError("vmap with out_sharding is not supported. "
"Please open an issue.")
out_sharding = batching.get_sharding_for_vmap(
axis_data, out_sharding, result_batch_dim)
batched_out = invoke_prim(
lhs,
rhs,
@ -3752,9 +3754,6 @@ def _dot_batch_rule(
preferred_element_type=preferred_element_type,
out_sharding=out_sharding,
)
result_batch_dim = batching.shape_as_bdim(
result_stack_dim,
_dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers))
return batched_out, result_batch_dim
@ -3919,7 +3918,8 @@ _dot_general_batch_rule = functools.partial(
_dot_general_batch_unpack_dims,
dot_general,
)
batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
batching.fancy_primitive_batchers[dot_general_p] = _dot_general_batch_rule
batching.skippable_batchers[dot_general_p] = lambda _: ()
pe.padding_rules[dot_general_p] = _dot_general_padding_rule
core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule
batching.ragged_prop_rules[dot_general_p] = _dot_general_ragged_prop_rule
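
Note that moving `dot_general` (and, below, `ragged_dot`, `broadcast_in_dim`, and `reshape`) from `primitive_batchers` to `fancy_primitive_batchers` is what gives these rules access to `axis_data`, and hence to the explicit mesh axis. Schematically, the two registries expect different rule signatures (illustrative stubs only):

# Plain batcher: no access to vmap axis metadata.
def plain_batch_rule(batched_args, batch_dims, **params):
  ...

# "Fancy" batcher: BatchTrace passes AxisData first, so the rule can call
# batching.get_sharding_for_vmap(axis_data, out_sharding, result_batch_dim).
def fancy_batch_rule(axis_data, batched_args, batch_dims, **params):
  ...

Each fancy registration in this change is paired with a `batching.skippable_batchers[...] = lambda _: ()` entry, as in the lines above and below.
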
@ -4261,6 +4261,7 @@ def _ragged_dot_invoke_prim(
def _ragged_dot_batch_rule(
axis_data,
batched_args,
batch_dims,
*,
@ -4274,6 +4275,7 @@ def _ragged_dot_batch_rule(
_ragged_dot_batch_unpack_args,
_ragged_dot_batch_unpack_dims,
invoke,
axis_data,
batched_args,
batch_dims,
dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
@ -4288,7 +4290,8 @@ ragged_dot_p = standard_primitive(_ragged_dot_shape_rule,
ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p))
ad.primitive_jvps[ragged_dot_p] = _ragged_dot_jvp_rule
ad.primitive_transposes[ragged_dot_p] = _ragged_dot_transpose_rule
batching.primitive_batchers[ragged_dot_p] = _ragged_dot_batch_rule
batching.fancy_primitive_batchers[ragged_dot_p] = _ragged_dot_batch_rule
batching.skippable_batchers[ragged_dot_p] = lambda _: ()
def _ragged_dot_impl(
lhs: Array,
@ -4393,7 +4396,7 @@ def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape,
return ([expand_dims(_reduce_sum(ct, axes), unit_dims)] +
[None] * len(dyn_shape))
def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape,
def _broadcast_in_dim_batch_rule(axis_data, batched_args, batch_dims, shape,
broadcast_dimensions, sharding):
# `dyn_shape` is the dynamic portion of the target shape. `shape`
# is the target shape, with `None` for dynamic sections.
@ -4407,13 +4410,11 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape,
if operand_bdim is not None:
if isinstance(operand_bdim, RaggedAxis):
stacked_axis = operand_bdim.stacked_axis
else:
stacked_axis = operand_bdim
new_operand = batching.moveaxis(operand, stacked_axis, 0)
if isinstance(operand_bdim, RaggedAxis):
stacked_size = operand_bdim.size
else:
stacked_axis = operand_bdim
stacked_size = operand.shape[stacked_axis]
new_operand = batching.moveaxis(operand, stacked_axis, 0)
new_broadcast_dimensions = (0,) + tuple(np.add(1, broadcast_dimensions))
else:
new_operand = operand
@ -4440,12 +4441,10 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape,
assert len(sizes) == stacked_size, msg
dyn_limits.append(bound)
new_shape = (stacked_size,) + _merge_dyn_shape(shape, dyn_limits)
if sharding is not None:
if sharding.mesh._are_all_axes_auto or sharding.mesh._are_all_axes_manual:
sharding = None
else:
raise NotImplementedError('Implement sharding support for '
'broadcast_in_dim_batch_rule')
sharding = batching.get_sharding_for_vmap(axis_data, sharding, 0)
result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions,
sharding=sharding)
out_ragged_axes = [idx+1 for idx, s in enumerate(shape) if s is None]
@ -4569,7 +4568,8 @@ broadcast_in_dim_p = standard_primitive(
broadcast_in_dim_p.def_abstract_eval(_broadcast_in_dim_abstract_eval)
ad.primitive_jvps[broadcast_in_dim_p] = _broadcast_in_dim_jvp_rule
ad.primitive_transposes[broadcast_in_dim_p] = _broadcast_in_dim_transpose_rule
batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
batching.fancy_primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
batching.skippable_batchers[broadcast_in_dim_p] = lambda _: ()
pe.forwarding_rules[broadcast_in_dim_p] = _broadcast_in_dim_fwd_rule
pe.custom_partial_eval_rules[broadcast_in_dim_p] = _broadcast_in_dim_partial_eval
pe.custom_staging_rules[broadcast_in_dim_p] = _broadcast_in_dim_staging_rule
@ -5130,18 +5130,17 @@ def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding):
sharding=t_s),
np.argsort(dimensions))]
def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions,
sharding):
if sharding is not None:
if sharding.mesh._are_all_axes_manual or sharding.mesh._are_all_axes_auto:
sharding = None
else:
raise NotImplementedError('reshape batch sharding support')
def _reshape_batch_rule(axis_data, batched_args, batch_dims, *, new_sizes,
dimensions, sharding):
operand, = batched_args
bdim, = batch_dims
operand = batching.moveaxis(operand, bdim, 0)
if dimensions is not None:
dimensions = (0,) + tuple(np.add(1, dimensions))
if sharding is not None:
sharding = batching.get_sharding_for_vmap(axis_data, sharding, 0)
out = reshape(operand, operand.shape[:1] + new_sizes, dimensions,
sharding=sharding)
return out, 0
@ -5171,7 +5170,8 @@ def _reshape_staging_rule(
reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule,
'reshape', sharding_rule=_reshape_sharding_rule)
ad.deflinear2(reshape_p, _reshape_transpose_rule)
batching.primitive_batchers[reshape_p] = _reshape_batch_rule
batching.fancy_primitive_batchers[reshape_p] = _reshape_batch_rule
batching.skippable_batchers[reshape_p] = lambda _: ()
mlir.register_lowering(reshape_p, _reshape_lower)
core.custom_typechecks[reshape_p] = _reshape_typecheck_rule
pe.custom_staging_rules[reshape_p] = _reshape_staging_rule

@ -1604,7 +1604,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
'Please see the jax.Array migration guide for more information '
'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. '
f'Got arg shape: {arg.shape}, arg value: {arg}')
if not isinstance(arg_s, UnspecifiedValue) and arg_s.is_concrete:
if not isinstance(arg_s, UnspecifiedValue) and arg_s._is_concrete:
# jax.jit does not allow resharding across different memory kinds even
# if the argument is uncommitted. Use jax.device_put for those cases,
# either outside or inside jax.jit.
@ -2763,10 +2763,8 @@ def _mesh_cast_batcher(axis_data, vals_in, dims_in, dst_sharding):
assert axis_data.spmd_name is None
x, = vals_in
d, = dims_in
val = None
new_spec = PartitionSpec(*util.tuple_insert(dst_sharding.spec, d, val))
vmapped_dst_sharding = NamedSharding(dst_sharding.mesh, new_spec)
vmapped_dst_sharding = batching.get_sharding_for_vmap(
axis_data, dst_sharding, d)
y = mesh_cast_p.bind(x, dst_sharding=vmapped_dst_sharding)
return y, d
batching.fancy_primitive_batchers[mesh_cast_p] = _mesh_cast_batcher
@ -2814,6 +2812,17 @@ def _reshard_hlo_lowering(ctx, x_node, *, dst_sharding):
return [mlir.lower_sharding_under_shit(ctx, x_node, aval_out, proto)]
mlir.register_lowering(reshard_p, _reshard_hlo_lowering)
def _reshard_batcher(axis_data, vals_in, dims_in, dst_sharding):
assert axis_data.spmd_name is None
x, = vals_in
d, = dims_in
vmapped_dst_sharding = batching.get_sharding_for_vmap(
axis_data, dst_sharding, d)
y = reshard_p.bind(x, dst_sharding=vmapped_dst_sharding)
return y, d
batching.fancy_primitive_batchers[reshard_p] = _reshard_batcher
batching.skippable_batchers[reshard_p] = lambda _: ()
# -------------------- auto and user mode -------------------------
def _get_new_mesh(axes: str | tuple[str, ...] | None,

@ -139,7 +139,7 @@ class Sharding:
# Default implementations below that all subclasses will inherit.
@property
def is_concrete(self) -> bool:
def _is_concrete(self) -> bool:
return True
@functools.cached_property

@ -392,7 +392,7 @@ class NamedSharding(jsharding.Sharding):
return not self.mesh.is_multi_process
@property
def is_concrete(self) -> bool:
def _is_concrete(self) -> bool:
if isinstance(self.mesh, mesh_lib.AbstractMesh):
return False
return True

@ -376,8 +376,9 @@ class AbstractRef(core.AbstractValue):
def _map_ref(size, axis, ref_aval):
return AbstractRef(core.mapped_aval(size, axis, ref_aval.inner_aval))
def _unmap_ref(size, axis, ref_aval):
return AbstractRef(core.unmapped_aval(size, axis, ref_aval.inner_aval))
def _unmap_ref(size, axis, explicit_mesh_axis, ref_aval):
return AbstractRef(core.unmapped_aval(
size, axis, ref_aval.inner_aval, explicit_mesh_axis))
core.aval_mapping_handlers[AbstractRef] = (_map_ref, _unmap_ref)

@ -1449,7 +1449,8 @@ def _shard_map_batch(
new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped
else ns for ns, d in zip(new_in_names, in_dims)]
new_size = trace.axis_data.size // prod(mesh.shape[n] for n in spmd_axis_name)
new_axis_data = batching.AxisData(trace.axis_data.name, new_size, trace.axis_data.spmd_name)
new_axis_data = batching.AxisData(trace.axis_data.name, new_size,
trace.axis_data.spmd_name, None)
else:
new_axis_data = trace.axis_data
fun, out_dims = batching.batch_subtrace(fun, trace.tag, new_axis_data, tuple(in_dims))

@ -4912,7 +4912,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
@parameterized.parameters([True, False])
@jtu.with_user_mesh((4,), ('x',))
def test_dot_general_out_sharding(self, jit, mesh):
def test_dot_general_out_sharding(self, use_jit, mesh):
np_inp1 = np.arange(16.).reshape(8, 2)
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None)))
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x')))
@ -4922,7 +4922,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out.aval.sharding.spec, P('x', None))
return jnp.sum(out)
if jit:
if use_jit:
f = jax.jit(f)
out = f(arr1, arr2)
@ -4939,7 +4939,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):
self.assertEqual(out[0].sharding, arr1.sharding)
self.assertEqual(out[1].sharding, arr2.sharding)
if jit:
if use_jit:
jitted_grad = jax.jit(jax.grad(f, argnums=(0, 1)))
out = jitted_grad(arr1, arr2)
self.assertEqual(out[0].sharding, arr1.sharding)
@ -6226,6 +6226,18 @@ class ShardingInTypesTest(jtu.JaxTestCase):
out = jax.jit(jax.grad(g))(arr)
self.assertEqual(out.sharding, arr.sharding)
def f_vmap(x):
self.assertEqual(x.aval.sharding.spec, P('y'))
y = reshard(x, P(None))
self.assertEqual(y.aval.sharding.spec, P(None))
return y
out = jax.vmap(f_vmap)(arr)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
out = jax.jit(jax.vmap(jax.jit(f_vmap)))(arr)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
@jax.jit
def h(x):
with use_auto_axes('x'):
@ -6394,6 +6406,121 @@ class ShardingInTypesTest(jtu.JaxTestCase):
lowered_text = f.lower(aval, aval2).as_text()
self.assertNotIn("mhlo.sharding", lowered_text)
@parameterized.parameters(True, False)
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_mul_vmap(self, use_jit, mesh):
np_inp = np.arange(16.).reshape(8, 2)
s = NamedSharding(mesh, P('x', 'y'))
arr = jax.device_put(np_inp, s)
def f(x):
self.assertEqual(x.aval.sharding.spec, P(s.spec[1]))
x = x * 2
self.assertEqual(x.aval.sharding.spec, P(s.spec[1]))
x = x * x
self.assertEqual(x.aval.sharding.spec, P(s.spec[1]))
return x
if use_jit:
f = jax.jit(f)
f = jax.vmap(f)
out = f(arr)
self.assertEqual(out.sharding, s)
self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2))
out = jax.jit(f)(arr)
self.assertEqual(out.sharding, s)
self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2))
def g(x):
return jnp.sum(f(x))
out = jax.grad(g)(arr)
self.assertEqual(out.sharding, arr.sharding)
out = jax.jit(jax.grad(g))(arr)
self.assertEqual(out.sharding, arr.sharding)
@parameterized.parameters(True, False)
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_dot_general_vmap(self, use_jit, mesh):
np_inp1 = np.arange(16.).reshape(4, 2, 2)
np_inp2 = np.arange(16.).reshape(2, 4, 2)
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None, 'y')))
arr2 = jax.device_put(np_inp2, NamedSharding(mesh, P(None, 'x', 'y')))
def f(x, y):
return jnp.einsum('xy,yz->xz', x, y, out_sharding=P(None, 'y'))
if use_jit:
f = jax.jit(f)
f = jax.vmap(f, in_axes=(0, 1), out_axes=2)
out = f(arr1, arr2)
self.assertEqual(out.shape, (2, 2, 4))
self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_reshape_vmap(self, mesh):
np_inp = np.arange(16).reshape(2, 8)
arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, 'x')))
def f(x):
y = lax.reshape(x, (1, 2), sharding=P(None, 'y'))
y = y * 2
self.assertEqual(y.aval.sharding.spec, P(None, 'y'))
return y
out = jax.jit(jax.vmap(f, in_axes=1))(arr)
self.assertEqual(out.shape, (8, 1, 2))
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, 'y')))
@parameterized.parameters(True, False)
@jtu.with_user_mesh((2, 2), ('x', 'y'))
def test_shit_vmap_error_check(self, use_jit, mesh):
np_inp = np.arange(16).reshape(8, 2)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None)))
arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P(None, 'y')))
def f(x, y):
return x @ y
if use_jit:
f = jax.jit(f)
with self.assertRaisesRegex(
ValueError,
"Mapped away dimension of inputs passed to vmap should be sharded "
"the same"):
jax.vmap(f, in_axes=(0, 1))(arr, arr2)
with self.assertRaisesRegex(
ValueError,
'Mapped away dimension of inputs passed to vmap should be sharded the'
' same'):
jax.jit(jax.vmap(f, in_axes=(0, 1)))(arr, arr2)
with self.assertRaisesRegex(
ValueError,
"Only one of spmd_axis_name or arrays sharded on.*spmd_axis_name"):
jax.vmap(f, spmd_axis_name='y')(arr, arr)
@jtu.with_user_mesh((2,), ('x',))
def test_unmapped_last_vmap(self, mesh):
np_inp = np.arange(8)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x',)))
@partial(jax.vmap, out_axes=-1)
def f(x):
return np.zeros((4,))
out = f(arr)
self.assertEqual(out.shape, (4, 8))
self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'x')))
@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):