mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
[pallas:gpu] Fix swap Triton lowering.
PiperOrigin-RevId: 573141426
This commit is contained in:
parent
0da5828d03
commit
2bc2e173cb
@ -375,20 +375,30 @@ ad.primitive_jvps[swap_p] = _swap_jvp
|
||||
|
||||
def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_):
|
||||
del out_avals # Unused.
|
||||
ref, idx, val, _ = args_tree.unflatten(args_flat)
|
||||
ref, idx, val, mask = args_tree.unflatten(args_flat)
|
||||
if all((isinstance(s, Slice) or not s.shape) for s in idx.indices):
|
||||
indices = idx.indices
|
||||
scalar_dims = [not isinstance(s, Slice) and s.shape == () for s in indices]
|
||||
scalar_dims = [
|
||||
i
|
||||
for i, s in enumerate(indices)
|
||||
if not isinstance(s, Slice) and not s.shape
|
||||
]
|
||||
slice_starts = [s.start if isinstance(s, Slice) else s for s in indices]
|
||||
slice_sizes = tuple(s.size if isinstance(s, Slice) else 1 for s in indices)
|
||||
val_indexer = tuple(None if scalar else slice(None) for scalar in scalar_dims)
|
||||
val = val[val_indexer]
|
||||
out = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
|
||||
out = jnp.squeeze(out, scalar_dims)
|
||||
if mask is not None:
|
||||
out_ = out
|
||||
out = jnp.where(mask, out, val)
|
||||
val = jnp.where(mask, val, out_)
|
||||
val = jnp.expand_dims(val, scalar_dims)
|
||||
x_new = lax.dynamic_update_slice(ref, val, start_indices=slice_starts)
|
||||
out_ones = lax.dynamic_slice(ref, slice_starts, slice_sizes=slice_sizes)
|
||||
out_indexer = tuple(0 if scalar else slice(None) for scalar in scalar_dims)
|
||||
out = out_ones[out_indexer]
|
||||
elif all(not isinstance(s, Slice) for s in idx.indices):
|
||||
out = ref[idx.indices]
|
||||
if mask is not None:
|
||||
out_ = out
|
||||
out = jnp.where(mask, out, val)
|
||||
val = jnp.where(mask, val, out_)
|
||||
x_new = ref.at[idx.indices].set(val)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
@ -740,7 +740,9 @@ def _get_lowering_rule(
|
||||
ptr = _compute_pointers_from_indices(
|
||||
ptr, ref_block_info, idx, avals_in[0].shape, ctx.builder
|
||||
)
|
||||
return tl.load(ptr, _builder=ctx.builder)
|
||||
val = tl.load(ptr, _builder=ctx.builder)
|
||||
# `tl.load` of a `*int1` returns a tensor with type `int8`, so fix the type.
|
||||
return val.to(ptr.dtype.element_ty, _builder=ctx.builder)
|
||||
|
||||
|
||||
triton_lowering_rules[sp.get_p] = _get_lowering_rule
|
||||
@ -761,7 +763,7 @@ def _masked_load_lowering_rule(
|
||||
ptr = _compute_pointers_from_indices(
|
||||
ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape, ctx.builder
|
||||
)
|
||||
return tl.load(
|
||||
val = tl.load(
|
||||
ptr,
|
||||
mask=mask,
|
||||
other=other,
|
||||
@ -770,6 +772,8 @@ def _masked_load_lowering_rule(
|
||||
eviction_policy=eviction_policy,
|
||||
_builder=ctx.builder,
|
||||
)
|
||||
# `tl.load` of a `*int1` returns a tensor with type `int8`, so fix the type.
|
||||
return val.to(ptr.dtype.element_ty, _builder=ctx.builder)
|
||||
|
||||
|
||||
triton_lowering_rules[primitives.load_p] = _masked_load_lowering_rule
|
||||
@ -819,13 +823,16 @@ def _masked_swap_lowering_rule(
|
||||
ptr = _compute_pointers_from_indices(
|
||||
ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape, ctx.builder
|
||||
)
|
||||
return tl.store(
|
||||
other = None if mask is None else value
|
||||
old_value = tl.load(ptr, mask=mask, other=other, _builder=ctx.builder)
|
||||
tl.store(
|
||||
ptr,
|
||||
value,
|
||||
mask=mask,
|
||||
eviction_policy=eviction_policy,
|
||||
_builder=ctx.builder,
|
||||
)
|
||||
return old_value
|
||||
|
||||
|
||||
triton_lowering_rules[primitives.swap_p] = _masked_swap_lowering_rule
|
||||
|
@ -436,6 +436,47 @@ class PallasCallTest(PallasTest):
|
||||
x = random.normal(key, (m, n))
|
||||
np.testing.assert_allclose(load(x), x + 1., atol=1e-5, rtol=1e-5)
|
||||
|
||||
def test_swap(self):
  """Checks that an unmasked `pl.swap` exchanges the contents of two refs.

  The kernel reads all of `x_ref`, swaps that value into `y_ref`, and writes
  the value `pl.swap` hands back into `x_ref`. The test then expects the
  first output to equal the original `y` and the second the original `x`.
  """
  rows, cols = 16, 32
  spec = jax.ShapeDtypeStruct((rows, cols), jnp.float32)

  def swap_kernel(_, _2, x_ref, y_ref):
    # The first two refs are unused. NOTE(review): presumably they are the
    # input refs that `input_output_aliases` ties to the two outputs, so
    # `x_ref`/`y_ref` already hold the input values -- confirm.
    previous_y = pl.swap(y_ref, (slice(None),), x_ref[:])
    x_ref[:] = previous_y

  swap = self.pallas_call(
      swap_kernel,
      out_shape=(spec, spec),
      grid=1,
      input_output_aliases={0: 0, 1: 1},
  )

  x = random.normal(random.PRNGKey(0), (rows, cols))
  y = random.normal(random.PRNGKey(1), (rows, cols))
  swapped = swap(x, y)
  np.testing.assert_array_equal(swapped[0], y)
  np.testing.assert_array_equal(swapped[1], x)
|
||||
|
||||
def test_masked_swap(self):
  """Checks `pl.swap` with a mask: only positions where the mask is set swap.

  The kernel swaps the full contents of `x_ref` into `y_ref` under the mask
  read from `mask_ref`, then stores the value returned by `pl.swap` into
  `x_ref`. The expected outputs are `where(mask, y, x)` for the first output
  and `where(mask, x, y)` for the second.
  """
  rows, cols = 16, 32
  spec = jax.ShapeDtypeStruct((rows, cols), jnp.float32)

  def masked_swap_kernel(_, _2, mask_ref, x_ref, y_ref):
    # The first two refs are unused. NOTE(review): presumably they are the
    # input refs that `input_output_aliases` ties to the two outputs, so
    # `x_ref`/`y_ref` already hold the input values -- confirm.
    new_value = x_ref[:]
    returned = pl.swap(y_ref, (slice(None),), new_value, mask=mask_ref[:])
    x_ref[:] = returned

  masked_swap = self.pallas_call(
      masked_swap_kernel,
      out_shape=(spec, spec),
      grid=1,
      input_output_aliases={0: 0, 1: 1},
  )

  x = random.normal(random.PRNGKey(0), (rows, cols))
  y = random.normal(random.PRNGKey(1), (rows, cols))
  mask = random.bernoulli(random.PRNGKey(2), shape=(rows, cols))
  swapped = masked_swap(x, y, mask)
  np.testing.assert_array_equal(swapped[0], jnp.where(mask, y, x))
  np.testing.assert_array_equal(swapped[1], jnp.where(mask, x, y))
|
||||
|
||||
def test_unused_ref(self):
|
||||
m, n = 16, 32
|
||||
@functools.partial(
|
||||
|
Loading…
x
Reference in New Issue
Block a user