diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index 66678076e..736d52bd3 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -1044,7 +1044,8 @@ def pallas_call( The default value for `out_specs` specifies the whole array, e.g., as ``pl.BlockSpec(x.shape, lambda *indices: indices)``. input_output_aliases: a dictionary mapping the index of some inputs to - the index of the output that aliases them. + the index of the output that aliases them. These indices are in the + flattened inputs and outputs. interpret: runs the ``pallas_call`` as a ``jax.jit`` of a scan over the grid whose body is the kernel lowered as a JAX function. This does not require a TPU or a GPU, and is the only way to run Pallas kernels on CPU. @@ -1086,6 +1087,27 @@ def pallas_call( raise ValueError( "The kernel function in a pallas_call should return None. " f"Found a PyTree: {f_out_tree}") + for i_idx, o_idx in input_output_aliases.items(): + if i_idx not in range(len(flat_in_avals)): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' with " + f"input index {i_idx} outside the range " + f"[0, {len(flat_in_avals)})") + if o_idx not in range(len(flat_out_avals)): + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' with " + f"output index {o_idx} outside the range " + f"[0, {len(flat_out_avals)})") + in_aval = flat_in_avals[i_idx] + out_aval = flat_out_avals[o_idx] + if in_aval.shape != out_aval.shape or in_aval.dtype != out_aval.dtype: + raise ValueError( + f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' " + f"referring to input{tree_util.keystr(in_paths[i_idx])} with " + f"abstract value {in_aval} " + f"and to output{tree_util.keystr(out_paths[o_idx])} with " + f"a different abstract value {out_aval}.") + out_flat = pallas_call_p.bind( *dynamic_grid_bounds, *consts, *flat_args, jaxpr=jaxpr, name=name, diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 5ac052834..a7e6ff7de 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -567,6 +567,44 @@ class ApiErrorTest(PallasBaseTest): "array shape"): f(a) + def test_pallas_call_input_output_aliases_errors(self): + x = np.arange(8 * 128, dtype=np.int32).reshape((8, 128)) + + with self.assertRaisesRegex( + ValueError, + "input_output_aliases contains the mapping '2:0' with input index 2 " + "outside the range .*"): + self.pallas_call(lambda x_ref, y_ref, o1_ref: None, + out_shape=[x], + input_output_aliases={2: 0})(x, x) + + with self.assertRaisesRegex( + ValueError, + "input_output_aliases contains the mapping '1:1' with output index 1 " + "outside the range .*"): + self.pallas_call(lambda x_ref, y_ref, o1_ref: None, + out_shape=[x], + input_output_aliases={1: 1})(x, x) + + y = np.concatenate([x, x], axis=0) + with self.assertRaisesRegex( + ValueError, + "input_output_aliases contains the mapping '1:0' referring to " + "input\\[1\\] with abstract value .*int32\\[16,128\\].* " + "output\\[0\\] with a different abstract value .*int32\\[8,128\\]"): + self.pallas_call(lambda x_ref, y_ref, o1_ref: None, + out_shape=[x], + input_output_aliases={1: 0})(x, y) + + with self.assertRaisesRegex( + ValueError, + "input_output_aliases contains the mapping '1:0' referring to " + "input\\[1\\] with abstract value .*int32\\[8,128\\].* " + "output\\[0\\] with a different abstract value .*float32\\[8,128\\]"): + self.pallas_call(lambda x_ref, y_ref, o1_ref: None, + out_shape=[jax.ShapeDtypeStruct(x.shape, jnp.float32)], + input_output_aliases={1: 0})(x, x) + class ApiErrorInterpreterTest(ApiErrorTest): INTERPRET = True