mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[pallas] Improve error messages for input_output_aliasing
Currently, we get MLIR verification errors when the inputs and outputs declared to be aliased do not have matching shapes and dtypes. We add a nicer error message that localizes the inputs and outputs in the corresponding PyTrees. Interestingly, if the output index is out of bounds, there is no MLIR verification error. This seems to be a bug in the StableHLO verification code. Currently, in interpreter mode we get a mix of internal assertion errors when there are errors in input_output_aliasing.
This commit is contained in:
parent
e3347700bb
commit
f960c287c4
@ -1044,7 +1044,8 @@ def pallas_call(
|
||||
The default value for `out_specs` specifies the whole array,
|
||||
e.g., as ``pl.BlockSpec(x.shape, lambda *indices: indices)``.
|
||||
input_output_aliases: a dictionary mapping the index of some inputs to
|
||||
the index of the output that aliases them.
|
||||
the index of the output that aliases them. These indices are in the
|
||||
flattened inputs and outputs.
|
||||
interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the
|
||||
grid whose body is the kernel lowered as a JAX function. This does not
|
||||
require a TPU or a GPU, and is the only way to run Pallas kernels on CPU.
|
||||
@ -1086,6 +1087,27 @@ def pallas_call(
|
||||
raise ValueError(
|
||||
"The kernel function in a pallas_call should return None. "
|
||||
f"Found a PyTree: {f_out_tree}")
|
||||
for i_idx, o_idx in input_output_aliases.items():
|
||||
if i_idx not in range(len(flat_in_avals)):
|
||||
raise ValueError(
|
||||
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' with "
|
||||
f"input index {i_idx} outside the range "
|
||||
f"[0, {len(flat_in_avals)})")
|
||||
if o_idx not in range(len(flat_out_avals)):
|
||||
raise ValueError(
|
||||
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' with "
|
||||
f"output index {o_idx} outside the range "
|
||||
f"[0, {len(flat_out_avals)})")
|
||||
in_aval = flat_in_avals[i_idx]
|
||||
out_aval = flat_out_avals[o_idx]
|
||||
if in_aval.shape != out_aval.shape or in_aval.dtype != out_aval.dtype:
|
||||
raise ValueError(
|
||||
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
|
||||
f"referring to input{tree_util.keystr(in_paths[i_idx])} with "
|
||||
f"abstract value {in_aval} "
|
||||
f"and to output{tree_util.keystr(out_paths[o_idx])} with "
|
||||
f"a different abstract value {out_aval}.")
|
||||
|
||||
out_flat = pallas_call_p.bind(
|
||||
*dynamic_grid_bounds, *consts, *flat_args,
|
||||
jaxpr=jaxpr, name=name,
|
||||
|
@ -567,6 +567,44 @@ class ApiErrorTest(PallasBaseTest):
|
||||
"array shape"):
|
||||
f(a)
|
||||
|
||||
def test_pallas_call_input_output_aliases_errors(self):
|
||||
x = np.arange(8 * 128, dtype=np.int32).reshape((8, 128))
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"input_output_aliases contains the mapping '2:0' with input index 2 "
|
||||
"outside the range .*"):
|
||||
self.pallas_call(lambda x_ref, y_ref, o1_ref: None,
|
||||
out_shape=[x],
|
||||
input_output_aliases={2: 0})(x, x)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"input_output_aliases contains the mapping '1:1' with output index 1 "
|
||||
"outside the range .*"):
|
||||
self.pallas_call(lambda x_ref, y_ref, o1_ref: None,
|
||||
out_shape=[x],
|
||||
input_output_aliases={1: 1})(x, x)
|
||||
|
||||
y = np.concatenate([x, x], axis=0)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"input_output_aliases contains the mapping '1:0' referring to "
|
||||
"input\\[1\\] with abstract value .*int32\\[16,128\\].* "
|
||||
"output\\[0\\] with a different abstract value .*int32\\[8,128\\]"):
|
||||
self.pallas_call(lambda x_ref, y_ref, o1_ref: None,
|
||||
out_shape=[x],
|
||||
input_output_aliases={1: 0})(x, y)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"input_output_aliases contains the mapping '1:0' referring to "
|
||||
"input\\[1\\] with abstract value .*int32\\[8,128\\].* "
|
||||
"output\\[0\\] with a different abstract value .*float32\\[8,128\\]"):
|
||||
self.pallas_call(lambda x_ref, y_ref, o1_ref: None,
|
||||
out_shape=[jax.ShapeDtypeStruct(x.shape, jnp.float32)],
|
||||
input_output_aliases={1: 0})(x, x)
|
||||
|
||||
|
||||
class ApiErrorInterpreterTest(ApiErrorTest):
|
||||
INTERPRET = True
|
||||
|
Loading…
x
Reference in New Issue
Block a user