[dynamic-shapes] add basic vmap-of-indexing support

The main changes here are only indirectly related to gather: we just had to
update some other rules (e.g. for comparison and squeeze) for a simple
dynamic-batch-shape gather to work.

I also skipped two tests and deleted some old dynamic-shape slicing logic
because we want to handle that differently. We didn't have to do that removal
in this PR, but it's convenient now that I'm looking at indexing again anyway.
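For reference, the pattern this change enables looks roughly like the following (a minimal sketch mirroring the new test at the end of this diff; it assumes the experimental jax_dynamic_shapes flag is enabled):

import jax
import jax.numpy as jnp

jax.config.update('jax_dynamic_shapes', True)  # experimental flag

x = jnp.arange(3.)

def f(idxs):
  # vmap of scalar indexing into a closed-over array
  return jax.vmap(lambda i: x[i])(idxs)

# Abstracting axis 'n' makes the batch dimension dynamic, so the traced jaxpr
# contains a gather whose index operand has a dynamic batch shape.
print(jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3)))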
Matthew Johnson 2022-07-07 16:44:00 -07:00
parent 49672cd2bc
commit 58826507cc
6 changed files with 72 additions and 57 deletions


@@ -110,6 +110,9 @@ class IreeBuffer(xla_client.DeviceArrayBase):
def _value(self):
return np.asarray(self)
def copy_to_host_async(self):
return self
class IreeExecutable:
def __init__(self, client, devices, module_object, function_name):


@@ -629,8 +629,14 @@ def concatenate(operands: Sequence[Array], dimension: int) -> Array:
Returns:
An array containing the concatenation.
"""
from jax.experimental import array
if len(operands) == 0:
raise ValueError("concatenate requires a non-empty sequences of arrays")
if len(operands) == 1:
op, = operands
if isinstance(op, (core.Tracer, device_array.DeviceArray, array.Array)):
return op
return concatenate_p.bind(*operands, dimension=dimension)
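The single-operand early return treats one-element concatenation as the identity for traced and device-resident values, so they skip binding concatenate_p altogether. A trivial standalone illustration of the unchanged semantics (example only, not part of the diff):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(4.)
# Concatenating a single operand is the identity; with this change a
# traced/device value is handed back directly instead of going through
# the concatenate primitive.
y = lax.concatenate([x], dimension=0)
assert (y == x).all()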
@@ -2249,8 +2255,12 @@ mlir.register_lowering(shift_right_logical_p,
partial(_nary_lower_mhlo, mhlo.ShiftRightLogicalOp))
def _compare_lower_mhlo(direction: str, ctx, x, y):
x_aval, y_aval = ctx.avals_in
aval_out, = ctx.avals_out
avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out
if config.jax_dynamic_shapes:
substitute = partial(_substitute_axis_sizes_in_aval, ctx.axis_size_env)
avals_in = map(substitute, avals_in)
aval_out = substitute(aval_out)
x_aval, y_aval = avals_in
x, y = broadcast_mhlo(aval_out.update(dtype=x_aval.dtype), ctx.avals_in,
(x, y))
if dtypes.issubdtype(x_aval.dtype, np.inexact):
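Conceptually, the new branch swaps symbolic axis sizes in the input/output avals for the lowered size values recorded in ctx.axis_size_env before computing the broadcast. A rough standalone sketch of that substitution over a shape tuple (hypothetical helper, not the real _substitute_axis_sizes_in_aval):

def substitute_axis_sizes(axis_size_env, shape):
  # Replace symbolic axis-size entries with their bindings from the lowering
  # environment; concrete int dimensions pass through unchanged.
  return tuple(axis_size_env.get(d, d) for d in shape)

# Toy usage, with a string standing in for a size variable / ir.Value:
env = {'b': 'ir.Value<b>'}
print(substitute_axis_sizes(env, ('b', 4)))  # ('ir.Value<b>', 4)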
@@ -2757,8 +2767,9 @@ def _broadcast_in_dim_transpose_rule(ct, operand, *dyn_shape,
return ([expand_dims(_reduce_sum(ct, axes), unit_dims)] +
[None] * len(dyn_shape))
def _broadcast_in_dim_batch_rule(batched_args, batch_dims, *, shape,
def _broadcast_in_dim_batch_rule(batched_args, batch_dims, *dyn_shape, shape,
broadcast_dimensions):
if dyn_shape: raise NotImplementedError # TODO(mattjj)
operand, = batched_args
bdim, = batch_dims
new_operand = batching.moveaxis(operand, bdim, 0)
@@ -3157,7 +3168,16 @@ batching.primitive_batchers[squeeze_p] = _squeeze_batch_rule
def _squeeze_lower(ctx, operand, *, dimensions):
del dimensions # Implied by the output aval.
aval_out, = ctx.avals_out
return mhlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), operand).results
if config.jax_dynamic_shapes:
substitute = partial(_substitute_axis_sizes_in_aval, ctx.axis_size_env)
aval_out = substitute(aval_out)
if any(isinstance(d, ir.Value) for d in aval_out.shape):
return mhlo.DynamicReshapeOp(
mlir.aval_to_ir_type(aval_out), operand,
mlir.shape_tensor(aval_out.shape),
).results
else:
return mhlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), operand).results
mlir.register_lowering(squeeze_p, _squeeze_lower)
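The lowering now follows the usual dynamic-shape dispatch: if any output dimension is an ir.Value rather than a Python int, the shape is materialized as a shape tensor and mhlo.DynamicReshapeOp is emitted; the static mhlo.ReshapeOp path is untouched. The condition can be sketched in isolation (toy stand-in, not the real lowering code):

def shape_is_dynamic(shape):
  # A dimension that is an SSA value (ir.Value) rather than a Python int
  # forces the dynamic path, which needs an explicit shape tensor operand.
  return any(not isinstance(d, int) for d in shape)

print(shape_is_dynamic((3, 1)))          # False -> mhlo.ReshapeOp
print(shape_is_dynamic(('dyn_b', 1)))    # True  -> mhlo.DynamicReshapeOp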


@@ -287,7 +287,6 @@ def gather(operand: Array, start_indices: Array,
fill_value=fill_value)
class ScatterDimensionNumbers(NamedTuple):
"""
Describes the dimension number arguments to an `XLA's Scatter operator
@@ -1161,7 +1160,8 @@ def _gather_shape_rule(operand, indices, *, dimension_numbers,
expanded_indices_shape.pop(index_vector_dim)
indices_shape = iter(expanded_indices_shape)
slice_sizes = iter(np.delete(slice_sizes, collapsed_slice_dims))
slice_sizes = (s for i, s in enumerate(slice_sizes)
if i not in collapsed_slice_dims)
return tuple(next(slice_sizes) if i in offset_dims
else next(indices_shape) for i in range(output_shape_rank))
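The replacement expresses "drop the entries at collapsed_slice_dims" in pure Python instead of via np.delete, presumably so the remaining slice sizes are not round-tripped through a NumPy array (which converts their types, a problem once dimension entries can be symbolic). A standalone restatement (example only):

slice_sizes = (7, 1, 5)
collapsed_slice_dims = (1,)

# Pure-Python equivalent of np.delete(slice_sizes, collapsed_slice_dims),
# without converting entries into a NumPy array:
kept = tuple(s for i, s in enumerate(slice_sizes)
             if i not in collapsed_slice_dims)
print(kept)           # (7, 5)
print(type(kept[0]))  # <class 'int'>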
@@ -1250,7 +1250,7 @@ def _gather_batching_rule(batched_args, batch_dims, *, dimension_numbers,
elif operand_bdim is None and indices_bdim is not None:
indices = batching.moveaxis(indices, indices_bdim, 0)
offset_dims = tuple(np.add(1, dimension_numbers.offset_dims))
offset_dims = tuple(1 + d for d in dimension_numbers.offset_dims)
dnums = GatherDimensionNumbers(
offset_dims=offset_dims,
collapsed_slice_dims=dimension_numbers.collapsed_slice_dims,
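Same flavor of change in the batching rule: adding 1 to each offset dimension with a generator keeps plain Python ints, whereas np.add yields NumPy integer scalars. A quick comparison (example only):

import numpy as np

offset_dims = (0, 2)
via_numpy  = tuple(np.add(1, offset_dims))      # NumPy integer scalars
via_python = tuple(1 + d for d in offset_dims)  # plain Python ints

print([type(d).__name__ for d in via_numpy])   # e.g. ['int64', 'int64']
print([type(d).__name__ for d in via_python])  # ['int', 'int']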
@@ -1308,12 +1308,10 @@ gather_p = standard_primitive(
_gather_shape_rule, _gather_dtype_rule, 'gather',
weak_type_rule=_argnum_weak_type(0))
ad.defjvp(gather_p, _gather_jvp_rule, None)
ad.primitive_transposes[gather_p] = _gather_transpose_rule
batching.primitive_batchers[gather_p] = _gather_batching_rule
def _gather_lower(ctx, operand, indices, *,
dimension_numbers, slice_sizes, unique_indices,
indices_are_sorted, mode, fill_value):
@@ -2054,39 +2052,3 @@ def _dynamic_slice_indices(operand, start_indices: Any):
d = lax.convert_element_type(core.dimension_as_value(d), _dtype(i))
result.append(lax.select(i < 0, i + d, i))
return result
# TODO(mattjj): getslice is a prototype for dynamic shapes, revise or remove it
def _getslice(x, lo, hi):
return getslice_p.bind(x, lo, hi)
getslice_p = core.Primitive('getslice')
@getslice_p.def_impl
def getslice_impl(x, lo, hi):
return x[lo:hi]
def _getslice_staging_rule(trace, x, lo, hi):
size = lax.make_bint(lax.clamp(0, hi - lo, x.shape[0]), x.shape[0])
aval = core.DShapedArray((size,), x.dtype, x.weak_type)
source_info = source_info_util.current()
out_tracer = pe.DynamicJaxprTracer(trace, aval, source_info)
invars = map(trace.getvar, [x, lo, hi])
eqn = pe.new_jaxpr_eqn(invars, [trace.makevar(out_tracer)],
getslice_p, {}, source_info)
trace.frame.eqns.append(eqn)
return out_tracer
pe.custom_staging_rules[getslice_p] = _getslice_staging_rule
def _getslice_padding_rule(in_avals, out_avals, x, lo, hi):
xx = lax.concatenate([x, x], 0)
return [dynamic_slice_in_dim(xx, lo, x.shape[0])]
pe.padding_rules[getslice_p] = _getslice_padding_rule
def _getslice_lower(ctx, x, lo, hi):
aval_out, = ctx.avals_out
return mhlo.RealDynamicSliceOp(
mlir.aval_to_ir_type(aval_out), x,
mlir.shape_tensor([lo]), mlir.shape_tensor([hi]), mlir.shape_tensor([1])
).results
mlir.register_lowering(getslice_p, _getslice_lower)


@@ -51,7 +51,6 @@ from jax._src.api_util import _ensure_index_tuple
from jax._src.lax.lax import (_array_copy, _sort_lt_comparator,
_sort_le_comparator)
from jax._src.lax import lax as lax_internal
from jax._src.lax.slicing import _getslice
from jax._src.numpy.ndarray import ndarray
from jax._src.numpy.reductions import ( # noqa: F401
_ensure_optional_axes, _reduction_dims,
@@ -3620,14 +3619,16 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
(start, stop, step) != (0, n, 1)):
return lax.slice_in_dim(arr, start, stop, step)
# TODO(mattjj,dougalm): expand dynamic shape indexing support
if (jax.config.jax_dynamic_shapes and type(idx) is slice and idx.step is None
and (isinstance(idx.start, core.Tracer) or isinstance(idx.stop, core.Tracer))
and arr.shape):
start = 0 if idx.start is None else idx.start
stop = arr.shape[0] if idx.stop is None else idx.stop
return _getslice(arr, start, stop)
if jax.config.jax_dynamic_shapes and arr.ndim > 0:
try: aval = core.get_aval(idx)
except: pass
else:
if (isinstance(aval, core.DShapedArray) and aval.shape == () and
dtypes.issubdtype(aval.dtype, np.integer) and
not dtypes.issubdtype(aval.dtype, dtypes.bool_) and
isinstance(arr.shape[0], int)):
return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
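The new fast path sends a scalar, traced integer index straight to lax.dynamic_index_in_dim rather than through the general gather machinery. Outside the tracing context the two spellings agree; a small sketch of the targeted operation (example only):

import jax.numpy as jnp
from jax import lax

x = jnp.arange(5.)
i = jnp.int32(3)

# What the new branch in _rewriting_take emits for a scalar integer index:
print(lax.dynamic_index_in_dim(x, i, keepdims=False))  # 3.0
# Equivalent static indexing through the usual __getitem__ path:
print(x[3])                                            # 3.0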


@@ -1692,7 +1692,6 @@ _SPECIAL_DIMENSION_HANDLERS: Dict[type, DimensionHandler] = {}
def _get_special_dim_handler(dim: DimSize) -> Optional[DimensionHandler]:
if isinstance(dim, Tracer) and not config.jax_dynamic_shapes:
return None
# TODO: look up DynamicJaxprTracer
return _SPECIAL_DIMENSION_HANDLERS.get(type(dim))
def _dim_handler_and_canonical(*dlist: DimSize) -> Tuple[DimensionHandler, Tuple[DimSize, ...]]:
@@ -1801,7 +1800,9 @@ def dimension_as_value(d: DimSize):
return handler.as_value(*ds)
def _canonicalize_dimension(dim: DimSize) -> DimSize:
if is_special_dim_size(dim):
if isinstance(dim, Tracer) and config.jax_dynamic_shapes:
return dim
elif is_special_dim_size(dim):
return dim
else:
return operator.index(dim)
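The updated dispatch order is: Tracer dimensions pass through untouched when dynamic shapes are enabled, special dimension sizes (e.g. shape-polymorphic sizes) pass through next, and everything else must be a concrete Python integer. A toy restatement of that ordering (stand-in predicates, not the real ones):

import operator

def canonicalize_dimension(dim, *, dynamic_shapes, is_tracer, is_special):
  # Toy mirror of the new dispatch order in _canonicalize_dimension.
  if is_tracer(dim) and dynamic_shapes:
    return dim                   # traced sizes stay symbolic
  elif is_special(dim):
    return dim                   # e.g. shape-polymorphic dimension sizes
  else:
    return operator.index(dim)   # anything else must be a concrete integer

print(canonicalize_dimension(5, dynamic_shapes=True,
                             is_tracer=lambda d: False,
                             is_special=lambda d: False))  # 5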


@@ -754,7 +754,7 @@ class DynamicShapeTest(jtu.JaxTestCase):
self.assertAllClose(y, jnp.arange(2 * 4), check_dtypes=False)
self.assertEqual(count, 1)
@jtu.skip_on_devices('iree') # TODO(mattjj): update getslice, no bints
@unittest.skip("revising slicing logic")
def test_slicing_basic(self):
f = jax.jit(lambda x, n: jnp.sum(x[:n]))
# TODO(mattjj): revise getslice, add typecheck rule for it, enable checks
@@ -765,7 +765,7 @@ class DynamicShapeTest(jtu.JaxTestCase):
# TODO(mattjj,dougalm,phawkins): debug iree failure, "failed to legalize
# operation 'mhlo.while' that was explicitly marked illegal"
@jtu.skip_on_devices('iree')
@unittest.skip("revising slicing logic")
def test_scan_basic(self):
def cumsum(x):
def body(i, _):
@@ -1299,6 +1299,34 @@ class DynamicShapeTest(jtu.JaxTestCase):
f, = jaxpr.outvars
self.assertEqual(f.aval.shape, (a,))
def test_vmap_of_indexing_basic(self):
x = jnp.arange(3.)
def f(idxs):
return jax.vmap(lambda i: x[i])(idxs)
idxs = jnp.arange(3)
jaxpr = jax.make_jaxpr(f, abstracted_axes=('n',))(idxs).jaxpr
# { lambda a:f32[3]; b:i32[] c:i32[b]. let
# d:bool[b] = lt c 0
# e:i32[b] = add c 3
# f:i32[b] = select_n d c e
# g:i32[b,1] = broadcast_in_dim[broadcast_dimensions=(0,) shape=(None, 1)] f b
# h:f32[b,1] = gather[
# dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(), start_index_map=(0,))
# fill_value=None
# indices_are_sorted=False
# mode=GatherScatterMode.PROMISE_IN_BOUNDS
# slice_sizes=(1,)
# unique_indices=False
# ] a g
# i:f32[b] = squeeze[dimensions=(1,)] h
# in (i,) }
b, _ = jaxpr.invars
e, = (e for e in jaxpr.eqns if str(e.primitive) == 'gather')
h, = e.outvars
self.assertEqual(h.aval.shape, (b, 1))
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())