Merge pull request #16072 from axch:stacked-piles

PiperOrigin-RevId: 538839592
This commit is contained in:
jax authors 2023-06-08 11:22:30 -07:00
commit 5c5baccb2e
6 changed files with 119 additions and 149 deletions

View File

@ -1666,6 +1666,9 @@ class DShapedArray(UnshapedArray):
weak_type = self.weak_type
return DShapedArray(shape, dtype, weak_type)
def _len(self, tracer):
return self.shape[0]
def __eq__(self, other):
return (type(self) is type(other)
and self.dtype == other.dtype and self.shape == other.shape
@ -1720,7 +1723,7 @@ class DArray:
slices = tuple(slice(int(d._data)) if type(d) is DArray and
type(d.dtype) is bint else slice(None) for d in self.shape)
data = self._data[slices]
return f'{dtypestr}[{shapestr}] with value:\n{data}'
return f'{dtypestr}[{shapestr}] with value: {data}'
def __hash__(self) -> int:
if not self.shape:
return hash((self._aval, int(self._data)))
@ -1729,6 +1732,8 @@ class DArray:
if isinstance(other, DArray) and self._aval == other._aval:
return self._data == other._data
return False
def __len__(self):
return self.shape[0]
pytype_aval_mappings[DArray] = \
lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type,

View File

@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations
import collections
import dataclasses
from functools import partial
from typing import (Any, Callable, Dict, Iterable, Optional, Sequence, Set,
@ -98,23 +97,30 @@ def _pile_unflatten(aval, x):
return Pile(aval, data)
register_pytree_node(Pile, _pile_flatten, _pile_unflatten)
def _pile_result(axis_size, axis, segment_lens, x):
def _pile_result(axis_size, stacked_axis, ragged_axis, segment_lens, x):
binder = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))
if stacked_axis != 0:
raise NotImplementedError # TODO Transpose x so the stacked axis is axis 0
shape = list(x.shape)
shape[axis] = IndexedAxisSize(binder, segment_lens)
del shape[0]
shape[ragged_axis-1] = IndexedAxisSize(binder, segment_lens)
elt_ty = core.DShapedArray(tuple(shape), x.dtype, x.weak_type)
return Pile(PileTy(binder, axis_size, elt_ty), x)
@dataclasses.dataclass(frozen=True)
class ConcatAxis:
axis: int
class RaggedAxis:
stacked_axis: int
# TODO(mattjj,axch): Generalize to multiple ragged dimensions
# e.g. `i:(Fin 3) => f32[lens1.i, lens2.i]`
ragged_axis: int
segment_lengths: Array
def _update_annotation(
f: lu.WrappedFun, orig_type: Optional[core.InputType],
axis_size: core.AxisSize, axis_name: AxisName,
explicit_in_dims: Sequence[Optional[Union[int, ConcatAxis]]],
explicit_in_dims: Sequence[Optional[Union[int, RaggedAxis]]],
segment_lens: Sequence[Array],
) -> lu.WrappedFun:
if orig_type is None: return f
@ -126,7 +132,7 @@ def _update_annotation(
# * drop other implicit binders, replacing DBIdx which refer to them with
# Name objects;
# * for each (aval, in_dim) pair: if int-valued in_dim, add batch axis (int
# size if `axis_size` is int, otherwise Name); if ConcatAxis-valued in_dim,
# size if `axis_size` is int, otherwise Name); if RaggedAxis-valued in_dim,
# add batch axis (int if corresponding segment_lengths is concrete, Name if
# not);
# * generate full in_type with implicit args too.
@ -141,14 +147,8 @@ def _update_annotation(
new_avals = [core.raise_to_shaped(core.get_aval(s)) for s in segment_lens]
sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size
for a, d in zip(avals, explicit_in_dims):
if isinstance(d, ConcatAxis):
s = segment_lens[d.segment_lengths.val]
if isinstance(core.get_aval(s), core.ConcreteArray):
shape = list(a.shape) # type: ignore
shape[d.axis] = int(s.sum()) # specialize on shape if we can
new_avals.append(a.update(shape=tuple(shape)))
else:
new_avals.append(a)
if isinstance(d, RaggedAxis):
raise NotImplementedError
else:
new_avals.append(core.unmapped_aval(sz, axis_name, d, a)) # type: ignore
@ -182,7 +182,7 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
raise TypeError("pile input without using pile_axis in_axes spec")
(d, ias), = ((i, sz) for i, sz in enumerate(x.aval.elt_ty.shape)
if type(sz) is IndexedAxisSize)
return BatchTracer(trace, x.data, ConcatAxis(d, ias.lengths)) # type: ignore
return BatchTracer(trace, x.data, RaggedAxis(0, d+1, ias.lengths)) # type: ignore
elif isinstance(spec, int) or spec is None:
spec = spec and canonicalize_axis(spec, len(np.shape(x)))
return (BatchTracer(trace, x, spec, source_info_util.current())
@ -198,11 +198,11 @@ def from_elt(trace: 'BatchTrace', axis_size: AxisSize, x: Elt, spec: MapSpec
return handler(partial(from_elt, trace), axis_size, x, spec)
x_ = trace.full_raise(x)
val, bdim = x_.val, x_.batch_dim
if type(bdim) is ConcatAxis:
if type(bdim) is RaggedAxis:
if spec is not pile_axis:
# TODO(mattjj): improve this error message
raise TypeError("ragged output without using pile_axis out_axes spec")
return _pile_result(axis_size, bdim.axis, bdim.segment_lengths, val)
return _pile_result(axis_size, bdim.stacked_axis, bdim.ragged_axis, bdim.segment_lengths, val)
else:
return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val)
from_elt_handlers: Dict[Type, FromEltHandler] = {}
@ -253,10 +253,10 @@ not_mapped = None
class BatchTracer(Tracer):
__slots__ = ['val', 'batch_dim', 'source_info']
def __init__(self, trace, val, batch_dim: Union[NotMapped, int, ConcatAxis],
def __init__(self, trace, val, batch_dim: Union[NotMapped, int, RaggedAxis],
source_info: Optional[source_info_util.SourceInfo] = None):
if config.jax_enable_checks:
assert type(batch_dim) in (NotMapped, int, ConcatAxis)
assert type(batch_dim) in (NotMapped, int, RaggedAxis)
if type(batch_dim) is int:
aval = raise_to_shaped(core.get_aval(val))
assert 0 <= batch_dim < len(aval.shape) # type: ignore
@ -272,10 +272,15 @@ class BatchTracer(Tracer):
return aval
elif type(self.batch_dim) is int:
return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval)
elif type(self.batch_dim) is ConcatAxis:
shape = list(aval.shape)
elif type(self.batch_dim) is RaggedAxis:
new_aval = core.mapped_aval(
aval.shape[self.batch_dim.stacked_axis], self.batch_dim.stacked_axis, aval)
shape = list(new_aval.shape) # type: ignore
size_tracer = BatchTracer(self._trace, self.batch_dim.segment_lengths, 0)
shape[self.batch_dim.axis] = size_tracer
ragged_axis = self.batch_dim.ragged_axis
if self.batch_dim.stacked_axis < self.batch_dim.ragged_axis:
ragged_axis -= 1
shape[ragged_axis] = size_tracer
return core.DShapedArray(shape=tuple(shape), dtype=aval.dtype,
weak_type=aval.weak_type)
@ -297,7 +302,7 @@ class BatchTracer(Tracer):
def get_referent(self):
if self.batch_dim is None or type(self.batch_dim) is int:
return core.get_referent(self.val)
else: # TODO(mattjj): could handle the ConcatAxis case?
else: # TODO(mattjj): could handle the RaggedAxis case?
return self
class BatchTrace(Trace):
@ -376,12 +381,9 @@ class BatchTrace(Trace):
sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths)
for x, d in zip(vals, dims) if d is not not_mapped)
axis_size, = core.dedup_referents(sizes)
segment_lens, dims = unpack_concat_axes(dims)
f_, dims_out = batch_subtrace(f, self.main, tuple(dims))
f_ = _update_annotation(f_, f.in_type, axis_size, self.axis_name, dims,
segment_lens)
vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params)
vals_out, dims_out = reassemble_concat_axes(vals_out, dims_out())
f_ = _update_annotation(f_, f.in_type, axis_size, self.axis_name, dims, [])
vals_out = call_primitive.bind(f_, *vals, **params)
src = source_info_util.current()
return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out)]
@ -607,34 +609,12 @@ def vtile(f_flat: lu.WrappedFun,
def batch_subtrace(main, in_dims, *in_vals):
trace = main.with_cur_sublevel()
in_dims = in_dims() if callable(in_dims) else in_dims
in_vals, in_dims = reassemble_concat_axes(in_vals, in_dims)
in_tracers = [BatchTracer(trace, x, dim, source_info_util.current())
if dim is not None else x for x, dim in zip(in_vals, in_dims)]
outs = yield in_tracers, {}
out_tracers = map(trace.full_raise, outs)
out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
segment_lens, out_dims = unpack_concat_axes(out_dims)
yield (*segment_lens, *out_vals), out_dims
def unpack_concat_axes(dims):
if not any(type(d) is ConcatAxis for d in dims):
return [], dims
concat_axis_map = collections.OrderedDict()
def convert(d: ConcatAxis) -> ConcatAxis:
_, dbidx = concat_axis_map.setdefault(
id(core.get_referent(d.segment_lengths)),
(d.segment_lengths, pe.DBIdx(len(concat_axis_map))))
return ConcatAxis(d.axis, dbidx)
new_dims = [convert(d) if isinstance(d, ConcatAxis) else d for d in dims]
segment_lens = [s for s, _ in concat_axis_map.values()]
return segment_lens, new_dims
def reassemble_concat_axes(vals, dims):
idxs = {d.segment_lengths.val for d in dims if isinstance(d, ConcatAxis)}
dims = [ConcatAxis(d.axis, vals[d.segment_lengths.val])
if isinstance(d, ConcatAxis) else d for d in dims]
vals = [x for i, x in enumerate(vals) if i not in idxs]
return vals, dims
yield out_vals, out_dims
### API for batching jaxprs
@ -880,40 +860,43 @@ def _handle_scalar_broadcasting(nd, x, d):
else:
return jax.lax.expand_dims(x, tuple(range(np.ndim(x), nd)))
def defreducer(prim):
primitive_batchers[prim] = partial(reducer_batcher, prim)
def defreducer(prim, ident):
primitive_batchers[prim] = partial(reducer_batcher, prim, ident)
def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
def reducer_batcher(prim, ident, batched_args, batch_dims, axes, **params):
def out_axis(axes, axis):
return int(list(np.delete(np.arange(operand.ndim), axes)).index(axis))
operand, = batched_args
bdim, = batch_dims
if isinstance(bdim, int):
axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
bdim_out = int(list(np.delete(np.arange(operand.ndim), axes)).index(bdim))
bdim_out = out_axis(axes, bdim)
if 'input_shape' in params:
params = dict(params, input_shape=operand.shape)
return prim.bind(operand, axes=axes, **params), bdim_out
elif isinstance(bdim, ConcatAxis):
if bdim.axis in axes:
other_axes = [i for i in axes if i != bdim.axis]
if other_axes:
operand = prim.bind(operand, axes=other_axes, **params)
c_axis = bdim.axis - sum(d < bdim.axis for d in other_axes)
operand = bdim_at_front(operand, c_axis, operand.shape[c_axis])
return segment_sum(operand, bdim.segment_lengths), 0
elif isinstance(bdim, RaggedAxis):
assert ident is not None, "TODO Ragged batching a reduction requires an identity"
axes = tuple(np.where(np.less(axes, bdim.stacked_axis), axes, np.add(axes, 1)))
bdim_out = out_axis(axes, bdim.stacked_axis)
if bdim.ragged_axis in axes:
operand = mask_ragged_axis(operand, ident, bdim)
result = prim.bind(operand, axes=axes, **params)
return result, bdim_out
else:
raise NotImplementedError # TODO(mattjj)
result = prim.bind(operand, axes=axes, **params)
return result, RaggedAxis(bdim_out, out_axis(axes, bdim.ragged_axis), bdim.segment_lengths)
else:
assert False
# TODO(mattjj): replace with jax.lax.ops.segment_sum (once it's easier to trace
# under dynamic shapes)
def segment_sum(operand, segment_lens):
scat_idx = jax.numpy.cumsum(segment_lens) - segment_lens
segment_ids = jax.numpy.cumsum(
jax.numpy.zeros(operand.shape[0], 'int32').at[scat_idx].set(1)) - 1
out = jax.numpy.zeros((len(segment_lens), *operand.shape[1:]),
operand.dtype).at[segment_ids].add(operand)
return out
def mask_ragged_axis(operand, ident, axis_spec):
value = ident(operand.dtype)
positions = jax.lax.broadcasted_iota('int32', operand.shape, axis_spec.ragged_axis)
# TODO(mattjj, axch) cant get ._data, need to convert it
lengths = jax.lax.convert_element_type(axis_spec.segment_lengths._data, 'int32')
limits = jax.lax.broadcast_in_dim(
lengths, operand.shape, [axis_spec.stacked_axis])
mask = positions < limits
return jax.lax.select(mask, operand, jax.lax.broadcast(value, operand.shape))
### general utilities for manipulating axes on jaxpr types (not vmappables)

View File

@ -51,7 +51,7 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax._src.interpreters import xla
from jax._src.interpreters.batching import ConcatAxis
from jax._src.interpreters.batching import RaggedAxis
from jax._src.lax import slicing
from jax._src.lax.utils import (
_input_dtype,
@ -2255,10 +2255,12 @@ def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type):
def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type):
if operand.dtype != new_dtype:
if dtypes.is_opaque_dtype(operand.dtype):
if (dtypes.is_opaque_dtype(operand.dtype) and
not isinstance(operand.dtype, core.bint)):
raise ValueError(
f"Cannot call convert_element_type on dtype {dtype_to_string(operand.dtype)}")
if dtypes.is_opaque_dtype(new_dtype):
if (dtypes.is_opaque_dtype(new_dtype) and
not isinstance(new_dtype, core.bint)):
raise ValueError(
f"Cannot convert_element_type to dtype={dtype_to_string(new_dtype)}")
return new_dtype
@ -2595,21 +2597,16 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
lhs, rhs = batched_args
lbd, rbd = batch_dims
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
if (type(lbd) is type(rbd) is ConcatAxis and
lbd.axis in lhs_contract and rbd.axis in rhs_contract):
# first handle any other part of the dot with these as batch dims
lhs_contract_ = [d for d in lhs_contract if d != lbd.axis]
rhs_contract_ = [d for d in rhs_contract if d != rbd.axis]
lhs_batch_ = (lbd.axis, *lhs_batch)
rhs_batch_ = (rbd.axis, *rhs_batch)
new_dnums = ((lhs_contract_, rhs_contract_), (lhs_batch_, rhs_batch_))
out = dot_general(lhs, rhs, new_dnums, precision=precision,
preferred_element_type=preferred_element_type)
# now a segment sum along that batch axis
return batching.segment_sum(out, lbd.segment_lengths), 0
left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd
right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd
new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
(lhs.ndim, rhs.ndim), batch_dims, dimension_numbers)
(lhs.ndim, rhs.ndim), (left_stack_dim, right_stack_dim), dimension_numbers)
# TODO Should probably check that any ragged dimensions have corresponding
# sizes, because otherwise the dot product is technically undefined.
if type(lbd) is RaggedAxis:
lhs = batching.mask_ragged_axis(lhs, _get_sum_identity, lbd)
if type(rbd) is RaggedAxis:
rhs = batching.mask_ragged_axis(rhs, _get_sum_identity, rbd)
batched_out = dot_general(lhs, rhs, new_dimension_numbers,
precision=precision,
preferred_element_type=preferred_element_type)
@ -2635,34 +2632,13 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
lhs_contract = bump_dims(lhs_contract, lbd)
rhs_contract = bump_dims(rhs_contract, rbd)
result_batch_dim = 0
elif rbd is None and type(lbd) is ConcatAxis and lbd.axis not in lhs_contract:
if lbd.axis in lhs_batch:
axis = int(np.sum(np.less(lhs_batch, lbd.axis)))
else:
lhs_tensor = [d for d in range(lhs_ndim)
if d not in lhs_batch and d not in lhs_contract]
axis = len(lhs_batch) + int(np.sum(np.less(lhs_tensor, lbd.axis)))
result_batch_dim = ConcatAxis(axis, lbd.segment_lengths)
elif lbd is None and type(rbd) is ConcatAxis and rbd.axis not in rhs_contract:
if rbd.axis in rhs_batch:
axis = int(np.sum(np.less(rhs_batch, rbd.axis)))
else:
rhs_tensor = [d for d in range(rhs_ndim)
if d not in rhs_batch and d not in rhs_contract]
axis = (lhs_ndim - len(lhs_contract) +
int(sum(np.less(rhs_tensor, rbd.axis))))
result_batch_dim = ConcatAxis(axis, rbd.segment_lengths)
elif (type(lbd) is int and
(rbd is None or type(rbd) is ConcatAxis and
rbd.axis not in rhs_contract)):
elif (type(lbd) is int and rbd is None):
lhs_tensor = [d for d in range(lhs_ndim)
if d not in lhs_batch and d not in lhs_contract]
result_batch_dim = len(lhs_batch) + int(sum(np.less(lhs_tensor, lbd)))
lhs_batch = bump_dims(lhs_batch, lbd)
lhs_contract = bump_dims(lhs_contract, lbd)
elif (type(rbd) is int and
(lbd is None or type(lbd) is ConcatAxis and
lbd.axis not in lhs_contract)):
elif (type(rbd) is int and lbd is None):
rhs_tensor = [d for d in range(rhs_ndim)
if d not in rhs_batch and d not in rhs_contract]
result_batch_dim = (lhs_ndim - len(lhs_contract) +
@ -2848,12 +2824,13 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape,
dyn_shape_bdims[0] is not None):
(d,), (d_bdim,) = dyn_shape, dyn_shape_bdims # NotImplementedError above
assert d_bdim == 0 # must be scalar in the program to be batched
new_shape = _merge_dyn_shape(shape, (int(d.sum()),))
bound = d.dtype.bound
new_shape = (len(d),) + _merge_dyn_shape(shape, (bound,))
out = broadcast_in_dim(operand, new_shape, broadcast_dimensions)
idx, = (i for i, s in enumerate(shape) if s is None)
return out, batching.ConcatAxis(idx, d)
return out, batching.RaggedAxis(0, idx+1, d)
else:
raise NotImplementedError # TODO(mattjj)
raise NotImplementedError # TODO(mattjj,axch)
def _broadcast_in_dim_fwd_rule(eqn):
v, *dyn = eqn.invars
@ -3710,7 +3687,7 @@ reduce_sum_p = standard_primitive(
_reduce_sum_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
'reduce_sum')
ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
batching.defreducer(reduce_sum_p)
batching.defreducer(reduce_sum_p, _get_sum_identity)
pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
_get_sum_identity)
@ -3734,7 +3711,7 @@ reduce_prod_p = standard_primitive(
_reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'),
'reduce_prod')
ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule
batching.defreducer(reduce_prod_p)
batching.defreducer(reduce_prod_p, _get_prod_identity)
pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, _reduce_prod,
_get_prod_identity)
@ -3756,7 +3733,7 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
reduce_max_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
'reduce_max')
ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_max_p)
batching.defreducer(reduce_max_p, _get_max_identity)
pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max,
_get_max_identity)
@ -3764,7 +3741,7 @@ pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max,
reduce_min_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
'reduce_min')
ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule)
batching.defreducer(reduce_min_p)
batching.defreducer(reduce_min_p, _get_min_identity)
pe.padding_rules[reduce_min_p] = partial(_reducer_padding, _reduce_min,
_get_min_identity)
@ -3810,12 +3787,12 @@ def _compute_argminmax(value_comparator, get_identity,
argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
'argmin', weak_type_rule=_strip_weak_type)
batching.defreducer(argmin_p)
batching.defreducer(argmin_p, _get_min_identity)
ad.defjvp_zero(argmin_p)
argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
'argmax', weak_type_rule=_strip_weak_type)
batching.defreducer(argmax_p)
batching.defreducer(argmax_p, _get_max_identity)
ad.defjvp_zero(argmax_p)
mlir.register_lowering(argmin_p, mlir.cache_lowering(mlir.lower_fun(
@ -3835,19 +3812,19 @@ def _reduce_logical_shape_rule(operand, *, axes):
reduce_or_p = standard_primitive(
_reduce_logical_shape_rule, _input_dtype, 'reduce_or',
weak_type_rule=_strip_weak_type)
batching.defreducer(reduce_or_p)
batching.defreducer(reduce_or_p, _get_bitwise_or_identity)
reduce_and_p = standard_primitive(
_reduce_logical_shape_rule, _input_dtype, 'reduce_and',
weak_type_rule=_strip_weak_type)
batching.defreducer(reduce_and_p)
batching.defreducer(reduce_and_p, _get_bitwise_and_identity)
reduce_xor_p = standard_primitive(
_reduce_logical_shape_rule, _input_dtype, 'reduce_xor',
weak_type_rule=_strip_weak_type)
batching.defreducer(reduce_xor_p)
batching.defreducer(reduce_xor_p, _get_bitwise_or_identity)
def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
@ -4555,9 +4532,12 @@ mlir.register_lowering(iota_p, _iota_lower)
def _iota_batching_rule(in_vals, in_dims, *, dtype, shape, dimension):
(segment_lengths,), (ax,) = in_vals, in_dims
shapes = [_merge_dyn_shape(shape, (d,)) for d in segment_lengths]
iotas = [broadcasted_iota(dtype, s, dimension) for s in shapes]
return concatenate(iotas, dimension), batching.ConcatAxis(ax, segment_lengths)
assert ax == 0
bound = segment_lengths.dtype.bound
ragged_axis, = [i for i, dim in enumerate(shape) if dim is None]
shape = (len(segment_lengths),) + _merge_dyn_shape(shape, (bound,))
iota = broadcasted_iota(dtype, shape, dimension+1)
return iota, batching.RaggedAxis(ax, ragged_axis+1, segment_lengths)
batching.primitive_batchers[iota_p] = _iota_batching_rule
def _iota_pp_rule(eqn, context, settings):

View File

@ -850,7 +850,7 @@ def _shard_map_batch(
if all(bdim is batching.not_mapped for bdim in in_dims):
return prim.bind(fun, *in_vals, mesh=mesh, in_names=in_names,
out_names_thunk=out_names_thunk, check_rep=check_rep)
if any(isinstance(d, batching.ConcatAxis) for d in in_dims):
if any(isinstance(d, batching.RaggedAxis) for d in in_dims):
raise NotImplementedError
fun, out_dims = batching.batch_subtrace(fun, trace.main, tuple(in_dims))
new_in_names = [{ax + (d is not batching.not_mapped and d <= ax): names[ax] # type: ignore

View File

@ -21,7 +21,7 @@ from jax._src.interpreters.batching import (
BatchTrace as BatchTrace,
BatchTracer as BatchTracer,
BatchingRule as BatchingRule,
ConcatAxis as ConcatAxis,
RaggedAxis as RaggedAxis,
Elt as Elt,
FromEltHandler as FromEltHandler,
GetIdx as GetIdx,
@ -62,15 +62,12 @@ from jax._src.interpreters.batching import (
not_mapped as not_mapped,
pile_axis as pile_axis,
primitive_batchers as primitive_batchers,
reassemble_concat_axes as reassemble_concat_axes,
reducer_batcher as reducer_batcher,
register_vmappable as register_vmappable,
segment_sum as segment_sum,
spec_types as spec_types,
spmd_axis_primitive_batchers as spmd_axis_primitive_batchers,
to_elt as to_elt,
to_elt_handlers as to_elt_handlers,
unpack_concat_axes as unpack_concat_axes,
unregister_vmappable as unregister_vmappable,
vectorized_batcher as vectorized_batcher,
vmappables as vmappables,

View File

@ -1490,42 +1490,47 @@ class DynamicShapeExecutionTest(jtu.JaxTestCase):
self.assertEqual(y.shape, (sz, 4))
self.assertAllClose(y._data, x)
@unittest.skip("Test does not work with jax.Array")
@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow")
@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow",
jax_disable_jit=True, jax_traceback_filtering='off')
class PileTest(jtu.JaxTestCase):
def test_internal_pile(self):
xs = jax.vmap(lambda n: jnp.arange(n).sum())(jnp.array([3, 1, 4]))
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
xs = jax.vmap(lambda n: jax.lax.iota('int32', n).sum())(ins)
self.assertAllClose(xs, jnp.array([3, 0, 6]), check_dtypes=False)
def test_make_pile_from_dynamic_shape(self):
# We may not want to support returning piles from vmapped functions (instead
# preferring to have a separate API which allows piles). But for now it
# makes for a convenient way to construct piles for the other tests!
p = jax.vmap(partial(jnp.arange, dtype='int32'), out_axes=batching.pile_axis
)(jnp.array([3, 1, 4]))
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
p = jax.vmap(partial(jnp.arange, dtype='int32'),
out_axes=batching.pile_axis)(ins)
self.assertIsInstance(p, batching.Pile)
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[\[3 1 4\]\.Var[0-9]+\]')
data = jnp.concatenate([jnp.arange(3), jnp.arange(1), jnp.arange(4)])
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
self.assertAllClose(p.data, data, check_dtypes=False)
def test_pile_map_eltwise(self):
p = jax.vmap(partial(jnp.arange, dtype='int32'), out_axes=batching.pile_axis
)(jnp.array([3, 1, 4]))
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
p = jax.vmap(partial(jnp.arange, dtype='int32'),
out_axes=batching.pile_axis)(ins)
p = pile_map(lambda x: x ** 2)(p)
self.assertIsInstance(p, batching.Pile)
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[\[3 1 4\]\.Var[0-9]+\]')
data = jnp.concatenate([jnp.arange(3), jnp.arange(1), jnp.arange(4)]) ** 2
self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
data = jax.lax.broadcasted_iota('int32', (3, 5), 1) ** 2
self.assertAllClose(p.data, data, check_dtypes=False)
def test_pile_map_vector_dot(self):
p = jax.vmap(jnp.arange, out_axes=batching.pile_axis)(jnp.array([3, 1, 4]))
ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
p = jax.vmap(partial(jnp.arange, dtype='int32'),
out_axes=batching.pile_axis)(ins)
y = pile_map(jnp.dot)(p, p)
self.assertIsInstance(y, batching.Pile)
self.assertAllClose(y.data, jnp.array([5, 0, 14]))
self.assertAllClose(y.data, jnp.array([5, 0, 14], dtype='int32'))
def test_pile_map_matrix_dot(self):
sizes = jnp.array([3, 1, 4])
sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
p1 = jax.vmap(lambda n: jnp.ones((7, n)), out_axes=batching.pile_axis
)(sizes)
p2 = jax.vmap(lambda n: jnp.ones((n, 7)), out_axes=batching.pile_axis