Mirror of https://github.com/ROCm/jax.git (synced 2025-04-17 20:36:05 +00:00)
[Pallas] Add interpret mode support for dynamic grid

PiperOrigin-RevId: 603818776

Commit a7a6b40b55 (parent c4b6266049)
@@ -45,10 +45,6 @@ def pallas_call_tpu_lowering_rule(
     **compiler_params: Any):
   """Lowers a pallas_call to a Mosaic TPU custom call."""
   if interpret:
-    if grid_mapping.num_dynamic_grid_bounds:
-      raise NotImplementedError(
-          "Dynamic grid bounds not supported in interpret mode."
-      )
     return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
         ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
         in_shapes=in_shapes,
@@ -16,7 +16,7 @@
 from __future__ import annotations
 
 from functools import partial
-import itertools as it
+from functools import reduce
 
 from typing import Any, Callable
 from collections.abc import Sequence
@@ -89,24 +89,37 @@ def _uninitialized_value(shape, dtype):
     return jnp.full(shape, jnp.iinfo(dtype).min, dtype)
   raise NotImplementedError(dtype)
 
+def _get_next_indices(grid, indices):
+  next_indices = []
+  carry = True
+  for dim_size, index in reversed(list(zip(grid, indices))):
+    i = jnp.where(carry, index + 1, index)
+    carry = dim_size == i
+    next_indices.append(jnp.where(carry, 0, i))
+  return tuple(reversed(next_indices))
+
+
 def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
                       interpret, debug: bool,
                       in_shapes,
                       input_output_aliases: tuple[tuple[int, int], ...],
                       grid_mapping: GridMapping,
                       **compiler_params: Any):
-  if grid_mapping.num_dynamic_grid_bounds:
-    raise NotImplementedError("interpret with dynamic grid bounds unsupported")
+  dynamic_grid_args, args = split_list(  # type: ignore
+      args, [grid_mapping.num_dynamic_grid_bounds]
+  )
   if interpret:
     # If we're in interpreter mode, we *scan* over the grid and eval the
     # discharged jaxpr. This should reproduce exactly what compiling to Triton
     # will do.
-    grid = grid_mapping.static_grid
+    dynamic_grid_args_iter = iter(dynamic_grid_args)
+    grid = tuple(
+        a if a is not None else next(dynamic_grid_args_iter)
+        for a in grid_mapping.grid
+    )
+    assert next(dynamic_grid_args_iter, None) is None
     discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
     if debug:
       print(discharged_jaxpr)
-    loop_indices = jnp.array(list(it.product(*(range(g) for g in grid))),
-                             dtype=jnp.int32)
     oi_map = {v: k for k, v in input_output_aliases}
     out = []
     for i, out_shape in enumerate(out_shapes):
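Note on the new helper: _get_next_indices advances the per-dimension grid counters like an odometer. The trailing dimension increments first, and a carry resets each wrapped counter to zero and bumps the dimension to its left. Because it is built from jnp.where rather than Python control flow, it traces cleanly even when grid bounds are dynamic values. A minimal standalone sketch of the stepping behavior (the driver loop is illustrative only, not part of the change):

```python
import jax.numpy as jnp

def _get_next_indices(grid, indices):
  # Odometer increment: bump the last index, propagate the carry leftward
  # whenever a counter wraps around at its dimension size.
  next_indices = []
  carry = True
  for dim_size, index in reversed(list(zip(grid, indices))):
    i = jnp.where(carry, index + 1, index)
    carry = dim_size == i
    next_indices.append(jnp.where(carry, 0, i))
  return tuple(reversed(next_indices))

# Walk a (2, 3) grid: (0,0) (0,1) (0,2) (1,0) (1,1) (1,2)
idx = (jnp.int32(0), jnp.int32(0))
for _ in range(6):
  print(tuple(int(x) for x in idx))
  idx = _get_next_indices((2, 3), idx)
```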
@@ -135,12 +148,18 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
                       for a in scratch_avals]
     carry = [*args, *out, *scratch_values]
     num_carry = len(args) + len(out)
+    grid_start_indices = (jnp.int32(0),) * len(grid)
+    if grid:
+      num_iterations = reduce(jnp.multiply, grid)
+    else:
+      # Base case is always one iteration when grid is ()
+      num_iterations = 1
     def cond(carry):
-      return carry[0] < loop_indices.shape[0]
+      i, *_ = carry
+      return i < num_iterations
     def body(carry):
-      i, *carry = carry
+      i, loop_idx, *carry = carry
       carry, scratch = split_list(carry, [num_carry])
-      loop_idx = loop_indices[i]
       start_indices = [
           None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
           for bm in grid_mapping.block_mappings]
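With the precomputed loop_indices array gone, the trip count is now derived directly from the (possibly traced) grid sizes, and an empty grid means exactly one kernel invocation. A small sketch of that rule; the num_iterations helper name here is only for illustration:

```python
from functools import reduce
import jax.numpy as jnp

def num_iterations(grid):
  # () means a single invocation; otherwise multiply the per-dimension
  # bounds, which may be traced scalars when the grid is dynamic.
  if grid:
    return reduce(jnp.multiply, grid)
  return 1

print(num_iterations(()))                 # 1
print(num_iterations((4, 8)))             # 32
print(num_iterations((jnp.int32(4), 8)))  # 32, works with traced bounds too
```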
@@ -177,8 +196,10 @@ def _pallas_call_impl(*args, jaxpr, name, out_shapes, which_linear,
       blocks, out_scratch = split_list(blocks, [num_carry])
       carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
                   carry, blocks, is_indexing_dim)
-      return (i + 1, *carry, *out_scratch)
-    (_, *carry) = lax.while_loop(cond, body, (jnp.int32(0), *carry))
+      return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch)
+    (_, _, *carry) = lax.while_loop(
+        cond, body, (jnp.int32(0), grid_start_indices, *carry)
+    )
     _, out, _ = split_list(carry, [len(args), len(out)])
     return out
   return xla.apply_primitive(pallas_call_p, *args, jaxpr=jaxpr, name=name,
@@ -44,15 +44,17 @@ partial = functools.partial
 class PallasTPUTest(jtu.JaxTestCase):
   interpret: bool = False
 
-
-class PallasCallScalarPrefetchTest(PallasTPUTest):
-  interpret: bool = False
-
   def setUp(self):
     super().setUp()
     if not self.interpret and jtu.device_under_test() != 'tpu':
       self.skipTest('Only interpret mode supported on non-TPU')
 
   def pallas_call(self, *args, **kwargs):
     return pl.pallas_call(*args, **kwargs, interpret=self.interpret)
 
+
+class PallasCallScalarPrefetchTest(PallasTPUTest):
+
   def test_trivial_scalar_prefetch(self):
     def body(_, x_ref, o_ref):
       o_ref[...] = x_ref[...]
@@ -288,6 +290,107 @@ class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest):
   interpret: bool = True
 
 
+class PallasCallDynamicGridTest(PallasTPUTest):
+
+  def test_dynamic_grid(self):
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = jnp.zeros_like(y_ref)
+      y_ref[...] += 1
+
+    @jax.jit
+    def dynamic_kernel(steps):
+      return self.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )()
+    np.testing.assert_array_equal(
+        dynamic_kernel(jnp.int32(4)), np.full(shape, 8.0, np.float32)
+    )
+
+  def test_vmap_trivial_dynamic_grid(self):
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(x_ref, y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = x_ref[...]
+      y_ref[...] += 1
+
+    @jax.jit
+    @jax.vmap
+    def dynamic_kernel(steps, x):
+      return self.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )(x)
+    x = jnp.arange(8 * 128., dtype=jnp.float32).reshape((1, *shape))
+    np.testing.assert_array_equal(
+        dynamic_kernel(jnp.array([4], jnp.int32), x), x + 8.0
+    )
+
+  def test_vmap_nontrivial_dynamic_grid(self):
+    # Dynamic grid doesn't support vmapping over multiple distinct grid values
+    # at the moment.
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = jnp.zeros_like(y_ref)
+      y_ref[...] += 1
+
+    @jax.jit
+    @jax.vmap
+    def dynamic_kernel(steps):
+      return self.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )()
+    with self.assertRaises(NotImplementedError):
+      dynamic_kernel(jnp.array([4, 8], jnp.int32))
+
+  def test_vmap_dynamic_grid(self):
+    shape = (8, 128)
+    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
+
+    def kernel(x_ref, y_ref):
+      @pl.when(pl.program_id(0) == 0)
+      def _init():
+        y_ref[...] = x_ref[...]
+      y_ref[...] += jnp.float32(1.)
+
+    @jax.jit
+    def dynamic_kernel(x, steps):
+      return self.pallas_call(
+          kernel,
+          grid=(steps * 2,),
+          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
+          out_shape=result_ty,
+      )(x)
+    x = jnp.arange(4 * 8 * 128., dtype=jnp.float32).reshape((4, *shape))
+    np.testing.assert_array_equal(
+        jax.jit(jax.vmap(dynamic_kernel, in_axes=(0, None)))(x, jnp.int32(4)),
+        x + 8,
+    )
+
+
+class PallasCallInterpretDynamicGridTest(PallasCallDynamicGridTest):
+  interpret: bool = True
+
+
 class PallasCallDMATest(parameterized.TestCase):
 
   def setUp(self):
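The tests above pin down the user-facing semantics: a grid bound may now be a traced scalar, vmap works when every batch element shares a single grid size, and vmapping over distinct grid sizes still raises NotImplementedError. A condensed, runnable version of the basic case, mirroring test_dynamic_grid with interpret=True so it runs off-TPU (pl.BlockSpec takes the index map first, per the API of this era):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl

shape = (8, 128)

def kernel(y_ref):
  @pl.when(pl.program_id(0) == 0)
  def _init():
    y_ref[...] = jnp.zeros_like(y_ref)
  y_ref[...] += 1

@jax.jit
def dynamic_kernel(steps):
  return pl.pallas_call(
      kernel,
      grid=(steps * 2,),  # grid bound is a traced value, not a Python int
      out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
      out_shape=jax.ShapeDtypeStruct(shape, jnp.float32),
      interpret=True,
  )()

# Zero-init at step 0, then +1 on each of the eight grid steps -> 8.0.
np.testing.assert_array_equal(
    dynamic_kernel(jnp.int32(4)), np.full(shape, 8.0, np.float32))
```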
@@ -1091,98 +1194,6 @@ class PallasCallTest(PallasTPUTest):
         kernel, out_shape=x, mosaic_params=dict(vmem_limit_bytes=int(2**18))
     )(x)
 
-  def test_dynamic_grid(self):
-    shape = (8, 128)
-    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
-
-    def kernel(y_ref):
-      @pl.when(pl.program_id(0) == 0)
-      def _init():
-        y_ref[...] = jnp.zeros_like(y_ref)
-      y_ref[...] += 1
-
-    @jax.jit
-    def dynamic_kernel(steps):
-      return pl.pallas_call(
-          kernel,
-          grid=(steps * 2,),
-          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
-          out_shape=result_ty,
-      )()
-    np.testing.assert_array_equal(
-        dynamic_kernel(4), np.full(shape, 8.0, np.float32)
-    )
-
-  def test_vmap_trivial_dynamic_grid(self):
-    shape = (8, 128)
-    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
-
-    def kernel(x_ref, y_ref):
-      @pl.when(pl.program_id(0) == 0)
-      def _init():
-        y_ref[...] = x_ref[...]
-      y_ref[...] += 1
-
-    @jax.jit
-    @jax.vmap
-    def dynamic_kernel(steps, x):
-      return pl.pallas_call(
-          kernel,
-          grid=(steps * 2,),
-          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
-          out_shape=result_ty,
-      )(x)
-    x = jnp.arange(8 * 128.).reshape((1, *shape))
-    np.testing.assert_array_equal(dynamic_kernel(jnp.array([4]), x), x + 8.0)
-
-  def test_vmap_nontrivial_dynamic_grid(self):
-    # Dynamic grid doesn't support vmapping over multiple distinct grid values
-    # at the moment.
-    shape = (8, 128)
-    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
-
-    def kernel(y_ref):
-      @pl.when(pl.program_id(0) == 0)
-      def _init():
-        y_ref[...] = jnp.zeros_like(y_ref)
-      y_ref[...] += 1
-
-    @jax.jit
-    @jax.vmap
-    def dynamic_kernel(steps):
-      return pl.pallas_call(
-          kernel,
-          grid=(steps * 2,),
-          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
-          out_shape=result_ty,
-      )()
-    with self.assertRaises(NotImplementedError):
-      dynamic_kernel(jnp.array([4, 8]))
-
-  def test_vmap_dynamic_grid(self):
-    shape = (8, 128)
-    result_ty = jax.ShapeDtypeStruct(shape, jnp.float32)
-
-    def kernel(x_ref, y_ref):
-      @pl.when(pl.program_id(0) == 0)
-      def _init():
-        y_ref[...] = x_ref[...]
-      y_ref[...] += 1.
-
-    @jax.jit
-    def dynamic_kernel(x, steps):
-      return pl.pallas_call(
-          kernel,
-          grid=(steps * 2,),
-          out_specs=pl.BlockSpec(lambda i: (0, 0), shape),
-          out_shape=result_ty,
-      )(x)
-    x = jnp.arange(4 * 8 * 128.).reshape((4, *shape))
-    np.testing.assert_array_equal(
-        jax.jit(jax.vmap(dynamic_kernel, in_axes=(0, None)))(x, 4),
-        x + 8,
-    )
-
-
 class PallasUXTest(PallasTPUTest):
 