[pallas] Improve the error localization

  * Added the source location information for the index map function to
    `BlockMapping`.
  * Removed the `compute_index` wrapper around the `index_map`, so that
    we get the location information for the `index_map` itself, not for the wrapper.
  * Added the source location to the errors related to index map functions
    (see the sketch below).
  * Added an error if the index map returns something other than integer
    scalars.
  * Constructed `BlockSpec` origins for arguments using JAX helper functions
    to get the argument names.
  * Removed redundant API error tests from tpu_pallas_test.py.
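For illustration, here is a minimal sketch (hypothetical shapes and names, not part of this commit) of the kind of mistake the new messages localize: an index map that returns the wrong number of values is now reported with the function's name and source location, using the kernel argument name as the origin.

```python
import jax
import numpy as np
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
  o_ref[...] = x_ref[...]

def my_index_map(i):
  return i, 0  # Two values, but the block shape below is 1-D.

x = np.arange(256, dtype=np.int32)
f = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(2,),
    in_specs=[pl.BlockSpec((128,), my_index_map)])
f(x)
# ValueError: Index map function my_index_map at <file>:<line> for x_ref
# must return 1 values to match block_shape=(128,).
# Currently returning 2 values.
```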
George Necula 2024-07-18 15:33:40 +02:00
parent cc212457d2
commit 6d53aaf7d0
9 changed files with 160 additions and 123 deletions

View File

@@ -6,6 +6,8 @@ see {ref}`pallas-changelog`.
<!--
Remember to align the itemized text with the first line of an item within a list.
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
-->
## jax 0.4.32

View File

@@ -11,7 +11,17 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c
Remember to align the itemized text with the first line of an item within a list.
-->
## Released with JAX 0.4.31
## Released with jax 0.4.32
* Changes
* Deprecations
* New functionality:
* Improved error messages for mistakes in the signature of the index map functions;
the messages now include the name and source location of the index map.
## Released with jax 0.4.31 (July 29, 2024)
* Changes
* {class}`jax.experimental.pallas.BlockSpec` now expects `block_shape` to

View File

@@ -2301,8 +2301,7 @@ class DebugInfo(NamedTuple):
def debug_info(fn: Callable, in_tree: PyTreeDef | None,
out_tree_thunk: Callable[[], PyTreeDef] | None,
has_kwargs: bool, traced_for: str) -> DebugInfo:
try: sig = inspect.signature(fn)
except (ValueError, TypeError): sig = None
sig = api_util.fun_signature(fn)
src_info = fun_sourceinfo(fn)
return DebugInfo(src_info, sig, in_tree, out_tree_thunk, has_kwargs,
traced_for)

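The new `api_util.fun_signature` helper packages the same try/except that is removed here; a rough sketch of its behavior (paraphrased from the removed lines, not the exact JAX source):

```python
import inspect
from typing import Callable

def fun_signature(fun: Callable) -> inspect.Signature | None:
  # Some callables (e.g. certain builtins) make inspect.signature raise;
  # return None instead of propagating the error.
  try:
    return inspect.signature(fun)
  except (ValueError, TypeError):
    return None
```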
View File

@@ -56,6 +56,8 @@ TupleGrid = tuple[GridElement, ...]
Grid = Union[NamedGrid, TupleGrid]
StaticGrid = tuple[int, ...]
GridMappingGrid = tuple[int | DynamicGridDim, ...]
SrcInfoStr = str # function_name at filename:linenumber
OriginStr = str # The origin of a block spec, e.g. input[2]["field"]
# Pytrees of jax.ShapeDtypeStruct
ShapeDtypeStructTree = tuple[jax.ShapeDtypeStruct, ...]
@@ -247,14 +249,6 @@ class BlockSpec:
self.indexing_mode = indexing_mode
def compute_index(bs: BlockSpec, *args):
assert bs.index_map is not None
out = bs.index_map(*args)
if not isinstance(out, tuple):
out = (out,)
return out
class NoBlockSpec:
def __repr__(self):
return "NoBlockSpec"
@@ -274,9 +268,10 @@ class BlockMapping:
block_shape: tuple[Mapped | int, ...]
block_aval: AbstractMemoryRef # The block ref aval
index_map_jaxpr: jax_core.ClosedJaxpr
index_map_src_info: SrcInfoStr
indexing_mode: IndexingMode
array_shape_dtype: jax.ShapeDtypeStruct # The whole array
origin: str # The origin, e.g. input[2]["field"]
origin: OriginStr
def check_invariants(self) -> None:
if not config.enable_checks.value: return
@@ -501,7 +496,7 @@ def _is_valid_grid_dim(dim: int | jax.Array) -> bool:
def _convert_block_spec_to_block_mapping(
block_spec: BlockSpec,
path: tree_util.KeyPath,
origin: OriginStr,
array_aval: jax_core.ShapedArray,
*,
# Inputs for the index_map
@@ -509,15 +504,13 @@ def _convert_block_spec_to_block_mapping(
index_map_tree: tree_util.PyTreeDef,
grid: GridMappingGrid,
mapped_dims: tuple[int, ...],
what: str, # Used to localize error messages, e.g., {what}{path}
) -> BlockMapping:
origin = f"{what}{tree_util.keystr(path)}"
if block_spec is no_block_spec:
block_spec = BlockSpec(None, None)
if block_spec.index_map is None:
index_map_func = lambda *args: (0,) * len(array_aval.shape)
else:
index_map_func = functools.partial(compute_index, block_spec)
index_map_func = block_spec.index_map
if block_spec.block_shape is None:
block_shape = array_aval.shape
else:
@@ -525,8 +518,9 @@ def _convert_block_spec_to_block_mapping(
if len(array_aval.shape) != len(block_shape):
raise ValueError(
f"Block shape for {origin} (= {block_shape}) "
f"must have the same number of dimensions as the array shape {array_aval.shape}"
)
"must have the same number of dimensions as the "
f"array shape {array_aval.shape}.")
unmapped_block_shape = tuple(s for s in block_shape if s is not None)
block_aval = AbstractMemoryRef(array_aval.update(shape=unmapped_block_shape),
block_spec.memory_space)
@@ -535,31 +529,44 @@ def _convert_block_spec_to_block_mapping(
raise ValueError(
"shape polymorphism for Pallas does not support "
"dynamically-shaped blocks. "
f"{origin} has block_shape: {block_aval.shape}")
f"Block spec for {origin} has block_shape: {block_aval.shape}")
flat_index_map_fun, _ = api_util.flatten_fun(lu.wrap_init(index_map_func),
index_map_tree)
flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun(
lu.wrap_init(index_map_func), index_map_tree)
debug = pe.debug_info(index_map_func, index_map_tree, index_map_out_tree_thunk,
False, "pallas_call index_map")
index_map_src_info = debug.func_src_info or "<unknown>"
with tracing_grid_env(grid, mapped_dims):
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_index_map_fun,
index_map_avals)
index_map_avals,
debug_info=debug)
mapped_block_shape = tuple(
mapped if s is None else s for s in block_shape)
if len(out_avals) != len(mapped_block_shape):
mapped if s is None else s for s in block_shape)
if len(out_avals) != len(block_shape):
raise ValueError(
# TODO(necula): show the name and location of the index map function
f"Index map for {origin} must return "
f"{len(block_aval.shape)} values to match block shape {mapped_block_shape}. "
f"Currently returning {len(out_avals)} values."
)
f"Index map function {index_map_src_info} for "
f"{origin} must return "
f"{len(block_shape)} values to match {block_shape=}. "
f"Currently returning {len(out_avals)} values.")
for i, ov in enumerate(out_avals):
if ov.shape or ov.dtype not in [jnp.int32, jnp.int64]:
raise ValueError(
f"Index map function {index_map_src_info} for "
f"{origin} must return integer scalars. Output[{i}] has type "
f"{ov}.")
if consts:
raise NotImplementedError(
# TODO(necula): show the name and location of the index map function
f"Index map for {origin} captures constants: "
f"{consts}")
raise ValueError(
f"Index map function {index_map_src_info} for "
f"{origin} must not capture constants: {consts}")
mapping = BlockMapping(
block_shape=mapped_block_shape,
block_aval=block_aval,
index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts),
index_map_src_info=index_map_src_info,
indexing_mode=block_spec.indexing_mode,
array_shape_dtype=jax.ShapeDtypeStruct(array_aval.shape, array_aval.dtype),
origin=origin,
@@ -632,10 +639,10 @@ def get_grid_mapping(
grid_spec: GridSpec,
in_avals: Sequence[jax_core.AbstractValue],
in_tree: tree_util.PyTreeDef,
in_paths: Sequence[tree_util.KeyPath],
in_origins: Sequence[OriginStr],
out_avals: Sequence[jax_core.AbstractValue],
out_tree: tree_util.PyTreeDef,
out_paths: Sequence[tree_util.KeyPath],
out_origins: Sequence[OriginStr],
) -> tuple[tuple[jax_core.AbstractValue, ...],
GridMapping]:
assert all(i is None or isinstance(i, int) for i in grid_spec.grid)
@@ -700,10 +707,9 @@ def get_grid_mapping(
index_map_tree=index_map_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="inputs",
),
flat_in_specs,
in_paths[num_flat_scalar_prefetch:],
in_origins[num_flat_scalar_prefetch:],
in_avals,
)
@@ -723,10 +729,9 @@ def get_grid_mapping(
index_map_tree=index_map_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="outputs",
),
flat_out_specs,
out_paths,
out_origins,
out_avals,
)
grid_mapping = GridMapping(

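Taken together, the checks above require an index map to return exactly one integer scalar per block dimension and to not capture constants. A short sketch of a conforming `BlockSpec` (hypothetical shapes):

```python
from jax.experimental import pallas as pl

# For a (512, 512) array tiled into (128, 128) blocks, the index map
# returns one integer scalar per dimension of the block shape.
def index_map(i, j):
  return i, j

spec = pl.BlockSpec((128, 128), index_map)

# Now rejected, each with the index map's name and source location:
#   return 5.0, j          -> "... must return integer scalars ..."
#   return jnp.arange(2)   -> "... must return integer scalars ..." (not a scalar)
#   return i               -> "... must return 2 values ..."
#   closing over an array  -> "... must not capture constants ..."
```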
View File

@@ -276,7 +276,7 @@ class BufferedRef:
@property
def compute_index(self):
return lambda *args: pallas_core.compute_index(self.spec, *args)
return self.spec.index_map
@property
def memory_space(self):

View File

@@ -21,9 +21,9 @@ import itertools
from typing import Any
import jax
from jax import api_util
from jax import lax
from jax._src import ad_util
from jax._src import api_util
from jax._src import checkify
from jax._src import config
from jax._src import core as jax_core
@@ -797,15 +797,15 @@ def pallas_call_checkify_rule(error: checkify.Error,
# for the new error inputs and outputs.
error_block_specs = [pallas_core.BlockSpec(None, None)] * len(shaped_err_avals)
error_paths, _ = unzip2(tree_util.tree_flatten_with_path(error_block_specs)[0])
error_origins = tuple(f"errors{tree_util.keystr(p)}" for p in error_paths)
error_block_mappings = map(
partial(
pallas_core._convert_block_spec_to_block_mapping,
index_map_avals=grid_mapping.index_map_avals,
index_map_tree=grid_mapping.index_map_tree,
grid=grid_mapping.grid,
mapped_dims=grid_mapping.vmapped_dims,
what="error"),
error_block_specs, error_paths, shaped_err_avals)
mapped_dims=grid_mapping.vmapped_dims),
error_block_specs, error_origins, shaped_err_avals)
input_block_mappings, output_block_mappings = split_list(
grid_mapping.block_mappings, [num_kernel_inputs,])
grid_mapping_with_error = grid_mapping.replace(
@@ -837,6 +837,7 @@ checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule
@weakref_lru_cache
def _trace_kernel_to_jaxpr(fun: Callable,
fun_src_info: pallas_core.SrcInfoStr,
grid_mapping: GridMapping,
kernel_avals: tuple[pallas_core.AbstractMemRef, ...],
kernel_in_tree: tree_util.PyTreeDef,
@@ -863,13 +864,12 @@ def _trace_kernel_to_jaxpr(fun: Callable,
for c_idx, c in enumerate(consts):
const_block_mapping = pallas_core._convert_block_spec_to_block_mapping(
pallas_core.BlockSpec(None, None),
path=(tree_util.SequenceKey(c_idx),),
origin=f"consts[{c_idx}]",
array_aval=jax_core.ShapedArray(c.shape, c.dtype),
index_map_avals=grid_mapping.index_map_avals,
index_map_tree=grid_mapping.index_map_tree,
grid=grid_mapping.grid,
mapped_dims=(),
what="consts",
)
const_block_mappings.append(const_block_mapping)
@@ -880,8 +880,9 @@ def _trace_kernel_to_jaxpr(fun: Callable,
kernel_out_tree = out_tree_thunk()
if kernel_out_tree != tree_util.tree_structure(None):
raise ValueError(
"The kernel function in a pallas_call should return None. "
f"Found a PyTree: {kernel_out_tree}")
f"The kernel function {fun_src_info} in a "
f"pallas_call should return None. "
f"It returns a PyTree: {kernel_out_tree}")
return grid_mapping, jaxpr, consts
def _extract_function_name(f: Callable, name: str | None) -> str:
@@ -979,7 +980,7 @@ jax_core.custom_typechecks[pallas_call_p] = _pallas_call_typecheck_rule
def pallas_call(
f: Callable[..., None],
kernel: Callable[..., None],
out_shape: Any,
*,
grid_spec: GridSpec | None = None,
@@ -997,7 +998,7 @@ See `Pallas Quickstart <https://jax.readthedocs.io/en/latest/pallas/quickstart.html>`_.
See `Pallas Quickstart <https://jax.readthedocs.io/en/latest/pallas/quickstart.html>`_.
Args:
f: the kernel function, that receives a Ref for each input and output.
kernel: the kernel function, that receives a Ref for each input and output.
The shapes of the Refs are given by the ``block_shape`` in the
corresponding ``in_specs`` and ``out_specs``.
out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shape
@@ -1034,7 +1035,7 @@
invoke the Pallas kernel.
"""
name = _extract_function_name(f, name)
name = _extract_function_name(kernel, name)
if compiler_params is None:
compiler_params = {}
@@ -1072,14 +1073,31 @@ def pallas_call(
for a in flat_args)
flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
for v in flat_out_shapes)
kernel_fun_sig = api_util.fun_signature(kernel)
arg_names = None
kernel_src_info: pallas_core.SrcInfoStr = "<unknown>"
if kernel_fun_sig:
kernel_debug_info = api_util.debug_info(
"pallas_call kernel",
api_util.fun_sourceinfo(kernel),
kernel_fun_sig,
[1] * len(kernel_fun_sig.parameters), {}, (), ())
if kernel_debug_info:
arg_names = kernel_debug_info.arg_names
kernel_src_info = kernel_debug_info.func_src_info
in_origins = tuple(in_path_to_input_origin(p, arg_names)
for p in in_paths)
out_origins = tuple(f"outputs{tree_util.keystr(p)}" for p in out_paths)
# TODO(necula): check that input_output_aliases is well-formed: no duplicates, etc.
kernel_avals, grid_mapping = pallas_core.get_grid_mapping(
grid_spec,
flat_in_avals, in_tree, in_paths,
flat_out_avals, out_tree, out_paths)
flat_in_avals, in_tree, in_origins,
flat_out_avals, out_tree, out_origins)
flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(kernel_avals)
grid_mapping, jaxpr, consts = _trace_kernel_to_jaxpr(
f, grid_mapping, tuple(flat_kernel_avals), kernel_in_tree,
kernel, kernel_src_info,
grid_mapping, tuple(flat_kernel_avals), kernel_in_tree,
interpret=interpret)
for i_idx, o_idx in input_output_aliases.items():
if i_idx not in range(len(flat_in_avals)):
@@ -1116,6 +1134,20 @@ def pallas_call(
return wrapped
def in_path_to_input_origin(in_path: tree_util.KeyPath,
arg_names: tuple[str, ...] | None) -> pallas_core.OriginStr:
"""Converts `args[k]<rest>` into `arg_k_name<rest>`."""
if arg_names is None:
return f"args{tree_util.keystr(in_path)}"
if len(in_path) == 0:
return "args"
arg_idx, *rest_path = in_path
if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len(arg_names):
return arg_names[arg_idx.idx] + tree_util.keystr(tuple(rest_path))
else:
return f"args{tree_util.keystr(tuple(in_path))}"
# We import the TPU backend at the top level because it defines flags. Note that
# we can only do that at the bottom of this file, because it also depends on
# this module already being initialized.

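To make the origin construction concrete, here is a sketch of what `in_path_to_input_origin` produces, with made-up inputs (the helper itself is the one defined above):

```python
from jax import tree_util

# Hypothetical case: the kernel's first positional argument is named
# x_ref, and the corresponding pallas_call operand is a dict.
arg_names = ("x_ref", "y_ref")
in_path = (tree_util.SequenceKey(0), tree_util.DictKey("field"))

arg_idx, *rest_path = in_path
origin = arg_names[arg_idx.idx] + tree_util.keystr(tuple(rest_path))
assert origin == "x_ref['field']"  # args[0]['field'] becomes x_ref['field']
```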
View File

@@ -810,7 +810,7 @@ def run_scoped(f: Callable[..., Any], *types, **kw_types) -> Any:
"""Call the function with allocated references.
Args:
f: The function that generatest the jaxpr.
f: The function that generates the jaxpr.
*types: The types of the function's positional arguments.
**kw_types: The types of the function's keyword arguments.
"""

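For context on the function whose docstring is fixed here, a minimal sketch of `run_scoped` usage on TPU (assuming the TPU backend and that the function is reachable as `pl.run_scoped` in this version; names and shapes are illustrative):

```python
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def kernel(x_ref, o_ref):
  def body(scratch_ref):
    # scratch_ref is a freshly allocated VMEM ref, live only inside body.
    scratch_ref[...] = x_ref[...] * 2.0
    o_ref[...] = scratch_ref[...]
  pl.run_scoped(body, pltpu.VMEM((8, 128), jnp.float32))
```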
View File

@@ -671,7 +671,7 @@ class PallasCallInterpreterTest(PallasCallTest):
class ApiErrorTest(PallasBaseTest):
def test_pallas_kernel_args_mismatch(self):
def test_pallas_call_kernel_args_mismatch(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref: None, # Missing o_ref
out_shape=a)
@@ -687,11 +687,22 @@ class ApiErrorTest(PallasBaseTest):
def test_pallas_call_error_kernel_returns_something(self, returns):
a = np.arange(256, dtype=np.int32)
# The kernel should not return anything
f = self.pallas_call(lambda x_ref, o1_ref, o2_ref: returns,
def my_kernel(x_ref, o1_ref, o2_ref):
return returns
f = self.pallas_call(my_kernel,
out_shape=(a, a))
with self.assertRaisesRegex(
ValueError,
"The kernel function in a pallas_call should return None"):
"The kernel function my_kernel at .*pallas_test.py:.* in a pallas_call should return None"):
f(a)
def test_pallas_call_kernel_with_no_signature_returns_something(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda *args: 0, # Returns 0
out_shape=a)
with self.assertRaisesRegex(
ValueError,
"The kernel function .* at .*pallas_test.py:.* in a pallas_call should return None"):
f(a)
def test_pallas_call_in_specs_not_a_sequence(self):
@@ -729,12 +740,46 @@ class ApiErrorTest(PallasBaseTest):
def test_pallas_call_index_map_wrong_number_of_results(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
def my_index_map():
return 0, 0
f = self.pallas_call(lambda x_ref, o_ref: None,
out_shape=a,
in_specs=[pl.BlockSpec((4,), lambda: (0, 0))])
in_specs=[pl.BlockSpec((4,), my_index_map)])
with self.assertRaisesRegex(
ValueError,
"Index map for inputs\\[0\\] must return 1 values to match .*Currently returning 2 values."):
"Index map function my_index_map at .*/pallas_test.py:.* for "
"x_ref must return 1 values to match .*"
"Currently returning 2 values."):
f(a)
def test_pallas_call_index_map_wrong_return_type(self):
a = np.arange(256, dtype=np.int32)
def my_index_map(i):
return 5.
f = self.pallas_call(lambda x_ref, o_ref: None,
out_shape=a,
grid=(1,),
in_specs=[pl.BlockSpec((4,), my_index_map)])
with self.assertRaisesRegex(
ValueError,
"Index map function my_index_map at .*/pallas_test.py:.* for "
"x_ref must return integer scalars. Output\\[0\\] has "
"type .*float"):
f(a)
def test_pallas_call_index_map_wrong_return_shape(self):
a = np.arange(256, dtype=np.int32)
def my_index_map(i):
return jnp.arange(4, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o_ref: None,
out_shape=a,
grid=(1,),
in_specs=[pl.BlockSpec((4,), my_index_map)])
with self.assertRaisesRegex(
ValueError,
"Index map function my_index_map at .*/pallas_test.py:.* for "
"x_ref must return integer scalars. Output\\[0\\] has "
"type .*int32\\[4\\]"):
f(a)
def test_pallas_call_index_map_captures_consts(self):
@@ -742,10 +787,12 @@ class ApiErrorTest(PallasBaseTest):
index_map_result = np.array([0], dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=a,
in_specs=[pl.BlockSpec((4,), lambda: index_map_result)])
grid=(1,),
in_specs=[pl.BlockSpec((4,),
lambda i: jnp.array(index_map_result)[i])])
with self.assertRaisesRegex(
NotImplementedError,
"Index map for inputs\\[0\\] captures constants"):
ValueError,
"Index map function .* for x_ref must not capture constants:"):
f(a)
def test_pallas_call_out_specs_mismatch_shape(self):
@@ -767,7 +814,7 @@ class ApiErrorTest(PallasBaseTest):
in_specs=[pl.BlockSpec((1, 1), lambda: (0, 0))])
with self.assertRaisesRegex(
ValueError,
"Block shape for inputs\\[0\\] .* must have the same number of dimensions as the "
"Block shape for x_ref .* must have the same number of dimensions as the "
"array shape"):
f(a)

View File

@@ -208,64 +208,6 @@ class PallasCallScalarPrefetchTest(PallasBaseTest):
self.assertIsInstance(res, tuple) # Even though we asked for a list!
self.assertAllClose(res[0][0], x)
def test_block_spec_with_wrong_block_shape_errors(self):
def body(x_ref, o_ref):
o_ref[...] = x_ref[...]
x = jnp.ones((16, 128))
with self.assertRaisesRegex(
ValueError,
'Block shape .* must have the same number of dimensions as the array shape .*'):
_ = self.pallas_call(
body,
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[pl.BlockSpec((128,), lambda i: (i, 0))], # WRONG
out_specs=pl.BlockSpec((8, 128,), lambda i: (i, 0)),
grid=(2,),
),
out_shape=x,
)(x)
def test_block_spec_with_index_map_that_accepts_wrong_number_of_args_errors(self):
def body(x_ref, o_ref):
o_ref[...] = x_ref[...]
x = jnp.ones((16, 128))
with self.assertRaisesRegex(
TypeError,
'missing 1 required positional argument: \'j\''):
_ = self.pallas_call(
body,
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[pl.BlockSpec((8, 128,), lambda i, j: (i, 0))], # WRONG
out_specs=pl.BlockSpec((8, 128,), lambda i: (i, 0),),
grid=(2,),
),
out_shape=x,
)(x)
def test_block_spec_with_index_map_returns_wrong_number_of_values_errors(self):
def body(x_ref, o_ref):
o_ref[...] = x_ref[...]
x = jnp.ones((16, 128))
with self.assertRaisesRegex(
ValueError,
r'Index map for inputs\[0\] must return 2 values to match block shape \(8, 128\).'
' Currently returning 1 values.'):
_ = self.pallas_call(
body,
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[pl.BlockSpec((8, 128,), lambda i: (i,))], # WRONG
out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)),
grid=(2,),
),
out_shape=x,
)(x)
def test_vmap_scalar_prefetch(self):
def body(_, x_ref, o_ref):
o_ref[...] = x_ref[...]