defer to custom eltype for slice lowering rule

We already handled dynamic slice, but plain slice is eltype-polymorphic too.
This commit is contained in:
Roy Frostig 2022-08-09 19:13:34 -07:00
parent 8a1b4785de
commit 7955799ae3
2 changed files with 26 additions and 0 deletions

View File

@ -800,6 +800,10 @@ batching.primitive_batchers[slice_p] = _slice_batching_rule
def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
strides = strides or [1] * len(start_indices)
aval_out, = ctx.avals_out
if type(aval_out.dtype) in core.custom_eltypes:
return aval_out.dtype.slice_mlir(ctx, x, start_indices, limit_indices,
strides)
return mhlo.SliceOp(x,
mlir.dense_int_elements(start_indices),
mlir.dense_int_elements(limit_indices),

View File

@ -3007,6 +3007,16 @@ class FooTy:
def empty_mlir(ctx):
return mlir.ir_constants(np.zeros((2,), dtype=np.dtype('uint32')))
@staticmethod
def slice_mlir(ctx, x, start_indices, limit_indices, strides):
start_indices = (*start_indices, 0)
limit_indices = (*limit_indices, 2)
strides = (*strides, 1)
return mhlo.SliceOp(x,
mlir.dense_int_elements(start_indices),
mlir.dense_int_elements(limit_indices),
mlir.dense_int_elements(strides)).results
@staticmethod
def dynamic_slice_mlir(ctx, x, start_indices, slice_sizes):
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
@ -3251,6 +3261,18 @@ class CustomElementTypesTest(jtu.JaxTestCase):
expected = jnp.broadcast_to(3 * 4 * 5, (3, 5, 4)).astype('float32')
self.assertAllClose(ys, expected)
def test_slice(self):
ks = jax.jit(lambda: make((3, 4)))()
ys = jax.jit(lambda x: lax.slice_in_dim(x, 1, 3))(ks)
self.assertIsInstance(ys, FooArray)
self.assertEqual(ys.shape, (2, 4))
def test_dynamic_slice(self):
ks = jax.jit(lambda: make((3, 4)))()
ys = jax.jit(lambda x, i: lax.dynamic_slice_in_dim(x, i, 2))(ks, 1)
self.assertIsInstance(ys, FooArray)
self.assertEqual(ys.shape, (2, 4))
def test_transpose(self):
ks = jax.jit(lambda: make((3, 4)))()
ys = jax.jit(lambda x: x.T)(ks)