diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c5cd34d1..b10c408fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,9 @@ Remember to align the itemized text with the first line of an item within a list * Removed a number of previously-deprecated internal APIs related to polymorphic shapes. From {mod}`jax.core`: removed `canonicalize_shape`, `dimension_as_value`, `definitely_equal`, and `symbolic_equal_dim`. + * HLO lowering rules should no longer wrap singleton ir.Values in tuples. + Instead, return singleton ir.Values unwrapped. Support for wrapped values + will be removed in a future version of JAX. ## jaxlib 0.4.31 diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index f3e5a3f2a..73f81a472 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1555,8 +1555,18 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, def write(v: core.Var, node: IrValues): assert node is not None w: IrValues - w = node if isinstance(node, ir.Value) else tuple(node) - assert _is_ir_values(w), w + if isinstance(node, ir.Value): + w = node + else: + if len(node) == 1: + warnings.warn( + "JAX lowering rules should not wrap singleton values in tuples. " + "It will be an error to wrap a singleton value in a tuple in a " + "future version of JAX.", + DeprecationWarning, stacklevel=2) + w = node[0] + else: + w = tuple(node) env[v] = w def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None: @@ -1623,7 +1633,6 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, platform_rules, default_rule, eqn.effects, *in_nodes, **eqn.params) - assert all(_is_ir_values(v) for v in ans), (eqn, ans) if effects: # If there were ordered effects in the primitive, there should be output @@ -1646,8 +1655,6 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, raise ValueError("Output of translation rule must be iterable: " f"{eqn}, got output {ans}") from e - # assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), ( - # ans, "lowering function returned a bad output", eqn) assert len(ans) == len(eqn.outvars), (ans, eqn) map(write, eqn.outvars, out_nodes) core.clean_up_dead_vars(eqn, env, last_used)