mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #13108 from mattjj:djax-vmap2
PiperOrigin-RevId: 486534673
This commit is contained in:
commit
8d59b0d47a
@ -45,6 +45,7 @@ from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters.batching import ConcatAxis
|
||||
import jax._src.pretty_printer as pp
|
||||
from jax._src import util
|
||||
from jax._src.util import (cache, prod, safe_zip, safe_map, canonicalize_axis,
|
||||
@ -2597,6 +2598,21 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
|
||||
precision,
|
||||
preferred_element_type: Optional[DTypeLike]):
|
||||
lhs, rhs = batched_args
|
||||
lbd, rbd = batch_dims
|
||||
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
|
||||
if (type(lbd) is type(rbd) is ConcatAxis and
|
||||
lbd.axis in lhs_contract and rbd.axis in rhs_contract):
|
||||
# first handle any other part of the dot with these as batch dims
|
||||
lhs_contract_ = [d for d in lhs_contract if d != lbd.axis]
|
||||
rhs_contract_ = [d for d in rhs_contract if d != rbd.axis]
|
||||
lhs_batch_ = (lbd.axis, *lhs_batch)
|
||||
rhs_batch_ = (rbd.axis, *rhs_batch)
|
||||
new_dnums = ((lhs_contract_, rhs_contract_), (lhs_batch_, rhs_batch_))
|
||||
out = dot_general(lhs, rhs, new_dnums, precision=precision,
|
||||
preferred_element_type=preferred_element_type)
|
||||
# now a segment sum along that batch axis
|
||||
return batching.segment_sum(out, lbd.segment_lengths), 0
|
||||
|
||||
new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
|
||||
(lhs.ndim, rhs.ndim), batch_dims, dimension_numbers)
|
||||
batched_out = dot_general(lhs, rhs, new_dimension_numbers,
|
||||
@ -2617,31 +2633,52 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
|
||||
def bump_dims(dims, b):
|
||||
return tuple(np.add(dims, np.greater_equal(dims, b)))
|
||||
|
||||
if lbd is not None and rbd is not None:
|
||||
if type(lbd) is type(rbd) is int:
|
||||
# adding a batch dimension
|
||||
lhs_batch = (lbd,) + bump_dims(lhs_batch, lbd)
|
||||
rhs_batch = (rbd,) + bump_dims(rhs_batch, rbd)
|
||||
lhs_contract = bump_dims(lhs_contract, lbd)
|
||||
rhs_contract = bump_dims(rhs_contract, rbd)
|
||||
result_batch_dim = 0
|
||||
else:
|
||||
# adding a tensor product dimension
|
||||
if lbd is not None:
|
||||
other = tuple(d for d in range(lhs_ndim)
|
||||
if d not in lhs_batch and d not in lhs_contract)
|
||||
result_batch_dim = (len(lhs_batch) + sum(np.less(other, lbd)))
|
||||
lhs_batch = bump_dims(lhs_batch, lbd)
|
||||
lhs_contract = bump_dims(lhs_contract, lbd)
|
||||
elif rbd is None and type(lbd) is ConcatAxis and lbd.axis not in lhs_contract:
|
||||
if lbd.axis in lhs_batch:
|
||||
axis = int(np.sum(np.less(lhs_batch, lbd.axis)))
|
||||
else:
|
||||
other = tuple(d for d in range(rhs_ndim)
|
||||
if d not in rhs_batch and d not in rhs_contract)
|
||||
result_batch_dim = (lhs_ndim - len(lhs_contract) +
|
||||
sum(np.less(other, rbd)))
|
||||
rhs_batch = bump_dims(rhs_batch, rbd)
|
||||
rhs_contract = bump_dims(rhs_contract, rbd)
|
||||
lhs_tensor = [d for d in range(lhs_ndim)
|
||||
if d not in lhs_batch and d not in lhs_contract]
|
||||
axis = len(lhs_batch) + int(np.sum(np.less(lhs_tensor, lbd.axis)))
|
||||
result_batch_dim = ConcatAxis(axis, lbd.segment_lengths)
|
||||
elif lbd is None and type(rbd) is ConcatAxis and rbd.axis not in rhs_contract:
|
||||
if rbd.axis in rhs_batch:
|
||||
axis = int(np.sum(np.less(rhs_batch, rbd.axis)))
|
||||
else:
|
||||
rhs_tensor = [d for d in range(rhs_ndim)
|
||||
if d not in rhs_batch and d not in rhs_contract]
|
||||
axis = (lhs_ndim - len(lhs_contract) +
|
||||
int(sum(np.less(rhs_tensor, rbd.axis))))
|
||||
result_batch_dim = ConcatAxis(axis, rbd.segment_lengths)
|
||||
elif (type(lbd) is int and
|
||||
(rbd is None or type(rbd) is ConcatAxis and
|
||||
rbd.axis not in rhs_contract)):
|
||||
lhs_tensor = [d for d in range(lhs_ndim)
|
||||
if d not in lhs_batch and d not in lhs_contract]
|
||||
result_batch_dim = len(lhs_batch) + int(sum(np.less(lhs_tensor, lbd)))
|
||||
lhs_batch = bump_dims(lhs_batch, lbd)
|
||||
lhs_contract = bump_dims(lhs_contract, lbd)
|
||||
elif (type(rbd) is int and
|
||||
(lbd is None or type(lbd) is ConcatAxis and
|
||||
lbd.axis not in lhs_contract)):
|
||||
rhs_tensor = [d for d in range(rhs_ndim)
|
||||
if d not in rhs_batch and d not in rhs_contract]
|
||||
result_batch_dim = (lhs_ndim - len(lhs_contract) +
|
||||
int(sum(np.less(rhs_tensor, rbd))))
|
||||
rhs_batch = bump_dims(rhs_batch, rbd)
|
||||
rhs_contract = bump_dims(rhs_contract, rbd)
|
||||
else:
|
||||
assert False
|
||||
|
||||
new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
|
||||
return new_dimension_numbers, int(result_batch_dim)
|
||||
return new_dimension_numbers, result_batch_dim
|
||||
|
||||
def _dot_general_padding_rule(in_avals, out_avals, lhs, rhs, *,
|
||||
dimension_numbers, **params):
|
||||
@ -2782,15 +2819,27 @@ def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape,
|
||||
return ([expand_dims(_reduce_sum(ct, axes), unit_dims)] +
|
||||
[None] * len(dyn_shape))
|
||||
|
||||
def _broadcast_in_dim_batch_rule(batched_args, batch_dims, *dyn_shape, shape,
|
||||
def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape,
|
||||
broadcast_dimensions):
|
||||
if dyn_shape: raise NotImplementedError # TODO(mattjj)
|
||||
operand, = batched_args
|
||||
bdim, = batch_dims
|
||||
new_operand = batching.moveaxis(operand, bdim, 0)
|
||||
new_shape = (operand.shape[bdim],) + shape
|
||||
new_broadcast_dimensions = (0,) + tuple(np.add(1, broadcast_dimensions))
|
||||
return broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions), 0
|
||||
operand, *dyn_shape = batched_args
|
||||
operand_bdim, *dyn_shape_bdims = batch_dims
|
||||
if len(dyn_shape) > 1: raise NotImplementedError
|
||||
if (operand_bdim is not None and
|
||||
(not dyn_shape_bdims or dyn_shape_bdims[0] is None)):
|
||||
new_operand = batching.moveaxis(operand, operand_bdim, 0)
|
||||
new_shape = (operand.shape[operand_bdim],) + shape
|
||||
new_broadcast_dimensions = (0,) + tuple(np.add(1, broadcast_dimensions))
|
||||
return broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions), 0
|
||||
elif (operand_bdim is None and dyn_shape_bdims and
|
||||
dyn_shape_bdims[0] is not None):
|
||||
(d,), (d_bdim,) = dyn_shape, dyn_shape_bdims # NotImplementedError above
|
||||
assert d_bdim == 0 # must be scalar in the program to be batched
|
||||
new_shape = _merge_dyn_shape(shape, (int(d.sum()),))
|
||||
out = broadcast_in_dim(operand, new_shape, broadcast_dimensions)
|
||||
idx, = (i for i, s in enumerate(shape) if s is None)
|
||||
return out, batching.ConcatAxis(idx, d)
|
||||
else:
|
||||
raise NotImplementedError # TODO(mattjj)
|
||||
|
||||
def _broadcast_in_dim_fwd_rule(eqn):
|
||||
v, *dyn = eqn.invars
|
||||
@ -4471,6 +4520,13 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension):
|
||||
mlir.i64_attr(dimension)).results
|
||||
mlir.register_lowering(iota_p, _iota_lower)
|
||||
|
||||
def _iota_batching_rule(in_vals, in_dims, *, dtype, shape, dimension):
  """Batching rule for `iota_p` when its single dynamic size is batched.

  One iota is built per entry of the batched size operand, and the results
  are concatenated along `dimension`.

  Args:
    in_vals: one-element sequence holding the per-example sizes.
    in_dims: one-element sequence holding the batch dim of that operand.
    dtype: element type of the iota.
    shape: static shape template, completed with each size via
      `_merge_dyn_shape`.
    dimension: the iota dimension, also used as the concatenation axis.

  Returns:
    A pair `(concatenated_iotas, ConcatAxis(in_dim, sizes))`.
  """
  segment_lengths, = in_vals
  ax, = in_dims
  pieces = []
  for seg_len in segment_lengths:
    piece_shape = _merge_dyn_shape(shape, (seg_len,))
    pieces.append(broadcasted_iota(dtype, piece_shape, dimension))
  stacked = concatenate(pieces, dimension)
  return stacked, batching.ConcatAxis(ax, segment_lengths)
