Mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 12:56:07 +00:00)
[pallas] Fix the handling of captured consts
There was an attempt to handle consts captured by the kernel, but it was incomplete and had errors: the calling convention was wrong, and the support for handling consts alongside scalar prefetch and scratch values was incomplete. I expanded the tests: one in pallas_test.py and two in tpu_pallas_test.py (to cover scalar prefetch, with and without scratch inputs).

The calling convention is now `*scalar_refs, *consts, *ins, *outs, *scratch`. This differs from the previous order (`*consts, *scalar_refs, *ins, ...`) so that the block arguments (consts, ins, outs) stay together, which makes the lowering easier to write.

I will follow up with a cleanup PR for the handling of grid_mapping; here I attempted to minimize the changes.
This commit is contained in:
parent b3469a61d1
commit b7105ccd19
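To make the new argument order concrete, here is a small, self-contained sketch (plain Python with made-up helper and argument names, not JAX internals) of how a flat kernel argument list splits under the convention `*scalar_refs, *consts, *ins, *outs, *scratch`:

```python
# Toy illustration of the argument order this commit settles on; the helper
# and the argument names are invented for the example.

def split_kernel_args(args, num_scalar_prefetch, num_consts,
                      num_inputs, num_outputs, num_scratch):
    """Splits a flat kernel argument list per the new calling convention."""
    assert len(args) == (num_scalar_prefetch + num_consts + num_inputs
                         + num_outputs + num_scratch)
    it = iter(args)
    take = lambda n: [next(it) for _ in range(n)]
    scalar_refs = take(num_scalar_prefetch)  # scalar prefetch comes first
    consts = take(num_consts)                # hoisted constants
    ins = take(num_inputs)                   # block inputs
    outs = take(num_outputs)                 # block outputs
    scratch = take(num_scratch)              # scratch operands come last
    return scalar_refs, consts, ins, outs, scratch

# Example: 1 scalar-prefetch ref, 1 hoisted const, 2 inputs, 1 output, 1 scratch ref.
print(split_kernel_args(["s0", "c0", "x0", "x1", "o0", "acc"], 1, 1, 2, 1, 1))
# -> (['s0'], ['c0'], ['x0', 'x1'], ['o0'], ['acc'])
```

Keeping the consts next to the ins and outs means all block operands occupy one contiguous slice of the argument list, which is the reason the commit message gives for the reordering.
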
@@ -245,7 +245,7 @@ class BlockMapping:
   index_map_jaxpr: jax_core.ClosedJaxpr
   indexing_mode: IndexingMode
 
-  def compute_start_indices(self, loop_idx, *args):
+  def compute_start_indices_interpret(self, loop_idx, *args):
     discharged_jaxpr, discharged_consts = state_discharge.discharge_state(
         self.index_map_jaxpr.jaxpr, self.index_map_jaxpr.consts
     )
@@ -344,6 +344,10 @@ def _convert_block_spec_to_block_mapping(
         f"{len(aval.shape)} values to match {block_shape=}. "
         f"Currently returning {len(out_avals)} values."
     )
+  if consts:
+    raise NotImplementedError(
+        f"Index map for {what}{tree_util.keystr(path)} captures constants: "
+        f"{consts}")
   return BlockMapping(
       block_shape, jax_core.ClosedJaxpr(jaxpr, consts), block_spec.indexing_mode
   )

@@ -274,7 +274,7 @@ class MosaicGridMapping:
     self.mapped_dims = grid_mapping.mapped_dims
     num_scalar_prefetch = grid_mapping.num_index_operands
     num_scratch = grid_mapping.num_scratch_operands
-    # jaxpr has signature [*scalar_prefetch, *in_ops *out_ops, *scratch]
+    # jaxpr has signature [*scalar_prefetch, *consts, *in_ops, *out_ops, *scratch]
     num_operands = (
         len(self.jaxpr.invars)
         - num_scalar_prefetch
@@ -411,8 +411,12 @@ def lower_jaxpr_to_module(
         grid_mapping.num_index_operands:-grid_mapping.num_scratch_operands]
   else:
     invars = invars[grid_mapping.num_index_operands:]
+  # invars now = *consts, *ins, *outs
   avals = tuple(v.aval for v in invars)
   block_operand_shapes = (
+      *[jax.ShapeDtypeStruct(v.aval.shape,
+                             v.aval.dtype)
+        for v in invars[:grid_mapping.num_constant_operands]],
       *in_shapes[grid_mapping.num_index_operands:],
       *out_shapes,
   )
@@ -425,8 +429,6 @@ def lower_jaxpr_to_module(
       raise NotImplementedError(
          "BlockSpecs are required on TPU when grid is specified"
       )
-    if bm.index_map_jaxpr.consts:
-      raise NotImplementedError("Index map jaxpr with consts not supported.")
     # ANY operands don't support windowing and require empty window_params.
     if aval.memory_space == tpu_core.TPUMemorySpace.ANY:
       # We may not require windowing if our block_shape matches the original

@@ -76,6 +76,7 @@ def pallas_call_tpu_lowering_rule(
     compiler_params: dict[str, Any]):
   """Lowers a pallas_call to a Mosaic TPU custom call."""
   if interpret:
+    # TODO(necula): is this branch still needed?
     return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
         ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
         in_shapes=in_shapes,

@@ -42,6 +42,7 @@ def pallas_call_lowering(
     compiler_params: dict[str, Any],
 ):
   if interpret:
+    # TODO(necula): is this still needed?
     return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
         ctx,
         *args,
@@ -68,7 +69,10 @@ def pallas_call_lowering(
   if debug:
     print(jaxpr)
     print(grid_mapping)
 
+  if grid_mapping.num_constant_operands:
+    raise NotImplementedError(
+        "captured consts not supported in the Mosaic GPU backend"
+    )
   lowering_result = lowering.lower_jaxpr_to_module(
       grid_mapping,
       in_shapes,

@@ -187,12 +187,13 @@ def _pallas_call_impl_interpret(
   )
   assert next(dynamic_grid_args_iter, None) is None
   with grid_mapping.trace_env():
-    discharged_jaxpr, consts = state_discharge.discharge_state(jaxpr, ())
+    discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ())
   if debug:
     print(discharged_jaxpr)
   out = _initialize_output_vals(out_shapes, args, input_output_aliases)
   scalars, args = split_list(args, [grid_mapping.num_index_operands])  # type: ignore
-  # invars: [*scalar_prefetch, *inputs, *outputs, *scratch]
+  # invars: [*scalar_prefetch, *consts, *inputs, *outputs, *scratch]
+  # args now contains: *consts, *inputs, *outputs
   num_invars = len(jaxpr.invars)
   num_inputs_outputs = (
       num_invars
@@ -243,6 +244,10 @@ def _pallas_call_impl_interpret(
   else:
     # Base case is always one iteration when grid is ()
     num_iterations = 1
+
+  # The scan carry: (i, loop_idx, *consts, *ins, *outs, *scratch)
+  # i:int32 is the interation index
+  # loop_idx: tuple[int32] are the program ids for each grid axis
   def cond(carry):
     i, *_ = carry
     return i < num_iterations
@@ -256,7 +261,7 @@ def _pallas_call_impl_interpret(
     carry, scratch = split_list(carry, [num_inout])
     with pallas_core.grid_env(local_grid_env):
       start_indices = [
-          None if bm is None else bm.compute_start_indices(loop_idx, *scalars)
+          None if bm is None else bm.compute_start_indices_interpret(loop_idx, *scalars)
           for bm in grid_mapping.block_mappings]
     blocks = map(_maybe_dynamic_slice, start_indices, block_shapes, carry,
                  is_indexing_dim)
@@ -269,13 +274,14 @@ def _pallas_call_impl_interpret(
         len(blocks),
         len(scratch_values),
     )
-    blocks = jax.core.eval_jaxpr(discharged_jaxpr, consts, *scalars,
+    blocks = jax.core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars,
                                  *blocks, *scratch)
     blocks = blocks[grid_mapping.num_index_operands:]
     blocks, out_scratch = split_list(blocks, [num_inout])
     carry = map(_maybe_dynamic_update_slice, start_indices, block_shapes,
                 carry, blocks, is_indexing_dim)
     return (i + 1, _get_next_indices(grid, loop_idx), *carry, *out_scratch)
 
   (_, _, *carry) = lax.while_loop(
       cond, body, (jnp.int32(0), grid_start_indices, *carry)
   )
@@ -604,13 +610,13 @@ def _pallas_call_batching_rule(
   # Ordinarily, adding support for scalar prefetch in vmap would involve
   # modifying the block specs in a nontrivial way. However, if we are only
   # vmapping over 1-sized dimensions, we can just get rid of the dimensions
-  # and pretend we were never vmapping over them at all.
+  # and pretend we were never vmapped over them at all.
   if all(
       bdim is batching.not_mapped or arg.shape[bdim] == 1
       for arg, bdim in zip(scalar_args, scalar_bdims)
   ):
     scalar_args = safe_map(_maybe_squeeze_out_bdim, scalar_args, scalar_bdims)
-    scalar_bdims = [None] * len(scalar_args)
+    scalar_bdims = [batching.not_mapped] * len(scalar_args)
     args = (*scalar_args, *args)
     dims = (*scalar_bdims, *bdims)
   else:
@@ -648,6 +654,7 @@ def _pallas_call_batching_rule(
   all_dims = list(dims) + [0] * len(out_shapes)
 
   num_index_operands = grid_mapping.num_index_operands
+  num_constant_operands = grid_mapping.num_constant_operands
   num_scratch_operands = grid_mapping.num_scratch_operands
 
   # Only add a batch dimension for the avals that actually have a grid mapping.
@@ -661,11 +668,16 @@ def _pallas_call_batching_rule(
       block_mappings,
   )
 
+  # TODO(necula): should fix in_shapes to include the consts
+  dims_no_consts = (
+      dims[:num_index_operands] +
+      dims[num_index_operands + num_constant_operands:]
+  )
   batched_in_shapes = tuple(
       jax.ShapeDtypeStruct(x.shape if dim is batching.not_mapped else
                            tuple_insert(x.shape, dim, axis_size),
                            x.dtype)
-      for x, dim in zip(in_shapes, dims))
+      for x, dim in zip(in_shapes, dims_no_consts))
   batched_out_shapes = tuple(
       jax.ShapeDtypeStruct(tuple_insert(x.shape, 0, axis_size), x.dtype)
       for x in out_shapes)
@@ -900,11 +912,37 @@ def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
   jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun,
                                                    jaxpr_flat_avals, debug)
   if consts:
-    jaxpr = state_utils.hoist_consts_to_refs(jaxpr)
+    # Pad ``block_mappings`` to account for the hoisted constants.
+    # The constants will be right after the index operands and just before
+    # the real inputs and outputs.
+    jaxpr = state_utils.hoist_consts_to_refs(
+        jaxpr,
+        index=grid_mapping.num_index_operands,
+        make_abstract_ref=lambda aval: pallas_core.AbstractMemoryRef(aval, None))
+    num_constant_operands = len(consts)
+    # TODO(necula): refactor grid_mapping to remove this code duplication
+    grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
+    if grid_mapping.num_index_operands:
+      grid_avals += flat_in_avals[:grid_mapping.num_index_operands]  # type: ignore
+    # Create args, kwargs pytree def
+    grid_tree = tree_util.tree_structure((tuple(grid_avals), {}))
+    const_block_mappings = []
+    for c_idx, c in enumerate(consts):
+      const_block_mapping = pallas_core._convert_block_spec_to_block_mapping(
+          grid_avals,
+          pallas_core.BlockSpec(None, None),
+          path=(tree_util.SequenceKey(c_idx),),
+          aval=jax_core.ShapedArray(c.shape, c.dtype),
+          in_tree=grid_tree,
+          grid=grid_mapping.grid,
+          mapped_dims=(),
+          what="consts",
+      )
+      const_block_mappings.append(const_block_mapping)
+
     grid_mapping = grid_mapping.replace(
-        block_mappings=(*grid_mapping.block_mappings, *[None] * len(consts)),
-        num_constant_operands=len(consts),
+        block_mappings=(*const_block_mappings, *grid_mapping.block_mappings),
+        num_constant_operands=num_constant_operands,
     )
   return grid_mapping, jaxpr, consts, out_tree_thunk()
@@ -1105,8 +1143,9 @@ def pallas_call(
           f"and to output{tree_util.keystr(out_paths[o_idx])} with "
           f"a different abstract value {out_aval}.")
 
+  index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands])
   out_flat = pallas_call_p.bind(
-      *dynamic_grid_bounds, *consts, *flat_args,
+      *dynamic_grid_bounds, *index_args, *consts, *rest_args,
       jaxpr=jaxpr, name=name,
       in_shapes=tuple(jax.ShapeDtypeStruct(a.shape, a.dtype)
                       for a in flat_args),

@@ -251,7 +251,7 @@ def _new_ir_context() -> ir.Context:
 
 def lower_jaxpr_to_triton_module(
     jaxpr: jax_core.Jaxpr,
-    in_shapes,
+    in_out_shapes,
     grid_mapping: GridMapping,
     name: str,
     platform: str
@@ -301,6 +301,10 @@ def lower_jaxpr_to_triton_module(
       functools.partial(_eval_index_map, ctx, program_ids),
       grid_mapping.block_mappings,
   )
+  consts_shapes = [
+      jax.ShapeDtypeStruct(v.aval.shape, v.aval.dtype)
+      for v in jaxpr.invars[grid_mapping.num_index_operands:grid_mapping.num_index_operands + grid_mapping.num_constant_operands]
+  ]
   block_infos = [
       BlockInfo(
           jax.ShapeDtypeStruct(shape_dtype.shape, shape_dtype.dtype),
@@ -310,7 +314,7 @@ def lower_jaxpr_to_triton_module(
       if block_mapping is not None
       else None
       for shape_dtype, block_mapping, start_idx in zip(
-          (*in_shapes, *[()] * grid_mapping.num_constant_operands),
+          (*consts_shapes, *in_out_shapes),
           grid_mapping.block_mappings,
           start_indices,
       )

@@ -54,6 +54,7 @@ def pallas_call_lowering(
     compiler_params: dict[str, Any],
 ):
   if interpret:
+    # TODO(necula): is this branch still needed?
     return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
         ctx,
         *in_nodes,
@@ -72,6 +73,10 @@ def pallas_call_lowering(
     raise NotImplementedError(
         "dynamic grid bounds not supported in the Triton backend"
     )
+  if grid_mapping.num_index_operands:
+    raise NotImplementedError(
+        "scalar prefetch not implemented in the Triton backend"
+    )
   triton_params = compiler_params.get("triton", compiler_params)
   num_warps = triton_params.pop("num_warps", 4)
   [lowering_platform] = ctx.platforms or ctx.module_context.platforms

@@ -13,6 +13,8 @@
 # limitations under the License.
 """Utilities for tracing stateful functions."""
 
+from typing import Callable
+
 from jax._src.interpreters import partial_eval as pe
 from jax._src import core
 from jax._src import linear_util as lu
@@ -24,13 +26,20 @@ map, unsafe_map = safe_map, map
 zip, unsafe_zip = safe_zip, zip
 
 
-def hoist_consts_to_refs(jaxpr: core.Jaxpr, *, index: int = 0) -> core.Jaxpr:
+def hoist_consts_to_refs(
+    jaxpr: core.Jaxpr,
+    *,
+    index: int = 0,
+    make_abstract_ref: Callable[[core.AbstractValue], AbstractRef] = lambda aval: AbstractRef(aval)
+) -> core.Jaxpr:
   """Hoists the constants in the given jaxpr into invars.
 
   Args:
     jaxpr: The jaxpr.
     index: The index where the invars for the constants should be inserted.
       By default, the new invars are inserted *before* any existing invars.
+    make_abstract_ref: a callable to construct an AbstractRef, or subtype
+      thereof, from a constant AbstractValue.
 
   Returns:
     A new jaxpr where the constants were hoisted into invars as ``Ref``s.
@@ -42,7 +51,7 @@ def hoist_consts_to_refs(jaxpr: core.Jaxpr, *, index: int = 0) -> core.Jaxpr:
       isinstance(var.aval, AbstractRef) for var in jaxpr.constvars
   ]
   const_avals = [
-      var.aval if is_ref else AbstractRef(var.aval)
+      var.aval if is_ref else make_abstract_ref(var.aval)
       for is_ref, var in zip(is_const_ref, jaxpr.constvars)
   ]
   in_avals = [var.aval for var in jaxpr.invars]

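As background for the `hoist_consts_to_refs` change above, here is a toy, self-contained model (plain Python stand-ins, not the real jaxpr machinery) of what hoisting does: each captured constant becomes an extra `Ref` invar inserted at `index`, and the Pallas call site in this commit passes `index=grid_mapping.num_index_operands` so the consts land right after the scalar-prefetch operands and just before the real inputs and outputs.

```python
# Toy model of const hoisting; ToyJaxpr and make_ref are stand-ins invented
# for this sketch, not JAX types.
from dataclasses import dataclass

@dataclass
class ToyJaxpr:
    constvars: list[str]  # names of captured constants
    invars: list[str]     # existing kernel parameters

def toy_hoist_consts_to_refs(jaxpr: ToyJaxpr, *, index: int = 0,
                             make_ref=lambda name: f"Ref({name})") -> ToyJaxpr:
    """Turns each captured constant into a Ref invar inserted at `index`."""
    const_refs = [make_ref(c) for c in jaxpr.constvars]
    new_invars = jaxpr.invars[:index] + const_refs + jaxpr.invars[index:]
    return ToyJaxpr(constvars=[], invars=new_invars)

before = ToyJaxpr(constvars=["to_store"],
                  invars=["s_ref", "x_ref", "o_ref", "scratch_ref"])
# index=1 mimics inserting after one scalar-prefetch operand.
print(toy_hoist_consts_to_refs(before, index=1).invars)
# -> ['s_ref', 'Ref(to_store)', 'x_ref', 'o_ref', 'scratch_ref']
```
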
@@ -150,7 +150,6 @@ class PallasBaseTest(jtu.JaxTestCase):
 
 class PallasCallTest(PallasBaseTest):
 
-
   def test_add_one(self):
     if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
       # TODO: assertion failures on CPU in 64-bit mode
@@ -468,19 +467,22 @@ class PallasCallTest(PallasBaseTest):
 
   def test_hoisted_consts(self):
     # See https://github.com/google/jax/issues/21557.
+    if jtu.test_device_matches(["tpu"]) and not self.INTERPRET:
+      self.skipTest("On TPU the test works only in interpret mode")
-    x = jnp.zeros(32)
-    indices = jnp.arange(4).reshape((2, 2))
+    # to_store will be hoisted as a constant. Choose distinct shapes from in/outs.
+    to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
+    x = np.arange(16 * 128, dtype=np.float32).reshape((16, 128))
 
     @functools.partial(
         self.pallas_call,
-        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
+        out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype),
+        grid=(2,),
+        in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))],
+        out_specs=pl.BlockSpec((32, 128), lambda i: (i, 0)),
     )
     def kernel(src, dst):
-      dst[indices] = src[indices]
+      dst[0:1] = to_store
 
-    jax.block_until_ready(kernel(x))
+    res = kernel(x)
+    self.assertAllClose(res[0:1], to_store)
 
   def test_vector_slicing(self):
     if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
@@ -744,6 +746,17 @@ class ApiErrorTest(PallasBaseTest):
         "Index map for input\\[0\\] must return 1 values to match .*Currently returning 2 values."):
       f(a)
 
+  def test_pallas_call_index_map_captures_consts(self):
+    a = np.arange(256, dtype=np.int32)
+    index_map_result = np.array([0], dtype=np.int32)
+    f = self.pallas_call(lambda x_ref, o1_ref: None,
+                         out_shape=a,
+                         in_specs=[pl.BlockSpec((4,), lambda: index_map_result)])
+    with self.assertRaisesRegex(
+        NotImplementedError,
+        "Index map for input\\[0\\] captures constants"):
+      f(a)
+
   def test_pallas_call_out_specs_mismatch_shape(self):
     a = np.arange(256, dtype=np.int32)
     f = self.pallas_call(lambda x_ref, o1_ref: None,

@@ -37,7 +37,7 @@ config.parse_flags_with_absl()
 
 
 @jtu.with_config(jax_traceback_filtering="off")
-class PallasTest(jtu.JaxTestCase):
+class PallasBaseTest(jtu.JaxTestCase):
   INTERPRET = False
 
   def setUp(self):
@@ -58,7 +58,7 @@ class PallasTest(jtu.JaxTestCase):
     return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET)
 
 
-class PallasCallVmapTest(PallasTest):
+class PallasCallVmapTest(PallasBaseTest):
 
   def setUp(self):
     super().setUp()
@@ -130,6 +130,26 @@ class PallasCallVmapTest(PallasTest):
     out_ref = jnp.arange(1, 9).reshape((4, 2))
     np.testing.assert_allclose(out, out_ref)
 
+  def test_vmap_with_hoisted_consts(self):
+    # to_store will be hoisted as a constant. Choose distinct shapes from in/outs.
+    to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
+    x = np.arange(4 * 16 * 128, dtype=np.float32).reshape((4, 16, 128))
+
+    @jax.vmap
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype),
+        grid=(2,),
+        in_specs=[pl.BlockSpec((8, 128), lambda i: (i, 0))],
+        out_specs=pl.BlockSpec((32, 128), lambda i: (i, 0)),
+    )
+    def kernel(src, dst):
+      dst[0:1] = to_store
+
+    res = kernel(x)
+    for i in range(x.shape[0]):
+      self.assertAllClose(res[i, 0:1], to_store)
+
   def test_vmap_of_kernel_with_input_output_aliases(self):
     @functools.partial(
         self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.int32),

@@ -17,6 +17,7 @@
 import contextlib
 import functools
 import io
+import math
 import re
 import sys
 from absl.testing import absltest
@@ -71,7 +72,6 @@ class PallasBaseTest(jtu.JaxTestCase):
 
 
 class PallasCallScalarPrefetchTest(PallasBaseTest):
-
   def test_trivial_scalar_prefetch(self):
     def body(_, x_ref, o_ref):
       o_ref[...] = x_ref[...]
@@ -115,6 +115,56 @@ class PallasCallScalarPrefetchTest(PallasBaseTest):
     )(s, x)
     np.testing.assert_array_equal(out, x)
 
+  @jtu.parameterized_filterable(
+      kwargs=[
+          dict(scratch=scratch, vmap=vmap)
+          for scratch in [True, False]
+          for vmap in [True, False]
+      ]
+  )
+  def test_scalar_prefetch_hoisted_const(self, *, scratch: bool, vmap: bool):
+    if jtu.test_device_matches(["cpu"]) and jax.config.x64_enabled:
+      self.skipTest("TODO: dslice(start, 1) raises error about slice inputs being int32 and int64")
+    # to_store will be hoisted as constants. Choose distinct shapes from in/outs.
+    to_store = np.arange(128, dtype=np.float32).reshape((1, 128))
+    if vmap:
+      x_shape = (4, 16, 128)
+    else:
+      x_shape = (16, 128)
+    x = np.arange(math.prod(x_shape), dtype=np.float32).reshape(x_shape)
+
+    def f(x):
+      s = jnp.array([1, 0], jnp.int32)  # iteration 0 -> 1, iteration 1 -> 0
+      @functools.partial(
+          self.pallas_call,
+          out_shape=jax.ShapeDtypeStruct((64, 128), x.dtype),
+          grid_spec=pltpu.PrefetchScalarGridSpec(
+              num_scalar_prefetch=1,
+              grid=(2,),
+              in_specs=[pl.BlockSpec((8, 128),
+                                     lambda i, s_ref: (pl.load(s_ref, (i,)), 0))],
+              out_specs=pl.BlockSpec((32, 128),
+                                     lambda i, s_ref: (pl.load(s_ref, i), 0)),
+              scratch_shapes=([pltpu.SemaphoreType.REGULAR((3,))] if scratch
+                              else []),
+          ),
+      )
+      def kernel(s_ref, src, dst, *scratch_refs):
+        store_idx = s_ref[pl.program_id(0)]
+        pl.store(dst, (pl.dslice(store_idx, 1), slice(None)), to_store)
+      return kernel(s, x)
+
+    if vmap:
+      f = jax.vmap(f)
+    res = f(x)
+    if vmap:
+      for i in range(x.shape[0]):
+        self.assertAllClose(res[i, 0:1], to_store)
+        self.assertAllClose(res[i, 33:34], to_store)
+    else:
+      self.assertAllClose(res[0:1], to_store)
+      self.assertAllClose(res[33:34], to_store)
+
   def test_block_spec_with_wrong_block_shape_errors(self):
     def body(x_ref, o_ref):
       o_ref[...] = x_ref[...]