Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)

Merge pull request #16072 from axch:stacked-piles

PiperOrigin-RevId: 538839592

Commit 5c5baccb2e
@@ -1666,6 +1666,9 @@ class DShapedArray(UnshapedArray):
     weak_type = self.weak_type
     return DShapedArray(shape, dtype, weak_type)
 
+  def _len(self, tracer):
+    return self.shape[0]
+
   def __eq__(self, other):
     return (type(self) is type(other)
             and self.dtype == other.dtype and self.shape == other.shape
@@ -1720,7 +1723,7 @@ class DArray:
     slices = tuple(slice(int(d._data)) if type(d) is DArray and
                    type(d.dtype) is bint else slice(None) for d in self.shape)
     data = self._data[slices]
-    return f'{dtypestr}[{shapestr}] with value:\n{data}'
+    return f'{dtypestr}[{shapestr}] with value: {data}'
   def __hash__(self) -> int:
     if not self.shape:
       return hash((self._aval, int(self._data)))
@@ -1729,6 +1732,8 @@ class DArray:
     if isinstance(other, DArray) and self._aval == other._aval:
       return self._data == other._data
     return False
+  def __len__(self):
+    return self.shape[0]
 
 pytype_aval_mappings[DArray] = \
     lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type,
@@ -13,7 +13,6 @@
 # limitations under the License.
 from __future__ import annotations
 
-import collections
 import dataclasses
 from functools import partial
 from typing import (Any, Callable, Dict, Iterable, Optional, Sequence, Set,
@@ -98,23 +97,30 @@ def _pile_unflatten(aval, x):
   return Pile(aval, data)
 register_pytree_node(Pile, _pile_flatten, _pile_unflatten)
 
-def _pile_result(axis_size, axis, segment_lens, x):
+def _pile_result(axis_size, stacked_axis, ragged_axis, segment_lens, x):
   binder = core.Var(0, '', core.ShapedArray((), np.dtype('int32')))
+  if stacked_axis != 0:
+    raise NotImplementedError  # TODO Transpose x so the stacked axis is axis 0
   shape = list(x.shape)
-  shape[axis] = IndexedAxisSize(binder, segment_lens)
+  del shape[0]
+  shape[ragged_axis-1] = IndexedAxisSize(binder, segment_lens)
   elt_ty = core.DShapedArray(tuple(shape), x.dtype, x.weak_type)
   return Pile(PileTy(binder, axis_size, elt_ty), x)
 
 
 @dataclasses.dataclass(frozen=True)
-class ConcatAxis:
-  axis: int
+class RaggedAxis:
+  stacked_axis: int
+  # TODO(mattjj,axch): Generalize to multiple ragged dimensions
+  # e.g. `i:(Fin 3) => f32[lens1.i, lens2.i]`
+  ragged_axis: int
   segment_lengths: Array
 
 
 def _update_annotation(
     f: lu.WrappedFun, orig_type: Optional[core.InputType],
     axis_size: core.AxisSize, axis_name: AxisName,
-    explicit_in_dims: Sequence[Optional[Union[int, ConcatAxis]]],
+    explicit_in_dims: Sequence[Optional[Union[int, RaggedAxis]]],
     segment_lens: Sequence[Array],
   ) -> lu.WrappedFun:
   if orig_type is None: return f
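
(Illustrative sketch, not part of the diff: the new RaggedAxis represents a ragged batch as one dense array, stacked along `stacked_axis` and padded along `ragged_axis` out to a bound, with `segment_lengths` recording each element's true size. In NumPy terms, assuming lengths [3, 1, 4] and bound 5:)

    import numpy as np

    # RaggedAxis(stacked_axis=0, ragged_axis=1, segment_lengths=[3, 1, 4])
    # over a bound of 5 describes data laid out like this; only the first
    # n entries of row i are meaningful, the rest are padding.
    lengths = np.array([3, 1, 4])
    data = np.zeros((3, 5), dtype=np.int32)
    for i, n in enumerate(lengths):
        data[i, :n] = np.arange(n)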
@@ -126,7 +132,7 @@ def _update_annotation(
   # * drop other implicit binders, replacing DBIdx which refer to them with
   #   Name objects;
   # * for each (aval, in_dim) pair: if int-valued in_dim, add batch axis (int
-  #   size if `axis_size` is int, otherwise Name); if ConcatAxis-valued in_dim,
+  #   size if `axis_size` is int, otherwise Name); if RaggedAxis-valued in_dim,
   #   add batch axis (int if corresponding segment_lengths is concrete, Name if
   #   not);
   # * generate full in_type with implicit args too.
@@ -141,14 +147,8 @@ def _update_annotation(
   new_avals = [core.raise_to_shaped(core.get_aval(s)) for s in segment_lens]
   sz = Name(axis_size.aval) if isinstance(axis_size, Tracer) else axis_size
   for a, d in zip(avals, explicit_in_dims):
-    if isinstance(d, ConcatAxis):
-      s = segment_lens[d.segment_lengths.val]
-      if isinstance(core.get_aval(s), core.ConcreteArray):
-        shape = list(a.shape)  # type: ignore
-        shape[d.axis] = int(s.sum())  # specialize on shape if we can
-        new_avals.append(a.update(shape=tuple(shape)))
-      else:
-        new_avals.append(a)
+    if isinstance(d, RaggedAxis):
+      raise NotImplementedError
     else:
       new_avals.append(core.unmapped_aval(sz, axis_name, d, a))  # type: ignore
 
@@ -182,7 +182,7 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
       raise TypeError("pile input without using pile_axis in_axes spec")
     (d, ias), = ((i, sz) for i, sz in enumerate(x.aval.elt_ty.shape)
                  if type(sz) is IndexedAxisSize)
-    return BatchTracer(trace, x.data, ConcatAxis(d, ias.lengths))  # type: ignore
+    return BatchTracer(trace, x.data, RaggedAxis(0, d+1, ias.lengths))  # type: ignore
   elif isinstance(spec, int) or spec is None:
     spec = spec and canonicalize_axis(spec, len(np.shape(x)))
     return (BatchTracer(trace, x, spec, source_info_util.current())
@@ -198,11 +198,11 @@ def from_elt(trace: 'BatchTrace', axis_size: AxisSize, x: Elt, spec: MapSpec
     return handler(partial(from_elt, trace), axis_size, x, spec)
   x_ = trace.full_raise(x)
   val, bdim = x_.val, x_.batch_dim
-  if type(bdim) is ConcatAxis:
+  if type(bdim) is RaggedAxis:
     if spec is not pile_axis:
       # TODO(mattjj): improve this error message
       raise TypeError("ragged output without using pile_axis out_axes spec")
-    return _pile_result(axis_size, bdim.axis, bdim.segment_lengths, val)
+    return _pile_result(axis_size, bdim.stacked_axis, bdim.ragged_axis, bdim.segment_lengths, val)
   else:
     return matchaxis(trace.axis_name, axis_size, x_.batch_dim, spec, x_.val)
 from_elt_handlers: Dict[Type, FromEltHandler] = {}
@@ -253,10 +253,10 @@ not_mapped = None
 class BatchTracer(Tracer):
   __slots__ = ['val', 'batch_dim', 'source_info']
 
-  def __init__(self, trace, val, batch_dim: Union[NotMapped, int, ConcatAxis],
+  def __init__(self, trace, val, batch_dim: Union[NotMapped, int, RaggedAxis],
                source_info: Optional[source_info_util.SourceInfo] = None):
     if config.jax_enable_checks:
-      assert type(batch_dim) in (NotMapped, int, ConcatAxis)
+      assert type(batch_dim) in (NotMapped, int, RaggedAxis)
       if type(batch_dim) is int:
         aval = raise_to_shaped(core.get_aval(val))
         assert 0 <= batch_dim < len(aval.shape)  # type: ignore
@@ -272,10 +272,15 @@ class BatchTracer(Tracer):
       return aval
     elif type(self.batch_dim) is int:
       return core.mapped_aval(aval.shape[self.batch_dim], self.batch_dim, aval)
-    elif type(self.batch_dim) is ConcatAxis:
-      shape = list(aval.shape)
+    elif type(self.batch_dim) is RaggedAxis:
+      new_aval = core.mapped_aval(
+          aval.shape[self.batch_dim.stacked_axis], self.batch_dim.stacked_axis, aval)
+      shape = list(new_aval.shape)  # type: ignore
       size_tracer = BatchTracer(self._trace, self.batch_dim.segment_lengths, 0)
-      shape[self.batch_dim.axis] = size_tracer
+      ragged_axis = self.batch_dim.ragged_axis
+      if self.batch_dim.stacked_axis < self.batch_dim.ragged_axis:
+        ragged_axis -= 1
+      shape[ragged_axis] = size_tracer
       return core.DShapedArray(shape=tuple(shape), dtype=aval.dtype,
                                weak_type=aval.weak_type)
 
@@ -297,7 +302,7 @@ class BatchTracer(Tracer):
   def get_referent(self):
     if self.batch_dim is None or type(self.batch_dim) is int:
       return core.get_referent(self.val)
-    else:  # TODO(mattjj): could handle the ConcatAxis case?
+    else:  # TODO(mattjj): could handle the RaggedAxis case?
       return self
 
 class BatchTrace(Trace):
@@ -376,12 +381,9 @@ class BatchTrace(Trace):
     sizes = (x.shape[d] if type(d) is int else len(d.segment_lengths)
              for x, d in zip(vals, dims) if d is not not_mapped)
     axis_size, = core.dedup_referents(sizes)
-    segment_lens, dims = unpack_concat_axes(dims)
     f_, dims_out = batch_subtrace(f, self.main, tuple(dims))
-    f_ = _update_annotation(f_, f.in_type, axis_size, self.axis_name, dims,
-                            segment_lens)
-    vals_out = call_primitive.bind(f_, *segment_lens, *vals, **params)
-    vals_out, dims_out = reassemble_concat_axes(vals_out, dims_out())
+    f_ = _update_annotation(f_, f.in_type, axis_size, self.axis_name, dims, [])
+    vals_out = call_primitive.bind(f_, *vals, **params)
     src = source_info_util.current()
     return [BatchTracer(self, v, d, src) for v, d in zip(vals_out, dims_out)]
 
@@ -607,34 +609,12 @@ def vtile(f_flat: lu.WrappedFun,
 def batch_subtrace(main, in_dims, *in_vals):
   trace = main.with_cur_sublevel()
   in_dims = in_dims() if callable(in_dims) else in_dims
-  in_vals, in_dims = reassemble_concat_axes(in_vals, in_dims)
   in_tracers = [BatchTracer(trace, x, dim, source_info_util.current())
                 if dim is not None else x for x, dim in zip(in_vals, in_dims)]
   outs = yield in_tracers, {}
   out_tracers = map(trace.full_raise, outs)
   out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
-  segment_lens, out_dims = unpack_concat_axes(out_dims)
-  yield (*segment_lens, *out_vals), out_dims
-
-def unpack_concat_axes(dims):
-  if not any(type(d) is ConcatAxis for d in dims):
-    return [], dims
-  concat_axis_map = collections.OrderedDict()
-  def convert(d: ConcatAxis) -> ConcatAxis:
-    _, dbidx = concat_axis_map.setdefault(
-        id(core.get_referent(d.segment_lengths)),
-        (d.segment_lengths, pe.DBIdx(len(concat_axis_map))))
-    return ConcatAxis(d.axis, dbidx)
-  new_dims = [convert(d) if isinstance(d, ConcatAxis) else d for d in dims]
-  segment_lens = [s for s, _ in concat_axis_map.values()]
-  return segment_lens, new_dims
-
-def reassemble_concat_axes(vals, dims):
-  idxs = {d.segment_lengths.val for d in dims if isinstance(d, ConcatAxis)}
-  dims = [ConcatAxis(d.axis, vals[d.segment_lengths.val])
-          if isinstance(d, ConcatAxis) else d for d in dims]
-  vals = [x for i, x in enumerate(vals) if i not in idxs]
-  return vals, dims
+  yield out_vals, out_dims
 
 
 ### API for batching jaxprs
@@ -880,40 +860,43 @@ def _handle_scalar_broadcasting(nd, x, d):
   else:
     return jax.lax.expand_dims(x, tuple(range(np.ndim(x), nd)))
 
-def defreducer(prim):
-  primitive_batchers[prim] = partial(reducer_batcher, prim)
+def defreducer(prim, ident):
+  primitive_batchers[prim] = partial(reducer_batcher, prim, ident)
 
-def reducer_batcher(prim, batched_args, batch_dims, axes, **params):
+def reducer_batcher(prim, ident, batched_args, batch_dims, axes, **params):
+  def out_axis(axes, axis):
+    return int(list(np.delete(np.arange(operand.ndim), axes)).index(axis))
   operand, = batched_args
   bdim, = batch_dims
   if isinstance(bdim, int):
     axes = tuple(np.where(np.less(axes, bdim), axes, np.add(axes, 1)))
-    bdim_out = int(list(np.delete(np.arange(operand.ndim), axes)).index(bdim))
+    bdim_out = out_axis(axes, bdim)
    if 'input_shape' in params:
       params = dict(params, input_shape=operand.shape)
     return prim.bind(operand, axes=axes, **params), bdim_out
-  elif isinstance(bdim, ConcatAxis):
-    if bdim.axis in axes:
-      other_axes = [i for i in axes if i != bdim.axis]
-      if other_axes:
-        operand = prim.bind(operand, axes=other_axes, **params)
-      c_axis = bdim.axis - sum(d < bdim.axis for d in other_axes)
-      operand = bdim_at_front(operand, c_axis, operand.shape[c_axis])
-      return segment_sum(operand, bdim.segment_lengths), 0
+  elif isinstance(bdim, RaggedAxis):
+    assert ident is not None, "TODO Ragged batching a reduction requires an identity"
+    axes = tuple(np.where(np.less(axes, bdim.stacked_axis), axes, np.add(axes, 1)))
+    bdim_out = out_axis(axes, bdim.stacked_axis)
+    if bdim.ragged_axis in axes:
+      operand = mask_ragged_axis(operand, ident, bdim)
+      result = prim.bind(operand, axes=axes, **params)
+      return result, bdim_out
     else:
-      raise NotImplementedError  # TODO(mattjj)
+      result = prim.bind(operand, axes=axes, **params)
+      return result, RaggedAxis(bdim_out, out_axis(axes, bdim.ragged_axis), bdim.segment_lengths)
   else:
     assert False
 
-# TODO(mattjj): replace with jax.lax.ops.segment_sum (once it's easier to trace
-# under dynamic shapes)
-def segment_sum(operand, segment_lens):
-  scat_idx = jax.numpy.cumsum(segment_lens) - segment_lens
-  segment_ids = jax.numpy.cumsum(
-      jax.numpy.zeros(operand.shape[0], 'int32').at[scat_idx].set(1)) - 1
-  out = jax.numpy.zeros((len(segment_lens), *operand.shape[1:]),
-                        operand.dtype).at[segment_ids].add(operand)
-  return out
+def mask_ragged_axis(operand, ident, axis_spec):
+  value = ident(operand.dtype)
+  positions = jax.lax.broadcasted_iota('int32', operand.shape, axis_spec.ragged_axis)
+  # TODO(mattjj, axch) cant get ._data, need to convert it
+  lengths = jax.lax.convert_element_type(axis_spec.segment_lengths._data, 'int32')
+  limits = jax.lax.broadcast_in_dim(
+      lengths, operand.shape, [axis_spec.stacked_axis])
+  mask = positions < limits
+  return jax.lax.select(mask, operand, jax.lax.broadcast(value, operand.shape))
 
 ### general utilities for manipulating axes on jaxpr types (not vmappables)
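
(A standalone sketch of the masking idea behind the new mask_ragged_axis, using only public jax.lax ops; the concrete array values here are illustrative assumptions:)

    import jax
    import jax.numpy as jnp

    # Rows are padded to the bound 5; the true lengths are [3, 1, 4]. Filling
    # the padding with the reduction's identity (0 for sum) lets one dense
    # reduction produce the per-element ragged results.
    operand = jnp.arange(15, dtype='float32').reshape(3, 5)
    lengths = jnp.array([3, 1, 4], dtype='int32')
    positions = jax.lax.broadcasted_iota('int32', operand.shape, 1)  # ragged axis
    limits = jax.lax.broadcast_in_dim(lengths, operand.shape, [0])   # stacked axis
    masked = jax.lax.select(positions < limits, operand, jnp.zeros_like(operand))
    print(masked.sum(axis=1))  # ragged sums: [3., 5., 46.]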
@@ -51,7 +51,7 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.interpreters import pxla
 from jax._src.interpreters import xla
-from jax._src.interpreters.batching import ConcatAxis
+from jax._src.interpreters.batching import RaggedAxis
 from jax._src.lax import slicing
 from jax._src.lax.utils import (
   _input_dtype,
@@ -2255,10 +2255,12 @@ def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type):
 
 def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type):
   if operand.dtype != new_dtype:
-    if dtypes.is_opaque_dtype(operand.dtype):
+    if (dtypes.is_opaque_dtype(operand.dtype) and
+        not isinstance(operand.dtype, core.bint)):
       raise ValueError(
           f"Cannot call convert_element_type on dtype {dtype_to_string(operand.dtype)}")
-    if dtypes.is_opaque_dtype(new_dtype):
+    if (dtypes.is_opaque_dtype(new_dtype) and
+        not isinstance(new_dtype, core.bint)):
       raise ValueError(
           f"Cannot convert_element_type to dtype={dtype_to_string(new_dtype)}")
   return new_dtype
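
(Sketch of the case this carve-out enables; it assumes the experimental jax_dynamic_shapes mode, as in the tests below. core.bint is the bounded-int dtype this commit special-cases:)

    import jax
    import jax.numpy as jnp
    from jax import lax
    from jax._src import core

    jax.config.update('jax_dynamic_shapes', True)
    # Converting an ordinary int array to the opaque bounded-int dtype
    # core.bint no longer raises; this is how bounded sizes enter a program.
    sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))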
@@ -2595,21 +2597,16 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers,
   lhs, rhs = batched_args
   lbd, rbd = batch_dims
   (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
-  if (type(lbd) is type(rbd) is ConcatAxis and
-      lbd.axis in lhs_contract and rbd.axis in rhs_contract):
-    # first handle any other part of the dot with these as batch dims
-    lhs_contract_ = [d for d in lhs_contract if d != lbd.axis]
-    rhs_contract_ = [d for d in rhs_contract if d != rbd.axis]
-    lhs_batch_ = (lbd.axis, *lhs_batch)
-    rhs_batch_ = (rbd.axis, *rhs_batch)
-    new_dnums = ((lhs_contract_, rhs_contract_), (lhs_batch_, rhs_batch_))
-    out = dot_general(lhs, rhs, new_dnums, precision=precision,
-                      preferred_element_type=preferred_element_type)
-    # now a segment sum along that batch axis
-    return batching.segment_sum(out, lbd.segment_lengths), 0
-
+  left_stack_dim = lbd.stacked_axis if type(lbd) is RaggedAxis else lbd
+  right_stack_dim = rbd.stacked_axis if type(rbd) is RaggedAxis else rbd
   new_dimension_numbers, result_batch_dim = _dot_general_batch_dim_nums(
-      (lhs.ndim, rhs.ndim), batch_dims, dimension_numbers)
+      (lhs.ndim, rhs.ndim), (left_stack_dim, right_stack_dim), dimension_numbers)
+  # TODO Should probably check that any ragged dimensions have corresponding
+  # sizes, because otherwise the dot product is technically undefined.
+  if type(lbd) is RaggedAxis:
+    lhs = batching.mask_ragged_axis(lhs, _get_sum_identity, lbd)
+  if type(rbd) is RaggedAxis:
+    rhs = batching.mask_ragged_axis(rhs, _get_sum_identity, rbd)
   batched_out = dot_general(lhs, rhs, new_dimension_numbers,
                             precision=precision,
                             preferred_element_type=preferred_element_type)
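
(Why masking both operands with the sum identity is sound here, as a minimal NumPy check: zero-padded positions contribute nothing to a contraction, so the dense dot over the padded bound equals the dot over the true ragged length:)

    import numpy as np

    a = np.array([1., 2., 3., 0., 0.])  # true length 3, padded to bound 5
    b = np.array([4., 5., 6., 0., 0.])
    assert np.dot(a, b) == np.dot(a[:3], b[:3])  # both equal 32.0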
@@ -2635,34 +2632,13 @@ def _dot_general_batch_dim_nums(ndims, batch_dims, dimension_numbers):
     lhs_contract = bump_dims(lhs_contract, lbd)
     rhs_contract = bump_dims(rhs_contract, rbd)
     result_batch_dim = 0
-  elif rbd is None and type(lbd) is ConcatAxis and lbd.axis not in lhs_contract:
-    if lbd.axis in lhs_batch:
-      axis = int(np.sum(np.less(lhs_batch, lbd.axis)))
-    else:
-      lhs_tensor = [d for d in range(lhs_ndim)
-                    if d not in lhs_batch and d not in lhs_contract]
-      axis = len(lhs_batch) + int(np.sum(np.less(lhs_tensor, lbd.axis)))
-    result_batch_dim = ConcatAxis(axis, lbd.segment_lengths)
-  elif lbd is None and type(rbd) is ConcatAxis and rbd.axis not in rhs_contract:
-    if rbd.axis in rhs_batch:
-      axis = int(np.sum(np.less(rhs_batch, rbd.axis)))
-    else:
-      rhs_tensor = [d for d in range(rhs_ndim)
-                    if d not in rhs_batch and d not in rhs_contract]
-      axis = (lhs_ndim - len(lhs_contract) +
-              int(sum(np.less(rhs_tensor, rbd.axis))))
-    result_batch_dim = ConcatAxis(axis, rbd.segment_lengths)
-  elif (type(lbd) is int and
-        (rbd is None or type(rbd) is ConcatAxis and
-         rbd.axis not in rhs_contract)):
+  elif (type(lbd) is int and rbd is None):
     lhs_tensor = [d for d in range(lhs_ndim)
                   if d not in lhs_batch and d not in lhs_contract]
     result_batch_dim = len(lhs_batch) + int(sum(np.less(lhs_tensor, lbd)))
     lhs_batch = bump_dims(lhs_batch, lbd)
     lhs_contract = bump_dims(lhs_contract, lbd)
-  elif (type(rbd) is int and
-        (lbd is None or type(lbd) is ConcatAxis and
-         lbd.axis not in lhs_contract)):
+  elif (type(rbd) is int and lbd is None):
     rhs_tensor = [d for d in range(rhs_ndim)
                   if d not in rhs_batch and d not in rhs_contract]
     result_batch_dim = (lhs_ndim - len(lhs_contract) +
@@ -2848,12 +2824,13 @@ def _broadcast_in_dim_batch_rule(batched_args, batch_dims, shape,
       dyn_shape_bdims[0] is not None):
     (d,), (d_bdim,) = dyn_shape, dyn_shape_bdims  # NotImplementedError above
     assert d_bdim == 0  # must be scalar in the program to be batched
-    new_shape = _merge_dyn_shape(shape, (int(d.sum()),))
+    bound = d.dtype.bound
+    new_shape = (len(d),) + _merge_dyn_shape(shape, (bound,))
     out = broadcast_in_dim(operand, new_shape, broadcast_dimensions)
     idx, = (i for i, s in enumerate(shape) if s is None)
-    return out, batching.ConcatAxis(idx, d)
+    return out, batching.RaggedAxis(0, idx+1, d)
   else:
-    raise NotImplementedError  # TODO(mattjj)
+    raise NotImplementedError  # TODO(mattjj,axch)
 
 def _broadcast_in_dim_fwd_rule(eqn):
   v, *dyn = eqn.invars
@@ -3710,7 +3687,7 @@ reduce_sum_p = standard_primitive(
     _reduce_sum_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_sum'),
     'reduce_sum')
 ad.deflinear2(reduce_sum_p, _reduce_sum_transpose_rule)
-batching.defreducer(reduce_sum_p)
+batching.defreducer(reduce_sum_p, _get_sum_identity)
 pe.padding_rules[reduce_sum_p] = partial(_reducer_padding, _reduce_sum,
                                          _get_sum_identity)
 
@@ -3734,7 +3711,7 @@ reduce_prod_p = standard_primitive(
     _reduce_op_shape_rule, partial(_reduce_number_dtype_rule, 'reduce_prod'),
     'reduce_prod')
 ad.primitive_jvps[reduce_prod_p] = _reduce_prod_jvp_rule
-batching.defreducer(reduce_prod_p)
+batching.defreducer(reduce_prod_p, _get_prod_identity)
 pe.padding_rules[reduce_prod_p] = partial(_reducer_padding, _reduce_prod,
                                           _get_prod_identity)
 
@@ -3756,7 +3733,7 @@ def _reduce_chooser_jvp_rule(g, ans, operand, *, axes):
 reduce_max_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
                                   'reduce_max')
 ad.defjvp2(reduce_max_p, _reduce_chooser_jvp_rule)
-batching.defreducer(reduce_max_p)
+batching.defreducer(reduce_max_p, _get_max_identity)
 pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max,
                                          _get_max_identity)
 
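
(A small check of why each reducer now carries its own identity: for a ragged max, padding filled with 0 could beat the true entries, while -inf, the max identity, cannot:)

    import numpy as np

    row_bad  = np.array([-2., -7., -3., 0., 0.])            # zero padding
    row_good = np.array([-2., -7., -3., -np.inf, -np.inf])  # identity padding
    assert row_bad.max() == 0.0    # padding wins: wrong ragged max
    assert row_good.max() == -2.0  # true ragged max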
@@ -3764,7 +3741,7 @@ pe.padding_rules[reduce_max_p] = partial(_reducer_padding, _reduce_max,
 
 reduce_min_p = standard_primitive(_reduce_op_shape_rule, _input_dtype,
                                   'reduce_min')
 ad.defjvp2(reduce_min_p, _reduce_chooser_jvp_rule)
-batching.defreducer(reduce_min_p)
+batching.defreducer(reduce_min_p, _get_min_identity)
 pe.padding_rules[reduce_min_p] = partial(_reducer_padding, _reduce_min,
                                          _get_min_identity)
@@ -3810,12 +3787,12 @@ def _compute_argminmax(value_comparator, get_identity,
 
 argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
                               'argmin', weak_type_rule=_strip_weak_type)
-batching.defreducer(argmin_p)
+batching.defreducer(argmin_p, _get_min_identity)
 ad.defjvp_zero(argmin_p)
 
 argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
                               'argmax', weak_type_rule=_strip_weak_type)
-batching.defreducer(argmax_p)
+batching.defreducer(argmax_p, _get_max_identity)
 ad.defjvp_zero(argmax_p)
 
 mlir.register_lowering(argmin_p, mlir.cache_lowering(mlir.lower_fun(
@@ -3835,19 +3812,19 @@ def _reduce_logical_shape_rule(operand, *, axes):
 reduce_or_p = standard_primitive(
     _reduce_logical_shape_rule, _input_dtype, 'reduce_or',
     weak_type_rule=_strip_weak_type)
-batching.defreducer(reduce_or_p)
+batching.defreducer(reduce_or_p, _get_bitwise_or_identity)
 
 
 reduce_and_p = standard_primitive(
     _reduce_logical_shape_rule, _input_dtype, 'reduce_and',
     weak_type_rule=_strip_weak_type)
-batching.defreducer(reduce_and_p)
+batching.defreducer(reduce_and_p, _get_bitwise_and_identity)
 
 
 reduce_xor_p = standard_primitive(
     _reduce_logical_shape_rule, _input_dtype, 'reduce_xor',
     weak_type_rule=_strip_weak_type)
-batching.defreducer(reduce_xor_p)
+batching.defreducer(reduce_xor_p, _get_bitwise_or_identity)
 
 
 def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
@@ -4555,9 +4532,12 @@ mlir.register_lowering(iota_p, _iota_lower)
 
 def _iota_batching_rule(in_vals, in_dims, *, dtype, shape, dimension):
   (segment_lengths,), (ax,) = in_vals, in_dims
-  shapes = [_merge_dyn_shape(shape, (d,)) for d in segment_lengths]
-  iotas = [broadcasted_iota(dtype, s, dimension) for s in shapes]
-  return concatenate(iotas, dimension), batching.ConcatAxis(ax, segment_lengths)
+  assert ax == 0
+  bound = segment_lengths.dtype.bound
+  ragged_axis, = [i for i, dim in enumerate(shape) if dim is None]
+  shape = (len(segment_lengths),) + _merge_dyn_shape(shape, (bound,))
+  iota = broadcasted_iota(dtype, shape, dimension+1)
+  return iota, batching.RaggedAxis(ax, ragged_axis+1, segment_lengths)
 batching.primitive_batchers[iota_p] = _iota_batching_rule
 
 def _iota_pp_rule(eqn, context, settings):
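
(What the new batched iota produces, as a sketch: one dense iota padded out to the bound and tagged with a RaggedAxis, instead of concatenated per-segment iotas. With segment lengths [3, 1, 4] and bound 5:)

    import jax

    iota = jax.lax.broadcasted_iota('int32', (3, 5), 1)
    # [[0 1 2 3 4]
    #  [0 1 2 3 4]
    #  [0 1 2 3 4]]  -- entries at or past each row's true length are padding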
@@ -850,7 +850,7 @@ def _shard_map_batch(
   if all(bdim is batching.not_mapped for bdim in in_dims):
     return prim.bind(fun, *in_vals, mesh=mesh, in_names=in_names,
                      out_names_thunk=out_names_thunk, check_rep=check_rep)
-  if any(isinstance(d, batching.ConcatAxis) for d in in_dims):
+  if any(isinstance(d, batching.RaggedAxis) for d in in_dims):
     raise NotImplementedError
   fun, out_dims = batching.batch_subtrace(fun, trace.main, tuple(in_dims))
   new_in_names = [{ax + (d is not batching.not_mapped and d <= ax): names[ax]  # type: ignore
@@ -21,7 +21,7 @@ from jax._src.interpreters.batching import (
   BatchTrace as BatchTrace,
   BatchTracer as BatchTracer,
   BatchingRule as BatchingRule,
-  ConcatAxis as ConcatAxis,
+  RaggedAxis as RaggedAxis,
   Elt as Elt,
   FromEltHandler as FromEltHandler,
   GetIdx as GetIdx,
@@ -62,15 +62,12 @@ from jax._src.interpreters.batching import (
   not_mapped as not_mapped,
   pile_axis as pile_axis,
   primitive_batchers as primitive_batchers,
-  reassemble_concat_axes as reassemble_concat_axes,
   reducer_batcher as reducer_batcher,
   register_vmappable as register_vmappable,
-  segment_sum as segment_sum,
   spec_types as spec_types,
   spmd_axis_primitive_batchers as spmd_axis_primitive_batchers,
   to_elt as to_elt,
   to_elt_handlers as to_elt_handlers,
-  unpack_concat_axes as unpack_concat_axes,
   unregister_vmappable as unregister_vmappable,
   vectorized_batcher as vectorized_batcher,
   vmappables as vmappables,
@@ -1490,42 +1490,47 @@ class DynamicShapeExecutionTest(jtu.JaxTestCase):
     self.assertEqual(y.shape, (sz, 4))
     self.assertAllClose(y._data, x)
 
 @unittest.skip("Test does not work with jax.Array")
-@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow")
+@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_rank_promotion="allow",
+                 jax_disable_jit=True, jax_traceback_filtering='off')
 class PileTest(jtu.JaxTestCase):
 
   def test_internal_pile(self):
-    xs = jax.vmap(lambda n: jnp.arange(n).sum())(jnp.array([3, 1, 4]))
+    ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
+    xs = jax.vmap(lambda n: jax.lax.iota('int32', n).sum())(ins)
     self.assertAllClose(xs, jnp.array([3, 0, 6]), check_dtypes=False)
 
   def test_make_pile_from_dynamic_shape(self):
     # We may not want to support returning piles from vmapped functions (instead
     # preferring to have a separate API which allows piles). But for now it
     # makes for a convenient way to construct piles for the other tests!
-    p = jax.vmap(partial(jnp.arange, dtype='int32'), out_axes=batching.pile_axis
-                 )(jnp.array([3, 1, 4]))
+    ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
+    p = jax.vmap(partial(jnp.arange, dtype='int32'),
+                 out_axes=batching.pile_axis)(ins)
     self.assertIsInstance(p, batching.Pile)
-    self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[\[3 1 4\]\.Var[0-9]+\]')
-    data = jnp.concatenate([jnp.arange(3), jnp.arange(1), jnp.arange(4)])
+    self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
+    data = jax.lax.broadcasted_iota('int32', (3, 5), 1)
     self.assertAllClose(p.data, data, check_dtypes=False)
 
   def test_pile_map_eltwise(self):
-    p = jax.vmap(partial(jnp.arange, dtype='int32'), out_axes=batching.pile_axis
-                 )(jnp.array([3, 1, 4]))
+    ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
+    p = jax.vmap(partial(jnp.arange, dtype='int32'),
+                 out_axes=batching.pile_axis)(ins)
     p = pile_map(lambda x: x ** 2)(p)
     self.assertIsInstance(p, batching.Pile)
-    self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[\[3 1 4\]\.Var[0-9]+\]')
-    data = jnp.concatenate([jnp.arange(3), jnp.arange(1), jnp.arange(4)]) ** 2
+    self.assertRegex(str(p.aval), r'Var[0-9]+:3 => i32\[bint\{≤5\}\[3\] with value: \[3 1 4\]\.Var[0-9]+\]')
+    data = jax.lax.broadcasted_iota('int32', (3, 5), 1) ** 2
     self.assertAllClose(p.data, data, check_dtypes=False)
 
   def test_pile_map_vector_dot(self):
-    p = jax.vmap(jnp.arange, out_axes=batching.pile_axis)(jnp.array([3, 1, 4]))
+    ins = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
+    p = jax.vmap(partial(jnp.arange, dtype='int32'),
+                 out_axes=batching.pile_axis)(ins)
     y = pile_map(jnp.dot)(p, p)
     self.assertIsInstance(y, batching.Pile)
-    self.assertAllClose(y.data, jnp.array([5, 0, 14]))
+    self.assertAllClose(y.data, jnp.array([5, 0, 14], dtype='int32'))
 
   def test_pile_map_matrix_dot(self):
-    sizes = jnp.array([3, 1, 4])
+    sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
     p1 = jax.vmap(lambda n: jnp.ones((7, n)), out_axes=batching.pile_axis
                   )(sizes)
     p2 = jax.vmap(lambda n: jnp.ones((n, 7)), out_axes=batching.pile_axis
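
(A condensed sketch of the pile-construction pattern the updated tests use; it assumes the experimental jax_dynamic_shapes mode is enabled, as in the class decorator above:)

    import jax
    import jax.numpy as jnp
    from functools import partial
    from jax import lax
    from jax._src import core
    from jax._src.interpreters import batching

    jax.config.update('jax_dynamic_shapes', True)
    # Sizes become a bounded-int array; vmap with out_axes=pile_axis returns
    # a Pile whose underlying data is padded out to the bound (here 5).
    sizes = lax.convert_element_type(jnp.array([3, 1, 4]), core.bint(5))
    p = jax.vmap(partial(jnp.arange, dtype='int32'),
                 out_axes=batching.pile_axis)(sizes)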