|
||||
batching.primitive_batchers[iota_p] = _iota_batching_rule
|
||||
|
||||
def _iota_pp_rule(eqn, context, settings):
|
||||
printed_params = {}
|
||||
if len(eqn.params['shape']) > 1:
|
||||
|
34
jax/core.py
34
jax/core.py
@ -2195,35 +2195,41 @@ def unmapped_aval(size: AxisSize, axis_name, axis: Optional[int],
|
||||
else:
|
||||
raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}")
|
||||
|
||||
def _map_shaped_array(size: int, axis: Optional[int], aval: ShapedArray
|
||||
) -> ShapedArray:
|
||||
|
||||
def _map_shaped_array(
|
||||
size: int, axis: Optional[int], aval: ShapedArray) -> ShapedArray:
|
||||
assert axis is None or aval.shape[axis] == size
|
||||
# TODO: Extend the named shape
|
||||
if axis is None: return aval
|
||||
return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
|
||||
named_shape=aval.named_shape, weak_type=aval.weak_type)
|
||||
|
||||
def _unmap_shaped_array(size: int, axis_name, axis: Optional[int],
|
||||
aval: ShapedArray) -> ShapedArray:
|
||||
def _unmap_shaped_array(
|
||||
size: int, axis_name: AxisName, axis: Optional[int], aval: ShapedArray
|
||||
) -> ShapedArray:
|
||||
named_shape = dict(aval.named_shape)
|
||||
# TODO: Make this mandatory
|
||||
named_shape.pop(axis_name, None)
|
||||
named_shape.pop(axis_name, None) # TODO: make this mandatory
|
||||
if axis is None: return aval.update(named_shape=named_shape)
|
||||
return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
|
||||
named_shape=named_shape, weak_type=aval.weak_type)
|
||||
elif type(axis) is int:
|
||||
return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
|
||||
named_shape=named_shape, weak_type=aval.weak_type)
|
||||
else: raise TypeError(axis)
|
||||
|
||||
def _map_dshaped_array(size: AxisSize, axis: Optional[int],
|
||||
aval: DShapedArray) -> DShapedArray:
|
||||
def _map_dshaped_array(
|
||||
size: AxisSize, axis: Optional[int], aval: DShapedArray) -> DShapedArray:
|
||||
if axis is None: return aval
|
||||
return DShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
|
||||
aval.weak_type)
|
||||
|
||||
def _unmap_dshaped_array(
|
||||
size: AxisSize, axis_name, axis: Optional[int],
|
||||
aval: DShapedArray) -> DShapedArray:
|
||||
size: AxisSize, axis_name: AxisName, axis: Optional[int], aval: DShapedArray
|
||||
) -> DShapedArray:
|
||||
if axis is None: return aval
|
||||
return DShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
|
||||
weak_type=aval.weak_type)
|
||||
elif type(axis) is int:
|
||||
return DShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
|
||||
weak_type=aval.weak_type)
|
||||
else:
|
||||
raise TypeError(axis)
|
||||
|
||||
AvalMapHandlerPair = Tuple[Callable, Callable]
|
||||
aval_mapping_handlers: Dict[Type, AvalMapHandlerPair] = {
|
||||
|
@ -11,7 +11,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import dataclasses
|
||||
from functools import partial
|
||||
from typing import (Any, Callable, Dict, Hashable, Iterable, Optional, Sequence,
|
||||
Set, Tuple, Type, Union)
|
||||
@ -23,7 +26,8 @@ from jax.config import config
|
||||
from jax import core
|
||||
from jax.core import raise_to_shaped, Trace, Tracer
|
||||
from jax._src import source_info_util
|
||||
from jax._src.tree_util import tree_unflatten, tree_flatten
|
||||
from jax._src.tree_util import (tree_unflatten, tree_flatten,
|
||||
register_pytree_node)
|
||||
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
zeros_like_p, Zero)
|
||||
from jax import linear_util as lu
|
||||
@ -33,31 +37,122 @@ from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, wrap_name,
|
||||
weakref_lru_cache)
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
||||
Array = Any
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
|
||||
# Piles
|
||||
|
||||
# i:(Fin 3) => f32[[3, 1, 4].i]
|
||||
@dataclasses.dataclass(frozen=True)
class PileTy:
  """Type of a pile: `binder:length => elt_ty`.

  `elt_ty` is the element type; per the example comment in this file, one of
  its axis sizes is indexed by `binder` (e.g. `f32[[3, 1, 4].i]`).
  """
  binder: core.Var
  length: Union[int, Tracer, core.Var]
  elt_ty: core.DShapedArray

  def __repr__(self) -> str:
    # The binder has no printable name here, so it is shown by object id.
    return 'Var{}:{} => {}'.format(id(self.binder), self.length, self.elt_ty)

  def replace(self, **changes):
    """Return a copy of this type with the given fields replaced."""
    return dataclasses.replace(self, **changes)
