Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)
[dynamic-shapes] add basic vmap-of-indexing support
The main changes here are only indirectly related to gather: we just had to update some other rules (e.g. for comparison and squeeze) so that a simple dynamic-batch-shape gather works. I also skipped two tests and deleted some old dynamic-shape slicing logic, since we want to handle slicing differently. The removal didn't have to happen in this PR, but it's convenient given that I'm looking at indexing again.
This commit is contained in:
parent
49672cd2bc
commit
58826507cc
@@ -110,6 +110,9 @@ class IreeBuffer(xla_client.DeviceArrayBase):
   def _value(self):
     return np.asarray(self)
 
+  def copy_to_host_async(self):
+    return self
+
 class IreeExecutable:
 
   def __init__(self, client, devices, module_object, function_name):
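Note: copy_to_host_async returning self is coherent here because these prototype IREE buffers appear to hold their data in host memory already (as _value's np.asarray suggests), so there is nothing to transfer. A minimal standalone sketch of the duck-typed buffer protocol this satisfies (HostArray is a hypothetical stand-in, not part of JAX or IREE):

import numpy as np

class HostArray:
  """Hypothetical stand-in for a device buffer whose data is host-resident."""

  def __init__(self, data):
    self._data = np.asarray(data)

  def __array__(self):
    return self._data            # lets np.asarray(self) succeed

  def _value(self):
    return np.asarray(self)      # same contract as IreeBuffer._value above

  def copy_to_host_async(self):
    return self                  # nothing to transfer: already on host

buf = HostArray([1.0, 2.0, 3.0])
buf.copy_to_host_async()         # no-op, mirroring the method above
print(np.asarray(buf))           # [1. 2. 3.]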
@@ -629,8 +629,14 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
   Returns:
     An array containing the concatenation.
   """
+  from jax.experimental import array
+
   if len(operands) == 0:
     raise ValueError("concatenate requires a non-empty sequences of arrays")
+  if len(operands) == 1:
+    op, = operands
+    if isinstance(op, (core.Tracer, device_array.DeviceArray, array.Array)):
+      return op
   return concatenate_p.bind(*operands, dimension=dimension)
 
 
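Note: the single-operand fast path is sound because concatenating one array is the identity; the isinstance check just confirms the operand is already a JAX value before returning it without binding concatenate_p. A quick sketch of the resulting behavior (illustrative, not part of the diff):

import jax
import jax.numpy as jnp
from jax import lax

x = jnp.arange(4.0)
print(lax.concatenate([x], dimension=0))   # [0. 1. 2. 3.], the identity

# Under tracing the operand is a core.Tracer, so the fast path returns it
# directly and no concatenate equation should appear in the jaxpr:
print(jax.make_jaxpr(lambda a: lax.concatenate([a], 0))(x))
# expected roughly: { lambda ; a:f32[4]. let  in (a,) }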
@@ -2249,8 +2255,12 @@ mlir.register_lowering(shift_right_logical_p,
                        partial(_nary_lower_mhlo, mhlo.ShiftRightLogicalOp))
 
 def _compare_lower_mhlo(direction: str, ctx, x, y):
-  x_aval, y_aval = ctx.avals_in
-  aval_out, = ctx.avals_out
+  avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out
+  if config.jax_dynamic_shapes:
+    substitute = partial(_substitute_axis_sizes_in_aval, ctx.axis_size_env)
+    avals_in = map(substitute, avals_in)
+    aval_out = substitute(aval_out)
+  x_aval, y_aval = avals_in
   x, y = broadcast_mhlo(aval_out.update(dtype=x_aval.dtype), ctx.avals_in,
                         (x, y))
   if dtypes.issubdtype(x_aval.dtype, np.inexact):
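Note: under jax_dynamic_shapes an aval's shape may contain axis-size variables rather than ints, and ctx.axis_size_env maps those variables to the MLIR ir.Values carrying the sizes at run time. A conceptual sketch of the substitution step (an assumption about the helper's behavior, not JAX's actual _substitute_axis_sizes_in_aval body):

def substitute_axis_sizes_in_aval(axis_size_env, aval):
  # axis_size_env: mapping from axis-size variables to ir.Value operands.
  # Plain int dimensions are absent from the mapping and pass through.
  new_shape = tuple(axis_size_env.get(d, d) for d in aval.shape)
  return aval.update(shape=new_shape)

With the shapes rewritten, the comparison lowering can treat static and dynamic dimensions uniformly when computing broadcasts.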
@@ -2757,8 +2767,9 @@ def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape,
   return ([expand_dims(_reduce_sum(ct, axes), unit_dims)] +
           [None] * len(dyn_shape))
 
-def _broadcast_in_dim_batch_rule(batched_args, batch_dims, *, shape,
+def _broadcast_in_dim_batch_rule(batched_args, batch_dims, *dyn_shape, shape,
                                  broadcast_dimensions):
+  if dyn_shape: raise NotImplementedError  # TODO(mattjj)
   operand, = batched_args
   bdim, = batch_dims
   new_operand = batching.moveaxis(operand, bdim, 0)
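Note: the batch rule still handles only static shapes (dyn_shape must be empty, per the TODO). For the static case, vmap moves the batch axis to the front and shifts every broadcast dimension by one. A usage sketch of that path:

import jax
import jax.numpy as jnp
from jax import lax

f = lambda v: lax.broadcast_in_dim(v, shape=(2, 3), broadcast_dimensions=(1,))
xs = jnp.arange(12.0).reshape(4, 3)   # a batch of four length-3 vectors
print(jax.vmap(f)(xs).shape)          # (4, 2, 3)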
@@ -3157,7 +3168,16 @@ batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule
 def _squeeze_lower(ctx, operand, *, dimensions):
   del dimensions  # Implied by the output aval.
   aval_out, = ctx.avals_out
-  return mhlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), operand).results
+  if config.jax_dynamic_shapes:
+    substitute = partial(_substitute_axis_sizes_in_aval, ctx.axis_size_env)
+    aval_out = substitute(aval_out)
+  if any(isinstance(d, ir.Value) for d in aval_out.shape):
+    return mhlo.DynamicReshapeOp(
+        mlir.aval_to_ir_type(aval_out), operand,
+        mlir.shape_tensor(aval_out.shape),
+    ).results
+  else:
+    return mhlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), operand).results
 
 mlir.register_lowering(squeeze_p, _squeeze_lower)
 
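Note: squeeze just drops size-1 axes, and its result type normally pins the output shape, which is why the static lowering is a bare mhlo.ReshapeOp. The new branch matters only when the output aval's shape contains run-time ir.Values, where a static result type cannot be formed and mhlo.DynamicReshapeOp takes the shape as an explicit tensor operand instead. The static behavior, for reference:

import jax.numpy as jnp
from jax import lax

x = jnp.ones((4, 1, 3))
print(lax.squeeze(x, dimensions=(1,)).shape)   # (4, 3)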
@@ -287,7 +287,6 @@ def gather(operand: Array, start_indices: Array,
                 fill_value=fill_value)
 
-
 
 class ScatterDimensionNumbers(NamedTuple):
   """
   Describes the dimension number arguments to an `XLA's Scatter operator
@@ -1161,7 +1160,8 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers,
     expanded_indices_shape.pop(index_vector_dim)
   indices_shape = iter(expanded_indices_shape)
 
-  slice_sizes = iter(np.delete(slice_sizes, collapsed_slice_dims))
+  slice_sizes = (s for i, s in enumerate(slice_sizes)
+                 if i not in collapsed_slice_dims)
   return tuple(next(slice_sizes) if i in offset_dims
                else next(indices_shape) for i in range(output_shape_rank))
 
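Note on dropping np.delete: np.delete coerces its argument to an ndarray, which either turns every size into a NumPy scalar or produces a dtype=object array once slice_sizes can contain traced, dynamically-sized entries; the generator filters out collapsed positions while passing elements through untouched. The same rationale drives the np.add change in the next hunk. An illustration (FakeTracer is a hypothetical stand-in):

import numpy as np

class FakeTracer:
  """Hypothetical stand-in for a traced, dynamically-sized dimension."""

slice_sizes = [1, FakeTracer(), 3]
collapsed_slice_dims = (0,)

kept = [s for i, s in enumerate(slice_sizes) if i not in collapsed_slice_dims]
print(kept)   # [<FakeTracer ...>, 3]: elements keep their original types

# np.delete(slice_sizes, collapsed_slice_dims) would first build
# np.asarray(slice_sizes), a dtype=object array here, and in the all-int
# case would return NumPy scalars instead of the original Python ints.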
@@ -1250,7 +1250,7 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
 
   elif operand_bdim is None and indices_bdim is not None:
     indices = batching.moveaxis(indices, indices_bdim, 0)
-    offset_dims = tuple(np.add(1, dimension_numbers.offset_dims))
+    offset_dims = tuple(1 + d for d in dimension_numbers.offset_dims)
     dnums = GatherDimensionNumbers(
         offset_dims=offset_dims,
         collapsed_slice_dims=dimension_numbers.collapsed_slice_dims,
@@ -1308,12 +1308,10 @@ gather_p = standard_primitive(
     _gather_shape_rule, _gather_dtype_rule, 'gather',
     weak_type_rule=_argnum_weak_type(0))
 ad.defjvp(gather_p, _gather_jvp_rule, None)
-
 ad.primitive_transposes[gather_p] = _gather_transpose_rule
 batching.primitive_batchers[gather_p] = _gather_batching_rule
 
-
 
 def _gather_lower(ctx, operand, indices, *,
                   dimension_numbers, slice_sizes, unique_indices,
                   indices_are_sorted, mode, fill_value):
@@ -2054,39 +2052,3 @@ def _dynamic_slice_indices(operand, start_indices: Any):
     d = lax.convert_element_type(core.dimension_as_value(d), _dtype(i))
     result.append(lax.select(i < 0, i + d, i))
   return result
-
-
-# TODO(mattjj): getslice is a prototype for dynamic shapes, revise or remove it
-def _getslice(x, lo, hi):
-  return getslice_p.bind(x, lo, hi)
-
-getslice_p = core.Primitive('getslice')
-
-@getslice_p.def_impl
-def getslice_impl(x, lo, hi):
-  return x[lo:hi]
-
-def _getslice_staging_rule(trace, x, lo, hi):
-  size = lax.make_bint(lax.clamp(0, hi - lo, x.shape[0]), x.shape[0])
-  aval = core.DShapedArray((size,), x.dtype, x.weak_type)
-  source_info = source_info_util.current()
-  out_tracer = pe.DynamicJaxprTracer(trace, aval, source_info)
-  invars = map(trace.getvar, [x, lo, hi])
-  eqn = pe.new_jaxpr_eqn(invars, [trace.makevar(out_tracer)],
-                         getslice_p, {}, source_info)
-  trace.frame.eqns.append(eqn)
-  return out_tracer
-pe.custom_staging_rules[getslice_p] = _getslice_staging_rule
-
-def _getslice_padding_rule(in_avals, out_avals, x, lo, hi):
-  xx = lax.concatenate([x, x], 0)
-  return [dynamic_slice_in_dim(xx, lo, x.shape[0])]
-pe.padding_rules[getslice_p] = _getslice_padding_rule
-
-def _getslice_lower(ctx, x, lo, hi):
-  aval_out, = ctx.avals_out
-  return mhlo.RealDynamicSliceOp(
-      mlir.aval_to_ir_type(aval_out), x,
-      mlir.shape_tensor([lo]), mlir.shape_tensor([hi]), mlir.shape_tensor([1])
-  ).results
-mlir.register_lowering(getslice_p, _getslice_lower)
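Note: the deleted padding rule used a doubling trick worth recording before it disappears: slicing a fixed length x.shape[0] out of concatenate([x, x]) starting at lo yields x[lo:lo+n] with wraparound, a shape-stable stand-in for the variable-length x[lo:hi]. The same idea in plain NumPy:

import numpy as np

x = np.arange(5)
lo = 3
n = x.shape[0]

xx = np.concatenate([x, x])   # doubling makes every start index in-range
print(xx[lo:lo + n])          # [3 4 0 1 2], fixed size, wraps around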
@@ -51,7 +51,6 @@ from jax._src.api_util import _ensure_index_tuple
 from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
                               _sort_le_comparator)
 from jax._src.lax import lax as lax_internal
-from jax._src.lax.slicing import _getslice
 from jax._src.numpy.ndarray import ndarray
 from jax._src.numpy.reductions import (  # noqa: F401
     _ensure_optional_axes, _reduction_dims,
@@ -3620,14 +3619,16 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
         (start, stop, step) != (0, n, 1)):
       return lax.slice_in_dim(arr, start, stop, step)
 
-
   # TODO(mattjj,dougalm): expand dynamic shape indexing support
-  if (jax.config.jax_dynamic_shapes and type(idx) is slice and idx.step is None
-      and (isinstance(idx.start, core.Tracer) or isinstance(idx.stop, core.Tracer))
-      and arr.shape):
-    start = 0 if idx.start is None else idx.start
-    stop = arr.shape[0] if idx.stop is None else idx.stop
-    return _getslice(arr, start, stop)
+  if jax.config.jax_dynamic_shapes and arr.ndim > 0:
+    try: aval = core.get_aval(idx)
+    except: pass
+    else:
+      if (isinstance(aval, core.DShapedArray) and aval.shape == () and
+          dtypes.issubdtype(aval.dtype, np.integer) and
+          not dtypes.issubdtype(aval.dtype, dtypes.bool_) and
+          isinstance(arr.shape[0], int)):
+        return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
 
   treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
   return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
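Note: the new branch routes scalar integer indices of dynamic provenance (a shape-() DShapedArray) to lax.dynamic_index_in_dim instead of the general gather machinery. For reference, the public API it dispatches to, shown here with a concrete index in place of a traced one:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(3.0) * 10.0
print(lax.dynamic_index_in_dim(x, 1, axis=0, keepdims=False))   # 10.0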
@@ -1692,7 +1692,6 @@ _SPECIAL_DIMENSION_HANDLERS: Dict[type, DimensionHandler] = {}
 def _get_special_dim_handler(dim: DimSize) -> Optional[DimensionHandler]:
   if isinstance(dim, Tracer) and not config.jax_dynamic_shapes:
     return None
-  # TODO: look up DynamicJaxprTracer
   return _SPECIAL_DIMENSION_HANDLERS.get(type(dim))
 
 def _dim_handler_and_canonical(*dlist: DimSize) -> Tuple[DimensionHandler, Tuple[DimSize, ...]]:
@@ -1801,7 +1800,9 @@ def dimension_as_value(d: DimSize):
   return handler.as_value(*ds)
 
 def _canonicalize_dimension(dim: DimSize) -> DimSize:
-  if is_special_dim_size(dim):
+  if isinstance(dim, Tracer) and config.jax_dynamic_shapes:
+    return dim
+  elif is_special_dim_size(dim):
     return dim
   else:
     return operator.index(dim)
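Note: the order of the checks matters because operator.index raises TypeError on any object without __index__, tracers included, so traced dimension sizes must be intercepted before the final branch:

import operator

class FakeTracer:   # hypothetical stand-in for a JAX Tracer
  pass

try:
  operator.index(FakeTracer())
except TypeError as e:
  print(e)   # 'FakeTracer' object cannot be interpreted as an integer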
@@ -754,7 +754,7 @@ class DynamicShapeTest(jtu.JaxTestCase):
     self.assertAllClose(y, jnp.arange(2 * 4), check_dtypes=False)
     self.assertEqual(count, 1)
 
-  @jtu.skip_on_devices('iree')  # TODO(mattjj): update getslice, no bints
+  @unittest.skip("revising slicing logic")
   def test_slicing_basic(self):
     f = jax.jit(lambda x, n: jnp.sum(x[:n]))
     # TODO(mattjj): revise getslice, add typecheck rule for it, enable checks
@@ -765,7 +765,7 @@ class DynamicShapeTest(jtu.JaxTestCase):
 
   # TODO(mattjj,dougalm,phawkins): debug iree failure, "failed to legalize
   # operation 'mhlo.while' that was explicitly marked illegal"
-  @jtu.skip_on_devices('iree')
+  @unittest.skip("revising slicing logic")
   def test_scan_basic(self):
     def cumsum(x):
       def body(i, _):
@@ -1299,6 +1299,34 @@ class DynamicShapeTest(jtu.JaxTestCase):
     f, = jaxpr.outvars
     self.assertEqual(f.aval.shape, (a,))
 
+  def test_vmap_of_indexing_basic(self):
+    x = jnp.arange(3.)
+
+    def f(idxs):
+      return jax.vmap(lambda i: x[i])(idxs)
+
+    idxs = jnp.arange(3)
+    jaxpr = jax.make_jaxpr(f, abstracted_axes=('n',))(idxs).jaxpr
+    # { lambda a:f32[3]; b:i32[] c:i32[b]. let
+    #     d:bool[b] = lt c 0
+    #     e:i32[b] = add c 3
+    #     f:i32[b] = select_n d c e
+    #     g:i32[b,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(None, 1)] f b
+    #     h:f32[b,1] = gather[
+    #       dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,))
+    #       fill_value=None
+    #       indices_are_sorted=False
+    #       mode=GatherScatterMode.PROMISE_IN_BOUNDS
+    #       slice_sizes=(1,)
+    #       unique_indices=False
+    #     ] a g
+    #     i:f32[b] = squeeze[dimensions=(1,)] h
+    #   in (i,) }
+    b, _ = jaxpr.invars
+    e, = (e for e in jaxpr.eqns if str(e.primitive) == 'gather')
+    h, = e.outvars
+    self.assertEqual(h.aval.shape, (b, 1))
+
 
 if __name__ == '__main__':
   absltest.main(testLoader=jtu.JaxTestLoader())
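Note: the jaxpr comment in the test shows the traced form; executed eagerly, the same vmap-of-indexing simply gathers the rows. A usage sketch:

import jax
import jax.numpy as jnp

x = jnp.arange(3.0)
print(jax.vmap(lambda i: x[i])(jnp.arange(3)))   # [0. 1. 2.]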