Deprecate support for custom lowering rules that return tuple-wrapped ir.Values.

https://github.com/google/jax/pull/22211 forbade custom lowering rules from returning singleton tuples of ir.Value, but this appears to break downstream users, notably Transformer Engine. Instead, allow lowering rules to return singleton tuples and unwrap them if needed, but warn if this behavior is seen.

PiperOrigin-RevId: 650345051
This commit is contained in:
Peter Hawkins 2024-07-08 12:54:04 -07:00 committed by jax authors
parent 0d57c72644
commit 262a4f482c
2 changed files with 15 additions and 5 deletions

View File

@ -25,6 +25,9 @@ Remember to align the itemized text with the first line of an item within a list
* Removed a number of previously-deprecated internal APIs related to
polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`,
`dimension_as_value`, `definitely_equal`, and `symbolic_equal_dim`.
* HLO lowering rules should no longer wrap singleton ir.Values in tuples.
Instead, return singleton ir.Values unwrapped. Support for wrapped values
will be removed in a future version of JAX.
## jaxlib 0.4.31

View File

@ -1555,8 +1555,18 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
def write(v: core.Var, node: IrValues):
assert node is not None
w: IrValues
w = node if isinstance(node, ir.Value) else tuple(node)
assert _is_ir_values(w), w
if isinstance(node, ir.Value):
w = node
else:
if len(node) == 1:
warnings.warn(
"JAX lowering rules should not wrap singleton values in tuples. "
"It will be an error to wrap a singleton value in a tuple in a "
"future version of JAX.",
DeprecationWarning, stacklevel=2)
w = node[0]
else:
w = tuple(node)
env[v] = w
def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None:
@ -1623,7 +1633,6 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
platform_rules, default_rule,
eqn.effects,
*in_nodes, **eqn.params)
assert all(_is_ir_values(v) for v in ans), (eqn, ans)
if effects:
# If there were ordered effects in the primitive, there should be output
@ -1646,8 +1655,6 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
raise ValueError("Output of translation rule must be iterable: "
f"{eqn}, got output {ans}") from e
# assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (
# ans, "lowering function returned a bad output", eqn)
assert len(ans) == len(eqn.outvars), (ans, eqn)
map(write, eqn.outvars, out_nodes)
core.clean_up_dead_vars(eqn, env, last_used)