|
||||
|
||||
# [3, 1, 4].i
|
||||
@dataclasses.dataclass(frozen=True)
class IndexedAxisSize:
  """An axis size written `lengths.idx`, e.g. `[3, 1, 4].i`.

  `lengths` holds the candidate sizes and `idx` is the variable that
  indexes into them.
  """
  idx: core.Var
  lengths: Union[Array, core.Var, Tracer]

  def __repr__(self) -> str:
    # The index variable is shown by object id, as in PileTy's repr.
    return str(self.lengths) + '.Var' + str(id(self.idx))

  def replace(self, **changes):
    """Return a copy of this size with the given fields replaced."""
    return dataclasses.replace(self, **changes)
|
||||
|
||||
# Pile(aval=a:3 => f32[[3 1 4].a],
|
||||
# data=DeviceArray([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
|
||||
@dataclasses.dataclass(frozen=True)
class Pile:
  """A pile value: a `PileTy` together with the array that stores its data.

  Registered as a pytree node in this module via `_pile_flatten` and
  `_pile_unflatten`.
  """
  # Type describing the pile (binder, length, element type).
  aval: PileTy
  # Backing array; in the tests the per-example segments are stored
  # concatenated along one axis.
  data: Array
|
||||
|
||||
def _pile_flatten(pile):
  """Pytree-flatten a `Pile` into `((lengths, data), aval)`.

  Every `IndexedAxisSize` in the element shape has its `lengths` moved into
  the returned `lengths` list and replaced by that entry's 1-based position,
  so the auxiliary `aval` no longer holds the length arrays themselves.
  """
  pile_ty = pile.aval
  lengths = []
  dims = []
  for d in pile_ty.elt_ty.shape:
    if type(d) is IndexedAxisSize:
      lengths.append(d.lengths)
      # Store the 1-based index of the entry just appended.
      d = d.replace(lengths=len(lengths))
    dims.append(d)
  flat_elt_ty = pile_ty.elt_ty.update(shape=tuple(dims))
  return (lengths, pile.data), pile_ty.replace(elt_ty=flat_elt_ty)
