[pallas] Improve some error messages and add API tests.

We make the following improvements:

  * pytree structure mismatch messages now attempt to localize the
    mismatch using tree_util.KeyPath.
  * we generate a simpler error message when `in_specs` is not a
    sequence, instead of the confusing PyTreeDef mismatch error we
    raised previously.
  * we generate an error message when the index map function in a
    BlockSpec returns an unexpected number of results.
  * we add error localization to the existing shape polymorphism
    check that block shapes are static.
  * we check that the kernel function returns None (see the sketch
    after this list). Without this we used to get `body_fun output and
    input must have same type structure` in the interpreter,
    `assert len(jaxpr.outvars) == 0` on GPU, and
    `INTERNAL: Mosaic failed to compile TPU kernel: has 1 operands, but enclosing function (@main) returns 0`
    on TPU.
  * we check that the rank of the block_shape matches the rank of
    the overall array. Without this we used to get a `safe_zip`
    error. We also carry the pytree paths to localize the error.
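
As an illustration of the kernel-returns-None check, here is a minimal
sketch (the kernel and variable names are made up for this example; only
the error text comes from this change):

    import jax
    import numpy as np
    from jax.experimental import pallas as pl

    def copy_kernel(x_ref, o_ref):
      o_ref[...] = x_ref[...]
      return o_ref[...]  # WRONG: a pallas_call kernel must return None

    x = np.arange(256, dtype=np.int32)
    f = pl.pallas_call(copy_kernel,
                       out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype))
    # Calling f(x) now fails early with:
    #   ValueError: The kernel function in a pallas_call should return None. ...
    # instead of the backend-specific internal errors listed above.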

To simplify the generation of the error messages we added a helper
function `tree_util.equality_errors_pytreedef`, which is just like
`tree_util.equality_errors` but takes `PyTreeDef` inputs rather than
PyTrees. We then used this new helper function in `pjit.py` and `stages.py`.
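
A rough usage sketch of the new helper, assuming it lives in the internal
`jax._src.tree_util` module as in this change:

    # Minimal sketch: compare two PyTreeDefs directly, without building
    # dummy pytrees at the call site.
    from jax._src import tree_util

    tree1 = tree_util.tree_structure(((0, 0), {"a": 0}))
    tree2 = tree_util.tree_structure(((0, 0, 0), {"a": 0}))
    for path, thing1, thing2, explanation in tree_util.equality_errors_pytreedef(
        tree1, tree2):
      # Each entry localizes one structural mismatch.
      print(f"at {tree_util.keystr(path)}: {thing1} vs. {thing2}, so {explanation}")
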
Author: George Necula, 2024-07-02 00:40:13 -07:00
Commit: a4a9499a40 (parent: f0e36d5083)
10 changed files with 339 additions and 82 deletions

@ -296,10 +296,12 @@ def _preprocess_grid(grid: Grid | int | None) -> Grid:
def _convert_block_spec_to_block_mapping(
in_avals: Sequence[jax_core.ShapedArray],
block_spec: BlockSpec,
path: tree_util.KeyPath,
aval: jax_core.ShapedArray,
in_tree: Any,
grid: GridMappingGrid,
mapped_dims: tuple[int, ...],
what: str, # Used to localize error messages, e.g., {what}{path}
) -> BlockMapping | None:
if block_spec is no_block_spec:
return None
@ -313,7 +315,13 @@ def _convert_block_spec_to_block_mapping(
mapped if s is None else s for s in block_shape)
flat_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index), in_tree)
with tracing_grid_env(grid, mapped_dims):
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
if len(out_avals) != len(block_shape):
raise ValueError(
f"Index map for {what}{tree_util.keystr(path)} must return "
f"{len(aval.shape)} values to match {block_shape=}. "
f"Currently returning {len(out_avals)} values."
)
return BlockMapping(
block_shape, jax_core.ClosedJaxpr(jaxpr, consts), block_spec.indexing_mode
)
@ -327,42 +335,51 @@ def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
return ref.update(inner_aval=ref.inner_aval.update(shape=shape))
def _check_static_ref_shape(ref: state.AbstractRef) -> state.AbstractRef:
shape = ref.shape
if not jax_core.is_constant_shape(shape):
# TODO(necula): thread the tree labels so that we can localize the error
raise ValueError("shape polymorphism for Pallas does not support "
f"dynamically-shaped blocks. Found block_shape: {shape}")
return ref
def _get_ref_avals(grid, in_avals, in_specs, out_avals, out_specs):
def _get_memory_space(spec):
def _get_ref_avals(grid,
in_avals: Sequence[jax_core.ShapedArray],
in_specs: Sequence[BlockSpec],
in_paths: Sequence[tree_util.KeyPath],
out_avals: Sequence[jax_core.ShapedArray],
out_specs: Sequence[BlockSpec],
out_paths: Sequence[tree_util.KeyPath]):
def make_ref_aval(aval: jax_core.ShapedArray,
spec: BlockSpec,
path: tree_util.KeyPath,
what: str) -> state.AbstractRef:
if spec is no_block_spec:
return None
return spec.memory_space
memory_space = None
block_shape = None
else:
memory_space = spec.memory_space
block_shape = spec.block_shape
ref_aval = AbstractMemoryRef(aval, memory_space)
if block_shape is not None:
if len(ref_aval.shape) != len(block_shape):
raise ValueError(
f"Block shape for {what}{tree_util.keystr(path)} (= {block_shape}) "
f"must have the same number of dimensions as the array shape {ref_aval.shape}"
)
trimmed_block_shape = tuple(s for s in block_shape if s is not None)
ref_aval = ref_aval.update(
inner_aval=ref_aval.inner_aval.update(shape=trimmed_block_shape))
if not jax_core.is_constant_shape(ref_aval.shape):
raise ValueError(
"shape polymorphism for Pallas does not support "
"dynamically-shaped blocks. "
f"{what}{tree_util.keystr(path)} has block_shape: {ref_aval.shape}")
return ref_aval
in_ref_avals = [
AbstractMemoryRef(aval, _get_memory_space(in_spec))
for aval, in_spec in zip(in_avals, in_specs)
make_ref_aval(aval, in_spec, in_path, "input")
for aval, in_spec, in_path in zip(in_avals, in_specs, in_paths)
]
out_ref_avals = [
AbstractMemoryRef(aval, _get_memory_space(out_spec))
for aval, out_spec in zip(out_avals, out_specs)
make_ref_aval(aval, out_spec, out_path, "output")
for aval, out_spec, out_path in zip(out_avals, out_specs, out_paths)
]
if grid is None:
in_specs = [None] * len(in_avals)
out_specs = [None] * len(out_avals)
tiled_in_ref_avals = [
_check_static_ref_shape(aval if in_spec is no_block_spec
else _tile_ref(aval, in_spec.block_shape))
for aval, in_spec in zip(in_ref_avals, in_specs)
]
tiled_out_ref_avals = [
_check_static_ref_shape(aval if out_spec is no_block_spec
else _tile_ref(aval, out_spec.block_shape))
for aval, out_spec in zip(out_ref_avals, out_specs)
]
return in_specs, tiled_in_ref_avals, out_specs, tiled_out_ref_avals
return in_specs, in_ref_avals, out_specs, out_ref_avals
class NoBlockSpec:
pass
@ -386,6 +403,8 @@ class GridSpec:
# Be more lenient for in/out_specs
if isinstance(in_specs, list):
in_specs = tuple(in_specs)
elif in_specs is not no_block_spec and not isinstance(in_specs, Sequence):
raise ValueError(f"`in_specs` must be a tuple or a list. Found: {in_specs}")
if isinstance(out_specs, list):
out_specs = tuple(out_specs)
@ -410,20 +429,20 @@ class GridSpec:
flat_in_specs = self.in_specs
if self.in_specs_tree != in_tree:
raise ValueError(
"Pytree specs for arguments and `in_specs` must match: "
f"{in_tree} vs. {self.in_specs_tree}")
pytreedef_mismatch_err_msg("`in_specs`", self.in_specs_tree,
"inputs", in_tree))
if self.out_specs is no_block_spec:
flat_out_specs = [no_block_spec] * len(out_avals)
else:
flat_out_specs = self.out_specs
if self.out_specs_tree != out_tree:
raise ValueError(
"Pytree specs for `out_shape` and `out_specs` must match: "
f"{out_tree} vs. {self.out_specs_tree}")
pytreedef_mismatch_err_msg("`out_specs`", self.out_specs_tree,
"`out_shape`", out_tree))
return flat_in_specs, flat_out_specs
def get_grid_mapping(
self, in_avals, in_tree, out_avals, out_tree
self, in_avals, in_tree, in_paths, out_avals, out_tree, out_paths
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
assert all(i is None or isinstance(i, int) for i in self.grid)
grid_mapping_grid = tuple(
@ -432,8 +451,8 @@ class GridSpec:
flat_in_specs, flat_out_specs = self._get_in_out_specs(
in_avals, in_tree, out_avals, out_tree)
in_specs, in_ref_avals, out_specs, out_ref_avals = _get_ref_avals(
self.grid, in_avals, flat_in_specs, out_avals,
flat_out_specs)
self.grid, in_avals, flat_in_specs, in_paths,
out_avals, flat_out_specs, out_paths)
grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
# Create args, kwargs pytree def
grid_tree = tree_util.tree_structure((tuple(grid_avals), {}))
@ -444,8 +463,10 @@ class GridSpec:
in_tree=grid_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="input",
),
in_specs,
in_paths,
in_ref_avals,
)
out_block_mappings = map(
@ -455,8 +476,10 @@ class GridSpec:
in_tree=grid_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="output",
),
out_specs,
out_paths,
out_ref_avals,
)
grid_mapping = GridMapping(
@ -480,3 +503,18 @@ class GridSpec:
static_self = copy.copy(self)
static_self.grid = static_grid # type: ignore
return static_self, dynamic_bounds
def pytreedef_mismatch_err_msg(
what1: str, tree1: tree_util.PyTreeDef,
what2: str, tree2: tree_util.PyTreeDef) -> str:
errs = list(tree_util.equality_errors_pytreedef(tree1, tree2))
msg = []
msg.append(
f"Pytree for {what1} and {what2} do not match. "
f"There are {len(errs)} mismatches, including:")
for path, thing1, thing2, explanation in errs:
where = f"at {tree_util.keystr(path)}, " if path else ""
msg.append(
f" * {where}{what1} is a {thing1} but"
f" {what2} is a {thing2}, so {explanation}")
return "\n".join(msg)

@ -170,7 +170,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
self.scratch_shapes = tuple(scratch_shapes)
def get_grid_mapping(
self, in_avals, in_tree, out_avals, out_tree
self, in_avals, in_tree, in_paths, out_avals, out_tree, out_paths
) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
assert all(i is None or isinstance(i, int) for i in self.grid)
grid_mapping_grid = tuple(
@ -189,8 +189,8 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
in_avals, in_avals_tree, out_avals, out_tree)
in_specs, in_ref_avals, out_specs, out_ref_avals = (
pallas_core._get_ref_avals(
self.grid, in_avals, flat_in_specs,
out_avals, flat_out_specs))
self.grid, in_avals, flat_in_specs, in_paths[num_flat_scalar_prefetch:],
out_avals, flat_out_specs, out_paths))
scalar_ref_avals = [
AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype),
TPUMemorySpace.SMEM)
@ -207,8 +207,10 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
in_tree=index_map_in_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="input",
),
in_specs,
in_paths[num_flat_scalar_prefetch:],
in_ref_avals,
)
out_block_mappings = map(
@ -218,8 +220,10 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
in_tree=index_map_in_tree,
grid=grid_mapping_grid,
mapped_dims=(),
what="output",
),
out_specs,
out_paths,
out_ref_avals,
)
grid_mapping = GridMapping(

@ -22,10 +22,10 @@ import string
from typing import Any
import jax
from jax import core as jax_core
from jax import lax
from jax import tree_util
from jax._src import ad_util
from jax._src import core as jax_core
from jax._src import custom_derivatives
from jax._src import debugging
from jax._src import dtypes
@ -204,6 +204,11 @@ def ir_constant(x, mlir_type=None):
lowering_rules = {}
skip_mlir_conversions = set()
def _get_aval_physical_dtype_shape(aval):
dtype_physical_shape = jax_core.physical_aval(aval).shape[
len(aval.shape) :
]
return dtype_physical_shape
def _get_arg_type(
aval,
@ -427,6 +432,7 @@ def lower_jaxpr_to_module(
mlir_func = lower_jaxpr_to_transform_func(
ctx,
bm.index_map_jaxpr.jaxpr,
aval,
name=func_name,
mosaic_grid_mapping=mosaic_grid_mapping,
)
@ -434,6 +440,9 @@ def lower_jaxpr_to_module(
block_shape = [
1 if b is pl_core.mapped else b for b in bm.block_shape
]
# If we have an extended dtype, we need to add the block shape for the
# remaining physical dtype.
block_shape += list(_get_aval_physical_dtype_shape(aval.inner_aval))
window_shape = ir.DenseI64ArrayAttr.get(block_shape)
block_params = dict(
window_bounds=window_shape,
@ -469,6 +478,7 @@ def lower_jaxpr_to_module(
def lower_jaxpr_to_transform_func(
ctx: ir.Context,
jaxpr: jax_core.Jaxpr,
aval: jax_core.AbstractValue,
*,
name: str,
mosaic_grid_mapping: MosaicGridMapping,
@ -503,8 +513,16 @@ def lower_jaxpr_to_transform_func(
mesh_context=mesh_context,
traceback_caches=mlir.TracebackCaches(),
)
return jaxpr_subcomp(lowering_context, jaxpr, *jaxpr_indices,
out = jaxpr_subcomp(lowering_context, jaxpr, *jaxpr_indices,
*scalar_prefetch)
assert isinstance(aval, state.AbstractRef), aval
# If we have an extended dtype, we need to add 0s for the block indices
# for the remaining physical dtype.
out += [
ir_constant(0, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")))
] * len(_get_aval_physical_dtype_shape(aval.inner_aval))
return out
body_func.__name__ = name
body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
try:

@ -23,7 +23,6 @@ from typing import Any
import jax
from jax import api_util
from jax import lax
from jax import tree_util
from jax._src import ad_util
from jax._src import checkify
from jax._src import config
@ -31,6 +30,7 @@ from jax._src import core as jax_core
from jax._src import effects
from jax._src import linear_util as lu
from jax._src import state
from jax._src import tree_util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -45,6 +45,7 @@ from jax._src.util import (
safe_zip,
split_list,
tuple_insert,
unzip2,
weakref_lru_cache,
)
import jax.numpy as jnp
@ -848,6 +849,7 @@ def pallas_call_checkify_rule(error: checkify.Error,
# for the new error inputs and outputs.
scalar_avals = map(checkify.get_shaped_aval, scalars)
error_block_specs = [no_block_spec] * num_err_vals
error_paths, _ = unzip2(tree_util.tree_flatten_with_path(error_block_specs)[0])
grid_avals = [
jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
# TODO(justinfu): Place these in device-specific scalar memory.
@ -862,8 +864,9 @@ def pallas_call_checkify_rule(error: checkify.Error,
(*grid_avals, *scalar_ref_avals),
in_tree=grid_tree,
grid=grid_mapping.grid,
mapped_dims=grid_mapping.mapped_dims),
error_block_specs, error_memref_aval)
mapped_dims=grid_mapping.mapped_dims,
what="error"),
error_block_specs, error_paths, error_memref_aval)
input_block_mappings, output_block_mappings = split_list(
grid_mapping.block_mappings, [num_kernel_inputs,])
grid_mapping_with_error = grid_mapping.replace(
@ -893,10 +896,16 @@ def pallas_call_checkify_rule(error: checkify.Error,
checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule
@weakref_lru_cache
def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec, flat_in_avals,
flat_out_avals, in_tree, out_tree, interpret: bool):
avals, grid_mapping = grid_spec.get_grid_mapping(flat_in_avals, in_tree,
flat_out_avals, out_tree)
def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
flat_in_avals: Sequence[jax_core.AbstractValue],
flat_out_avals: Sequence[jax_core.AbstractValue],
in_tree: tree_util.PyTreeDef,
in_paths: Sequence[tree_util.KeyPath],
out_tree: tree_util.PyTreeDef,
out_paths: Sequence[tree_util.KeyPath],
interpret: bool):
avals, grid_mapping = grid_spec.get_grid_mapping(flat_in_avals, in_tree, in_paths,
flat_out_avals, out_tree, out_paths)
if interpret:
avals = jax.tree_util.tree_map(_logical_aval_to_interpret_mode_aval, avals)
jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals)
@ -1058,19 +1067,25 @@ def pallas_call(
grid_spec, dynamic_grid_bounds = grid_spec.unzip_dynamic_grid_bounds()
if isinstance(out_shape, list):
out_shape = tuple(out_shape)
flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape)
flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype)
flat_out_shapes_with_paths, out_tree = tree_util.tree_flatten_with_path(out_shape)
out_paths, flat_out_shapes = unzip2(flat_out_shapes_with_paths)
flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype) # type: ignore
for x in flat_out_shapes]
@jax.jit
def wrapped(*args):
flat_args, in_tree = tree_util.tree_flatten(args)
flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args)
in_paths, flat_args = unzip2(flat_args_with_paths)
flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a))
for a in flat_args)
flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
for v in flat_out_shapes)
grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
f, grid_spec, flat_in_avals, flat_out_avals, in_tree,
out_tree, interpret=interpret)
grid_mapping, jaxpr, consts, f_out_tree = _trace_to_jaxpr(
f, grid_spec, flat_in_avals, flat_out_avals, in_tree, in_paths,
out_tree, out_paths, interpret=interpret)
if f_out_tree != tree_util.tree_flatten(None)[1]:
raise ValueError(
"The kernel function in a pallas_call should return None. "
f"Found a PyTree: {f_out_tree}")
out_flat = pallas_call_p.bind(
*dynamic_grid_bounds, *consts, *flat_args,
jaxpr=jaxpr, name=name,

@ -1206,11 +1206,7 @@ def explain_tracing_cache_miss(
p(f" never seen input pytree{in_tree_str}")
dont_match = [t for t, *_ in seen_keys if t != in_tree]
closest_tree = min(dont_match, key=lambda t: abs(t.num_leaves - in_tree.num_leaves))
# TODO(mattjj): make equality_errors not print type name, avoid metaclass
leaf = type('LeafMeta', (type,), dict(__repr__=lambda _: 'leaf'))('Leaf', (), {})()
this_dummy = tree_unflatten(in_tree, [leaf] * in_tree.num_leaves)
close_dummy = tree_unflatten(closest_tree, [leaf] * closest_tree.num_leaves) # type: ignore
errs = list(tree_util.equality_errors(this_dummy, close_dummy))
errs = list(tree_util.equality_errors_pytreedef(in_tree, closest_tree)) # type: ignore[arg-type]
p(f" closest seen input pytree has {len(errs)} mismatches, including:")
for path, thing1, thing2, explanation in errs:
fst, *path = path # type: ignore

@ -43,7 +43,6 @@ from jax._src import config
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import tree_util
from jax._src.tree_util import tree_unflatten, keystr
from jax._src import util
from jax._src.sharding_impls import is_unspecified_or_auto
from jax._src.layout import Layout
@ -590,11 +589,7 @@ class Compiled(Stage):
f"keyword arguments, but called with keyword arguments: {kws}")
args_flat, in_tree = tree_util.tree_flatten((args, kwargs))
if in_tree != params.in_tree:
leaf = PytreeLeaf()
this_dummy = tree_unflatten(in_tree, [leaf] * in_tree.num_leaves)
other_dummy = tree_unflatten(
params.in_tree, [leaf] * params.in_tree.num_leaves)
errs = list(tree_util.equality_errors(this_dummy, other_dummy))
errs = list(tree_util.equality_errors_pytreedef(in_tree, params.in_tree))
msg = []
msg.append(
"Function compiled with input pytree does not match the input pytree"
@ -603,7 +598,7 @@ class Compiled(Stage):
fst, *rest = path
base = ['args', 'kwargs'][fst.idx]
msg.append(
f" * at {base}{keystr(tuple(rest))}, seen {thing2} but now"
f" * at {base}{tree_util.keystr(tuple(rest))}, seen {thing2} but now"
f" given {thing1}, so {explanation}")
raise TypeError('\n'.join(msg))
try:
@ -641,10 +636,6 @@ class Compiled(Stage):
return self._call(*args, **kwargs)
class PytreeLeaf:
def __repr__(self): return "pytree leaf"
class Lowered(Stage):
"""Lowering of a function specialized to argument types and values.

@ -621,7 +621,7 @@ def equality_errors(
"""Helper to describe structural differences between two pytrees.
Args:
tree1, tree2: pytrees to compare.
tree1, tree2: pytrees known to have different structure.
Usage:
@ -636,6 +636,15 @@ def equality_errors(
"""
yield from _equality_errors((), tree1, tree2, is_leaf)
def equality_errors_pytreedef(
tree1: PyTreeDef,
tree2: PyTreeDef) -> Iterable[tuple[KeyPath, str, str, str]]:
"""Like `equality_errors` but invoked on PyTreeDef."""
# TODO(mattjj): make equality_errors not print type name, avoid metaclass
leaf = type("LeafMeta", (type,), dict(__repr__=lambda _: "pytree leaf"))("Leaf", (), {})()
return equality_errors(tree_unflatten(tree1, [leaf] * tree1.num_leaves),
tree_unflatten(tree2, [leaf] * tree2.num_leaves))
# TODO(mattjj): maybe share some logic with _prefix_error?
def _equality_errors(path, t1, t2, is_leaf):
# If both are leaves, this isn't a structure equality error.

@ -16,6 +16,7 @@ import contextlib
import functools
import itertools
import os
import re
import sys
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
@ -228,6 +229,27 @@ class PallasCallTest(PallasTest):
for i in range(5):
np.testing.assert_allclose(index(x, i), x[i])
def test_pallas_call_no_outputs(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref: None, ())
self.assertAllClose((), f(a))
def test_pallas_call_out_shape_is_singleton_tuple(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=(a,))
res = f(a)
self.assertIsInstance(res, tuple)
self.assertLen(res, 1)
def test_pallas_call_out_shape_is_list(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=[a])
res = f(a)
# TODO(necula): we normalize out_shape to a tuple, we shouldn't.
self.assertIsInstance(res, tuple)
def test_hoisted_consts(self):
# See https://github.com/google/jax/issues/21557.
x = jnp.zeros(32)
@ -441,6 +463,112 @@ class PallasCallInterpreterTest(PallasCallTest):
INTERPRET = True
class ApiErrorTest(PallasTest):
def test_pallas_kernel_args_mismatch(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref: None, # Missing o_ref
out_shape=a)
with self.assertRaisesRegex(
TypeError,
"takes 1 positional argument but 2 were given"):
f(a)
@parameterized.named_parameters(
("array", 0),
("empty_tuple", ())
)
def test_pallas_call_error_kernel_returns_something(self, returns):
a = np.arange(256, dtype=np.int32)
# The kernel should not return anything
f = self.pallas_call(lambda x_ref, o1_ref, o2_ref: returns,
out_shape=(a, a))
with self.assertRaisesRegex(
ValueError,
"The kernel function in a pallas_call should return None"):
f(a)
def test_pallas_call_in_specs_not_a_sequence(self):
a = np.arange(256, dtype=np.int32)
with self.assertRaisesRegex(
ValueError,
"`in_specs` must be a tuple or a list"):
_ = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=a,
in_specs=pl.BlockSpec((4,), lambda: 0))
def test_pallas_call_in_specs_mismatch_inputs(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=a,
in_specs=[pl.BlockSpec((4,), lambda: 0),
pl.BlockSpec((4,), lambda: 0)])
with self.assertRaisesRegex(
ValueError,
re.compile("Pytree for `in_specs` and inputs do not match. "
"There are 1 mismatches, including:"
".* at \\[1\\], `in_specs` is a pytree leaf but "
"inputs is a.*", re.DOTALL)):
f(a, dict(a=a))
def test_pallas_call_index_map_wrong_number_of_arguments(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=a,
in_specs=[pl.BlockSpec((4,), lambda i, j: 0)])
with self.assertRaisesRegex(
TypeError,
"missing 2 required positional arguments: 'i' and 'j'"):
f(a)
def test_pallas_call_index_map_wrong_number_of_results(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=a,
in_specs=[pl.BlockSpec((4,), lambda: (0, 0))])
with self.assertRaisesRegex(
ValueError,
"Index map for input\\[0\\] must return 1 values to match .*Currently returning 2 values."):
f(a)
def test_pallas_call_out_specs_mismatch_shape(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=[a, a],
out_specs=[pl.BlockSpec((6,), lambda i: i)])
with self.assertRaisesRegex(
ValueError,
re.compile("Pytree for `out_specs` and `out_shape` do not match. There are 1 mismatches, including:"
".* `out_specs` is a tuple of length 1 but `out_shape` is a tuple of length 2.*", re.DOTALL)):
f(a)
def test_pallas_call_block_shape_ndim_mismatch(self):
a = np.arange(256, dtype=np.int32)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=[a],
in_specs=[pl.BlockSpec((1, 1), lambda: (0, 0))])
with self.assertRaisesRegex(
ValueError,
"Block shape for input\\[0\\] .* must have the same number of dimensions as the "
"array shape"):
f(a)
f = self.pallas_call(lambda x_ref, o1_ref: None,
out_shape=[a],
out_specs=[pl.BlockSpec((1, 1), lambda: 0)])
with self.assertRaisesRegex(
ValueError,
"Block shape for output\\[0\\] .* must have the same number of dimensions as the "
"array shape"):
f(a)
class ApiErrorInterpreterTest(ApiErrorTest):
INTERPRET = True
class PallasControlFlowTest(PallasTest):
def setUp(self):

@ -115,6 +115,67 @@ class PallasCallScalarPrefetchTest(PallasTPUTest):
)(s, x)
np.testing.assert_array_equal(out, x)
def test_block_spec_with_wrong_block_shape_errors(self):
def body(x_ref, o_ref):
o_ref[...] = x_ref[...]
x = jnp.ones((16, 128))
with self.assertRaisesRegex(
ValueError,
'Block shape .* must have the same number of dimensions as the array shape .*'):
_ = pl.pallas_call(
body,
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[pl.BlockSpec((128,), lambda i: (i, 0))], # WRONG
out_specs=pl.BlockSpec((8, 128,), lambda i: (i, 0)),
grid=(2,),
),
out_shape=x,
interpret=self.interpret,
)(x)
def test_block_spec_with_index_map_that_accepts_wrong_number_of_args_errors(self):
def body(x_ref, o_ref):
o_ref[...] = x_ref[...]
x = jnp.ones((16, 128))
with self.assertRaisesRegex(
TypeError,
'missing 1 required positional argument: \'j\''):
_ = pl.pallas_call(
body,
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[pl.BlockSpec((8, 128,), lambda i, j: (i, 0))], # WRONG
out_specs=pl.BlockSpec((8, 128,), lambda i: (i, 0),),
grid=(2,),
),
out_shape=x,
interpret=self.interpret
)(x)
def test_block_spec_with_index_map_returns_wrong_number_of_values_errors(self):
def body(x_ref, o_ref):
o_ref[...] = x_ref[...]
x = jnp.ones((16, 128))
with self.assertRaisesRegex(
ValueError,
r'Index map for input\[0\] must return 2 values to match block_shape=\(8, 128\).'
' Currently returning 1 values.'):
_ = pl.pallas_call(
body,
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=0,
in_specs=[pl.BlockSpec((8, 128,), lambda i: (i,))], # WRONG
out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)),
grid=(2,),
),
out_shape=x,
interpret=self.interpret,
)(x)
def test_vmap_scalar_prefetch(self):
def body(_, x_ref, o_ref):
o_ref[...] = x_ref[...]
@ -138,8 +199,7 @@ class PallasCallScalarPrefetchTest(PallasTPUTest):
out_specs=pl.BlockSpec(
(x.shape[0] // 8, x.shape[1]), lambda i, _: (i, 0)
),
grid=8,
),
grid=8),
interpret=self.interpret,
)(s, x)
np.testing.assert_allclose(
@ -363,7 +423,7 @@ class PallasCallDynamicGridTest(PallasTPUTest):
num_programs = pl.num_programs(0)
self.assertIsInstance(num_programs, int)
self.assertEqual(num_programs, 2)
return 0
return 0, 0
pl.pallas_call(
kernel,
in_specs=[pl.BlockSpec((8, 128), x_index_map)],

@ -184,7 +184,7 @@ class BlockInvarianceTest(parameterized.TestCase):
def make_kernel_body(index_map):
def body(key_ref, o_ref):
key = key_ref[0, 0]
key = key_ref[...]
samples = plrandom.sample_block(
jax.random.uniform,
key,
@ -199,9 +199,7 @@ class BlockInvarianceTest(parameterized.TestCase):
global_key = jax_random.key(0, impl="pallas_tpu")
o_shape = jnp.ones((64, 512), dtype=jnp.float32)
key_spec = pl.BlockSpec(
(1, 1), lambda i, j: (0, 0), memory_space=pltpu.TPUMemorySpace.SMEM
)
key_spec = pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)
out_spec = pl.BlockSpec((16, 128), lambda i, j: (i, j))
result_16x128 = pl.pallas_call(
make_kernel_body(index_map=lambda i, j: (i, j)),