[pallas:gpu] Fix swap Triton lowering.

PiperOrigin-RevId: 573141426
This commit is contained in:
Chris Jones 2023-10-13 01:33:41 -07:00 committed by jax authors
parent 0da5828d03
commit 2bc2e173cb
3 changed files with 68 additions and 10 deletions

View File

@ -375,20 +375,30 @@ ad.primitive_jvps[swap_p] = _swap_jvp
def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
del out_avals # Unused.
ref, idx, val, _ = args_tree.unflatten(args_flat)
ref, idx, val, mask = args_tree.unflatten(args_flat)
if all((isinstance(s, Slice) or not s.shape) for s in idx.indices):
indices = idx.indices
scalar_dims = [not isinstance(s, Slice) and s.shape == () for s in indices]
scalar_dims = [
i
for i, s in enumerate(indices)
if not isinstance(s, Slice) and not s.shape
]
slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
val_indexer = tuple(None if scalar else slice(None) for scalar in scalar_dims)
val = val[val_indexer]
out = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
out = jnp.squeeze(out, scalar_dims)
if mask is not None:
out_ = out
out = jnp.where(mask, out, val)
val = jnp.where(mask, val, out_)
val = jnp.expand_dims(val, scalar_dims)
x_new = lax.dynamic_update_slice(ref, val, start_indices=slice_starts)
out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
out = out_ones[out_indexer]
elif all(not isinstance(s, Slice) for s in idx.indices):
out = ref[idx.indices]
if mask is not None:
out_ = out
out = jnp.where(mask, out, val)
val = jnp.where(mask, val, out_)
x_new = ref.at[idx.indices].set(val)
else:
raise NotImplementedError

View File

@ -740,7 +740,9 @@ def _get_lowering_rule(
ptr = _compute_pointers_from_indices(
ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder
)
return tl.load(ptr, _builder=ctx.builder)
val = tl.load(ptr, _builder=ctx.builder)
# `tl.load` of a `*int1` returns a tensor with type `int8`, so fix the type.
return val.to(ptr.dtype.element_ty, _builder=ctx.builder)
triton_lowering_rules[sp.get_p] = _get_lowering_rule
@ -761,7 +763,7 @@ def _masked_load_lowering_rule(
ptr = _compute_pointers_from_indices(
ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape, ctx.builder
)
return tl.load(
val = tl.load(
ptr,
mask=mask,
other=other,
@ -770,6 +772,8 @@ def _masked_load_lowering_rule(
eviction_policy=eviction_policy,
_builder=ctx.builder,
)
# `tl.load` of a `*int1` returns a tensor with type `int8`, so fix the type.
return val.to(ptr.dtype.element_ty, _builder=ctx.builder)
triton_lowering_rules[primitives.load_p] = _masked_load_lowering_rule
@ -819,13 +823,16 @@ def _masked_swap_lowering_rule(
ptr = _compute_pointers_from_indices(
ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape, ctx.builder
)
return tl.store(
other = None if mask is None else value
old_value = tl.load(ptr, mask=mask, other=other, _builder=ctx.builder)
tl.store(
ptr,
value,
mask=mask,
eviction_policy=eviction_policy,
_builder=ctx.builder,
)
return old_value
triton_lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule

View File

@ -436,6 +436,47 @@ class PallasCallTest(PallasTest):
x = random.normal(key, (m, n))
np.testing.assert_allclose(load(x), x + 1., atol=1e-5, rtol=1e-5)
def test_swap(self):
m, n = 16, 32
@functools.partial(
self.pallas_call,
out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2,
grid=1,
input_output_aliases={0: 0, 1: 1},
)
def swap(_, _2, x_ref, y_ref):
x = x_ref[:]
y = pl.swap(y_ref, (slice(None),), x)
x_ref[:] = y
x = random.normal(random.PRNGKey(0), (m, n))
y = random.normal(random.PRNGKey(1), (m, n))
out = swap(x, y)
np.testing.assert_array_equal(out[0], y)
np.testing.assert_array_equal(out[1], x)
def test_masked_swap(self):
m, n = 16, 32
@functools.partial(
self.pallas_call,
out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2,
grid=1,
input_output_aliases={0: 0, 1: 1},
)
def masked_swap(_, _2, mask_ref, x_ref, y_ref):
x = x_ref[:]
y = pl.swap(y_ref, (slice(None),), x, mask=mask_ref[:])
x_ref[:] = y
x = random.normal(random.PRNGKey(0), (m, n))
y = random.normal(random.PRNGKey(1), (m, n))
mask = random.bernoulli(random.PRNGKey(2), shape=(m, n))
out = masked_swap(x, y, mask)
np.testing.assert_array_equal(out[0], jnp.where(mask, y, x))
np.testing.assert_array_equal(out[1], jnp.where(mask, x, y))
def test_unused_ref(self):
m, n = 16, 32
@functools.partial(