mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
defer to custom eltype for slice lowering rule
We already handled dynamic slice, but plain slice is eltype-polymorphic too.
This commit is contained in:
parent
8a1b4785de
commit
7955799ae3
@ -800,6 +800,10 @@ batching.primitive_batchers[slice_p] = _slice_batching_rule
|
||||
|
||||
def _slice_lower(ctx, x, *, start_indices, limit_indices, strides):
|
||||
strides = strides or [1] * len(start_indices)
|
||||
aval_out, = ctx.avals_out
|
||||
if type(aval_out.dtype) in core.custom_eltypes:
|
||||
return aval_out.dtype.slice_mlir(ctx, x, start_indices, limit_indices,
|
||||
strides)
|
||||
return mhlo.SliceOp(x,
|
||||
mlir.dense_int_elements(start_indices),
|
||||
mlir.dense_int_elements(limit_indices),
|
||||
|
@ -3007,6 +3007,16 @@ class FooTy:
|
||||
def empty_mlir(ctx):
|
||||
return mlir.ir_constants(np.zeros((2,), dtype=np.dtype('uint32')))
|
||||
|
||||
@staticmethod
|
||||
def slice_mlir(ctx, x, start_indices, limit_indices, strides):
|
||||
start_indices = (*start_indices, 0)
|
||||
limit_indices = (*limit_indices, 2)
|
||||
strides = (*strides, 1)
|
||||
return mhlo.SliceOp(x,
|
||||
mlir.dense_int_elements(start_indices),
|
||||
mlir.dense_int_elements(limit_indices),
|
||||
mlir.dense_int_elements(strides)).results
|
||||
|
||||
@staticmethod
|
||||
def dynamic_slice_mlir(ctx, x, start_indices, slice_sizes):
|
||||
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
|
||||
@ -3251,6 +3261,18 @@ class CustomElementTypesTest(jtu.JaxTestCase):
|
||||
expected = jnp.broadcast_to(3 * 4 * 5, (3, 5, 4)).astype('float32')
|
||||
self.assertAllClose(ys, expected)
|
||||
|
||||
def test_slice(self):
|
||||
ks = jax.jit(lambda: make((3, 4)))()
|
||||
ys = jax.jit(lambda x: lax.slice_in_dim(x, 1, 3))(ks)
|
||||
self.assertIsInstance(ys, FooArray)
|
||||
self.assertEqual(ys.shape, (2, 4))
|
||||
|
||||
def test_dynamic_slice(self):
|
||||
ks = jax.jit(lambda: make((3, 4)))()
|
||||
ys = jax.jit(lambda x, i: lax.dynamic_slice_in_dim(x, i, 2))(ks, 1)
|
||||
self.assertIsInstance(ys, FooArray)
|
||||
self.assertEqual(ys.shape, (2, 4))
|
||||
|
||||
def test_transpose(self):
|
||||
ks = jax.jit(lambda: make((3, 4)))()
|
||||
ys = jax.jit(lambda x: x.T)(ks)
|
||||
|
Loading…
x
Reference in New Issue
Block a user