|
||||
|
||||
def _pile_unflatten(aval, x):
  """Inverse of `_pile_flatten`: rebuild a `Pile` from `aval` and children.

  Args:
    aval: the auxiliary `PileTy` whose `IndexedAxisSize.lengths` fields hold
      1-based positions into the flattened lengths list.
    x: the pair `(lengths, data)`.
  """
  lengths, data = x

  def restore(d):
    # Turn the 1-based placeholder back into the actual lengths entry.
    if type(d) is IndexedAxisSize:
      return d.replace(lengths=lengths[d.lengths - 1])
    return d

  dims = tuple([restore(d) for d in aval.elt_ty.shape])
  full_ty = aval.replace(elt_ty=aval.elt_ty.update(shape=dims))
  return Pile(full_ty, data)
|
||||
|
||||
register_pytree_node(Pile, _pile_flatten, _pile_unflatten)
|
||||
|
||||
def _pile_result(axis_size, axis, segment_lens, x):
  """Wrap array `x` as a `Pile` whose `axis` has an indexed (ragged) size.

  Args:
    axis_size: stored as the pile type's `length`.
    axis: position in `x.shape` to replace with an `IndexedAxisSize`.
    segment_lens: the lengths stored in that `IndexedAxisSize`.
    x: the data array.
  """
  # Fresh scalar int32 variable, shared by the pile type and the axis size.
  idx_binder = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))
  dims = [*x.shape]
  dims[axis] = IndexedAxisSize(idx_binder, segment_lens)
  pile_ty = PileTy(idx_binder, axis_size,
                   core.DShapedArray(tuple(dims), x.dtype, x.weak_type))
  return Pile(pile_ty, x)
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class ConcatAxis:
  """Batch-dim descriptor for a value with a concatenated (ragged) axis.

  Used as a `BatchTracer.batch_dim` alongside `int` and `None`.
  """
  # Index of the axis that holds the concatenated segments.
  axis: int
  # Length of each segment. `unpack_concat_axes` temporarily replaces this
  # with a `pe.DBIdx`, and `reassemble_concat_axes` restores it.
  segment_lengths: Array
