From f960c287c4aee736971a6c6d4cf102de4272b9ce Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 8 Jul 2024 09:13:35 +0300 Subject: [PATCH] [pallas] Improve error messages for input_output_aliasing Currently, we get MLIR verification errors when the inputs and outputs declared to be aliased do not have matching shapes and dtypes. We add a nicer error message that localizes the inputs and outputs in the corresponding PyTrees. Interestingly, if the output index is out of bounds, there is no MLIR verification error. This seems to be a bug in the StableHLO verification code. Currently, in interpreter mode we get a mix of internal assertion errors when there are errors in input_output_aliasing. --- jax/_src/pallas/pallas_call.py | 24 ++++++++++++++++++++- tests/pallas/pallas_test.py | 38 ++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 66678076e..736d52bd3 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1044,7 +1044,8 @@ def pallas_call( The default value for `out_specs` specifies the whole array, e.g., as ``pl.BlockSpec(x.shape, lambda *indices: indices)``. input_output_aliases: a dictionary mapping the index of some inputs to - the index of the output that aliases them. + the index of the output that aliases them. These indices are in the + flattened inputs and outputs. interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the grid whose body is the kernel lowered as a JAX function. This does not require a TPU or a GPU, and is the only way to run Pallas kernels on CPU. @@ -1086,6 +1087,27 @@ def pallas_call( raise ValueError( "The kernel function in a pallas_call should return None. " f"Found a PyTree: {f_out_tree}") + for i_idx, o_idx in input_output_aliases.items(): + if i_idx not in range(len(flat_in_avals)): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' with " + f"input index {i_idx} outside the range " + f"[0, {len(flat_in_avals)})") + if o_idx not in range(len(flat_out_avals)): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' with " + f"output index {o_idx} outside the range " + f"[0, {len(flat_out_avals)})") + in_aval = flat_in_avals[i_idx] + out_aval = flat_out_avals[o_idx] + if in_aval.shape != out_aval.shape or in_aval.dtype != out_aval.dtype: + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"referring to input{tree_util.keystr(in_paths[i_idx])} with " + f"abstract value {in_aval} " + f"and to output{tree_util.keystr(out_paths[o_idx])} with " + f"a different abstract value {out_aval}.") + out_flat = pallas_call_p.bind( *dynamic_grid_bounds, *consts, *flat_args, jaxpr=jaxpr, name=name, diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 5ac052834..a7e6ff7de 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -567,6 +567,44 @@ class ApiErrorTest(PallasBaseTest): "array shape"): f(a) + def test_pallas_call_input_output_aliases_errors(self): + x = np.arange(8 * 128, dtype=np.int32).reshape((8, 128)) + + with self.assertRaisesRegex( + ValueError, + "input_output_aliases contains the mapping '2:0' with input index 2 " + "outside the range .*"): + self.pallas_call(lambda x_ref, y_ref, o1_ref: None, + out_shape=[x], + input_output_aliases={2: 0})(x, x) + + with self.assertRaisesRegex( + ValueError, + "input_output_aliases contains the mapping '1:1' with output index 1 " + "outside the range .*"): + self.pallas_call(lambda x_ref, y_ref, o1_ref: None, + out_shape=[x], + input_output_aliases={1: 1})(x, x) + + y = np.concatenate([x, x], axis=0) + with self.assertRaisesRegex( + ValueError, + "input_output_aliases contains the mapping '1:0' referring to " + "input\\[1\\] with abstract value .*int32\\[16,128\\].* " + "output\\[0\\] with a different abstract value .*int32\\[8,128\\]"): + self.pallas_call(lambda x_ref, y_ref, o1_ref: None, + out_shape=[x], + input_output_aliases={1: 0})(x, y) + + with self.assertRaisesRegex( + ValueError, + "input_output_aliases contains the mapping '1:0' referring to " + "input\\[1\\] with abstract value .*int32\\[8,128\\].* " + "output\\[0\\] with a different abstract value .*float32\\[8,128\\]"): + self.pallas_call(lambda x_ref, y_ref, o1_ref: None, + out_shape=[jax.ShapeDtypeStruct(x.shape, jnp.float32)], + input_output_aliases={1: 0})(x, x) + class ApiErrorInterpreterTest(ApiErrorTest): INTERPRET = True