[pallas] Improve error messages for input_output_aliasing

Currently, we get MLIR verification errors when the inputs
and outputs declared to be aliased do not have matching
shapes and dtypes. We add a nicer error message that localizes
the inputs and outputs in the corresponding PyTrees.

Interestingly, if the output index is out of bounds, there
is no MLIR verification error. This seems to be a bug
in the StableHLO verification code.

Currently, in interpreter mode we get a mix of internal
assertion errors when there are errors in input_output_aliasing.
This commit is contained in:
George Necula 2024-07-08 09:13:35 +03:00
parent e3347700bb
commit f960c287c4
2 changed files with 61 additions and 1 deletions

View File

@ -1044,7 +1044,8 @@ def pallas_call(
The default value for `out_specs` specifies the whole array,
e.g., as ``pl.BlockSpec(x.shape, lambda *indices: indices)``.
input_output_aliases: a dictionary mapping the index of some inputs to
the index of the output that aliases them.
the index of the output that aliases them. These indices are in the
flattened inputs and outputs.
interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the
grid whose body is the kernel lowered as a JAX function. This does not
require a TPU or a GPU, and is the only way to run Pallas kernels on CPU.
@ -1086,6 +1087,27 @@ def pallas_call(
raise ValueError(
"The kernel function in a pallas_call should return None. "
f"Found a PyTree: {f_out_tree}")
for i_idx, o_idx in input_output_aliases.items():
if i_idx not in range(len(flat_in_avals)):
raise ValueError(
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' with "
f"input index {i_idx} outside the range "
f"[0, {len(flat_in_avals)})")
if o_idx not in range(len(flat_out_avals)):
raise ValueError(
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' with "
f"output index {o_idx} outside the range "
f"[0, {len(flat_out_avals)})")
in_aval = flat_in_avals[i_idx]
out_aval = flat_out_avals[o_idx]
if in_aval.shape != out_aval.shape or in_aval.dtype != out_aval.dtype:
raise ValueError(
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
f"referring to input{tree_util.keystr(in_paths[i_idx])} with "
f"abstract value {in_aval} "
f"and to output{tree_util.keystr(out_paths[o_idx])} with "
f"a different abstract value {out_aval}.")
out_flat = pallas_call_p.bind(
*dynamic_grid_bounds, *consts, *flat_args,
jaxpr=jaxpr, name=name,

View File

@ -567,6 +567,44 @@ class ApiErrorTest(PallasBaseTest):
"array shape"):
f(a)
def test_pallas_call_input_output_aliases_errors(self):
x = np.arange(8 * 128, dtype=np.int32).reshape((8, 128))
with self.assertRaisesRegex(
ValueError,
"input_output_aliases contains the mapping '2:0' with input index 2 "
"outside the range .*"):
self.pallas_call(lambda x_ref, y_ref, o1_ref: None,
out_shape=[x],
input_output_aliases={2: 0})(x, x)
with self.assertRaisesRegex(
ValueError,
"input_output_aliases contains the mapping '1:1' with output index 1 "
"outside the range .*"):
self.pallas_call(lambda x_ref, y_ref, o1_ref: None,
out_shape=[x],
input_output_aliases={1: 1})(x, x)
y = np.concatenate([x, x], axis=0)
with self.assertRaisesRegex(
ValueError,
"input_output_aliases contains the mapping '1:0' referring to "
"input\\[1\\] with abstract value .*int32\\[16,128\\].* "
"output\\[0\\] with a different abstract value .*int32\\[8,128\\]"):
self.pallas_call(lambda x_ref, y_ref, o1_ref: None,
out_shape=[x],
input_output_aliases={1: 0})(x, y)
with self.assertRaisesRegex(
ValueError,
"input_output_aliases contains the mapping '1:0' referring to "
"input\\[1\\] with abstract value .*int32\\[8,128\\].* "
"output\\[0\\] with a different abstract value .*float32\\[8,128\\]"):
self.pallas_call(lambda x_ref, y_ref, o1_ref: None,
out_shape=[jax.ShapeDtypeStruct(x.shape, jnp.float32)],
input_output_aliases={1: 0})(x, x)
class ApiErrorInterpreterTest(ApiErrorTest):
INTERPRET = True