mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[pallas] Improve the error localization
* Add the source location information for the index map function to `BlockMapping`. * Removed the `compute_index` wrapper around the index_map, so that we can get the location information for the index_map, not the wrapper. * Added source location to the errors related to index map functions. * Added an error if the index map returns something other than integer scalars. * Construct BlockSpec origins for arguments using JAX helper functions to get argument names * Removed redundant API error tests from tpu_pallas_test.py
This commit is contained in:
parent
cc212457d2
commit
6d53aaf7d0
@ -6,6 +6,8 @@ see {ref}`pallas-changelog`.
|
||||
|
||||
<!--
|
||||
Remember to align the itemized text with the first line of an item within a list.
|
||||
|
||||
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
|
||||
-->
|
||||
|
||||
## jax 0.4.32
|
||||
|
@ -11,7 +11,17 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c
|
||||
Remember to align the itemized text with the first line of an item within a list.
|
||||
-->
|
||||
|
||||
## Released with JAX 0.4.31
|
||||
## Released with jax 0.4.32
|
||||
|
||||
* Changes
|
||||
|
||||
* Deprecations
|
||||
|
||||
* New functionality:
|
||||
* Improved error messages for mistakes in the signature of the index map functions,
|
||||
to include the name and source location of the index map.
|
||||
|
||||
## Released with jax 0.4.31 (July 29, 2024)
|
||||
|
||||
* Changes
|
||||
* {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to
|
||||
|
@ -2301,8 +2301,7 @@ class DebugInfo(NamedTuple):
|
||||
def debug_info(fn: Callable, in_tree: PyTreeDef | None,
|
||||
out_tree_thunk: Callable[[], PyTreeDef] | None,
|
||||
has_kwargs: bool, traced_for: str) -> DebugInfo:
|
||||
try: sig = inspect.signature(fn)
|
||||
except (ValueError, TypeError): sig = None
|
||||
sig = api_util.fun_signature(fn)
|
||||
src_info = fun_sourceinfo(fn)
|
||||
return DebugInfo(src_info, sig, in_tree, out_tree_thunk, has_kwargs,
|
||||
traced_for)
|
||||
|
@ -56,6 +56,8 @@ TupleGrid = tuple[GridElement, ...]
|
||||
Grid = Union[NamedGrid, TupleGrid]
|
||||
StaticGrid = tuple[int, ...]
|
||||
GridMappingGrid = tuple[int | DynamicGridDim, ...]
|
||||
SrcInfoStr = str # function_name at filename:linenumber
|
||||
OriginStr = str # The origin of a block spec, e.g. input[2]["field"]
|
||||
|
||||
# Pytrees of jax.ShapeDtypeStruct
|
||||
ShapeDtypeStructTree = tuple[jax.ShapeDtypeStruct, ...]
|
||||
@ -247,14 +249,6 @@ class BlockSpec:
|
||||
self.indexing_mode = indexing_mode
|
||||
|
||||
|
||||
def compute_index(bs: BlockSpec, *args):
|
||||
assert bs.index_map is not None
|
||||
out = bs.index_map(*args)
|
||||
if not isinstance(out, tuple):
|
||||
out = (out,)
|
||||
return out
|
||||
|
||||
|
||||
class NoBlockSpec:
|
||||
def __repr__(self):
|
||||
return "NoBlockSpec"
|
||||
@ -274,9 +268,10 @@ class BlockMapping:
|
||||
block_shape: tuple[Mapped | int, ...]
|
||||
block_aval: AbstractMemoryRef # The block ref aval
|
||||
index_map_jaxpr: jax_core.ClosedJaxpr
|
||||
index_map_src_info: SrcInfoStr
|
||||
indexing_mode: IndexingMode
|
||||
array_shape_dtype: jax.ShapeDtypeStruct # The whole array
|
||||
origin: str # The origin, e.g. input[2]["field"]
|
||||
origin: OriginStr
|
||||
|
||||
def check_invariants(self) -> None:
|
||||
if not config.enable_checks.value: return
|
||||
@ -501,7 +496,7 @@ def _is_valid_grid_dim(dim: int | jax.Array) -> bool:
|
||||
|
||||
def _convert_block_spec_to_block_mapping(
|
||||
block_spec: BlockSpec,
|
||||
path: tree_util.KeyPath,
|
||||
origin: OriginStr,
|
||||
array_aval: jax_core.ShapedArray,
|
||||
*,
|
||||
# Inputs for the index_map
|
||||
@ -509,15 +504,13 @@ def _convert_block_spec_to_block_mapping(
|
||||
index_map_tree: tree_util.PyTreeDef,
|
||||
grid: GridMappingGrid,
|
||||
mapped_dims: tuple[int, ...],
|
||||
what: str, # Used to localize error messages, e.g., {what}{path}
|
||||
) -> BlockMapping:
|
||||
origin = f"{what}{tree_util.keystr(path)}"
|
||||
if block_spec is no_block_spec:
|
||||
block_spec = BlockSpec(None, None)
|
||||
if block_spec.index_map is None:
|
||||
index_map_func = lambda *args: (0,) * len(array_aval.shape)
|
||||
else:
|
||||
index_map_func = functools.partial(compute_index, block_spec)
|
||||
index_map_func = block_spec.index_map
|
||||
if block_spec.block_shape is None:
|
||||
block_shape = array_aval.shape
|
||||
else:
|
||||
@ -525,8 +518,9 @@ def _convert_block_spec_to_block_mapping(
|
||||
if len(array_aval.shape) != len(block_shape):
|
||||
raise ValueError(
|
||||
f"Block shape for {origin} (= {block_shape}) "
|
||||
f"must have the same number of dimensions as the array shape {array_aval.shape}"
|
||||
)
|
||||
"must have the same number of dimensions as the "
|
||||
f"array shape {array_aval.shape}.")
|
||||
|
||||
unmapped_block_shape = tuple(s for s in block_shape if s is not None)
|
||||
block_aval = AbstractMemoryRef(array_aval.update(shape=unmapped_block_shape),
|
||||
block_spec.memory_space)
|
||||
@ -535,31 +529,44 @@ def _convert_block_spec_to_block_mapping(
|
||||
raise ValueError(
|
||||
"shape polymorphism for Pallas does not support "
|
||||
"dynamically-shaped blocks. "
|
||||
f"{origin} has block_shape: {block_aval.shape}")
|
||||
f"Block spec for {origin} has block_shape: {block_aval.shape}")
|
||||
|
||||
flat_index_map_fun, _ = api_util.flatten_fun(lu.wrap_init(index_map_func),
|
||||
index_map_tree)
|
||||
flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
|
||||
lu.wrap_init(index_map_func), index_map_tree)
|
||||
debug = pe.debug_info(index_map_func, index_map_tree, index_map_out_tree_thunk,
|
||||
False, "pallas_call index_map")
|
||||
index_map_src_info = debug.func_src_info or "<unknown>"
|
||||
with tracing_grid_env(grid, mapped_dims):
|
||||
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_index_map_fun,
|
||||
index_map_avals)
|
||||
index_map_avals,
|
||||
debug_info=debug)
|
||||
mapped_block_shape = tuple(
|
||||
mapped if s is None else s for s in block_shape)
|
||||
if len(out_avals) != len(mapped_block_shape):
|
||||
mapped if s is None else s for s in block_shape)
|
||||
if len(out_avals) != len(block_shape):
|
||||
raise ValueError(
|
||||
# TODO(necula): show the name and location of the index map function
|
||||
f"Index map for {origin} must return "
|
||||
f"{len(block_aval.shape)} values to match block shape {mapped_block_shape}. "
|
||||
f"Currently returning {len(out_avals)} values."
|
||||
)
|
||||
f"Index map function {index_map_src_info} for "
|
||||
f"{origin} must return "
|
||||
f"{len(block_shape)} values to match {block_shape=}. "
|
||||
f"Currently returning {len(out_avals)} values.")
|
||||
for i, ov in enumerate(out_avals):
|
||||
if ov.shape or ov.dtype not in [jnp.int32, jnp.int64]:
|
||||
raise ValueError(
|
||||
f"Index map function {index_map_src_info} for "
|
||||
f"{origin} must return integer scalars. Output[{i}] has type "
|
||||
f"{ov}.")
|
||||
|
||||
|
||||
if consts:
|
||||
raise NotImplementedError(
|
||||
# TODO(necula): show the name and location of the index map function
|
||||
f"Index map for {origin} captures constants: "
|
||||
f"{consts}")
|
||||
raise ValueError(
|
||||
f"Index map function {index_map_src_info} for "
|
||||
f"{origin} must not capture constants: {consts}")
|
||||
|
||||
|
||||
mapping = BlockMapping(
|
||||
block_shape=mapped_block_shape,
|
||||
block_aval=block_aval,
|
||||
index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts),
|
||||
index_map_src_info=index_map_src_info,
|
||||
indexing_mode=block_spec.indexing_mode,
|
||||
array_shape_dtype=jax.ShapeDtypeStruct(array_aval.shape, array_aval.dtype),
|
||||
origin=origin,
|
||||
@ -632,10 +639,10 @@ def get_grid_mapping(
|
||||
grid_spec: GridSpec,
|
||||
in_avals: Sequence[jax_core.AbstractValue],
|
||||
in_tree: tree_util.PyTreeDef,
|
||||
in_paths: Sequence[tree_util.KeyPath],
|
||||
in_origins: Sequence[OriginStr],
|
||||
out_avals: Sequence[jax_core.AbstractValue],
|
||||
out_tree: tree_util.PyTreeDef,
|
||||
out_paths: Sequence[tree_util.KeyPath],
|
||||
out_origins: Sequence[OriginStr],
|
||||
) -> tuple[tuple[jax_core.AbstractValue, ...],
|
||||
GridMapping]:
|
||||
assert all(i is None or isinstance(i, int) for i in grid_spec.grid)
|
||||
@ -700,10 +707,9 @@ def get_grid_mapping(
|
||||
index_map_tree=index_map_tree,
|
||||
grid=grid_mapping_grid,
|
||||
mapped_dims=(),
|
||||
what="inputs",
|
||||
),
|
||||
flat_in_specs,
|
||||
in_paths[num_flat_scalar_prefetch:],
|
||||
in_origins[num_flat_scalar_prefetch:],
|
||||
in_avals,
|
||||
)
|
||||
|
||||
@ -723,10 +729,9 @@ def get_grid_mapping(
|
||||
index_map_tree=index_map_tree,
|
||||
grid=grid_mapping_grid,
|
||||
mapped_dims=(),
|
||||
what="outputs",
|
||||
),
|
||||
flat_out_specs,
|
||||
out_paths,
|
||||
out_origins,
|
||||
out_avals,
|
||||
)
|
||||
grid_mapping = GridMapping(
|
||||
|
@ -276,7 +276,7 @@ class BufferedRef:
|
||||
|
||||
@property
|
||||
def compute_index(self):
|
||||
return lambda *args: pallas_core.compute_index(self.spec, *args)
|
||||
return self.spec.index_map
|
||||
|
||||
@property
|
||||
def memory_space(self):
|
||||
|
@ -21,9 +21,9 @@ import itertools
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
from jax import api_util
|
||||
from jax import lax
|
||||
from jax._src import ad_util
|
||||
from jax._src import api_util
|
||||
from jax._src import checkify
|
||||
from jax._src import config
|
||||
from jax._src import core as jax_core
|
||||
@ -797,15 +797,15 @@ def pallas_call_checkify_rule(error: checkify.Error,
|
||||
# for the new error inputs and outputs.
|
||||
error_block_specs = [pallas_core.BlockSpec(None, None)] * len(shaped_err_avals)
|
||||
error_paths, _ = unzip2(tree_util.tree_flatten_with_path(error_block_specs)[0])
|
||||
error_origins = tuple(f"errrors[{tree_util.keystr(p)}" for p in error_paths)
|
||||
error_block_mappings = map(
|
||||
partial(
|
||||
pallas_core._convert_block_spec_to_block_mapping,
|
||||
index_map_avals=grid_mapping.index_map_avals,
|
||||
index_map_tree=grid_mapping.index_map_tree,
|
||||
grid=grid_mapping.grid,
|
||||
mapped_dims=grid_mapping.vmapped_dims,
|
||||
what="error"),
|
||||
error_block_specs, error_paths, shaped_err_avals)
|
||||
mapped_dims=grid_mapping.vmapped_dims),
|
||||
error_block_specs, error_origins, shaped_err_avals)
|
||||
input_block_mappings, output_block_mappings = split_list(
|
||||
grid_mapping.block_mappings, [num_kernel_inputs,])
|
||||
grid_mapping_with_error = grid_mapping.replace(
|
||||
@ -837,6 +837,7 @@ checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule
|
||||
|
||||
@weakref_lru_cache
|
||||
def _trace_kernel_to_jaxpr(fun: Callable,
|
||||
fun_src_info: pallas_core.SrcInfoStr,
|
||||
grid_mapping: GridMapping,
|
||||
kernel_avals: tuple[pallas_core.AbstractMemRef, ...],
|
||||
kernel_in_tree: tree_util.PyTreeDef,
|
||||
@ -863,13 +864,12 @@ def _trace_kernel_to_jaxpr(fun: Callable,
|
||||
for c_idx, c in enumerate(consts):
|
||||
const_block_mapping = pallas_core._convert_block_spec_to_block_mapping(
|
||||
pallas_core.BlockSpec(None, None),
|
||||
path=(tree_util.SequenceKey(c_idx),),
|
||||
origin=f"consts[{c_idx}]",
|
||||
array_aval=jax_core.ShapedArray(c.shape, c.dtype),
|
||||
index_map_avals=grid_mapping.index_map_avals,
|
||||
index_map_tree=grid_mapping.index_map_tree,
|
||||
grid=grid_mapping.grid,
|
||||
mapped_dims=(),
|
||||
what="consts",
|
||||
)
|
||||
const_block_mappings.append(const_block_mapping)
|
||||
|
||||
@ -880,8 +880,9 @@ def _trace_kernel_to_jaxpr(fun: Callable,
|
||||
kernel_out_tree = out_tree_thunk()
|
||||
if kernel_out_tree != tree_util.tree_structure(None):
|
||||
raise ValueError(
|
||||
"The kernel function in a pallas_call should return None. "
|
||||
f"Found a PyTree: {kernel_out_tree}")
|
||||
f"The kernel function {fun_src_info} in a "
|
||||
f"pallas_call should return None. "
|
||||
f"It returns a PyTree: {kernel_out_tree}")
|
||||
return grid_mapping, jaxpr, consts
|
||||
|
||||
def _extract_function_name(f: Callable, name: str | None) -> str:
|
||||
@ -979,7 +980,7 @@ jax_core.custom_typechecks[pallas_call_p] = _pallas_call_typecheck_rule
|
||||
|
||||
|
||||
def pallas_call(
|
||||
f: Callable[..., None],
|
||||
kernel: Callable[..., None],
|
||||
out_shape: Any,
|
||||
*,
|
||||
grid_spec: GridSpec | None = None,
|
||||
@ -997,7 +998,7 @@ def pallas_call(
|
||||
See `Pallas Quickstart <https://jax.readthedocs.io/en/latest/pallas/quickstart.html>`_.
|
||||
|
||||
Args:
|
||||
f: the kernel function, that receives a Ref for each input and output.
|
||||
kernel: the kernel function, that receives a Ref for each input and output.
|
||||
The shape of the Refs are given by the ``block_shape`` in the
|
||||
corresponding ``in_specs`` and ``out_specs``.
|
||||
out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shape
|
||||
@ -1034,7 +1035,7 @@ def pallas_call(
|
||||
invoke the Pallas kernel.
|
||||
|
||||
"""
|
||||
name = _extract_function_name(f, name)
|
||||
name = _extract_function_name(kernel, name)
|
||||
if compiler_params is None:
|
||||
compiler_params = {}
|
||||
|
||||
@ -1072,14 +1073,31 @@ def pallas_call(
|
||||
for a in flat_args)
|
||||
flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
|
||||
for v in flat_out_shapes)
|
||||
|
||||
kernel_fun_sig = api_util.fun_signature(kernel)
|
||||
arg_names = None
|
||||
kernel_src_info: pallas_core.SrcInfoStr = "<unknown>"
|
||||
if kernel_fun_sig:
|
||||
kernel_debug_info = api_util.debug_info(
|
||||
"pallas_call kernel",
|
||||
api_util.fun_sourceinfo(kernel),
|
||||
kernel_fun_sig,
|
||||
[1] * len(kernel_fun_sig.parameters), {}, (), ())
|
||||
if kernel_debug_info:
|
||||
arg_names = kernel_debug_info.arg_names
|
||||
kernel_src_info = kernel_debug_info.func_src_info
|
||||
in_origins = tuple(in_path_to_input_origin(p, arg_names)
|
||||
for p in in_paths)
|
||||
out_origins = tuple(f"outputs{tree_util.keystr(p)}" for p in out_paths)
|
||||
# TODO(necula): check that input_output_aliases is well-formed: no duplicates, etc.
|
||||
kernel_avals, grid_mapping = pallas_core.get_grid_mapping(
|
||||
grid_spec,
|
||||
flat_in_avals, in_tree, in_paths,
|
||||
flat_out_avals, out_tree, out_paths)
|
||||
flat_in_avals, in_tree, in_origins,
|
||||
flat_out_avals, out_tree, out_origins)
|
||||
flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(kernel_avals)
|
||||
grid_mapping, jaxpr, consts = _trace_kernel_to_jaxpr(
|
||||
f, grid_mapping, tuple(flat_kernel_avals), kernel_in_tree,
|
||||
kernel, kernel_src_info,
|
||||
grid_mapping, tuple(flat_kernel_avals), kernel_in_tree,
|
||||
interpret=interpret)
|
||||
for i_idx, o_idx in input_output_aliases.items():
|
||||
if i_idx not in range(len(flat_in_avals)):
|
||||
@ -1116,6 +1134,20 @@ def pallas_call(
|
||||
return wrapped
|
||||
|
||||
|
||||
def in_path_to_input_origin(in_path: tree_util.KeyPath,
|
||||
arg_names: tuple[str, ...] | None) -> pallas_core.OriginStr:
|
||||
"""Converts `args[k]<rest>` into `arg_k_name<rest>`."""
|
||||
if arg_names is None:
|
||||
return f"args{tree_util.keystr(in_path)}"
|
||||
if len(in_path) == 0:
|
||||
return "args"
|
||||
arg_idx, *rest_path = in_path
|
||||
if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len(arg_names):
|
||||
return arg_names[arg_idx.idx] + tree_util.keystr(tuple(rest_path))
|
||||
else:
|
||||
return f"args{tree_util.keystr(tuple(in_path))}"
|
||||
|
||||
|
||||
# We import the TPU backend at the top level because it defines flags. Note that
|
||||
# we can only do that at the bottom of this file, beacuse it also depends on
|
||||
# this module already being initialized.
|
||||
|
@ -810,7 +810,7 @@ def run_scoped(f: Callable[..., Any], *types, **kw_types) -> Any:
|
||||
"""Call the function with allocated references.
|
||||
|
||||
Args:
|
||||
f: The function that generatest the jaxpr.
|
||||
f: The function that generates the jaxpr.
|
||||
*types: The types of the function's positional arguments.
|
||||
**kw_types: The types of the function's keyword arguments.
|
||||
"""
|
||||
|
@ -671,7 +671,7 @@ class PallasCallInterpreterTest(PallasCallTest):
|
||||
|
||||
|
||||
class ApiErrorTest(PallasBaseTest):
|
||||
def test_pallas_kernel_args_mismatch(self):
|
||||
def test_pallas_call_kernel_args_mismatch(self):
|
||||
a = np.arange(256, dtype=np.int32)
|
||||
f = self.pallas_call(lambda x_ref: None, # Missing o_ref
|
||||
out_shape=a)
|
||||
@ -687,11 +687,22 @@ class ApiErrorTest(PallasBaseTest):
|
||||
def test_pallas_call_error_kernel_returns_something(self, returns):
|
||||
a = np.arange(256, dtype=np.int32)
|
||||
# The kernel should not return anything
|
||||
f = self.pallas_call(lambda x_ref, o1_ref, o2_ref: returns,
|
||||
def my_kernel(x_ref, o1_ref, o2_ref):
|
||||
return returns
|
||||
f = self.pallas_call(my_kernel,
|
||||
out_shape=(a, a))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"The kernel function in a pallas_call should return None"):
|
||||
"The kernel function my_kernel at .*pallas_test.py:.* in a pallas_call should return None"):
|
||||
f(a)
|
||||
|
||||
def test_pallas_call_kernel_with_no_signature_returns_something(self):
|
||||
a = np.arange(256, dtype=np.int32)
|
||||
f = self.pallas_call(lambda *args: 0, # Returns 0
|
||||
out_shape=a)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"The kernel function .* at .*pallas_test.py:.* in a pallas_call should return None"):
|
||||
f(a)
|
||||
|
||||
def test_pallas_call_in_specs_not_a_sequence(self):
|
||||
@ -729,12 +740,46 @@ class ApiErrorTest(PallasBaseTest):
|
||||
|
||||
def test_pallas_call_index_map_wrong_number_of_results(self):
|
||||
a = np.arange(256, dtype=np.int32)
|
||||
f = self.pallas_call(lambda x_ref, o1_ref: None,
|
||||
def my_index_map():
|
||||
return 0, 0
|
||||
f = self.pallas_call(lambda x_ref, o_ref: None,
|
||||
out_shape=a,
|
||||
in_specs=[pl.BlockSpec((4,), lambda: (0, 0))])
|
||||
in_specs=[pl.BlockSpec((4,), my_index_map)])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Index map for inputs\\[0\\] must return 1 values to match .*Currently returning 2 values."):
|
||||
"Index map function my_index_map at .*/pallas_test.py:.* for "
|
||||
"x_ref must return 1 values to match .*"
|
||||
"Currently returning 2 values."):
|
||||
f(a)
|
||||
|
||||
def test_pallas_call_index_map_wrong_return_type(self):
|
||||
a = np.arange(256, dtype=np.int32)
|
||||
def my_index_map(i):
|
||||
return 5.
|
||||
f = self.pallas_call(lambda x_ref, o_ref: None,
|
||||
out_shape=a,
|
||||
grid=(1,),
|
||||
in_specs=[pl.BlockSpec((4,), my_index_map)])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Index map function my_index_map at .*/pallas_test.py:.* for "
|
||||
"x_ref must return integer scalars. Output\\[0\\] has "
|
||||
"type .*float"):
|
||||
f(a)
|
||||
|
||||
def test_pallas_call_index_map_wrong_return_shape(self):
|
||||
a = np.arange(256, dtype=np.int32)
|
||||
def my_index_map(i):
|
||||
return jnp.arange(4, dtype=np.int32)
|
||||
f = self.pallas_call(lambda x_ref, o_ref: None,
|
||||
out_shape=a,
|
||||
grid=(1,),
|
||||
in_specs=[pl.BlockSpec((4,), my_index_map)])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Index map function my_index_map at .*/pallas_test.py:.* for "
|
||||
"x_ref must return integer scalars. Output\\[0\\] has "
|
||||
"type .*int32\\[4\\]"):
|
||||
f(a)
|
||||
|
||||
def test_pallas_call_index_map_captures_consts(self):
|
||||
@ -742,10 +787,12 @@ class ApiErrorTest(PallasBaseTest):
|
||||
index_map_result = np.array([0], dtype=np.int32)
|
||||
f = self.pallas_call(lambda x_ref, o1_ref: None,
|
||||
out_shape=a,
|
||||
in_specs=[pl.BlockSpec((4,), lambda: index_map_result)])
|
||||
grid=(1,),
|
||||
in_specs=[pl.BlockSpec((4,),
|
||||
lambda i: jnp.array(index_map_result)[i])])
|
||||
with self.assertRaisesRegex(
|
||||
NotImplementedError,
|
||||
"Index map for inputs\\[0\\] captures constants"):
|
||||
ValueError,
|
||||
"Index map function .* for x_ref must not capture constants:"):
|
||||
f(a)
|
||||
|
||||
def test_pallas_call_out_specs_mismatch_shape(self):
|
||||
@ -767,7 +814,7 @@ class ApiErrorTest(PallasBaseTest):
|
||||
in_specs=[pl.BlockSpec((1, 1), lambda: (0, 0))])
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"Block shape for inputs\\[0\\] .* must have the same number of dimensions as the "
|
||||
"Block shape for x_ref .* must have the same number of dimensions as the "
|
||||
"array shape"):
|
||||
|
||||
f(a)
|
||||
|
@ -208,64 +208,6 @@ class PallasCallScalarPrefetchTest(PallasBaseTest):
|
||||
self.assertIsInstance(res, tuple) # Even though we asked for a list!
|
||||
self.assertAllClose(res[0][0], x)
|
||||
|
||||
def test_block_spec_with_wrong_block_shape_errors(self):
|
||||
def body(x_ref, o_ref):
|
||||
o_ref[...] = x_ref[...]
|
||||
|
||||
x = jnp.ones((16, 128))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
'Block shape .* must have the same number of dimensions as the array shape .*'):
|
||||
_ = self.pallas_call(
|
||||
body,
|
||||
grid_spec=pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=0,
|
||||
in_specs=[pl.BlockSpec((128,), lambda i: (i, 0))], # WRONG
|
||||
out_specs=pl.BlockSpec((8, 128,), lambda i: (i, 0)),
|
||||
grid=(2,),
|
||||
),
|
||||
out_shape=x,
|
||||
)(x)
|
||||
|
||||
def test_block_spec_with_index_map_that_accepts_wrong_number_of_args_errors(self):
|
||||
def body(x_ref, o_ref):
|
||||
o_ref[...] = x_ref[...]
|
||||
|
||||
x = jnp.ones((16, 128))
|
||||
with self.assertRaisesRegex(
|
||||
TypeError,
|
||||
'missing 1 required positional argument: \'j\''):
|
||||
_ = self.pallas_call(
|
||||
body,
|
||||
grid_spec=pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=0,
|
||||
in_specs=[pl.BlockSpec((8, 128,), lambda i, j: (i, 0))], # WRONG
|
||||
out_specs=pl.BlockSpec((8, 128,), lambda i: (i, 0),),
|
||||
grid=(2,),
|
||||
),
|
||||
out_shape=x,
|
||||
)(x)
|
||||
|
||||
def test_block_spec_with_index_map_returns_wrong_number_of_values_errors(self):
|
||||
def body(x_ref, o_ref):
|
||||
o_ref[...] = x_ref[...]
|
||||
|
||||
x = jnp.ones((16, 128))
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
r'Index map for inputs\[0\] must return 2 values to match block shape \(8, 128\).'
|
||||
' Currently returning 1 values.'):
|
||||
_ = self.pallas_call(
|
||||
body,
|
||||
grid_spec=pltpu.PrefetchScalarGridSpec(
|
||||
num_scalar_prefetch=0,
|
||||
in_specs=[pl.BlockSpec((8, 128,), lambda i: (i,))], # WRONG
|
||||
out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)),
|
||||
grid=(2,),
|
||||
),
|
||||
out_shape=x,
|
||||
)(x)
|
||||
|
||||
def test_vmap_scalar_prefetch(self):
|
||||
def body(_, x_ref, o_ref):
|
||||
o_ref[...] = x_ref[...]
|
||||
|
Loading…
x
Reference in New Issue
Block a user