[Pallas] Allow input-output-aliasing in interpret mode on TPU

PiperOrigin-RevId: 581391427
This commit is contained in:
Sharad Vikram 2023-11-10 15:50:05 -08:00 committed by jax authors
parent d8c94ee281
commit ed6fbd0166

View File

@ -43,9 +43,6 @@ def pallas_call_tpu_lowering_rule(
mosaic_params: dict[str, Any] | None = None,
**compiler_params: Any):
"""Lowers a pallas_call to a Mosaic TPU custom call."""
if input_output_aliases:
raise NotImplementedError(
"`input_output_aliases` not supported on TPU backend.")
if interpret:
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
@ -54,6 +51,9 @@ def pallas_call_tpu_lowering_rule(
interpret=interpret, debug=debug,
input_output_aliases=input_output_aliases,
grid_mapping=grid_mapping, **compiler_params)
if input_output_aliases:
raise NotImplementedError(
"`input_output_aliases` not supported on TPU backend.")
if debug:
print(jaxpr)
mesh = None