mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
[Pallas] Allow input-output-aliasing in interpret mode on TPU
PiperOrigin-RevId: 581391427
This commit is contained in:
parent
d8c94ee281
commit
ed6fbd0166
@ -43,9 +43,6 @@ def pallas_call_tpu_lowering_rule(
|
||||
mosaic_params: dict[str, Any] | None = None,
|
||||
**compiler_params: Any):
|
||||
"""Lowers a pallas_call to a Mosaic TPU custom call."""
|
||||
if input_output_aliases:
|
||||
raise NotImplementedError(
|
||||
"`input_output_aliases` not supported on TPU backend.")
|
||||
if interpret:
|
||||
return mlir.lower_fun(pallas_call_p.impl, multiple_results=True)(
|
||||
ctx, *in_nodes, jaxpr=jaxpr, name=name, out_shapes=out_shapes,
|
||||
@ -54,6 +51,9 @@ def pallas_call_tpu_lowering_rule(
|
||||
interpret=interpret, debug=debug,
|
||||
input_output_aliases=input_output_aliases,
|
||||
grid_mapping=grid_mapping, **compiler_params)
|
||||
if input_output_aliases:
|
||||
raise NotImplementedError(
|
||||
"`input_output_aliases` not supported on TPU backend.")
|
||||
if debug:
|
||||
print(jaxpr)
|
||||
mesh = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user