Merge pull request #13108 from mattjj:djax-vmap2

PiperOrigin-RevId: 486534673
jax authors 2022-11-06 17:27:53 -08:00
commit 8d59b0d47a
6 changed files with 371 additions and 94 deletions

jax/_src/lax/lax.py View File

@@ -45,6 +45,7 @@ from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters.batching import ConcatAxis
import jax._src.pretty_printer as pp
from jax._src import util
from jax._src.util import (cache, prod, safe_zip, safe_map, canonicalize_axis,
@@ -2597,6 +2598,21 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
                            precision,
                            preferred_element_type: Optional[DTypeLike]):
  lhs, rhs = batched_args
  lbd, rbd = batch_dims
  (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
  if (type(lbd) is type(rbd) is ConcatAxis and
      lbd.axis in lhs_contract and rbd.axis in rhs_contract):
    # first handle any other part of the dot with these as batch dims
    lhs_contract_ = [d for d in lhs_contract if d != lbd.axis]
    rhs_contract_ = [d for d in rhs_contract if d != rbd.axis]
    lhs_batch_ = (lbd.axis, *lhs_batch)
    rhs_batch_ = (rbd.axis, *rhs_batch)
    new_dnums = ((lhs_contract_, rhs_contract_), (lhs_batch_, rhs_batch_))
    out = dot_general(lhs, rhs, new_dnums, precision=precision,
                      preferred_element_type=preferred_element_type)
    # now a segment sum along that batch axis
    return batching.segment_sum(out, lbd.segment_lengths), 0
  new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
      (lhs.ndim, rhs.ndim), batch_dims, dimension_numbers)
  batched_out = dot_general(lhs, rhs, new_dimension_numbers,
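
A minimal NumPy sketch (not JAX internals; names are illustrative) of what the
new ConcatAxis branch above computes in the vector-dot case: contracting the
ragged axis amounts to an elementwise product over the concatenated data,
followed by a segment sum over the ragged groups.

    import numpy as np
    lens = np.array([3, 1, 4])                 # ragged sizes, as under vmap
    data = np.concatenate([np.arange(n, dtype=float) for n in lens])
    prod = data * data                         # dot with ragged axis as batch dim
    offsets = np.cumsum(lens)[:-1]
    print([float(g.sum()) for g in np.split(prod, offsets)])  # [5.0, 0.0, 14.0]
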
@@ -2617,31 +2633,52 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
  def bump_dims(dims, b):
    return tuple(np.add(dims, np.greater_equal(dims, b)))

  if type(lbd) is type(rbd) is int:
    # adding a batch dimension
    lhs_batch = (lbd,) + bump_dims(lhs_batch, lbd)
    rhs_batch = (rbd,) + bump_dims(rhs_batch, rbd)
    lhs_contract = bump_dims(lhs_contract, lbd)
    rhs_contract = bump_dims(rhs_contract, rbd)
    result_batch_dim = 0
  elif (rbd is None and type(lbd) is ConcatAxis and
        lbd.axis not in lhs_contract):
    if lbd.axis in lhs_batch:
      axis = int(np.sum(np.less(lhs_batch, lbd.axis)))
    else:
      lhs_tensor = [d for d in range(lhs_ndim)
                    if d not in lhs_batch and d not in lhs_contract]
      axis = len(lhs_batch) + int(np.sum(np.less(lhs_tensor, lbd.axis)))
    result_batch_dim = ConcatAxis(axis, lbd.segment_lengths)
  elif (lbd is None and type(rbd) is ConcatAxis and
        rbd.axis not in rhs_contract):
    if rbd.axis in rhs_batch:
      axis = int(np.sum(np.less(rhs_batch, rbd.axis)))
    else:
      rhs_tensor = [d for d in range(rhs_ndim)
                    if d not in rhs_batch and d not in rhs_contract]
      axis = (lhs_ndim - len(lhs_contract) +
              int(sum(np.less(rhs_tensor, rbd.axis))))
    result_batch_dim = ConcatAxis(axis, rbd.segment_lengths)
  elif (type(lbd) is int and
        (rbd is None or type(rbd) is ConcatAxis and
         rbd.axis not in rhs_contract)):
    # adding a tensor product dimension
    lhs_tensor = [d for d in range(lhs_ndim)
                  if d not in lhs_batch and d not in lhs_contract]
    result_batch_dim = len(lhs_batch) + int(sum(np.less(lhs_tensor, lbd)))
    lhs_batch = bump_dims(lhs_batch, lbd)
    lhs_contract = bump_dims(lhs_contract, lbd)
  elif (type(rbd) is int and
        (lbd is None or type(lbd) is ConcatAxis and
         lbd.axis not in lhs_contract)):
    rhs_tensor = [d for d in range(rhs_ndim)
                  if d not in rhs_batch and d not in rhs_contract]
    result_batch_dim = (lhs_ndim - len(lhs_contract) +
                        int(sum(np.less(rhs_tensor, rbd))))
    rhs_batch = bump_dims(rhs_batch, rbd)
    rhs_contract = bump_dims(rhs_contract, rbd)
  else:
    assert False

  new_dimension_numbers = ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch))
  return new_dimension_numbers, result_batch_dim
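
A standalone check of the `bump_dims` helper above (no assumptions beyond
NumPy): inserting a batch axis at position b shifts every existing dimension
index >= b up by one.

    import numpy as np
    def bump_dims(dims, b):
      return tuple(np.add(dims, np.greater_equal(dims, b)))
    assert bump_dims((0, 2), 1) == (0, 3)  # axis inserted at 1; index 2 shifts
    assert bump_dims((0, 1), 2) == (0, 1)  # insertion after both; unchanged
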
def _dot_general_padding_rule(in_avals, out_avals, lhs, rhs, *,
                              dimension_numbers, **params):
@@ -2782,15 +2819,27 @@ def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape,
  return ([expand_dims(_reduce_sum(ct, axes), unit_dims)] +
          [None] * len(dyn_shape))
def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape,
                                 broadcast_dimensions):
  operand, *dyn_shape = batched_args
  operand_bdim, *dyn_shape_bdims = batch_dims
  if len(dyn_shape) > 1: raise NotImplementedError
  if (operand_bdim is not None and
      (not dyn_shape_bdims or dyn_shape_bdims[0] is None)):
    new_operand = batching.moveaxis(operand, operand_bdim, 0)
    new_shape = (operand.shape[operand_bdim],) + shape
    new_broadcast_dimensions = (0,) + tuple(np.add(1, broadcast_dimensions))
    return broadcast_in_dim(new_operand, new_shape, new_broadcast_dimensions), 0
  elif (operand_bdim is None and dyn_shape_bdims and
        dyn_shape_bdims[0] is not None):
    (d,), (d_bdim,) = dyn_shape, dyn_shape_bdims  # NotImplementedError above
    assert d_bdim == 0  # must be scalar in the program to be batched
    new_shape = _merge_dyn_shape(shape, (int(d.sum()),))
    out = broadcast_in_dim(operand, new_shape, broadcast_dimensions)
    idx, = (i for i, s in enumerate(shape) if s is None)
    return out, batching.ConcatAxis(idx, d)
  else:
    raise NotImplementedError  # TODO(mattjj)
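
A NumPy sketch of the new dynamic-shape branch above: with shape=(None,) and a
batched size vector d, the rule broadcasts the unbatched operand once to the
flat total size d.sum() and tags axis 0 as ConcatAxis(0, d), rather than
materializing a rectangular batch. (Illustrative only.)

    import numpy as np
    d = np.array([3, 1, 4])                          # per-example dynamic sizes
    operand = np.float32(2.0)                        # unbatched scalar operand
    out = np.broadcast_to(operand, (int(d.sum()),))  # one flat ragged buffer
    # conceptually: out is tagged with ConcatAxis(axis=0, segment_lengths=d)
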
def _broadcast_in_dim_fwd_rule(eqn):
  v, *dyn = eqn.invars
@@ -4471,6 +4520,13 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension):
mlir.i64_attr(dimension)).results
mlir.register_lowering(iota_p, _iota_lower)
def _iota_batching_rule(in_vals, in_dims, *, dtype, shape, dimension):
  (segment_lengths,), (ax,) = in_vals, in_dims
  shapes = [_merge_dyn_shape(shape, (d,)) for d in segment_lengths]
  iotas = [broadcasted_iota(dtype, s, dimension) for s in shapes]
  return concatenate(iotas, dimension), batching.ConcatAxis(ax, segment_lengths)
batching.primitive_batchers[iota_p] = _iota_batching_rule
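
The rule above gives vmapped iota concatenated-segments semantics, matching the
PileTest cases in the test file below: each segment gets its own iota, laid out
end to end.

    import numpy as np
    sizes = [3, 1, 4]
    print(np.concatenate([np.arange(n, dtype=np.int32) for n in sizes]))
    # [0 1 2 0 0 1 2 3], tagged as ConcatAxis(0, [3, 1, 4])
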
def _iota_pp_rule(eqn, context, settings):
printed_params = {}
if len(eqn.params['shape']) > 1:

jax/core.py View File

@@ -2195,35 +2195,41 @@ def unmapped_aval(size: AxisSize, axis_name, axis: Optional[int],
  else:
    raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}")
def _map_shaped_array(
    size: int, axis: Optional[int], aval: ShapedArray) -> ShapedArray:
  assert axis is None or aval.shape[axis] == size
  # TODO: Extend the named shape
  if axis is None: return aval
  return ShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
                     named_shape=aval.named_shape, weak_type=aval.weak_type)

def _unmap_shaped_array(
    size: int, axis_name: AxisName, axis: Optional[int], aval: ShapedArray
  ) -> ShapedArray:
  named_shape = dict(aval.named_shape)
  named_shape.pop(axis_name, None)  # TODO: make this mandatory
  if axis is None: return aval.update(named_shape=named_shape)
  elif type(axis) is int:
    return ShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
                       named_shape=named_shape, weak_type=aval.weak_type)
  else: raise TypeError(axis)

def _map_dshaped_array(
    size: AxisSize, axis: Optional[int], aval: DShapedArray) -> DShapedArray:
  if axis is None: return aval
  return DShapedArray(tuple_delete(aval.shape, axis), aval.dtype,
                      aval.weak_type)

def _unmap_dshaped_array(
    size: AxisSize, axis_name: AxisName, axis: Optional[int], aval: DShapedArray
  ) -> DShapedArray:
  if axis is None: return aval
  elif type(axis) is int:
    return DShapedArray(tuple_insert(aval.shape, axis, size), aval.dtype,
                        weak_type=aval.weak_type)
  else:
    raise TypeError(axis)
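
For intuition, mapping deletes the vmapped axis from the aval and unmapping
reinserts it. A schematic, in informal aval notation (not runnable as written):

    _map_shaped_array(3, 0, f32[3,4])          # -> f32[4]
    _unmap_shaped_array(3, name, 0, f32[4])    # -> f32[3,4]
    _map_shaped_array(3, None, f32[4])         # -> f32[4]  (unmapped argument)
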
AvalMapHandlerPair = Tuple[Callable, Callable]
aval_mapping_handlers: Dict[Type, AvalMapHandlerPair] = {

jax/interpreters/batching.py View File

@@ -11,7 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import collections
import dataclasses
from functools import partial
from typing import (Any, Callable, Dict, Hashable, Iterable, Optional, Sequence,
Set, Tuple, Type, Union)
@@ -23,7 +26,8 @@ from jax.config import config
from jax import core
from jax.core import raise_to_shaped, Trace, Tracer
from jax._src import source_info_util
from jax._src.tree_util import (tree_unflatten, tree_flatten,
                                register_pytree_node)
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_p, Zero)
from jax import linear_util as lu
@@ -33,31 +37,122 @@ from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, wrap_name,
weakref_lru_cache)
from jax.interpreters import partial_eval as pe
Array = Any
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
# Piles

# i:(Fin 3) => f32[[3, 1, 4].i]
@dataclasses.dataclass(frozen=True)
class PileTy:
  binder: core.Var
  length: Union[int, Tracer, core.Var]
  elt_ty: core.DShapedArray
  def __repr__(self) -> str:
    return f'Var{id(self.binder)}:{self.length} => {self.elt_ty}'
  replace = dataclasses.replace

# [3, 1, 4].i
@dataclasses.dataclass(frozen=True)
class IndexedAxisSize:
  idx: core.Var
  lengths: Union[Array, core.Var, Tracer]
  def __repr__(self) -> str:
    return f'{str(self.lengths)}.Var{id(self.idx)}'
  replace = dataclasses.replace

# Pile(aval=a:3 => f32[[3 1 4].a],
#      data=DeviceArray([0., 1., 2., 0., 0., 1., 2., 3.], dtype=float32))
@dataclasses.dataclass(frozen=True)
class Pile:
  aval: PileTy
  data: Array
def _pile_flatten(pile):
  lengths = []
  # Pull each ragged dim's `lengths` array out as a pytree leaf, recording in
  # its place the 1-based index of that leaf so _pile_unflatten can restore it.
  new_shape = [lengths.append(d.lengths) or d.replace(lengths=len(lengths))
               if type(d) is IndexedAxisSize else d
               for d in pile.aval.elt_ty.shape]
  elt_ty = pile.aval.elt_ty.update(shape=tuple(new_shape))
  aval = pile.aval.replace(elt_ty=elt_ty)
  return (lengths, pile.data), aval
def _pile_unflatten(aval, x):
  lengths, data = x
  # Invert the flatten trick: look each leaf up by its recorded 1-based index.
  new_shape = [d.replace(lengths=lengths[d.lengths - 1])
               if type(d) is IndexedAxisSize else d
               for d in aval.elt_ty.shape]
  elt_ty = aval.elt_ty.update(shape=tuple(new_shape))
  aval = aval.replace(elt_ty=elt_ty)
  return Pile(aval, data)
register_pytree_node(Pile, _pile_flatten, _pile_unflatten)
def _pile_result(axis_size, axis, segment_lens, x):
  binder = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))
  shape = list(x.shape)
  shape[axis] = IndexedAxisSize(binder, segment_lens)
  elt_ty = core.DShapedArray(tuple(shape), x.dtype, x.weak_type)
  return Pile(PileTy(binder, axis_size, elt_ty), x)

@dataclasses.dataclass(frozen=True)
class ConcatAxis:
  axis: int
  segment_lengths: Array
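
What a ConcatAxis batch dim means, in plain NumPy terms: the tracer's val holds
every vmapped element concatenated along axis, and segment_lengths records the
per-element sizes. (Illustrative sketch, not the tracer machinery itself.)

    import numpy as np
    segment_lengths = np.array([3, 1, 4])
    val = np.concatenate([np.arange(n) for n in segment_lengths])
    segments = np.split(val, np.cumsum(segment_lengths)[:-1])
    print([s.tolist() for s in segments])  # [[0, 1, 2], [0], [0, 1, 2, 3]]
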
def _update_annotation(
    f: lu.WrappedFun, orig_type: Optional[core.InputType],
    axis_size: core.AxisSize, axis_name: core.AxisName,
    explicit_in_dims: Sequence[Optional[Union[int, ConcatAxis]]],
    segment_lens: Sequence[Array],
  ) -> lu.WrappedFun:
  if orig_type is None: return f
  # We need to:
  #  * if `axis_size` is dynamic, add a new implicit binder (type) for it;
  #  * for each element of `segment_lengths`, add a new explicit binder for it;
  #  * drop other implicit binders, replacing DBIdx which refer to them with
  #    Name objects;
  #  * for each (aval, in_dim) pair: if int-valued in_dim, add batch axis (int
  #    size if `axis_size` is int, otherwise Name); if ConcatAxis-valued in_dim,
  #    add batch axis (int if corresponding segment_lengths is concrete, Name
  #    if not);
  #  * generate full in_type with implicit args too.
  class Name:
    def __init__(self, a): self.a = a
  names = [Name(a) for a, _ in orig_type]
  avals = [a.update(shape=tuple(names[d.val] if type(d) is pe.DBIdx else d  # type: ignore
                                for d in a.shape))
           if type(a) is core.DShapedArray else a for a, e in orig_type if e]
  new_avals = [core.raise_to_shaped(core.get_aval(s)) for s in segment_lens]
  sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size
  for a, d in zip(avals, explicit_in_dims):
    if isinstance(d, ConcatAxis):
      s = segment_lens[d.segment_lengths.val]
      if isinstance(core.get_aval(s), core.ConcreteArray):
        shape = list(a.shape)  # type: ignore
        shape[d.axis] = int(s.sum())  # specialize on shape if we can
        new_avals.append(a.update(shape=tuple(shape)))
      else:
        new_avals.append(a)
    else:
      new_avals.append(core.unmapped_aval(sz, axis_name, d, a))  # type: ignore
  mentioned = {d for a in new_avals if type(a) is core.DShapedArray
               for d in a.shape if type(d) is Name}
  expl_names = set(map(Name, new_avals))
  impl_names = mentioned - expl_names  # type: ignore
  impl_part = [(n.a, False) for n in impl_names]  # type: ignore
  name_map = {n: pe.DBIdx(i) for i, n in enumerate((*impl_names, *expl_names))}
  expl_part = [(a.update(shape=tuple(name_map.get(d, d) for d in a.shape))
                if type(a) is core.DShapedArray else a, True) for a in new_avals]
  return lu.annotate(f, (*impl_part, *expl_part))
### vmappable typeclass
@@ -65,7 +160,6 @@ Vmappable = Any
Elt = Any
MapSpec = Any
AxisSize = Any
GetIdx = Callable[[], Tracer] # TODO(mattjj): revise this laziness
ToEltHandler = Callable[[Callable, GetIdx, Vmappable, MapSpec], Elt]
FromEltHandler = Callable[[Callable, AxisSize, Elt, MapSpec], Vmappable]
@@ -75,10 +169,16 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
  handler = to_elt_handlers.get(type(x))
  if handler:
    return handler(partial(to_elt, trace, get_idx), get_idx, x, spec)
  elif type(x) is Pile:
    (d, ias), = ((i, sz) for i, sz in enumerate(x.aval.elt_ty.shape)
                 if type(sz) is IndexedAxisSize)
    return BatchTracer(trace, x.data, ConcatAxis(d, ias.lengths))  # type: ignore
  elif isinstance(spec, int) or spec is None:
    spec = spec and canonicalize_axis(spec, len(np.shape(x)))
    return (BatchTracer(trace, x, spec, source_info_util.current())
            if spec is not None else x)
  else:
    assert False
to_elt_handlers: Dict[Type, ToEltHandler] = {}
def from_elt(trace: 'BatchTrace', axis_size: AxisSize, x: Elt, spec: MapSpec
@@ -86,8 +186,11 @@ def from_elt(trace: 'BatchTrace', axis_size: AxisSize, x: Elt, spec: MapSpec
  handler = from_elt_handlers.get(type(x))
  if handler:
    return handler(partial(from_elt, trace), axis_size, x, spec)
  x_ = trace.full_raise(x)
  val, bdim = x_.val, x_.batch_dim
  if type(bdim) is ConcatAxis:
    return _pile_result(axis_size, bdim.axis, bdim.segment_lengths, val)
  else:
    return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val)
from_elt_handlers: Dict[Type, FromEltHandler] = {}
@@ -119,7 +222,7 @@ def unregister_vmappable(data_type: Type) -> None:
del make_iota_handlers[axis_size_type]
def is_vmappable(x: Any) -> bool:
  return type(x) is Pile or type(x) in vmappables
@lu.transformation_with_aux
def flatten_fun_for_vmap(in_tree, *args_flat):
@@ -133,16 +236,17 @@ def flatten_fun_for_vmap(in_tree, *args_flat):
NotMapped = type(None)
not_mapped = None
class BatchTracer(Tracer):
  __slots__ = ['val', 'batch_dim', 'source_info']

  def __init__(self, trace, val, batch_dim: Union[NotMapped, int, ConcatAxis],
               source_info: Optional[source_info_util.SourceInfo] = None):
    if config.jax_enable_checks:
      assert type(batch_dim) in (NotMapped, int, ConcatAxis)
      if type(batch_dim) is int:
        aval = raise_to_shaped(core.get_aval(val))
        assert 0 <= batch_dim < len(aval.shape)  # type: ignore
    self._trace = trace
    self.val = val
    self.batch_dim = batch_dim
@@ -153,7 +257,14 @@ class BatchTracer(Tracer):
    aval = raise_to_shaped(core.get_aval(self.val))
    if self.batch_dim is not_mapped:
      return aval
    elif type(self.batch_dim) is int:
      return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval)
    elif type(self.batch_dim) is ConcatAxis:
      shape = list(aval.shape)
      size_tracer = BatchTracer(self._trace, self.batch_dim.segment_lengths, 0)
      shape[self.batch_dim.axis] = size_tracer
      return core.DShapedArray(shape=tuple(shape), dtype=aval.dtype,
                               weak_type=aval.weak_type)
def full_lower(self):
if self.batch_dim is not_mapped:
@@ -170,6 +281,12 @@ class BatchTracer(Tracer):
  def _contents(self):
    return [('val', self.val), ('batch_dim', self.batch_dim)]

  def get_referent(self):
    if self.batch_dim is None or type(self.batch_dim) is int:
      return core.get_referent(self.val)
    else:  # TODO(mattjj): could handle the ConcatAxis case?
      return self
class BatchTrace(Trace):
def __init__(self, *args, axis_name, spmd_axis_name = None):
@@ -206,8 +323,9 @@ class BatchTrace(Trace):
    if self.axis_name is core.no_axis_name:
      # If axis name is `no_axis_name` we can't find it via `core.axis_name` so
      # we reconstruct it from the information we have available
      sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths)
               for x, d in zip(vals, dims) if d is not not_mapped)
      axis_size, = core.dedup_referents(sizes)
      return core.AxisEnvFrame(self.axis_name, axis_size, self.main)
    return core.axis_frame(self.axis_name)
@@ -241,14 +359,17 @@ class BatchTrace(Trace):
    vals, dims = unzip2((t.val, t.batch_dim) for t in tracers)
    if all(bdim is not_mapped for bdim in dims):
      return call_primitive.bind(f, *vals, **params)
    sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths)
             for x, d in zip(vals, dims) if d is not not_mapped)
    axis_size, = core.dedup_referents(sizes)
    segment_lens, dims = unpack_concat_axes(dims)
    f_, dims_out = batch_subtrace(f, self.main, tuple(dims))
    f_ = _update_annotation(f_, f.in_type, axis_size, self.axis_name, dims,
                            segment_lens)
    vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params)
    vals_out, dims_out = reassemble_concat_axes(vals_out, dims_out())
    src = source_info_util.current()
    return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out)]
def post_process_call(self, call_primitive, out_tracers, params):
vals, dims, srcs = unzip3((t.val, t.batch_dim, t.source_info)
@@ -465,15 +586,36 @@ def vtile(f_flat: lu.WrappedFun,
@lu.transformation_with_aux
def batch_subtrace(main, in_dims, *in_vals):
  # used in e.g. process_call
  trace = main.with_cur_sublevel()
  in_dims = in_dims() if callable(in_dims) else in_dims
  in_vals, in_dims = reassemble_concat_axes(in_vals, in_dims)
  in_tracers = [BatchTracer(trace, x, dim, source_info_util.current())
                if dim is not None else x for x, dim in zip(in_vals, in_dims)]
  outs = yield in_tracers, {}
  out_tracers = map(trace.full_raise, outs)
  out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
  segment_lens, out_dims = unpack_concat_axes(out_dims)
  yield (*segment_lens, *out_vals), out_dims
def unpack_concat_axes(dims):
  if not any(type(d) is ConcatAxis for d in dims):
    return [], dims
  concat_axis_map = collections.OrderedDict()
  def convert(d: ConcatAxis) -> ConcatAxis:
    _, dbidx = concat_axis_map.setdefault(
        id(core.get_referent(d.segment_lengths)),
        (d.segment_lengths, pe.DBIdx(len(concat_axis_map))))
    return ConcatAxis(d.axis, dbidx)
  new_dims = [convert(d) if isinstance(d, ConcatAxis) else d for d in dims]
  segment_lens = [s for s, _ in concat_axis_map.values()]
  return segment_lens, new_dims

def reassemble_concat_axes(vals, dims):
  idxs = {d.segment_lengths.val for d in dims if isinstance(d, ConcatAxis)}
  dims = [ConcatAxis(d.axis, vals[d.segment_lengths.val])
          if isinstance(d, ConcatAxis) else d for d in dims]
  vals = [x for i, x in enumerate(vals) if i not in idxs]
  return vals, dims
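
Schematically, the calling convention these two helpers implement (hypothetical
values, for illustration only): segment-length arrays are hoisted out of the
dims and passed as ordinary leading arguments, each ConcatAxis holding a DBIdx
position instead of the array; on the way back, positions are resolved against
the leading outputs, which are then dropped.

    dims                      = [0, ConcatAxis(1, lens), None]
    unpack_concat_axes(dims)  # -> [lens], [0, ConcatAxis(1, DBIdx(0)), None]
    # call:  bind(f_, *segment_lens, *vals, ...)
    reassemble_concat_axes(outs, out_dims)  # DBIdx(0) -> outs[0]; outs[0] dropped
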
### API for batching jaxprs
@@ -668,11 +810,34 @@ def defreducer(prim):
def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
  operand, = batched_args
  bdim, = batch_dims
  if isinstance(bdim, int):
    axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
    bdim_out = int(list(np.delete(np.arange(operand.ndim), axes)).index(bdim))
    if 'input_shape' in params:
      params = dict(params, input_shape=operand.shape)
    return prim.bind(operand, axes=axes, **params), bdim_out
  elif isinstance(bdim, ConcatAxis):
    if bdim.axis in axes:
      other_axes = [i for i in axes if i != bdim.axis]
      if other_axes:
        operand = prim.bind(operand, axes=other_axes, **params)
      c_axis = bdim.axis - sum(d < bdim.axis for d in other_axes)
      operand = bdim_at_front(operand, c_axis, operand.shape[c_axis])
      return segment_sum(operand, bdim.segment_lengths), 0
    else:
      raise NotImplementedError  # TODO(mattjj)
  else:
    assert False
# TODO(mattjj): replace with jax.lax.ops.segment_sum (once it's easier to trace
# under dynamic shapes)
def segment_sum(operand, segment_lens):
  scat_idx = jax.numpy.cumsum(segment_lens) - segment_lens
  segment_ids = jax.numpy.cumsum(
      jax.numpy.zeros(operand.shape[0], 'int32').at[scat_idx].set(1)) - 1
  out = jax.numpy.zeros((len(segment_lens), *operand.shape[1:]),
                        operand.dtype).at[segment_ids].add(operand)
  return out
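
A worked example of segment_sum above (intermediate values checked by hand):

    segment_sum(jnp.arange(8.), jnp.array([3, 1, 4]))
    # scat_idx    = [0, 3, 4]
    # segment_ids = [0, 0, 0, 1, 2, 2, 2, 2]
    # result      = [0+1+2, 3, 4+5+6+7] = [3., 3., 22.]
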
### general utilities for manipulating axes on jaxpr types (not vmappables)

jax/interpreters/partial_eval.py View File

@@ -2224,14 +2224,13 @@ def _add_implicit_outputs(jaxpr: Jaxpr) -> Tuple[Jaxpr, OutputType]:
class TracerAsName:
  ref: Any
  def __init__(self, tracer):
    self.ref = core.get_referent(tracer)
  def __eq__(self, other):
    return isinstance(other, TracerAsName) and self.ref is other.ref
  def __hash__(self):
    return id(self.ref)
def _extract_implicit_args(
trace: DynamicJaxprTrace, in_type: Sequence[Tuple[AbstractValue, bool]],

jax/linear_util.py View File

@@ -245,14 +245,17 @@ def annotate(f: WrappedFun, in_type: core.InputType) -> WrappedFun:
def _check_input_type(in_type: core.InputType) -> None:
  # Check that in_type is syntactically well-formed
  assert type(in_type) is tuple and all(type(e) is tuple for e in in_type)
  assert all(isinstance(a, core.AbstractValue) and type(b) is bool
             and not isinstance(a, core.ConcreteArray) for a, b in in_type)

  def valid_size(d) -> bool:
    if isinstance(d, core.DBIdx) and type(d.val) is int and d.val >= 0:
      return True
    return (isinstance(d, (int, core.DBIdx, core.DArray)) and
            (not isinstance(d, core.DArray) or
             type(d.dtype) is core.bint and not d.shape))
  assert all(valid_size(d) for a, _ in in_type if type(a) is core.DShapedArray
             for d in a.shape)
# Check that all DBIdx point to positions to the left of the input on which
# they appear.

tests/dynamic_api_test.py View File

@@ -22,7 +22,9 @@ from absl.testing import absltest
import jax
import jax.numpy as jnp
from jax import core
from jax import lax
from jax.interpreters import batching
import jax._src.lib
from jax._src import test_util as jtu
import jax._src.util
@@ -1437,6 +1439,52 @@ class DynamicShapeTest(jtu.JaxTestCase):
y = jnp.arange(3.0) + 1
jax.make_jaxpr(f)(x, y) # doesn't crash
# TODO(https://github.com/google/jax/issues/12291): Enable jax.Array
@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow",
                 jax_array=False)
class PileTest(jtu.JaxTestCase):

  def test_internal_pile(self):
    xs = jax.vmap(lambda n: jnp.arange(n).sum())(jnp.array([3, 1, 4]))
    self.assertAllClose(xs, jnp.array([3, 0, 6]), check_dtypes=False)

  def test_make_pile_from_dynamic_shape(self):
    # We may not want to support returning piles from vmapped functions
    # (instead preferring to have a separate API which allows piles). But for
    # now it makes for a convenient way to construct piles for the other tests!
    p = jax.vmap(partial(jnp.arange, dtype='int32'))(jnp.array([3, 1, 4]))
    self.assertIsInstance(p, batching.Pile)
    self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[\[3 1 4\]\.Var[0-9]+\]')
    data = jnp.concatenate([jnp.arange(3), jnp.arange(1), jnp.arange(4)])
    self.assertAllClose(p.data, data, check_dtypes=False)

  def test_pile_map_eltwise(self):
    p = jax.vmap(partial(jnp.arange, dtype='int32'))(jnp.array([3, 1, 4]))
    p = pile_map(lambda x: x ** 2)(p)
    self.assertIsInstance(p, batching.Pile)
    self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[\[3 1 4\]\.Var[0-9]+\]')
    data = jnp.concatenate([jnp.arange(3), jnp.arange(1), jnp.arange(4)]) ** 2
    self.assertAllClose(p.data, data, check_dtypes=False)

  def test_pile_map_vector_dot(self):
    p = jax.vmap(jnp.arange)(jnp.array([3, 1, 4]))
    y = pile_map(jnp.dot)(p, p)
    self.assertAllClose(y, jnp.array([5, 0, 14]))

  def test_pile_map_matrix_dot(self):
    sizes = jnp.array([3, 1, 4])
    p1 = jax.vmap(lambda n: jnp.ones((7, n)))(sizes)
    p2 = jax.vmap(lambda n: jnp.ones((n, 7)))(sizes)
    y = pile_map(jnp.dot)(p1, p2)
    self.assertAllClose(y, np.tile(np.array([3, 1, 4])[:, None, None], (7, 7)),
                        check_dtypes=False)

# TODO(mattjj): could just make this vmap, just need to adjust how we infer axis
# sizes in api.py's _mapped_axis_size to handle piles. For another day...
def pile_map(f):
  def mapped(*piles):
    return jax.vmap(f, axis_size=piles[0].aval.length)(*piles)
  return mapped
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())