|
||||
|
||||
|
||||
def _update_annotation(
|
||||
f: lu.WrappedFun, orig_type: Optional[core.InputType],
|
||||
axis_size: core.AxisSize, axis_name: core.AxisName,
|
||||
explicit_in_dims: Sequence[Optional[int]]) -> lu.WrappedFun:
|
||||
explicit_in_dims: Sequence[Optional[Union[int, ConcatAxis]]],
|
||||
segment_lens: Sequence[Array],
|
||||
) -> lu.WrappedFun:
|
||||
if orig_type is None: return f
|
||||
# By convention, `explicit_in_dims` only accounts for explicit arguments.
|
||||
assert len(explicit_in_dims) == sum(explicit for _, explicit in orig_type)
|
||||
# We add a batch dim to each mapped argument type. If `axis_size` is dynamic
|
||||
# (i.e. a Tracer) the added batch dim size is a DBIdx and we add a new leading
|
||||
# implicit argument and increment all other DBIdx.
|
||||
new_arg = isinstance(axis_size, Tracer)
|
||||
sz = core.DBIdx(0) if new_arg else axis_size
|
||||
def unmap(d, a):
|
||||
if isinstance(a, core.DShapedArray):
|
||||
a = a.update(shape=tuple(core.DBIdx(d.val + new_arg)
|
||||
if type(d) is core.DBIdx else d for d in a.shape))
|
||||
return core.unmapped_aval(sz, axis_name, d, a)
|
||||
in_dims = iter(explicit_in_dims)
|
||||
in_type = [(unmap(next(in_dims), a), explicit) if explicit else (a, explicit)
|
||||
for a, explicit in orig_type]
|
||||
if new_arg: in_type = [(axis_size.aval, False), *in_type] # type: ignore
|
||||
return lu.annotate(f, tuple(in_type))
|
||||
# We need to:
|
||||
# * if `axis_size` is dynamic, add a new implicit binder (type) for it;
|
||||
# * for each element of `segment_lengths`, add a new explicit binder for it;
|
||||
# * drop other implicit binders, replacing DBIdx which refer to them with
|
||||
# Name objects;
|
||||
# * for each (aval, in_dim) pair: if int-valued in_dim, add batch axis (int
|
||||
# size if `axis_size` is int, otherwise Name); if ConcatAxis-valued in_dim,
|
||||
# add batch axis (int if corresponding segment_lengths is concrete, Name if
|
||||
# not);
|
||||
# * generate full in_type with implicit args too.
|
||||
|
||||
class Name:
|
||||
def __init__(self, a): self.a = a
|
||||
names = [Name(a) for a, _ in orig_type]
|
||||
avals = [a.update(shape=tuple(names[d.val] if type(d) is pe.DBIdx else d # type: ignore
|
||||
for d in a.shape))
|
||||
if type(a) is core.DShapedArray else a for a, e in orig_type if e]
|
||||
|
||||
new_avals = [core.raise_to_shaped(core.get_aval(s)) for s in segment_lens]
|
||||
sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size
|
||||
for a, d in zip(avals, explicit_in_dims):
|
||||
if isinstance(d, ConcatAxis):
|
||||
s = segment_lens[d.segment_lengths.val]
|
||||
if isinstance(core.get_aval(s), core.ConcreteArray):
|
||||
shape = list(a.shape) # type: ignore
|
||||
shape[d.axis] = int(s.sum()) # specialize on shape if we can
|
||||
new_avals.append(a.update(shape=tuple(shape)))
|
||||
else:
|
||||
new_avals.append(a)
|
||||
else:
|
||||
new_avals.append(core.unmapped_aval(sz, axis_name, d, a)) # type: ignore
|
||||
|
||||
mentioned = {d for a in new_avals if type(a) is core.DShapedArray
|
||||
for d in a.shape if type(d) is Name}
|
||||
expl_names = set(map(Name, new_avals))
|
||||
impl_names = mentioned - expl_names # type: ignore
|
||||
impl_part = [(n.a, False) for n in impl_names] # type: ignore
|
||||
name_map = {n: pe.DBIdx(i) for i, n in enumerate((*impl_names, *expl_names))}
|
||||
expl_part = [(a.update(shape=tuple(name_map.get(d, d) for d in a.shape))
|
||||
if type(a) is core.DShapedArray else a, True) for a in new_avals]
|
||||
return lu.annotate(f, (*impl_part, *expl_part))
|
||||
|
||||
### vmappable typeclass
|
||||
|
||||
@ -65,7 +160,6 @@ Vmappable = Any
|
||||
Elt = Any
|
||||
MapSpec = Any
|
||||
AxisSize = Any
|
||||
Array = Any
|
||||
GetIdx = Callable[[], Tracer] # TODO(mattjj): revise this laziness
|
||||
ToEltHandler = Callable[[Callable, GetIdx, Vmappable, MapSpec], Elt]
|
||||
FromEltHandler = Callable[[Callable, AxisSize, Elt, MapSpec], Vmappable]
|
||||
@ -75,10 +169,16 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
|
||||
handler = to_elt_handlers.get(type(x))
|
||||
if handler:
|
||||
return handler(partial(to_elt, trace, get_idx), get_idx, x, spec)
|
||||
else:
|
||||
elif type(x) is Pile:
|
||||
(d, ias), = ((i, sz) for i, sz in enumerate(x.aval.elt_ty.shape)
|
||||
if type(sz) is IndexedAxisSize)
|
||||
return BatchTracer(trace, x.data, ConcatAxis(d, ias.lengths)) # type: ignore
|
||||
elif isinstance(spec, int) or spec is None:
|
||||
spec = spec and canonicalize_axis(spec, len(np.shape(x)))
|
||||
return (BatchTracer(trace, x, spec, source_info_util.current())
|
||||
if spec is not None else x)
|
||||
else:
|
||||
assert False
|
||||
to_elt_handlers: Dict[Type, ToEltHandler] = {}
|
||||
|
||||
def from_elt(trace: 'BatchTrace', axis_size: AxisSize, x: Elt, spec: MapSpec
|
||||
@ -86,8 +186,11 @@ def from_elt(trace: 'BatchTrace', axis_size: AxisSize, x: Elt, spec: MapSpec
|
||||
handler = from_elt_handlers.get(type(x))
|
||||
if handler:
|
||||
return handler(partial(from_elt, trace), axis_size, x, spec)
|
||||
x_ = trace.full_raise(x)
|
||||
val, bdim = x_.val, x_.batch_dim
|
||||
if type(bdim) is ConcatAxis:
|
||||
return _pile_result(axis_size, bdim.axis, bdim.segment_lengths, val)
|
||||
else:
|
||||
x_ = trace.full_raise(x)
|
||||
return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val)
|
||||
from_elt_handlers: Dict[Type, FromEltHandler] = {}
|
||||
|
||||
@ -119,7 +222,7 @@ def unregister_vmappable(data_type: Type) -> None:
|
||||
del make_iota_handlers[axis_size_type]
|
||||
|
||||
def is_vmappable(x: Any) -> bool:
|
||||
return type(x) in vmappables
|
||||
return type(x) is Pile or type(x) in vmappables
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def flatten_fun_for_vmap(in_tree, *args_flat):
|
||||
@ -133,16 +236,17 @@ def flatten_fun_for_vmap(in_tree, *args_flat):
|
||||
NotMapped = type(None)
|
||||
not_mapped = None
|
||||
|
||||
|
||||
class BatchTracer(Tracer):
|
||||
__slots__ = ['val', 'batch_dim', 'source_info']
|
||||
|
||||
def __init__(self, trace, val, batch_dim: Optional[int],
|
||||
def __init__(self, trace, val, batch_dim: Union[NotMapped, int, ConcatAxis],
|
||||
source_info: Optional[source_info_util.SourceInfo] = None):
|
||||
if config.jax_enable_checks:
|
||||
assert type(batch_dim) in (int, NotMapped)
|
||||
assert type(batch_dim) in (NotMapped, int, ConcatAxis)
|
||||
if type(batch_dim) is int:
|
||||
aval = raise_to_shaped(core.get_aval(val))
|
||||
assert batch_dim is not_mapped or 0 <= batch_dim < len(aval.shape) # type: ignore
|
||||
assert 0 <= batch_dim < len(aval.shape) # type: ignore
|
||||
self._trace = trace
|
||||
self.val = val
|
||||
self.batch_dim = batch_dim
|
||||
@ -153,7 +257,14 @@ class BatchTracer(Tracer):
|
||||
aval = raise_to_shaped(core.get_aval(self.val))
|
||||
if self.batch_dim is not_mapped:
|
||||
return aval
|
||||
return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval)
|
||||
elif type(self.batch_dim) is int:
|
||||
return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval)
|
||||
elif type(self.batch_dim) is ConcatAxis:
|
||||
shape = list(aval.shape)
|
||||
size_tracer = BatchTracer(self._trace, self.batch_dim.segment_lengths, 0)
|
||||
shape[self.batch_dim.axis] = size_tracer
|
||||
return core.DShapedArray(shape=tuple(shape), dtype=aval.dtype,
|
||||
weak_type=aval.weak_type)
|
||||
|
||||
def full_lower(self):
|
||||
if self.batch_dim is not_mapped:
|
||||
@ -170,6 +281,12 @@ class BatchTracer(Tracer):
|
||||
def _contents(self):
|
||||
return [('val', self.val), ('batch_dim', self.batch_dim)]
|
||||
|
||||
def get_referent(self):
  """Return the referent of `val`, or `self` for other batch dims.

  When `batch_dim` is `None` or an `int`, this defers to
  `core.get_referent(self.val)`; for any other batch dim the tracer itself
  is returned.
  """
  bdim = self.batch_dim
  if bdim is not None and type(bdim) is not int:
    # TODO(mattjj): could handle the ConcatAxis case?
    return self
  return core.get_referent(self.val)
