diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py
index 9143acbd0..4692e7791 100644
--- a/jax/_src/pallas/core.py
+++ b/jax/_src/pallas/core.py
@@ -296,10 +296,12 @@ def _preprocess_grid(grid: Grid | int | None) -> Grid:
 def _convert_block_spec_to_block_mapping(
     in_avals: Sequence[jax_core.ShapedArray],
     block_spec: BlockSpec,
+    path: tree_util.KeyPath,
     aval: jax_core.ShapedArray,
     in_tree: Any,
     grid: GridMappingGrid,
     mapped_dims: tuple[int, ...],
+    what: str,  # Used to localize error messages, e.g., {what}{path}
 ) -> BlockMapping | None:
   if block_spec is no_block_spec:
     return None
@@ -313,7 +315,13 @@ def _convert_block_spec_to_block_mapping(
       mapped if s is None else s for s in block_shape)
   flat_fun, _ = api_util.flatten_fun(lu.wrap_init(compute_index), in_tree)
   with tracing_grid_env(grid, mapped_dims):
-    jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
+    jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals)
+    if len(out_avals) != len(block_shape):
+      raise ValueError(
+          f"Index map for {what}{tree_util.keystr(path)} must return "
+          f"{len(aval.shape)} values to match {block_shape=}. "
+          f"Currently returning {len(out_avals)} values."
+      )
   return BlockMapping(
       block_shape, jax_core.ClosedJaxpr(jaxpr, consts), block_spec.indexing_mode
   )
@@ -327,42 +335,51 @@ def _tile_ref(ref: state.AbstractRef, block_shape: tuple[int, ...] | None
   return ref.update(inner_aval=ref.inner_aval.update(shape=shape))


-def _check_static_ref_shape(ref: state.AbstractRef) -> state.AbstractRef:
-  shape = ref.shape
-  if not jax_core.is_constant_shape(shape):
-    # TODO(necula): thread the tree labels so that we can localize the error
-    raise ValueError("shape polymorphism for Pallas does not support "
-                     f"dynamically-shaped blocks. Found block_shape: {shape}")
-  return ref
-
-
-def _get_ref_avals(grid, in_avals, in_specs, out_avals, out_specs):
-  def _get_memory_space(spec):
+def _get_ref_avals(grid,
+                   in_avals: Sequence[jax_core.ShapedArray],
+                   in_specs: Sequence[BlockSpec],
+                   in_paths: Sequence[tree_util.KeyPath],
+                   out_avals: Sequence[jax_core.ShapedArray],
+                   out_specs: Sequence[BlockSpec],
+                   out_paths: Sequence[tree_util.KeyPath]):
+  def make_ref_aval(aval: jax_core.ShapedArray,
+                    spec: BlockSpec,
+                    path: tree_util.KeyPath,
+                    what: str) -> state.AbstractRef:
     if spec is no_block_spec:
-      return None
-    return spec.memory_space
+      memory_space = None
+      block_shape = None
+    else:
+      memory_space = spec.memory_space
+      block_shape = spec.block_shape
+
+    ref_aval = AbstractMemoryRef(aval, memory_space)
+    if block_shape is not None:
+      if len(ref_aval.shape) != len(block_shape):
+        raise ValueError(
+            f"Block shape for {what}{tree_util.keystr(path)} (= {block_shape}) "
+            f"must have the same number of dimensions as the array shape {ref_aval.shape}"
+        )
+      trimmed_block_shape = tuple(s for s in block_shape if s is not None)
+      ref_aval = ref_aval.update(
+          inner_aval=ref_aval.inner_aval.update(shape=trimmed_block_shape))
+
+    if not jax_core.is_constant_shape(ref_aval.shape):
+      raise ValueError(
+          "shape polymorphism for Pallas does not support "
+          "dynamically-shaped blocks. "
+          f"{what}{tree_util.keystr(path)} has block_shape: {ref_aval.shape}")
+    return ref_aval
+
   in_ref_avals = [
-      AbstractMemoryRef(aval, _get_memory_space(in_spec))
-      for aval, in_spec in zip(in_avals, in_specs)
+      make_ref_aval(aval, in_spec, in_path, "input")
+      for aval, in_spec, in_path in zip(in_avals, in_specs, in_paths)
   ]
   out_ref_avals = [
-      AbstractMemoryRef(aval, _get_memory_space(out_spec))
-      for aval, out_spec in zip(out_avals, out_specs)
+      make_ref_aval(aval, out_spec, out_path, "output")
+      for aval, out_spec, out_path in zip(out_avals, out_specs, out_paths)
   ]
-  if grid is None:
-    in_specs = [None] * len(in_avals)
-    out_specs = [None] * len(out_avals)
-  tiled_in_ref_avals = [
-      _check_static_ref_shape(aval if in_spec is no_block_spec
-                              else _tile_ref(aval, in_spec.block_shape))
-      for aval, in_spec in zip(in_ref_avals, in_specs)
-  ]
-  tiled_out_ref_avals = [
-      _check_static_ref_shape(aval if out_spec is no_block_spec
-                              else _tile_ref(aval, out_spec.block_shape))
-      for aval, out_spec in zip(out_ref_avals, out_specs)
-  ]
-  return in_specs, tiled_in_ref_avals, out_specs, tiled_out_ref_avals
+  return in_specs, in_ref_avals, out_specs, out_ref_avals

 class NoBlockSpec:
   pass
@@ -386,6 +403,8 @@ class GridSpec:
     # Be more lenient for in/out_specs
     if isinstance(in_specs, list):
       in_specs = tuple(in_specs)
+    elif in_specs is not no_block_spec and not isinstance(in_specs, Sequence):
+      raise ValueError(f"`in_specs` must be a tuple or a list. Found: {in_specs}")
     if isinstance(out_specs, list):
       out_specs = tuple(out_specs)

@@ -410,20 +429,20 @@ class GridSpec:
       flat_in_specs = self.in_specs
       if self.in_specs_tree != in_tree:
         raise ValueError(
-            "Pytree specs for arguments and `in_specs` must match: "
-            f"{in_tree} vs. {self.in_specs_tree}")
+            pytreedef_mismatch_err_msg("`in_specs`", self.in_specs_tree,
+                                       "inputs", in_tree))
     if self.out_specs is no_block_spec:
       flat_out_specs = [no_block_spec] * len(out_avals)
     else:
       flat_out_specs = self.out_specs
       if self.out_specs_tree != out_tree:
         raise ValueError(
-            "Pytree specs for `out_shape` and `out_specs` must match: "
-            f"{out_tree} vs. {self.out_specs_tree}")
+            pytreedef_mismatch_err_msg("`out_specs`", self.out_specs_tree,
+                                       "`out_shape`", out_tree))
     return flat_in_specs, flat_out_specs

   def get_grid_mapping(
-      self, in_avals, in_tree, out_avals, out_tree
+      self, in_avals, in_tree, in_paths, out_avals, out_tree, out_paths
   ) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
     assert all(i is None or isinstance(i, int) for i in self.grid)
     grid_mapping_grid = tuple(
@@ -432,8 +451,8 @@ class GridSpec:
     flat_in_specs, flat_out_specs = self._get_in_out_specs(
         in_avals, in_tree, out_avals, out_tree)
     in_specs, in_ref_avals, out_specs, out_ref_avals = _get_ref_avals(
-        self.grid, in_avals, flat_in_specs, out_avals,
-        flat_out_specs)
+        self.grid, in_avals, flat_in_specs, in_paths,
+        out_avals, flat_out_specs, out_paths)
     grid_avals = [jax_core.ShapedArray((), jnp.dtype("int32"))] * len(self.grid)
     # Create args, kwargs pytree def
     grid_tree = tree_util.tree_structure((tuple(grid_avals), {}))
@@ -444,8 +463,10 @@ class GridSpec:
             in_tree=grid_tree,
             grid=grid_mapping_grid,
             mapped_dims=(),
+            what="input",
         ),
         in_specs,
+        in_paths,
         in_ref_avals,
     )
     out_block_mappings = map(
@@ -455,8 +476,10 @@ class GridSpec:
             in_tree=grid_tree,
             grid=grid_mapping_grid,
             mapped_dims=(),
+            what="output",
         ),
         out_specs,
+        out_paths,
         out_ref_avals,
     )
     grid_mapping = GridMapping(
@@ -480,3 +503,18 @@ class GridSpec:
     static_self = copy.copy(self)
     static_self.grid = static_grid  # type: ignore
     return static_self, dynamic_bounds
+
+def pytreedef_mismatch_err_msg(
+    what1: str, tree1: tree_util.PyTreeDef,
+    what2: str, tree2: tree_util.PyTreeDef) -> str:
+  errs = list(tree_util.equality_errors_pytreedef(tree1, tree2))
+  msg = []
+  msg.append(
+      f"Pytree for {what1} and {what2} do not match. "
+      f"There are {len(errs)} mismatches, including:")
+  for path, thing1, thing2, explanation in errs:
+    where = f"at {tree_util.keystr(path)}, " if path else ""
+    msg.append(
+        f" * {where}{what1} is a {thing1} but"
+        f" {what2} is a {thing2}, so {explanation}")
+  return "\n".join(msg)
diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py
index f4a794792..09979faf8 100644
--- a/jax/_src/pallas/mosaic/core.py
+++ b/jax/_src/pallas/mosaic/core.py
@@ -170,7 +170,7 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
     self.scratch_shapes = tuple(scratch_shapes)

   def get_grid_mapping(
-      self, in_avals, in_tree, out_avals, out_tree
+      self, in_avals, in_tree, in_paths, out_avals, out_tree, out_paths
   ) -> tuple[tuple[jax_core.AbstractValue, ...], GridMapping]:
     assert all(i is None or isinstance(i, int) for i in self.grid)
     grid_mapping_grid = tuple(
@@ -189,8 +189,8 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
         in_avals, in_avals_tree, out_avals, out_tree)
     in_specs, in_ref_avals, out_specs, out_ref_avals = (
         pallas_core._get_ref_avals(
-            self.grid, in_avals, flat_in_specs,
-            out_avals, flat_out_specs))
+            self.grid, in_avals, flat_in_specs, in_paths[num_flat_scalar_prefetch:],
+            out_avals, flat_out_specs, out_paths))
     scalar_ref_avals = [
         AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype),
                           TPUMemorySpace.SMEM)
@@ -207,8 +207,10 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
             in_tree=index_map_in_tree,
             grid=grid_mapping_grid,
             mapped_dims=(),
+            what="input",
         ),
         in_specs,
+        in_paths[num_flat_scalar_prefetch:],
         in_ref_avals,
     )
     out_block_mappings = map(
@@ -218,8 +220,10 @@ class PrefetchScalarGridSpec(pallas_core.GridSpec):
             in_tree=index_map_in_tree,
             grid=grid_mapping_grid,
             mapped_dims=(),
+            what="output",
         ),
         out_specs,
+        out_paths,
         out_ref_avals,
     )
     grid_mapping = GridMapping(
diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py
index f47657a7a..1e011ee50 100644
--- a/jax/_src/pallas/mosaic/lowering.py
+++ b/jax/_src/pallas/mosaic/lowering.py
@@ -22,10 +22,10 @@ import string
 from typing import Any

 import jax
-from jax import core as jax_core
 from jax import lax
 from jax import tree_util
 from jax._src import ad_util
+from jax._src import core as jax_core
 from jax._src import custom_derivatives
 from jax._src import debugging
 from jax._src import dtypes
@@ -204,6 +204,11 @@ def ir_constant(x, mlir_type=None):
 lowering_rules = {}
 skip_mlir_conversions = set()

+def _get_aval_physical_dtype_shape(aval):
+  dtype_physical_shape = jax_core.physical_aval(aval).shape[
+      len(aval.shape) :
+  ]
+  return dtype_physical_shape

 def _get_arg_type(
     aval,
@@ -427,6 +432,7 @@ def lower_jaxpr_to_module(
     mlir_func = lower_jaxpr_to_transform_func(
         ctx,
         bm.index_map_jaxpr.jaxpr,
+        aval,
         name=func_name,
         mosaic_grid_mapping=mosaic_grid_mapping,
     )
@@ -434,6 +440,9 @@ def lower_jaxpr_to_module(
     block_shape = [
         1 if b is pl_core.mapped else b for b in bm.block_shape
     ]
+    # If we have an extended dtype, we need to add the block shape for the
+    # remaining physical dtype.
+    block_shape += list(_get_aval_physical_dtype_shape(aval.inner_aval))
     window_shape = ir.DenseI64ArrayAttr.get(block_shape)
     block_params = dict(
         window_bounds=window_shape,
@@ -469,6 +478,7 @@ def lower_jaxpr_to_module(
 def lower_jaxpr_to_transform_func(
     ctx: ir.Context,
     jaxpr: jax_core.Jaxpr,
+    aval: jax_core.AbstractValue,
     *,
     name: str,
     mosaic_grid_mapping: MosaicGridMapping,
@@ -503,8 +513,16 @@ def lower_jaxpr_to_transform_func(
         mesh_context=mesh_context,
         traceback_caches=mlir.TracebackCaches(),
     )
-    return jaxpr_subcomp(lowering_context, jaxpr, *jaxpr_indices,
-                         *scalar_prefetch)
+    out = jaxpr_subcomp(lowering_context, jaxpr, *jaxpr_indices,
+                        *scalar_prefetch)
+    assert isinstance(aval, state.AbstractRef), aval
+    # If we have an extended dtype, we need to add 0s for the block indices
+    # for the remaining physical dtype.
+    out += [
+        ir_constant(0, mlir_type=_dtype_to_ir_type(jnp.dtype("int32")))
+    ] * len(_get_aval_physical_dtype_shape(aval.inner_aval))
+    return out
+
   body_func.__name__ = name
   body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func)
   try:
diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py
index 41b2d2e61..66678076e 100644
--- a/jax/_src/pallas/pallas_call.py
+++ b/jax/_src/pallas/pallas_call.py
@@ -23,7 +23,6 @@ from typing import Any
 import jax
 from jax import api_util
 from jax import lax
-from jax import tree_util
 from jax._src import ad_util
 from jax._src import checkify
 from jax._src import config
@@ -31,6 +30,7 @@ from jax._src import core as jax_core
 from jax._src import effects
 from jax._src import linear_util as lu
 from jax._src import state
+from jax._src import tree_util
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -45,6 +45,7 @@ from jax._src.util import (
     safe_zip,
     split_list,
     tuple_insert,
+    unzip2,
     weakref_lru_cache,
 )
 import jax.numpy as jnp
@@ -848,6 +849,7 @@ def pallas_call_checkify_rule(error: checkify.Error,
   # for the new error inputs and outputs.
   scalar_avals = map(checkify.get_shaped_aval, scalars)
   error_block_specs = [no_block_spec] * num_err_vals
+  error_paths, _ = unzip2(tree_util.tree_flatten_with_path(error_block_specs)[0])
   grid_avals = [
       jax_core.ShapedArray((), jnp.dtype("int32"))] * len(grid_mapping.grid)
   # TODO(justinfu): Place these in device-specific scalar memory.
@@ -862,8 +864,9 @@ def pallas_call_checkify_rule(error: checkify.Error,
           (*grid_avals, *scalar_ref_avals),
           in_tree=grid_tree,
           grid=grid_mapping.grid,
-          mapped_dims=grid_mapping.mapped_dims),
-      error_block_specs, error_memref_aval)
+          mapped_dims=grid_mapping.mapped_dims,
+          what="error"),
+      error_block_specs, error_paths, error_memref_aval)
   input_block_mappings, output_block_mappings = split_list(
       grid_mapping.block_mappings, [num_kernel_inputs,])
   grid_mapping_with_error = grid_mapping.replace(
@@ -893,10 +896,16 @@ checkify.error_checks[pallas_call_p] = pallas_call_checkify_rule


 @weakref_lru_cache
-def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec, flat_in_avals,
-                    flat_out_avals, in_tree, out_tree, interpret: bool):
-  avals, grid_mapping = grid_spec.get_grid_mapping(flat_in_avals, in_tree,
-                                                   flat_out_avals, out_tree)
+def _trace_to_jaxpr(fun: Callable, grid_spec: GridSpec,
+                    flat_in_avals: Sequence[jax_core.AbstractValue],
+                    flat_out_avals: Sequence[jax_core.AbstractValue],
+                    in_tree: tree_util.PyTreeDef,
+                    in_paths: Sequence[tree_util.KeyPath],
+                    out_tree: tree_util.PyTreeDef,
+                    out_paths: Sequence[tree_util.KeyPath],
+                    interpret: bool):
+  avals, grid_mapping = grid_spec.get_grid_mapping(flat_in_avals, in_tree, in_paths,
+                                                   flat_out_avals, out_tree, out_paths)
   if interpret:
     avals = jax.tree_util.tree_map(_logical_aval_to_interpret_mode_aval, avals)
   jaxpr_flat_avals, jaxpr_in_tree = tree_util.tree_flatten(avals)
@@ -1058,19 +1067,25 @@ def pallas_call(
     grid_spec, dynamic_grid_bounds = grid_spec.unzip_dynamic_grid_bounds()
   if isinstance(out_shape, list):
     out_shape = tuple(out_shape)
-  flat_out_shapes, out_tree = tree_util.tree_flatten(out_shape)
-  flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype)
+  flat_out_shapes_with_paths, out_tree = tree_util.tree_flatten_with_path(out_shape)
+  out_paths, flat_out_shapes = unzip2(flat_out_shapes_with_paths)
+  flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype)  # type: ignore
                      for x in flat_out_shapes]
   @jax.jit
   def wrapped(*args):
-    flat_args, in_tree = tree_util.tree_flatten(args)
+    flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args)
+    in_paths, flat_args = unzip2(flat_args_with_paths)
     flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a))
                           for a in flat_args)
     flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype)
                            for v in flat_out_shapes)
-    grid_mapping, jaxpr, consts, _ = _trace_to_jaxpr(
-        f, grid_spec, flat_in_avals, flat_out_avals, in_tree,
-        out_tree, interpret=interpret)
+    grid_mapping, jaxpr, consts, f_out_tree = _trace_to_jaxpr(
+        f, grid_spec, flat_in_avals, flat_out_avals, in_tree, in_paths,
+        out_tree, out_paths, interpret=interpret)
+    if f_out_tree != tree_util.tree_flatten(None)[1]:
+      raise ValueError(
+          "The kernel function in a pallas_call should return None. "
+          f"Found a PyTree: {f_out_tree}")
     out_flat = pallas_call_p.bind(
         *dynamic_grid_bounds, *consts, *flat_args,
         jaxpr=jaxpr, name=name,
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 0f21cb314..3e56119c2 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -1206,11 +1206,7 @@ def explain_tracing_cache_miss(
     p(f" never seen input pytree{in_tree_str}")
     dont_match = [t for t, *_ in seen_keys if t != in_tree]
     closest_tree = min(dont_match, key=lambda t: abs(t.num_leaves - in_tree.num_leaves))
-    # TODO(mattjj): make equality_errors not print type name, avoid metaclass
-    leaf = type('LeafMeta', (type,), dict(__repr__=lambda _: 'leaf'))('Leaf', (), {})()
-    this_dummy = tree_unflatten(in_tree, [leaf] * in_tree.num_leaves)
-    close_dummy = tree_unflatten(closest_tree, [leaf] * closest_tree.num_leaves)  # type: ignore
-    errs = list(tree_util.equality_errors(this_dummy, close_dummy))
+    errs = list(tree_util.equality_errors_pytreedef(in_tree, closest_tree))  # type: ignore[arg-type]
     p(f" closest seen input pytree has {len(errs)} mismatches, including:")
     for path, thing1, thing2, explanation in errs:
       fst, *path = path  # type: ignore
diff --git a/jax/_src/stages.py b/jax/_src/stages.py
index 874ef8834..34a573abd 100644
--- a/jax/_src/stages.py
+++ b/jax/_src/stages.py
@@ -43,7 +43,6 @@ from jax._src import config
 from jax._src import source_info_util
 from jax._src import traceback_util
 from jax._src import tree_util
-from jax._src.tree_util import tree_unflatten, keystr
 from jax._src import util
 from jax._src.sharding_impls import is_unspecified_or_auto
 from jax._src.layout import Layout
@@ -590,11 +589,7 @@ class Compiled(Stage):
           f"keyword arguments, but called with keyword arguments: {kws}")
     args_flat, in_tree = tree_util.tree_flatten((args, kwargs))
     if in_tree != params.in_tree:
-      leaf = PytreeLeaf()
-      this_dummy = tree_unflatten(in_tree, [leaf] * in_tree.num_leaves)
-      other_dummy = tree_unflatten(
-          params.in_tree, [leaf] * params.in_tree.num_leaves)
-      errs = list(tree_util.equality_errors(this_dummy, other_dummy))
+      errs = list(tree_util.equality_errors_pytreedef(in_tree, params.in_tree))
       msg = []
       msg.append(
           "Function compiled with input pytree does not match the input pytree"
@@ -603,7 +598,7 @@ class Compiled(Stage):
         fst, *rest = path
         base = ['args', 'kwargs'][fst.idx]
         msg.append(
-            f" * at {base}{keystr(tuple(rest))}, seen {thing2} but now"
+            f" * at {base}{tree_util.keystr(tuple(rest))}, seen {thing2} but now"
            f" given {thing1}, so {explanation}")
      raise TypeError('\n'.join(msg))
    try:
@@ -641,10 +636,6 @@ class Compiled(Stage):
     return self._call(*args, **kwargs)


-class PytreeLeaf:
-  def __repr__(self): return "pytree leaf"
-
-
 class Lowered(Stage):
   """Lowering of a function specialized to argument types and values.

diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py
index 32f59b1df..a33a43ab6 100644
--- a/jax/_src/tree_util.py
+++ b/jax/_src/tree_util.py
@@ -621,7 +621,7 @@ def equality_errors(
   """Helper to describe structural differences between two pytrees.

   Args:
-    tree1, tree2: pytrees to compare.
+    tree1, tree2: pytrees known to have different structure.

   Usage:

@@ -636,6 +636,15 @@ def equality_errors(
   """
   yield from _equality_errors((), tree1, tree2, is_leaf)

+def equality_errors_pytreedef(
+    tree1: PyTreeDef,
+    tree2: PyTreeDef) -> Iterable[tuple[KeyPath, str, str, str]]:
+  """Like `equality_errors` but invoked on PyTreeDef."""
+  # TODO(mattjj): make equality_errors not print type name, avoid metaclass
+  leaf = type("LeafMeta", (type,), dict(__repr__=lambda _: "pytree leaf"))("Leaf", (), {})()
+  return equality_errors(tree_unflatten(tree1, [leaf] * tree1.num_leaves),
+                         tree_unflatten(tree2, [leaf] * tree2.num_leaves))
+
 # TODO(mattjj): maybe share some logic with _prefix_error?
 def _equality_errors(path, t1, t2, is_leaf):
diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py
index 876cbc52c..d5c927b5b 100644
--- a/tests/pallas/pallas_test.py
+++ b/tests/pallas/pallas_test.py
@@ -16,6 +16,7 @@ import contextlib
 import functools
 import itertools
 import os
+import re
 import sys

 os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"
@@ -228,6 +229,27 @@ class PallasCallTest(PallasTest):
     for i in range(5):
       np.testing.assert_allclose(index(x, i), x[i])

+  def test_pallas_call_no_outputs(self):
+    a = np.arange(256, dtype=np.int32)
+    f = self.pallas_call(lambda x_ref: None, ())
+    self.assertAllClose((), f(a))
+
+  def test_pallas_call_out_shape_is_singleton_tuple(self):
+    a = np.arange(256, dtype=np.int32)
+    f = self.pallas_call(lambda x_ref, o1_ref: None,
+                         out_shape=(a,))
+    res = f(a)
+    self.assertIsInstance(res, tuple)
+    self.assertLen(res, 1)
+
+  def test_pallas_call_out_shape_is_list(self):
+    a = np.arange(256, dtype=np.int32)
+    f = self.pallas_call(lambda x_ref, o1_ref: None,
+                         out_shape=[a])
+    res = f(a)
+    # TODO(necula): we normalize out_shape to a tuple, we shouldn't.
+    self.assertIsInstance(res, tuple)
+
   def test_hoisted_consts(self):
     # See https://github.com/google/jax/issues/21557.
     x = jnp.zeros(32)
@@ -441,6 +463,112 @@ class PallasCallInterpreterTest(PallasCallTest):
   INTERPRET = True


+class ApiErrorTest(PallasTest):
+
+  def test_pallas_kernel_args_mismatch(self):
+    a = np.arange(256, dtype=np.int32)
+    f = self.pallas_call(lambda x_ref: None,  # Missing o_ref
+                         out_shape=a)
+    with self.assertRaisesRegex(
+        TypeError,
+        "takes 1 positional argument but 2 were given"):
+      f(a)
+
+  @parameterized.named_parameters(
+      ("array", 0),
+      ("empty_tuple", ())
+  )
+  def test_pallas_call_error_kernel_returns_something(self, returns):
+    a = np.arange(256, dtype=np.int32)
+    # The kernel should not return anything
+    f = self.pallas_call(lambda x_ref, o1_ref, o2_ref: returns,
+                         out_shape=(a, a))
+    with self.assertRaisesRegex(
+        ValueError,
+        "The kernel function in a pallas_call should return None"):
+      f(a)
+
+  def test_pallas_call_in_specs_not_a_sequence(self):
+    a = np.arange(256, dtype=np.int32)
+    with self.assertRaisesRegex(
+        ValueError,
+        "`in_specs` must be a tuple or a list"):
+      _ = self.pallas_call(lambda x_ref, o1_ref: None,
+                           out_shape=a,
+                           in_specs=pl.BlockSpec((4,), lambda: 0))
+
+  def test_pallas_call_in_specs_mismatch_inputs(self):
+    a = np.arange(256, dtype=np.int32)
+    f = self.pallas_call(lambda x_ref, o1_ref: None,
+                         out_shape=a,
+                         in_specs=[pl.BlockSpec((4,), lambda: 0),
+                                   pl.BlockSpec((4,), lambda: 0)])
+    with self.assertRaisesRegex(
+        ValueError,
+        re.compile("Pytree for `in_specs` and inputs do not match. "
+                   "There are 1 mismatches, including:"
+                   ".* at \\[1\\], `in_specs` is a pytree leaf but "
+                   "inputs is a.*", re.DOTALL)):
+      f(a, dict(a=a))
+
+  def test_pallas_call_index_map_wrong_number_of_arguments(self):
+    a = np.arange(256, dtype=np.int32)
+    f = self.pallas_call(lambda x_ref, o1_ref: None,
+                         out_shape=a,
+                         in_specs=[pl.BlockSpec((4,), lambda i, j: 0)])
+    with self.assertRaisesRegex(
+        TypeError,
+        "missing 2 required positional arguments: 'i' and 'j'"):
+      f(a)
+
+  def test_pallas_call_index_map_wrong_number_of_results(self):
+    a = np.arange(256, dtype=np.int32)
+    f = self.pallas_call(lambda x_ref, o1_ref: None,
+                         out_shape=a,
+                         in_specs=[pl.BlockSpec((4,), lambda: (0, 0))])
+    with self.assertRaisesRegex(
+        ValueError,
+        "Index map for input\\[0\\] must return 1 values to match .*Currently returning 2 values."):
+      f(a)
+
+  def test_pallas_call_out_specs_mismatch_shape(self):
+    a = np.arange(256, dtype=np.int32)
+    f = self.pallas_call(lambda x_ref, o1_ref: None,
+                         out_shape=[a, a],
+                         out_specs=[pl.BlockSpec((6,), lambda i: i)])
+    with self.assertRaisesRegex(
+        ValueError,
+        re.compile("Pytree for `out_specs` and `out_shape` do not match. There are 1 mismatches, including:"
+                   ".* `out_specs` is a tuple of length 1 but `out_shape` is a tuple of length 2.*", re.DOTALL)):
+      f(a)
+
+
+  def test_pallas_call_block_shape_ndim_mismatch(self):
+    a = np.arange(256, dtype=np.int32)
+    f = self.pallas_call(lambda x_ref, o1_ref: None,
+                         out_shape=[a],
+                         in_specs=[pl.BlockSpec((1, 1), lambda: (0, 0))])
+    with self.assertRaisesRegex(
+        ValueError,
+        "Block shape for input\\[0\\] .* must have the same number of dimensions as the "
+        "array shape"):
+
+      f(a)
+
+    f = self.pallas_call(lambda x_ref, o1_ref: None,
+                         out_shape=[a],
+                         out_specs=[pl.BlockSpec((1, 1), lambda: 0)])
+    with self.assertRaisesRegex(
+        ValueError,
+        "Block shape for output\\[0\\] .* must have the same number of dimensions as the "
+        "array shape"):
+      f(a)
+
+
+class ApiErrorInterpreterTest(ApiErrorTest):
+  INTERPRET = True
+
+
 class PallasControlFlowTest(PallasTest):

   def setUp(self):
diff --git a/tests/pallas/tpu/pallas_call_test.py b/tests/pallas/tpu/pallas_call_test.py
index 92e518206..b85533146 100644
--- a/tests/pallas/tpu/pallas_call_test.py
+++ b/tests/pallas/tpu/pallas_call_test.py
@@ -115,6 +115,67 @@ class PallasCallScalarPrefetchTest(PallasTPUTest):
     )(s, x)
     np.testing.assert_array_equal(out, x)

+  def test_block_spec_with_wrong_block_shape_errors(self):
+    def body(x_ref, o_ref):
+      o_ref[...] = x_ref[...]
+
+    x = jnp.ones((16, 128))
+    with self.assertRaisesRegex(
+        ValueError,
+        'Block shape .* must have the same number of dimensions as the array shape .*'):
+      _ = pl.pallas_call(
+          body,
+          grid_spec=pltpu.PrefetchScalarGridSpec(
+              num_scalar_prefetch=0,
+              in_specs=[pl.BlockSpec((128,), lambda i: (i, 0))],  # WRONG
+              out_specs=pl.BlockSpec((8, 128,), lambda i: (i, 0)),
+              grid=(2,),
+          ),
+          out_shape=x,
+          interpret=self.interpret,
+      )(x)
+
+  def test_block_spec_with_index_map_that_accepts_wrong_number_of_args_errors(self):
+    def body(x_ref, o_ref):
+      o_ref[...] = x_ref[...]
+
+    x = jnp.ones((16, 128))
+    with self.assertRaisesRegex(
+        TypeError,
+        'missing 1 required positional argument: \'j\''):
+      _ = pl.pallas_call(
+          body,
+          grid_spec=pltpu.PrefetchScalarGridSpec(
+              num_scalar_prefetch=0,
+              in_specs=[pl.BlockSpec((8, 128,), lambda i, j: (i, 0))],  # WRONG
+              out_specs=pl.BlockSpec((8, 128,), lambda i: (i, 0),),
+              grid=(2,),
+          ),
+          out_shape=x,
+          interpret=self.interpret
+      )(x)
+
+  def test_block_spec_with_index_map_returns_wrong_number_of_values_errors(self):
+    def body(x_ref, o_ref):
+      o_ref[...] = x_ref[...]
+
+    x = jnp.ones((16, 128))
+    with self.assertRaisesRegex(
+        ValueError,
+        r'Index map for input\[0\] must return 2 values to match block_shape=\(8, 128\).'
+        ' Currently returning 1 values.'):
+      _ = pl.pallas_call(
+          body,
+          grid_spec=pltpu.PrefetchScalarGridSpec(
+              num_scalar_prefetch=0,
+              in_specs=[pl.BlockSpec((8, 128,), lambda i: (i,))],  # WRONG
+              out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)),
+              grid=(2,),
+          ),
+          out_shape=x,
+          interpret=self.interpret,
+      )(x)
+
   def test_vmap_scalar_prefetch(self):
     def body(_, x_ref, o_ref):
       o_ref[...] = x_ref[...]
@@ -138,8 +199,7 @@ class PallasCallScalarPrefetchTest(PallasTPUTest):
             out_specs=pl.BlockSpec(
                 (x.shape[0] // 8, x.shape[1]), lambda i, _: (i, 0)
             ),
-            grid=8,
-        ),
+            grid=8),
         interpret=self.interpret,
     )(s, x)
     np.testing.assert_allclose(
@@ -363,7 +423,7 @@ class PallasCallDynamicGridTest(PallasTPUTest):
       num_programs = pl.num_programs(0)
       self.assertIsInstance(num_programs, int)
       self.assertEqual(num_programs, 2)
-      return 0
+      return 0, 0
     pl.pallas_call(
         kernel,
         in_specs=[pl.BlockSpec((8, 128), x_index_map)],
diff --git a/tests/pallas/tpu/pallas_random_test.py b/tests/pallas/tpu/pallas_random_test.py
index c8db82179..e3d43125c 100644
--- a/tests/pallas/tpu/pallas_random_test.py
+++ b/tests/pallas/tpu/pallas_random_test.py
@@ -184,7 +184,7 @@ class BlockInvarianceTest(parameterized.TestCase):

     def make_kernel_body(index_map):
       def body(key_ref, o_ref):
-        key = key_ref[0, 0]
+        key = key_ref[...]
         samples = plrandom.sample_block(
             jax.random.uniform,
             key,
@@ -199,9 +199,7 @@ class BlockInvarianceTest(parameterized.TestCase):
     global_key = jax_random.key(0, impl="pallas_tpu")
     o_shape = jnp.ones((64, 512), dtype=jnp.float32)

-    key_spec = pl.BlockSpec(
-        (1, 1), lambda i, j: (0, 0), memory_space=pltpu.TPUMemorySpace.SMEM
-    )
+    key_spec = pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM)
     out_spec = pl.BlockSpec((16, 128), lambda i, j: (i, j))
     result_16x128 = pl.pallas_call(
         make_kernel_body(index_map=lambda i, j: (i, j)),
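
Illustration (a minimal sketch, not part of the patch): how the new `in_specs` validation in GridSpec surfaces to callers, mirroring ApiErrorTest.test_pallas_call_in_specs_not_a_sequence above. The pl.pallas_call / pl.BlockSpec usage is assumed to match the test code in this diff; out_shape accepts any object with .shape and .dtype, as in the tests.

import numpy as np
from jax.experimental import pallas as pl

a = np.arange(256, dtype=np.int32)
try:
  # WRONG on purpose: `in_specs` must be a list or tuple of BlockSpecs,
  # not a bare BlockSpec. With this change the error is raised eagerly,
  # while the call is being constructed, and names the offending argument.
  pl.pallas_call(
      lambda x_ref, o_ref: None,
      out_shape=a,
      in_specs=pl.BlockSpec((4,), lambda: 0),
  )
except ValueError as e:
  assert "`in_specs` must be a tuple or a list" in str(e)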