Mirror of https://github.com/ROCm/jax.git, synced 2025-04-20 05:46:06 +00:00
[sharding_in_types] Add vmap + explicit sharding support. The main changes are:

* Track `explicit_mesh_axis` on `AxisData`.
* Modify `unmapped_aval` to take the above explicit mesh axis and insert it into the right place in the sharding so out_shardings are correct.
* Make `matchaxis` also handle shardings correctly.
* All mapped dimensions should be sharded the same way.
* `spmd_axis_name` and explicitly sharded arrays cannot be used together.
* The `out_shardings` parameter on `dot_general`, `broadcast_in_dim`, `reshape`, `reshard` and `mesh_cast` is handled correctly in the presence of vmap.

This should eventually help us get rid of `spmd_axis_name` from `vmap`.

PiperOrigin-RevId: 721007659
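Illustrative usage (not part of the commit): the sketch below shows the behavior this change enables under explicit sharding. It assumes a recent JAX where `jax.sharding.AxisType.Explicit`, `jax.make_mesh(..., axis_types=...)` and `jax.sharding.use_mesh` are available and at least 4 devices are present; these helper names have shifted across releases (at the time of this change the feature was still gated on `jax.config.sharding_in_types`), so treat it as a sketch rather than the exact API.

import jax
import numpy as np
from jax.sharding import AxisType, NamedSharding, PartitionSpec as P

# Mesh whose axes use the Explicit axis type, so shardings are tracked in avals.
mesh = jax.make_mesh((2, 2), ('x', 'y'), axis_types=(AxisType.Explicit,) * 2)

with jax.sharding.use_mesh(mesh):
  arr = jax.device_put(np.arange(16.).reshape(8, 2),
                       NamedSharding(mesh, P('x', 'y')))

  def f(x):
    # Inside vmap the mapped 'x' dimension is stripped from the spec, so each
    # element is seen as sharded P('y').
    return x * 2

  out = jax.vmap(f)(arr)
  # The explicit mesh axis of the mapped-away dimension is re-inserted into the
  # output sharding, so `out` is sharded P('x', 'y') again.

  # Mixing spmd_axis_name with inputs sharded on an Explicit mesh axis is
  # rejected by this change:
  # jax.vmap(f, spmd_axis_name='y')(arr)  # raises ValueError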
This commit is contained in:
parent 20843643ab
commit dcb28f1218
@@ -815,8 +815,7 @@ def vmap(fun: F,
out_axes: Any = 0,
axis_name: AxisName | None = None,
axis_size: int | None = None,
- spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None
- ) -> F:
+ spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None) -> F:
"""Vectorizing map. Creates a function which maps ``fun`` over argument axes.

Args:
@@ -989,8 +988,15 @@ def vmap(fun: F,
in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
axis_size_ = (axis_size if axis_size is not None else
_mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap"))
+ explicit_mesh_axis = _mapped_axis_spec(args_flat, in_axes_flat)
+ if spmd_axis_name is not None and explicit_mesh_axis is not None:
+ raise ValueError(
+ "Only one of spmd_axis_name or arrays sharded on `Explicit` mesh"
+ f" axis type is allowed. Got {spmd_axis_name=} and"
+ f" arrays sharded on {explicit_mesh_axis=}")
try:
- axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name)
+ axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name,
+ explicit_mesh_axis)
out_flat = batching.batch(
flat_fun, axis_data, in_axes_flat,
lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
@@ -1007,6 +1013,28 @@ def vmap(fun: F,
return cast(F, vmap_f)

+ def _mapped_axis_spec(args_flat, in_axes_flat):
+ if not config.sharding_in_types.value:
+ return None
+
+ def _get_spec(arg, i):
+ try:
+ # Duck type arrays like BCOO arrays can be passed to vmap.
+ return shaped_abstractify(arg).sharding.spec[i]
+ except TypeError:
+ return None
+
+ temp_spec = None
+ for arg, i in zip(args_flat, in_axes_flat):
+ if i is not None:
+ spec = _get_spec(arg, i)
+ if temp_spec is not None and temp_spec != spec:
+ raise ValueError(
+ "Mapped away dimension of inputs passed to vmap should be sharded"
+ f" the same. Got inconsistent axis specs: {temp_spec} vs {spec}")
+ temp_spec = spec
+ return temp_spec
+
def _mapped_axis_size(fn, tree, vals, dims, name):
if not vals:
args, kwargs = tree_unflatten(tree, vals)
@@ -2417,10 +2417,10 @@ def mapped_aval(size: AxisSize, axis: int | None,
raise TypeError(f"no mapping handler for {aval} of type {type(aval)}")

def unmapped_aval(size: AxisSize, axis: int | None,
- aval: AbstractValue) -> AbstractValue:
+ aval: AbstractValue, explicit_mesh_axis=None) -> AbstractValue:
_, handler = aval_mapping_handlers.get(type(aval), (None, None))
if handler is not None:
- return handler(size, axis, aval)
+ return handler(size, axis, explicit_mesh_axis, aval)
else:
raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}")
@@ -2436,10 +2436,12 @@ def _map_shaped_array(
weak_type=aval.weak_type, sharding=sharding)

def _unmap_shaped_array(
- size: int, axis: int | None, aval: ShapedArray) -> ShapedArray:
+ size: int, axis: int | None, explicit_mesh_axis, aval: ShapedArray
+ ) -> ShapedArray:
if axis is None: return aval
elif type(axis) is int:
- sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis, None))
+ sharding = (aval.sharding.with_spec(tuple_insert(aval.sharding.spec, axis,
+ explicit_mesh_axis))
if config.sharding_in_types.value else None)
return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
weak_type=aval.weak_type, sharding=sharding)
@@ -2452,7 +2454,7 @@ def _map_dshaped_array(
aval.weak_type)

def _unmap_dshaped_array(
- size: AxisSize, axis: int | None, aval: DShapedArray
+ size: AxisSize, axis: int | None, explicit_mesh_axis, aval: DShapedArray
) -> DShapedArray:
if axis is None: return aval
elif type(axis) is int:
@@ -2465,7 +2467,7 @@ AvalMapHandlerPair = tuple[Callable, Callable]
aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
DShapedArray: (_map_dshaped_array, _unmap_dshaped_array),
ShapedArray: (_map_shaped_array, _unmap_shaped_array),
- AbstractToken: (lambda _, __, a: a, lambda _, __, a: a)
+ AbstractToken: (lambda _, __, a: a, lambda _, __, ____, a: a)
}

# When a mapped function is given no axis name, we generate a name object based
@@ -217,7 +217,7 @@ def maybe_bdim_at_front(x, bdim):
# axes instead of accepting and matching a given spec of output axes. Assumes
# `f` is pytree-flattened
def vmap_unrestricted(f: lu.WrappedFun, *args, in_axes, axis_name, axis_size):
- axis_data = batching.AxisData(axis_name, axis_size, None)
+ axis_data = batching.AxisData(axis_name, axis_size, None, None)
tag = core.TraceTag()
f, out_axes = batching.batch_subtrace(f, tag, axis_data, in_axes)
outs = f.call_wrapped(*args)
@@ -26,6 +26,8 @@ from jax._src import config
from jax._src import core
from jax._src import source_info_util
from jax._src import linear_util as lu
+ from jax._src.partition_spec import PartitionSpec as P
+ from jax._src.sharding_impls import NamedSharding
from jax._src.ad_util import (Zero, instantiate, SymbolicZero,
replace_rule_output_symbolic_zeros,
add_jaxvals, add_jaxvals_p)
@@ -36,7 +38,7 @@ from jax._src.tree_util import (tree_unflatten, tree_flatten,
from jax._src.typing import Array
from jax._src.util import (unzip2, safe_map, safe_zip, split_list,
canonicalize_axis, moveaxis, as_hashable_function,
- curry, memoize, weakref_lru_cache)
+ curry, memoize, weakref_lru_cache, tuple_insert)


map, unsafe_map = safe_map, map
@@ -241,6 +243,7 @@ Vmappable = Any
Elt = Any
MapSpec = Any
AxisSize = Any
+ MeshAxis = Any
GetIdx = Callable[[], Tracer] # TODO(mattjj): revise this laziness
ToEltHandler = Callable[[Callable, GetIdx, Vmappable, MapSpec], Elt]
FromEltHandler = Callable[[Callable, AxisSize, Elt, MapSpec], Vmappable]
@@ -277,12 +280,12 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:

to_elt_handlers: dict[type, ToEltHandler] = {}

- def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int,
- x: Elt, spec: MapSpec) -> Vmappable:
+ def from_elt(trace: BatchTrace, axis_size: AxisSize, mesh_axis: MeshAxis,
+ i: int, x: Elt, spec: MapSpec) -> Vmappable:
handler = from_elt_handlers.get(type(x))
if handler:
def _cont(axis_size, elt, axis):
- return from_elt(trace, axis_size, i, elt, axis)
+ return from_elt(trace, axis_size, mesh_axis, i, elt, axis)
return handler(_cont, axis_size, x, spec)
val, bdim = trace.to_batch_info(x)
if type(bdim) is RaggedAxis:
@@ -292,7 +295,8 @@ def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int,
return _jumble_result(axis_size, bdim.stacked_axis, bdim.ragged_axes, val)
else:
try:
- return matchaxis(trace.axis_data.name, axis_size, bdim, spec, val)
+ return matchaxis(trace.axis_data.name, axis_size, mesh_axis,
+ bdim, spec, val)
except SpecMatchError:
raise SpecMatchError(i, x.batch_dim, spec) from None
from_elt_handlers: dict[type, FromEltHandler] = {}
@@ -441,7 +445,17 @@ class BatchTracer(Tracer):
class AxisData:
name : Any
size : Any
+ # Only one of spmd_axis_name and explicit_mesh_axis is set.
spmd_name : Any
+ explicit_mesh_axis: Any


+ def get_sharding_for_vmap(axis_data, orig_sharding, axis):
+ if orig_sharding.mesh.empty:
+ return None
+ val = axis_data.explicit_mesh_axis
+ new_spec = P(*tuple_insert(orig_sharding.spec, axis, val))
+ return NamedSharding(orig_sharding.mesh, new_spec)
+
+
class BatchTrace(Trace):
@@ -472,7 +486,8 @@ class BatchTrace(Trace):
return p.bind_with_trace(self.parent_trace, vals_in, params)
else:
with core.set_current_trace(self.parent_trace):
- val_out, dim_out = fancy_primitive_batchers[p](self.axis_data, vals_in, dims_in, **params)
+ val_out, dim_out = fancy_primitive_batchers[p](
+ self.axis_data, vals_in, dims_in, **params)
elif args_not_mapped:
# no-op shortcut
return p.bind_with_trace(self.parent_trace, vals_in, params)
@@ -605,8 +620,9 @@ def _batch_inner(f: Callable, axis_data, out_dim_dests, tag, in_dims, *in_vals):
outs = f(*in_tracers)

out_dim_dests = out_dim_dests() if callable(out_dim_dests) else out_dim_dests
- out_vals = map(partial(from_elt, trace, axis_data.size), range(len(outs)),
- outs, out_dim_dests)
+ out_vals = map(partial(from_elt, trace, axis_data.size,
+ axis_data.explicit_mesh_axis),
+ range(len(outs)), outs, out_dim_dests)

return out_vals, trace
@@ -639,7 +655,7 @@ def vtile(f_flat: lu.WrappedFun,
outputs_flat = f(*map(tile_axis(tile_size=tile_size_), args_flat, in_axes_flat))
return map(untile_axis, outputs_flat, out_axes_flat)

- axis_data = AxisData(axis_name, tile_size, None)
+ axis_data = AxisData(axis_name, tile_size, None, None)
return _map_to_tile(batch(f_flat, axis_data, in_axes_flat, out_axes_flat))

### API for batching functions with jaxpr type inputs and outputs
@@ -751,7 +767,8 @@ def _batch_jaxpr2(
handle_ragged(closed_jaxpr.in_avals, dim, aval)
if isinstance(dim, RaggedAxis) else (dim, aval)
for dim, aval in zip(in_axes, closed_jaxpr.in_avals)])
- avals_in2 = [core.unmapped_aval(axis_data.size, b, aval)
+ avals_in2 = [core.unmapped_aval(axis_data.size, b, aval,
+ axis_data.explicit_mesh_axis)
if b is not not_mapped else aval
for aval, b in unsafe_zip(avals_in, in_axes2)]
jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in2)
@@ -788,7 +805,9 @@ def _batch_jaxpr_axes(closed_jaxpr, axis_data, in_axes, out_axes_dest):
f, out_axes = _batch_jaxpr_inner(f, axis_data)
f, out_batched = _match_axes_jaxpr(f, axis_data, out_axes_dest, out_axes)
f = _batch_jaxpr_outer(f, axis_data, in_axes)
- avals_in = [core.unmapped_aval(axis_data.size, b, aval) if b is not not_mapped
+ avals_in = [core.unmapped_aval(axis_data.size, b, aval,
+ axis_data.explicit_mesh_axis)
+ if b is not not_mapped
else aval for aval, b in unsafe_zip(closed_jaxpr.in_avals, in_axes)]
jaxpr_out, _, consts, () = pe.trace_to_jaxpr_dynamic(f, avals_in)
return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
@@ -821,7 +840,8 @@ def _match_axes_jaxpr(f, store, axis_data, out_axes_dest, out_axes, trace, in_ax
if len(out_axes_dest) != len(out_axes):
out_axis_dest, = out_axes_dest
out_axes_dest = [out_axis_dest] * len(out_axes)
- out_vals = map(partial(matchaxis, axis_data.name, axis_data.size),
+ out_vals = map(partial(matchaxis, axis_data.name, axis_data.size,
+ axis_data.explicit_mesh_axis),
out_axes, out_axes_dest, out_vals)
out_batched = [dst is not None for dst in out_axes_dest]
store.store(out_batched)
@@ -853,6 +873,7 @@ zero_if_mapped = ZeroIfMapped()
@lu.transformation_with_aux2
def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals):
size = axis_data.size
+ mesh_axis = axis_data.explicit_mesh_axis
with core.take_current_trace() as parent_trace:
trace = BatchTrace(parent_trace, tag, axis_data)
in_tracers = [val if dim is None else
@@ -869,9 +890,9 @@ def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals):
out_primals, out_tangents = split_list(out_vals, [len(out_vals) // 2])
out_primal_bds, out_tangent_bds = split_list(out_dims, [len(out_vals) // 2])
out_dims = map(_merge_bdims, out_primal_bds, out_tangent_bds)
- out_primals = map(partial(matchaxis, trace.axis_data.name, size),
+ out_primals = map(partial(matchaxis, trace.axis_data.name, size, mesh_axis),
out_primal_bds, out_dims, out_primals)
- out_tangents = map(partial(matchaxis, trace.axis_data.name, size),
+ out_tangents = map(partial(matchaxis, trace.axis_data.name, size, mesh_axis),
out_tangent_bds, out_dims, out_tangents)
store.store(out_dims * 2)
return out_primals + out_tangents
@@ -879,6 +900,7 @@ def batch_custom_jvp_subtrace(f, store, tag, axis_data, in_dims, *in_vals):
def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests):
axis_size = axis_data.size
axis_name = axis_data.name
+ mesh_axis = axis_data.explicit_mesh_axis
def new_bwd(*args):
in_dims_ = in_dims() if callable(in_dims) else in_dims
args = [SymbolicZero(core.mapped_aval(axis_size, dim, x.aval))
@@ -887,19 +909,22 @@ def batch_custom_vjp_bwd(bwd, tag, axis_data, in_dims, out_dim_dests):
in_dims_ = [None if type(x) is SymbolicZero else d
for x, d in zip(args, in_dims_)]
bwd_, out_dims_thunk = batch_subtrace(lu.wrap_init(bwd), tag, axis_data, in_dims_)
- bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, out_dims_thunk,
- out_dim_dests)
+ bwd_ = _match_axes_and_sum(bwd_, axis_size, axis_name, mesh_axis,
+ out_dims_thunk, out_dim_dests)
return bwd_.call_wrapped(*args)
return new_bwd

@lu.transformation2
- def _match_axes_and_sum(f, axis_size, axis_name, out_dims_thunk, out_dim_dests, *in_vals):
+ def _match_axes_and_sum(f, axis_size, axis_name, mesh_axis, out_dims_thunk,
+ out_dim_dests, *in_vals):
# this is like _match_axes, but we do reduce-sums as needed
out_vals = f(*in_vals)
- return map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, axis_name,
- sum_match=True), out_dims_thunk(), out_dim_dests, out_vals)
+ return map(partial(_matchaxis_symbolic_zeros, axis_name, axis_size, mesh_axis,
+ axis_name, sum_match=True),
+ out_dims_thunk(), out_dim_dests, out_vals)

- def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False):
+ def _matchaxis_symbolic_zeros(axis_name, sz, mesh_axis, name, src, dst, x,
+ sum_match=False):
# Just like `matchaxis`, but handles symbolic zeros using ad_util.py
# TODO(mattjj): dedup with matchaxis
if isinstance(x, (Zero, SymbolicZero)):
@@ -907,15 +932,15 @@ def _matchaxis_symbolic_zeros(axis_name, sz, name, src, dst, x, sum_match=False)
return x
elif type(src) == type(dst) == int:
aval = core.mapped_aval(sz, src, x.aval)
- return Zero(core.unmapped_aval(sz, dst, aval))
+ return Zero(core.unmapped_aval(sz, dst, aval, mesh_axis))
elif src is not_mapped and dst is not not_mapped:
- return Zero(core.unmapped_aval(sz, dst, x.aval))
+ return Zero(core.unmapped_aval(sz, dst, x.aval, mesh_axis))
elif dst is not_mapped and sum_match:
return Zero(core.mapped_aval(sz, src, x.aval))
else:
raise ValueError((axis_name, x, src, dst))
else:
- return matchaxis(axis_name, sz, src, dst, x, sum_match=sum_match)
+ return matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=sum_match)


### utilities for defining primitives' batching rules
@@ -1057,13 +1082,19 @@ def move_stacked_axis(operand, bdim, dst):

### general utilities for manipulating axes on jaxpr types (not vmappables)

- def broadcast(x, sz, axis):
+ def broadcast(x, sz, axis, mesh_axis=None):
shape = list(np.shape(x))
shape.insert(axis, sz)
broadcast_dims = tuple(np.delete(np.arange(len(shape)), axis))
- return jax.lax.broadcast_in_dim(x, shape, broadcast_dims)
+ if config.sharding_in_types.value:
+ x_aval = core.get_aval(x)
+ new_spec = P(*tuple_insert(x_aval.sharding.spec, axis, mesh_axis))
+ sharding = x_aval.sharding.with_spec(new_spec)
+ else:
+ sharding = None
+ return jax.lax.broadcast_in_dim(x, shape, broadcast_dims, sharding=sharding)

- def matchaxis(axis_name, sz, src, dst, x, sum_match=False):
+ def matchaxis(axis_name, sz, mesh_axis, src, dst, x, sum_match=False):
if dst == jumble_axis:
x = bdim_at_front(x, src, sz)
elt_ty = x.aval.update(shape=x.shape[1:])
@@ -1080,7 +1111,7 @@ def matchaxis(axis_name, sz, src, dst, x, sum_match=False):
elif type(src) == type(dst) == int:
return moveaxis(x, src, dst)
elif src is not_mapped and dst is not not_mapped:
- return broadcast(x, sz, canonicalize_axis(dst, np.ndim(x) + 1))
+ return broadcast(x, sz, canonicalize_axis(dst, np.ndim(x) + 1), mesh_axis)
elif dst is not_mapped and sum_match:
return x.sum(src)
else:
@@ -291,7 +291,8 @@ def _for_vmap(axis_data, args, dims, *,
batched = map(operator.or_, batched, out_batched)
else:
raise Exception("Invalid fixpoint")
- args = [batching.broadcast(x, axis_data.size, 0) if now_bat and not was_bat
+ args = [batching.broadcast(x, axis_data.size, 0, axis_data.explicit_mesh_axis)
+ if now_bat and not was_bat
else batching.moveaxis(x, d, 0) if now_bat else x
for x, d, was_bat, now_bat in zip(args, dims, init_batched, batched)]
batched_jaxpr_, _ = batching.batch_jaxpr(
@@ -951,7 +951,8 @@ def _scan_batching_rule(axis_data, args,
consts_bdims, init_bdims, xs_bdims = split_list(dims, [num_consts, num_carry])
new_consts = [batching.moveaxis(x, d, 0) if d is not batching.not_mapped and d != 0
else x for x, d in zip(consts, consts_bdims)]
- new_init = [batching.broadcast(x, axis_data.size, 0) if now_batched and not was_batched
+ new_init = [batching.broadcast(x, axis_data.size, 0, axis_data.explicit_mesh_axis)
+ if now_batched and not was_batched
else batching.moveaxis(x, d, 0) if now_batched else x
for x, d, was_batched, now_batched in
zip(init, init_bdims, init_batched, carry_batched)]
@@ -1492,7 +1493,8 @@ def _while_loop_batching_rule(axis_data, args, dims, cond_nconsts, cond_jaxpr,
new_init = []
for x, old_axis, new_axis in zip(init, init_dims, carry_dims):
if old_axis is batching.not_mapped and new_axis is not batching.not_mapped:
- new_init.append(batching.broadcast(x, axis_data.size, new_axis))
+ new_init.append(batching.broadcast(x, axis_data.size, new_axis,
+ axis_data.explicit_mesh_axis))
elif old_axis is batching.not_mapped and new_axis is batching.not_mapped:
new_init.append(x)
else:
@@ -465,7 +465,8 @@ def _linear_solve_batching_rule(axis_data, args, dims, const_lengths, jaxprs):
]
# Broadcast out b if necessary
new_b = [
- batching.broadcast(x, axis_data.size, 0) if now_bat and not was_bat else
+ batching.broadcast(x, axis_data.size, 0, axis_data.explicit_mesh_axis)
+ if now_bat and not was_bat else
batching.moveaxis(x, d, 0) if now_bat and d != 0 else x
for x, d, was_bat, now_bat in zip(b, b_dims, orig_b_bat, b_bat)
]
@@ -64,7 +64,6 @@ from jax._src.lax.utils import (
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
- from jax._src.mesh import get_abstract_mesh
from jax._src.sharding_impls import (PmapSharding, NamedSharding,
PartitionSpec as P, canonicalize_sharding)
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape
@@ -1932,13 +1931,13 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
fill_value = _convert_element_type(fill_value, dtype, weak_type)
if (sharding is not None and not isinstance(sharding, PmapSharding) and
- isinstance(fill_value, array.ArrayImpl) and sharding.is_concrete):
+ isinstance(fill_value, array.ArrayImpl) and sharding._is_concrete):
broadcast_shape = sharding.shard_shape(shape)
shard = broadcast(fill_value, broadcast_shape)
return array.make_array_from_callback(shape, sharding, lambda _: shard)

if (config.sharding_in_types.value and sharding is not None and
- not sharding.is_concrete):
+ not sharding._is_concrete):
return broadcast(fill_value, shape, sharding=sharding)
else:
return broadcast(fill_value, shape)
@@ -3185,7 +3184,7 @@ def _convert_element_type_sharding_rule(operand, *, new_dtype, weak_type,
sharding):
if sharding is None:
return operand.sharding
- if sharding.is_concrete:
+ if sharding._is_concrete:
if isinstance(sharding, NamedSharding):
return NamedSharding(sharding.mesh.abstract_mesh, sharding.spec)
else:
@@ -3274,7 +3273,7 @@ convert_element_type_p = Primitive('convert_element_type')
def _convert_element_type_bind_with_trace(trace, args, params):
sharding = params['sharding']
operand = core.Primitive.bind_with_trace(convert_element_type_p, trace, args, params)
- if sharding is not None and sharding.is_concrete:
+ if sharding is not None and sharding._is_concrete:
with core.set_current_trace(trace):
operand = pjit.with_sharding_constraint(operand, sharding)
return operand
@@ -3703,6 +3702,7 @@ def _dot_batch_rule(
unpack_args,
unpack_dims,
invoke_prim,
+ axis_data,
batched_args,
batch_dims,
*,
@@ -3737,13 +3737,15 @@ def _dot_batch_rule(
rhs_shape = batching.bdim_as_shape(rbd, rhs.shape)
else:
rhs_shape = np.shape(rhs)

+ result_batch_dim = batching.shape_as_bdim(
+ result_stack_dim,
+ _dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers))
+
if out_sharding is not None:
- cur_mesh = get_abstract_mesh()
- if cur_mesh._are_all_axes_auto or cur_mesh._are_all_axes_manual:
- out_sharding = None
- else:
- raise NotImplementedError("vmap with out_sharding is not supported. "
- "Please open an issue.")
+ out_sharding = batching.get_sharding_for_vmap(
+ axis_data, out_sharding, result_batch_dim)

batched_out = invoke_prim(
lhs,
rhs,
@@ -3752,9 +3754,6 @@
preferred_element_type=preferred_element_type,
out_sharding=out_sharding,
)
- result_batch_dim = batching.shape_as_bdim(
- result_stack_dim,
- _dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers))
return batched_out, result_batch_dim

@@ -3919,7 +3918,8 @@ _dot_general_batch_rule = functools.partial(
_dot_general_batch_unpack_dims,
dot_general,
)
- batching.primitive_batchers[dot_general_p] = _dot_general_batch_rule
+ batching.fancy_primitive_batchers[dot_general_p] = _dot_general_batch_rule
+ batching.skippable_batchers[dot_general_p] = lambda _: ()
pe.padding_rules[dot_general_p] = _dot_general_padding_rule
core.pp_eqn_rules[dot_general_p] = _dot_general_pp_rule
batching.ragged_prop_rules[dot_general_p] = _dot_general_ragged_prop_rule
@@ -4261,6 +4261,7 @@ def _ragged_dot_invoke_prim(


def _ragged_dot_batch_rule(
+ axis_data,
batched_args,
batch_dims,
*,
@@ -4274,6 +4275,7 @@ def _ragged_dot_batch_rule(
_ragged_dot_batch_unpack_args,
_ragged_dot_batch_unpack_dims,
invoke,
+ axis_data,
batched_args,
batch_dims,
dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS,
@@ -4288,7 +4290,8 @@ ragged_dot_p = standard_primitive(_ragged_dot_shape_rule,
ragged_dot_p.def_impl(partial(dispatch.apply_primitive, ragged_dot_p))
ad.primitive_jvps[ragged_dot_p] = _ragged_dot_jvp_rule
ad.primitive_transposes[ragged_dot_p] = _ragged_dot_transpose_rule
- batching.primitive_batchers[ragged_dot_p] = _ragged_dot_batch_rule
+ batching.fancy_primitive_batchers[ragged_dot_p] = _ragged_dot_batch_rule
+ batching.skippable_batchers[ragged_dot_p] = lambda _: ()

def _ragged_dot_impl(
lhs: Array,
@@ -4393,7 +4396,7 @@ def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape,
return ([expand_dims(_reduce_sum(ct, axes), unit_dims)] +
[None] * len(dyn_shape))

- def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape,
+ def _broadcast_in_dim_batch_rule(axis_data, batched_args, batch_dims, shape,
broadcast_dimensions, sharding):
# `dyn_shape` is the dynamic portion of the target shape. `shape`
# is the target shape, with `None` for dynamic sections.
@@ -4407,13 +4410,11 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape,
if operand_bdim is not None:
if isinstance(operand_bdim, RaggedAxis):
stacked_axis = operand_bdim.stacked_axis
- else:
- stacked_axis = operand_bdim
- new_operand = batching.moveaxis(operand, stacked_axis, 0)
- if isinstance(operand_bdim, RaggedAxis):
stacked_size = operand_bdim.size
else:
+ stacked_axis = operand_bdim
stacked_size = operand.shape[stacked_axis]
+ new_operand = batching.moveaxis(operand, stacked_axis, 0)
new_broadcast_dimensions = (0,) + tuple(np.add(1, broadcast_dimensions))
else:
new_operand = operand
@@ -4440,12 +4441,10 @@
assert len(sizes) == stacked_size, msg
dyn_limits.append(bound)
new_shape = (stacked_size,) + _merge_dyn_shape(shape, dyn_limits)
+
if sharding is not None:
- if sharding.mesh._are_all_axes_auto or sharding.mesh._are_all_axes_manual:
- sharding = None
- else:
- raise NotImplementedError('Implement sharding support for '
- 'broadcast_in_dim_batch_rule')
+ sharding = batching.get_sharding_for_vmap(axis_data, sharding, 0)
+
result = broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions,
sharding=sharding)
out_ragged_axes = [idx+1 for idx, s in enumerate(shape) if s is None]
@@ -4569,7 +4568,8 @@ broadcast_in_dim_p = standard_primitive(
broadcast_in_dim_p.def_abstract_eval(_broadcast_in_dim_abstract_eval)
ad.primitive_jvps[broadcast_in_dim_p] = _broadcast_in_dim_jvp_rule
ad.primitive_transposes[broadcast_in_dim_p] = _broadcast_in_dim_transpose_rule
- batching.primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
+ batching.fancy_primitive_batchers[broadcast_in_dim_p] = _broadcast_in_dim_batch_rule
+ batching.skippable_batchers[broadcast_in_dim_p] = lambda _: ()
pe.forwarding_rules[broadcast_in_dim_p] = _broadcast_in_dim_fwd_rule
pe.custom_partial_eval_rules[broadcast_in_dim_p] = _broadcast_in_dim_partial_eval
pe.custom_staging_rules[broadcast_in_dim_p] = _broadcast_in_dim_staging_rule
@@ -5130,18 +5130,17 @@ def _reshape_transpose_rule(t, operand, *, new_sizes, dimensions, sharding):
sharding=t_s),
np.argsort(dimensions))]

- def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions,
- sharding):
- if sharding is not None:
- if sharding.mesh._are_all_axes_manual or sharding.mesh._are_all_axes_auto:
- sharding = None
- else:
- raise NotImplementedError('reshape batch sharding support')
+ def _reshape_batch_rule(axis_data, batched_args, batch_dims, *, new_sizes,
+ dimensions, sharding):
operand, = batched_args
bdim, = batch_dims
operand = batching.moveaxis(operand, bdim, 0)
if dimensions is not None:
dimensions = (0,) + tuple(np.add(1, dimensions))

+ if sharding is not None:
+ sharding = batching.get_sharding_for_vmap(axis_data, sharding, 0)
+
out = reshape(operand, operand.shape[:1] + new_sizes, dimensions,
sharding=sharding)
return out, 0
@@ -5171,7 +5170,8 @@ def _reshape_staging_rule(
reshape_p = standard_primitive(_reshape_shape_rule, _reshape_dtype_rule,
'reshape', sharding_rule=_reshape_sharding_rule)
ad.deflinear2(reshape_p, _reshape_transpose_rule)
- batching.primitive_batchers[reshape_p] = _reshape_batch_rule
+ batching.fancy_primitive_batchers[reshape_p] = _reshape_batch_rule
+ batching.skippable_batchers[reshape_p] = lambda _: ()
mlir.register_lowering(reshape_p, _reshape_lower)
core.custom_typechecks[reshape_p] = _reshape_typecheck_rule
pe.custom_staging_rules[reshape_p] = _reshape_staging_rule
@@ -1604,7 +1604,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
'Please see the jax.Array migration guide for more information '
'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. '
f'Got arg shape: {arg.shape}, arg value: {arg}')
- if not isinstance(arg_s, UnspecifiedValue) and arg_s.is_concrete:
+ if not isinstance(arg_s, UnspecifiedValue) and arg_s._is_concrete:
# jax.jit does not allow resharding across different memory kinds even
# if the argument is uncommitted. Use jax.device_put for those cases,
# either outside or inside jax.jit.
@@ -2763,10 +2763,8 @@ def _mesh_cast_batcher(axis_data, vals_in, dims_in, dst_sharding):
assert axis_data.spmd_name is None
x, = vals_in
d, = dims_in

- val = None
- new_spec = PartitionSpec(*util.tuple_insert(dst_sharding.spec, d, val))
- vmapped_dst_sharding = NamedSharding(dst_sharding.mesh, new_spec)
+ vmapped_dst_sharding = batching.get_sharding_for_vmap(
+ axis_data, dst_sharding, d)
y = mesh_cast_p.bind(x, dst_sharding=vmapped_dst_sharding)
return y, d
batching.fancy_primitive_batchers[mesh_cast_p] = _mesh_cast_batcher
@@ -2814,6 +2812,17 @@ def _reshard_hlo_lowering(ctx, x_node, *, dst_sharding):
return [mlir.lower_sharding_under_shit(ctx, x_node, aval_out, proto)]
mlir.register_lowering(reshard_p, _reshard_hlo_lowering)

+ def _reshard_batcher(axis_data, vals_in, dims_in, dst_sharding):
+ assert axis_data.spmd_name is None
+ x, = vals_in
+ d, = dims_in
+ vmapped_dst_sharding = batching.get_sharding_for_vmap(
+ axis_data, dst_sharding, d)
+ y = reshard_p.bind(x, dst_sharding=vmapped_dst_sharding)
+ return y, d
+ batching.fancy_primitive_batchers[reshard_p] = _reshard_batcher
+ batching.skippable_batchers[reshard_p] = lambda _: ()
+
# -------------------- auto and user mode -------------------------

def _get_new_mesh(axes: str | tuple[str, ...] | None,
@@ -139,7 +139,7 @@ class Sharding:
# Default implementations below that all subclasses will inherit.

@property
- def is_concrete(self) -> bool:
+ def _is_concrete(self) -> bool:
return True

@functools.cached_property
@@ -392,7 +392,7 @@ class NamedSharding(jsharding.Sharding):
return not self.mesh.is_multi_process

@property
- def is_concrete(self) -> bool:
+ def _is_concrete(self) -> bool:
if isinstance(self.mesh, mesh_lib.AbstractMesh):
return False
return True
@@ -376,8 +376,9 @@ class AbstractRef(core.AbstractValue):
def _map_ref(size, axis, ref_aval):
return AbstractRef(core.mapped_aval(size, axis, ref_aval.inner_aval))

- def _unmap_ref(size, axis, ref_aval):
- return AbstractRef(core.unmapped_aval(size, axis, ref_aval.inner_aval))
+ def _unmap_ref(size, axis, explicit_mesh_axis, ref_aval):
+ return AbstractRef(core.unmapped_aval(
+ size, axis, ref_aval.inner_aval, explicit_mesh_axis))

core.aval_mapping_handlers[AbstractRef] = (_map_ref, _unmap_ref)
@@ -1449,7 +1449,8 @@ def _shard_map_batch(
new_in_names = [{**ns, d:spmd_axis_name} if d is not batching.not_mapped
else ns for ns, d in zip(new_in_names, in_dims)]
new_size = trace.axis_data.size // prod(mesh.shape[n] for n in spmd_axis_name)
- new_axis_data = batching.AxisData(trace.axis_data.name, new_size, trace.axis_data.spmd_name)
+ new_axis_data = batching.AxisData(trace.axis_data.name, new_size,
+ trace.axis_data.spmd_name, None)
else:
new_axis_data = trace.axis_data
fun, out_dims = batching.batch_subtrace(fun, trace.tag, new_axis_data, tuple(in_dims))
@@ -4912,7 +4912,7 @@ class ShardingInTypesTest(jtu.JaxTestCase):

@parameterized.parameters([True, False])
@jtu.with_user_mesh((4,), ('x',))
- def test_dot_general_out_sharding(self, jit, mesh):
+ def test_dot_general_out_sharding(self, use_jit, mesh):
np_inp1 = np.arange(16.).reshape(8, 2)
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None)))
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x')))
@@ -4922,7 +4922,7 @@
self.assertEqual(out.aval.sharding.spec, P('x', None))
return jnp.sum(out)

- if jit:
+ if use_jit:
f = jax.jit(f)

out = f(arr1, arr2)
@@ -4939,7 +4939,7 @@
self.assertEqual(out[0].sharding, arr1.sharding)
self.assertEqual(out[1].sharding, arr2.sharding)

- if jit:
+ if use_jit:
jitted_grad = jax.jit(jax.grad(f, argnums=(0, 1)))
out = jitted_grad(arr1, arr2)
self.assertEqual(out[0].sharding, arr1.sharding)
@@ -6226,6 +6226,18 @@
out = jax.jit(jax.grad(g))(arr)
self.assertEqual(out.sharding, arr.sharding)

+ def f_vmap(x):
+ self.assertEqual(x.aval.sharding.spec, P('y'))
+ y = reshard(x, P(None))
+ self.assertEqual(y.aval.sharding.spec, P(None))
+ return y
+
+ out = jax.vmap(f_vmap)(arr)
+ self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
+
+ out = jax.jit(jax.vmap(jax.jit(f_vmap)))(arr)
+ self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
+
@jax.jit
def h(x):
with use_auto_axes('x'):
@@ -6394,6 +6406,121 @@
lowered_text = f.lower(aval, aval2).as_text()
self.assertNotIn("mhlo.sharding", lowered_text)

+ @parameterized.parameters(True, False)
+ @jtu.with_user_mesh((2, 2), ('x', 'y'))
+ def test_mul_vmap(self, use_jit, mesh):
+ np_inp = np.arange(16.).reshape(8, 2)
+ s = NamedSharding(mesh, P('x', 'y'))
+ arr = jax.device_put(np_inp, s)
+
+ def f(x):
+ self.assertEqual(x.aval.sharding.spec, P(s.spec[1]))
+ x = x * 2
+ self.assertEqual(x.aval.sharding.spec, P(s.spec[1]))
+ x = x * x
+ self.assertEqual(x.aval.sharding.spec, P(s.spec[1]))
+ return x
+
+ if use_jit:
+ f = jax.jit(f)
+
+ f = jax.vmap(f)
+
+ out = f(arr)
+ self.assertEqual(out.sharding, s)
+ self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2))
+
+ out = jax.jit(f)(arr)
+ self.assertEqual(out.sharding, s)
+ self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2))
+
+ def g(x):
+ return jnp.sum(f(x))
+
+ out = jax.grad(g)(arr)
+ self.assertEqual(out.sharding, arr.sharding)
+
+ out = jax.jit(jax.grad(g))(arr)
+ self.assertEqual(out.sharding, arr.sharding)
+
+ @parameterized.parameters(True, False)
+ @jtu.with_user_mesh((2, 2), ('x', 'y'))
+ def test_dot_general_vmap(self, use_jit, mesh):
+ np_inp1 = np.arange(16.).reshape(4, 2, 2)
+ np_inp2 = np.arange(16.).reshape(2, 4, 2)
+ arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None, 'y')))
+ arr2 = jax.device_put(np_inp2, NamedSharding(mesh, P(None, 'x', 'y')))
+
+ def f(x, y):
+ return jnp.einsum('xy,yz->xz', x, y, out_sharding=P(None, 'y'))
+
+ if use_jit:
+ f = jax.jit(f)
+
+ f = jax.vmap(f, in_axes=(0, 1), out_axes=2)
+
+ out = f(arr1, arr2)
+ self.assertEqual(out.shape, (2, 2, 4))
+ self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
+
+ @jtu.with_user_mesh((2, 2), ('x', 'y'))
+ def test_reshape_vmap(self, mesh):
+ np_inp = np.arange(16).reshape(2, 8)
+ arr = jax.device_put(np_inp, NamedSharding(mesh, P(None, 'x')))
+
+ def f(x):
+ y = lax.reshape(x, (1, 2), sharding=P(None, 'y'))
+ y = y * 2
+ self.assertEqual(y.aval.sharding.spec, P(None, 'y'))
+ return y
+
+ out = jax.jit(jax.vmap(f, in_axes=1))(arr)
+ self.assertEqual(out.shape, (8, 1, 2))
+ self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, 'y')))
+
+ @parameterized.parameters(True, False)
+ @jtu.with_user_mesh((2, 2), ('x', 'y'))
+ def test_shit_vmap_error_check(self, use_jit, mesh):
+ np_inp = np.arange(16).reshape(8, 2)
+ arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', None)))
+ arr2 = jax.device_put(np_inp.T, NamedSharding(mesh, P(None, 'y')))
+
+ def f(x, y):
+ return x @ y
+
+ if use_jit:
+ f = jax.jit(f)
+
+ with self.assertRaisesRegex(
+ ValueError,
+ "Mapped away dimension of inputs passed to vmap should be sharded "
+ "the same"):
+ jax.vmap(f, in_axes=(0, 1))(arr, arr2)
+
+ with self.assertRaisesRegex(
+ ValueError,
+ 'Mapped away dimension of inputs passed to vmap should be sharded the'
+ ' same'):
+ jax.jit(jax.vmap(f, in_axes=(0, 1)))(arr, arr2)
+
+ with self.assertRaisesRegex(
+ ValueError,
+ "Only one of spmd_axis_name or arrays sharded on.*spmd_axis_name"):
+ jax.vmap(f, spmd_axis_name='y')(arr, arr)
+
+ @jtu.with_user_mesh((2,), ('x',))
+ def test_unmapped_last_vmap(self, mesh):
+ np_inp = np.arange(8)
+ arr = jax.device_put(np_inp, NamedSharding(mesh, P('x',)))
+
+ @partial(jax.vmap, out_axes=-1)
+ def f(x):
+ return np.zeros((4,))
+
+ out = f(arr)
+ self.assertEqual(out.shape, (4, 8))
+ self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'x')))
+

@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):