|
||||
|
||||
class BatchTrace(Trace):
|
||||
|
||||
def __init__(self, *args, axis_name, spmd_axis_name = None):
|
||||
@ -206,8 +323,9 @@ class BatchTrace(Trace):
|
||||
if self.axis_name is core.no_axis_name:
|
||||
# If axis name is `no_axis_name` we can't find it via `core.axis_name` so
|
||||
# we reconstruct it from the information we have available
|
||||
axis_size, = core.dedup_referents(x.shape[d] for x, d in zip(vals, dims)
|
||||
if d is not not_mapped)
|
||||
sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths)
|
||||
for x, d in zip(vals, dims) if d is not not_mapped)
|
||||
axis_size, = core.dedup_referents(sizes)
|
||||
return core.AxisEnvFrame(self.axis_name, axis_size, self.main)
|
||||
return core.axis_frame(self.axis_name)
|
||||
|
||||
@ -241,14 +359,17 @@ class BatchTrace(Trace):
|
||||
vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
|
||||
if all(bdim is not_mapped for bdim in dims):
|
||||
return call_primitive.bind(f, *vals, **params)
|
||||
else:
|
||||
f_, dims_out = batch_subtrace(f, self.main, dims)
|
||||
axis_size, = core.dedup_referents(x.shape[d] for x, d in zip(vals, dims)
|
||||
if d is not not_mapped)
|
||||
f_ = _update_annotation(f_, f.in_type, axis_size, self.axis_name, dims)
|
||||
vals_out = call_primitive.bind(f_, *vals, **params)
|
||||
src = source_info_util.current()
|
||||
return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out())]
|
||||
sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths)
|
||||
for x, d in zip(vals, dims) if d is not not_mapped)
|
||||
axis_size, = core.dedup_referents(sizes)
|
||||
segment_lens, dims = unpack_concat_axes(dims)
|
||||
f_, dims_out = batch_subtrace(f, self.main, tuple(dims))
|
||||
f_ = _update_annotation(f_, f.in_type, axis_size, self.axis_name, dims,
|
||||
segment_lens)
|
||||
vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params)
|
||||
vals_out, dims_out = reassemble_concat_axes(vals_out, dims_out())
|
||||
src = source_info_util.current()
|
||||
return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out)]
|
||||
|
||||
def post_process_call(self, call_primitive, out_tracers, params):
|
||||
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
|
||||
@ -465,15 +586,36 @@ def vtile(f_flat: lu.WrappedFun,
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def batch_subtrace(main, in_dims, *in_vals):
|
||||
# used in e.g. process_call
|
||||
trace = main.with_cur_sublevel()
|
||||
in_dims = in_dims() if callable(in_dims) else in_dims
|
||||
in_vals, in_dims = reassemble_concat_axes(in_vals, in_dims)
|
||||
in_tracers = [BatchTracer(trace, x, dim, source_info_util.current())
|
||||
if dim is not None else x for x, dim in zip(in_vals, in_dims)]
|
||||
outs = yield in_tracers, {}
|
||||
out_tracers = map(trace.full_raise, outs)
|
||||
out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
|
||||
yield out_vals, out_dims
|
||||
segment_lens, out_dims = unpack_concat_axes(out_dims)
|
||||
yield (*segment_lens, *out_vals), out_dims
|
||||
|
||||
def unpack_concat_axes(dims):
  """Pull segment-length arrays out of any `ConcatAxis` entries in `dims`.

  Returns `(segment_lens, new_dims)`. `segment_lens` lists the distinct
  segment-length values in first-seen order, deduplicated by the identity of
  their referent. `new_dims` mirrors `dims`, with each `ConcatAxis` rebuilt
  so that `segment_lengths` is a `pe.DBIdx` pointing into `segment_lens`.
  If `dims` has no `ConcatAxis`, returns `([], dims)` with `dims` untouched.
  """
  has_concat = False
  for d in dims:
    if type(d) is ConcatAxis:
      has_concat = True
      break
  if not has_concat:
    return [], dims

  # referent id -> (segment lengths, DBIdx assigned on first sight)
  seen = collections.OrderedDict()
  new_dims = []
  for d in dims:
    if not isinstance(d, ConcatAxis):
      new_dims.append(d)
      continue
    key = id(core.get_referent(d.segment_lengths))
    if key not in seen:
      seen[key] = (d.segment_lengths, pe.DBIdx(len(seen)))
    new_dims.append(ConcatAxis(d.axis, seen[key][1]))
  segment_lens = [lens for lens, _ in seen.values()]
  return segment_lens, new_dims
|
||||
|
||||
def reassemble_concat_axes(vals, dims):
  """Inverse of `unpack_concat_axes`.

  Each `ConcatAxis` in `dims` has a `segment_lengths` whose `.val` is a
  position in `vals`. That position is replaced by the value found there,
  and all such positions are then dropped from `vals`.

  Returns:
    `(remaining_vals, rebuilt_dims)`, both as lists.
  """
  # Positions in `vals` that hold segment lengths rather than real values.
  length_slots = set()
  for d in dims:
    if isinstance(d, ConcatAxis):
      length_slots.add(d.segment_lengths.val)

  rebuilt = []
  for d in dims:
    if isinstance(d, ConcatAxis):
      d = ConcatAxis(d.axis, vals[d.segment_lengths.val])
    rebuilt.append(d)

  remaining = [v for pos, v in enumerate(vals) if pos not in length_slots]
  return remaining, rebuilt
|
||||
|
||||
|
||||
### API for batching jaxprs
|
||||
@ -668,11 +810,34 @@ def defreducer(prim):
|
||||
def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
|
||||
operand, = batched_args
|
||||
bdim, = batch_dims
|
||||
axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
|
||||
bdim_out = int(list(np.delete(np.arange(operand.ndim), axes)).index(bdim))
|
||||
if 'input_shape' in params:
|
||||
params = dict(params, input_shape=operand.shape)
|
||||
return prim.bind(operand, axes=axes, **params), bdim_out
|
||||
if isinstance(bdim, int):
|
||||
axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
|
||||
bdim_out = int(list(np.delete(np.arange(operand.ndim), axes)).index(bdim))
|
||||
if 'input_shape' in params:
|
||||
params = dict(params, input_shape=operand.shape)
|
||||
return prim.bind(operand, axes=axes, **params), bdim_out
|
||||
elif isinstance(bdim, ConcatAxis):
|
||||
if bdim.axis in axes:
|
||||
other_axes = [i for i in axes if i != bdim.axis]
|
||||
if other_axes:
|
||||
operand = prim.bind(operand, axes=other_axes, **params)
|
||||
c_axis = bdim.axis - sum(d < bdim.axis for d in other_axes)
|
||||
operand = bdim_at_front(operand, c_axis, operand.shape[c_axis])
|
||||
return segment_sum(operand, bdim.segment_lengths), 0
|
||||
else:
|
||||
raise NotImplementedError # TODO(mattjj)
|
||||
else:
|
||||
assert False
|
||||
|
||||
# TODO(mattjj): replace with jax.lax.ops.segment_sum (once it's easier to trace
|
||||
# under dynamic shapes)
|
||||
def segment_sum(operand, segment_lens):
  """Sum contiguous segments of `operand` along axis 0.

  Args:
    operand: array whose leading axis is the concatenation of
      `len(segment_lens)` segments.
    segment_lens: 1D integer array with the length of each segment, in
      order. Zero-length segments are allowed and yield a zero row.

  Returns:
    An array of shape `(len(segment_lens), *operand.shape[1:])` whose i-th
    row is the sum of the rows of `operand` in segment i.
  """
  # Start offset of each segment (exclusive prefix sum of the lengths).
  scat_idx = jax.numpy.cumsum(segment_lens) - segment_lens
  # Count how many segments start at each row. Using `add` rather than `set`
  # matters when segments are empty: several segments then share one start
  # offset, and each must bump the running segment id. Offsets equal to
  # operand.shape[0] (trailing empty segments) are out of bounds and are
  # dropped by the scatter.
  starts = jax.numpy.zeros(operand.shape[0], 'int32').at[scat_idx].add(1)
  segment_ids = jax.numpy.cumsum(starts) - 1
  out = jax.numpy.zeros((len(segment_lens), *operand.shape[1:]),
                        operand.dtype).at[segment_ids].add(operand)
  return out
|
||||
|
||||
### general utilities for manipulating axes on jaxpr types (not vmappables)
|
||||
|
||||
|
@ -2224,14 +2224,13 @@ def _add_implicit_outputs(jaxpr: Jaxpr) -> Tuple[Jaxpr, OutputType]:
|
||||
|
||||
|
||||
class TracerAsName:
|
||||
tracer: DynamicJaxprTracer
|
||||
ref: Any
|
||||
def __init__(self, tracer):
|
||||
trace = core.thread_local_state.trace_state.trace_stack.dynamic
|
||||
self.tracer = trace.with_cur_sublevel().full_raise(tracer)
|
||||
self.ref = core.get_referent(tracer)
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, TracerAsName) and self.tracer is other.tracer
|
||||
return isinstance(other, TracerAsName) and self.ref is other.ref
|
||||
def __hash__(self):
|
||||
return id(self.tracer)
|
||||
return id(self.ref)
|
||||
|
||||
def _extract_implicit_args(
|
||||
trace: DynamicJaxprTrace, in_type: Sequence[Tuple[AbstractValue, bool]],
|
||||
|
@ -245,14 +245,17 @@ def annotate(f: WrappedFun, in_type: core.InputType) -> WrappedFun:
|
||||
|
||||
def _check_input_type(in_type: core.InputType) -> None:
|
||||
# Check that in_type is syntactically well-formed
|
||||
assert (type(in_type) is tuple and all(type(e) is tuple for e in in_type) and
|
||||
all(isinstance(a, core.AbstractValue) and type(b) is bool
|
||||
and not isinstance(a, core.ConcreteArray) for a, b in in_type) and
|
||||
all(isinstance(d, (int, core.DBIdx, core.DArray))
|
||||
and (not isinstance(d, core.DArray) or
|
||||
type(d.dtype) is core.bint and not d.shape)
|
||||
for a, _ in in_type if type(a) is core.DShapedArray
|
||||
for d in a.shape))
|
||||
assert type(in_type) is tuple and all(type(e) is tuple for e in in_type)
|
||||
assert all(isinstance(a, core.AbstractValue) and type(b) is bool
|
||||
and not isinstance(a, core.ConcreteArray) for a, b in in_type)
|
||||
|
||||
def valid_size(d) -> bool:
  """Return whether `d` is a well-formed axis size for a `DShapedArray`.

  Accepted sizes are plain ints, `core.DBIdx` instances, and `core.DArray`
  instances that are scalars (empty shape) with a `core.bint` dtype.
  """
  if isinstance(d, core.DBIdx) and type(d.val) is int and d.val >= 0:
    return True
  # For a DArray the bint check applies to its dtype, not to the array's own
  # type: `type(d)` is DArray and can never be `core.bint`.
  return (isinstance(d, (int, core.DBIdx, core.DArray)) and
          (not isinstance(d, core.DArray) or
           type(d.dtype) is core.bint and not d.shape))
|
||||
assert all(valid_size(d) for a, _ in in_type if type(a) is core.DShapedArray
|
||||
for d in a.shape)
|
||||
|
||||
# Check that all DBIdx point to positions to the left of the input on which
|
||||
# they appear.
|
||||
|
@ -22,7 +22,9 @@ from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import core, lax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax.interpreters import batching
|
||||
import jax._src.lib
|
||||
from jax._src import test_util as jtu
|
||||
import jax._src.util
|
||||
@ -1437,6 +1439,52 @@ class DynamicShapeTest(jtu.JaxTestCase):
|
||||
y = jnp.arange(3.0) + 1
|
||||
jax.make_jaxpr(f)(x, y) # doesn't crash
|
||||
|
||||
# TODO(https://github.com/google/jax/issues/12291): Enable jax.Array
|
||||
@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow",
                 jax_array=False)
class PileTest(jtu.JaxTestCase):
  """Tests for piles produced and consumed by `jax.vmap` with dynamic shapes."""

  def test_internal_pile(self):
    # The pile exists only inside the vmapped function; each example is
    # reduced to a scalar, so the result is an ordinary array.
    xs = jax.vmap(lambda n: jnp.arange(n).sum())(jnp.array([3, 1, 4]))
    self.assertAllClose(xs, jnp.array([3, 0, 6]), check_dtypes=False)

  def test_make_pile_from_dynamic_shape(self):
    # We may not want to support returning piles from vmapped functions (instead
    # preferring to have a separate API which allows piles). But for now it
    # makes for a convenient way to construct piles for the other tests!
    p = jax.vmap(partial(jnp.arange, dtype='int32'))(jnp.array([3, 1, 4]))
    self.assertIsInstance(p, batching.Pile)
    self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[\[3 1 4\]\.Var[0-9]+\]')
    # The pile's data is the per-example results concatenated together.
    data = jnp.concatenate([jnp.arange(3), jnp.arange(1), jnp.arange(4)])
    self.assertAllClose(p.data, data, check_dtypes=False)

  def test_pile_map_eltwise(self):
    # An elementwise op maps a pile to a pile with the same type.
    p = jax.vmap(partial(jnp.arange, dtype='int32'))(jnp.array([3, 1, 4]))
    p = pile_map(lambda x: x ** 2)(p)
    self.assertIsInstance(p, batching.Pile)
    self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[\[3 1 4\]\.Var[0-9]+\]')
    data = jnp.concatenate([jnp.arange(3), jnp.arange(1), jnp.arange(4)]) ** 2
    self.assertAllClose(p.data, data, check_dtypes=False)

  def test_pile_map_vector_dot(self):
    # Contracting over the ragged axis gives one scalar per example:
    # sum(i**2 for i in range(n)) for n in (3, 1, 4).
    p = jax.vmap(jnp.arange)(jnp.array([3, 1, 4]))
    y = pile_map(jnp.dot)(p, p)
    self.assertAllClose(y, jnp.array([5, 0, 14]))

  def test_pile_map_matrix_dot(self):
    # ones((7, n)) @ ones((n, 7)) is a 7x7 matrix filled with n.
    sizes = jnp.array([3, 1, 4])
    p1 = jax.vmap(lambda n: jnp.ones((7, n)))(sizes)
    p2 = jax.vmap(lambda n: jnp.ones((n, 7)))(sizes)
    y = pile_map(jnp.dot)(p1, p2)
    self.assertAllClose(y, np.tile(np.array([3, 1, 4])[:, None, None], (7, 7)),
                        check_dtypes=False)
|
||||
|
||||
# TODO(mattjj): could just make this vmap, just need to adjust how we infer axis
|
||||
# sizes in api.py's _mapped_axis_size to handle piles. For another day...
|
||||
def pile_map(f):
  """Return a function that vmaps `f` over piles.

  The mapped axis size is read from the first pile's type
  (`piles[0].aval.length`) and passed to `jax.vmap` explicitly.
  """
  def mapped(*piles):
    n = piles[0].aval.length
    vmapped = jax.vmap(f, axis_size=n)
    return vmapped(*piles)
  return mapped
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user