Don't wrap singleton ir.Values with tuples during HLO lowering.

In general a JAX value may correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. In practice, however, the common case is exactly one value, so almost all of these tuples are singletons.

To reduce the number of singleton list and tuple objects allocated during MLIR lowering, represent a singleton as the unwrapped ir.Value itself, and use a tuple only when a JAX value is backed by zero or more than one ir.Value.
commit 8ab0c07edc (parent 9def0f1c00)
Author: Peter Hawkins
Date: 2024-07-01 08:42:48 -04:00
23 changed files with 205 additions and 196 deletions
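
To make the change concrete before the per-file diffs: under the new scheme a JAX value lowers to a bare ir.Value in the common case, and to a tuple only for zero or multiple ir.Values. A minimal sketch of the two conventions, using a stand-in Value class since the real ir.Value needs a live MLIR context (the helper names match the diffs below, but the bodies are simplified, not JAX's exact code):

from typing import Union

class Value:
    """Stand-in for mlir.ir.Value."""

# New representation: the common singleton case is a bare ir.Value; a tuple
# is used only when a JAX value maps to zero or several ir.Values.
IrValues = Union[Value, tuple[Value, ...]]

def wrap_singleton_ir_values(x) -> tuple[Value, ...]:
    # Old convention (removed by this commit): every value paid for a
    # wrapper tuple, even in the overwhelmingly common singleton case.
    return (x,) if isinstance(x, Value) else tuple(x)

def flatten_ir_values(xs) -> list[Value]:
    # New helper: singletons pass through with no wrapper object at all.
    out: list[Value] = []
    for x in xs:
        if isinstance(x, Value):
            out.append(x)
        else:
            out.extend(x)
    return out

a, b, c = Value(), Value(), Value()
assert wrap_singleton_ir_values(a) == (a,)          # old: one tuple per value
assert flatten_ir_values([a, (b, c)]) == [a, b, c]  # new: no singleton tuples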


@@ -40,7 +40,6 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.lax import lax as lax_internal
from jax._src.lax import convolution as lax_convolution
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr
@@ -769,7 +768,7 @@ def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool,
differentiated: bool, policy, is_gpu_platform=False):
jaxpr_args: Sequence[Sequence[ir.Value]]
jaxpr_args: Sequence[mlir.IrValues]
if differentiated and prevent_cse:
# If we're using the loop or cond lowerings, use the slower lower_fun
# based path.
@@ -780,11 +779,11 @@ def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool,
is_gpu_platform=is_gpu_platform)
arg_types = map(mlir.aval_to_ir_types, ctx.avals_in)
flat_args = mlir.flatten_lowering_ir_args(args)
flat_args = mlir.flatten_ir_values(args)
barrier_op = hlo.OptimizationBarrierOp(flat_args)
jaxpr_args = util.unflatten(barrier_op.results, map(len, arg_types))
jaxpr_args = mlir.unflatten_ir_values(barrier_op.results, map(len, arg_types))
else:
jaxpr_args = map(mlir.wrap_singleton_ir_values, args)
jaxpr_args = args
outs, tokens_out = mlir.jaxpr_subcomp(
ctx.module_context, jaxpr, ctx.name_stack.extend('checkpoint'),
ctx.tokens_in, (), *jaxpr_args, dim_var_values=ctx.dim_var_values)
@@ -800,7 +799,7 @@ def _optimization_barrier_abstract_eval(*args):
def _optimization_barrier_lowering_rule(ctx, *args):
barrier_types = map(mlir.aval_to_ir_types, ctx.avals_in)
flat_args = mlir.flatten_lowering_ir_args(args)
flat_args = mlir.flatten_ir_values(args)
barrier_op = hlo.OptimizationBarrierOp(flat_args)
return util.unflatten(barrier_op.results, map(len, barrier_types))
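
The barrier lowering above is the round-trip pattern this commit standardizes: flatten each argument's group of ir.Values into one flat operand list, emit the op, then regroup the flat results by arity. Regrouping now goes through unflatten_ir_values, sketched here with strings standing in for ir.Values (simplified from the helper added to mlir.py further down):

def unflatten_ir_values(xs, ns):
    # Split the flat sequence into groups of lengths `ns`; length-1 groups
    # come back as the bare element, everything else as a tuple.
    it = iter(xs)
    groups = [next(it) if n == 1 else tuple(next(it) for _ in range(n))
              for n in ns]
    assert next(it, None) is None  # sum(ns) must consume xs exactly
    return groups

flat_results = ["r0", "r1", "r2"]      # stand-ins for barrier_op.results
arg_arities = [1, 2]                   # map(len, arg_types) in the diff
assert unflatten_ir_values(flat_results, arg_arities) == ["r0", ("r1", "r2")]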


@@ -1051,7 +1051,7 @@ basearray.Array.register(ArrayImpl)
def _array_mlir_constant_handler(val):
try:
return mlir.ir_constants(val._value)
return mlir.ir_constant(val._value)
except RuntimeError as e:
# TODO(yashkatariya): Ideally we would catch a custom exception from
# `_value` function in ArrayImpl instead of checking the error string.


@@ -441,7 +441,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
op_sharding = _callback_op_sharding(ctx.module_context.axis_context, sharding)
if ordered:
token = ctx.tokens_in.get(_OrderedIOEffect)[0]
token = ctx.tokens_in.get(_OrderedIOEffect)
result, token, _ = mlir.emit_python_callback(
ctx,
_callback,
@@ -452,7 +452,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
has_side_effect=True,
sharding=op_sharding,
)
ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: (token,)}))
ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: token}))
else:
result, token, _ = mlir.emit_python_callback(
ctx,

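Ordered-effect tokens get the same treatment: the [0] indexing and (token,) wrapping deleted above existed because a token was stored as a singleton sequence, whereas it is now a bare ir.Value (the mlir.py diff below changes the Token alias to match). A before/after sketch of just the access pattern, with a plain dict standing in for mlir.TokenSet:

# Old convention: TokenSet stored each token as a singleton sequence.
old_tokens_in = {"ordered_io_effect": ("token0",)}
token = old_tokens_in["ordered_io_effect"][0]     # unwrap on read
old_tokens_out = {"ordered_io_effect": (token,)}  # rewrap on write

# New convention: the token is stored bare.
new_tokens_in = {"ordered_io_effect": "token0"}
token = new_tokens_in["ordered_io_effect"]
new_tokens_out = {"ordered_io_effect": token}
assert new_tokens_out["ordered_io_effect"] == "token0"
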

@@ -436,11 +436,10 @@ core.custom_typechecks[custom_jvp_call_p] = _custom_jvp_call_typecheck
def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_thunk,
num_consts, symbolic_zeros):
del jvp_jaxpr_thunk, num_consts, symbolic_zeros
args_ = map(mlir.wrap_singleton_ir_values, args)
consts = mlir._ir_consts(call_jaxpr.consts)
out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr,
ctx.name_stack, ctx.tokens_in, consts,
*args_, dim_var_values=ctx.dim_var_values)
*args, dim_var_values=ctx.dim_var_values)
ctx.set_tokens_out(tokens)
return out
mlir.register_lowering(custom_jvp_call_p, _custom_jvp_call_mlir_translation)
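
The custom_jvp lowering above is the simplest instance of a change repeated at most jaxpr_subcomp call sites: arguments are no longer pre-wrapped, because rule inputs already arrive in IrValues form. A hedged before/after sketch, with strings standing in for ir.Values:

def wrap_singleton_ir_values(x):
    # the old helper, specialized to the string stand-ins used below
    return (x,) if isinstance(x, str) else tuple(x)

args = ("x0", ("y0", "y1"))            # one singleton arg, one two-value arg

old_args = tuple(map(wrap_singleton_ir_values, args))
assert old_args == (("x0",), ("y0", "y1"))   # old: everything tupled

new_args = args                         # new: forwarded to jaxpr_subcomp as-is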


@@ -159,11 +159,11 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
*flat_args, effect=effect, callback=callback, **params)
return ()
if effects.ordered_effects.contains(effect):
[token] = ctx.tokens_in.get(effect)
token = ctx.tokens_in.get(effect)
result, token, _ = mlir.emit_python_callback(
ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out,
has_side_effect=True)
ctx.set_tokens_out(mlir.TokenSet({effect: (token,)}))
ctx.set_tokens_out(mlir.TokenSet({effect: token}))
else:
result, _, _ = mlir.emit_python_callback(
ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out,


@@ -23,7 +23,7 @@ import dataclasses
import functools
import itertools
import re
from typing import Any, Union
from typing import Any, Union, cast
import warnings
from absl import logging
@@ -729,7 +729,7 @@ def _wrap_main_func(
context = mlir.make_ir_context()
with context, ir.Location.unknown(context):
# Make a copy, do not mutate because it may be cached
wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module)) # type: ignore
wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module)) # type: ignore[arg-type]
symbol_table = ir.SymbolTable(wrapped_module.operation)
orig_main = symbol_table["main"]
orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")
@@ -838,7 +838,7 @@ def _wrap_main_func(
orig_main_args: list[ir.Value] = []
# The platform index and the dimension variables
for arg, arg_type in zip(
list(new_main_op.arguments[0:nr_platform_index_args]) + util.flatten(dim_values),
list(new_main_op.arguments[0:nr_platform_index_args]) + mlir.flatten_ir_values(dim_values),
platform_input_types + dim_var_input_types):
if arg.type != arg_type:
orig_main_args.append(hlo.convert(arg_type, arg))
@@ -1327,7 +1327,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
if len(lowering_platforms) > 1:
current_platform_idx = ctx.dim_var_values[0]
else:
current_platform_idx = mlir.ir_constant(np.int32(0))
current_platform_idx = cast(ir.Value, mlir.ir_constant(np.int32(0)))
# Compute the rule index based on the current platform
i32_type = mlir.aval_to_ir_types(core.ShapedArray((), dtype=np.int32))[0]
if current_platform_idx.type != i32_type:
@@ -1338,8 +1338,8 @@
for i in range(len(lowering_platforms)):
branch = callee_platform_idx.regions[i].blocks.append()
with ir.InsertionPoint(branch):
hlo.return_(mlir.ir_constants(
np.int32(callee_lowering_platform_index[i])))
hlo.return_([mlir.ir_constant(
np.int32(callee_lowering_platform_index[i]))])
if callee_platform_idx.result.type != callee_type.inputs[0]:
callee_platform_idx = hlo.ConvertOp(callee_type.inputs[0],
callee_platform_idx)
@@ -1350,7 +1350,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
ordered_effects = exported.ordered_effects
for eff in ordered_effects:
token_in = ctx.tokens_in.get(eff)[0]
token_in = ctx.tokens_in.get(eff)
submodule_args.append(token_in)
kept_args = [
convert_shape(a, a_aval, exported_in_aval)

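Two details in the export diff above are worth noting: ir_constant is now statically typed as returning the IrValues union, so call sites that need a plain ir.Value (like current_platform_idx) insert a typing cast, and hlo.return_ still takes a flat list, hence the explicit [...] around the now-bare constant. A sketch of the cast pattern with stand-in types:

from typing import Union, cast

class Value:
    """Stand-in for mlir.ir.Value."""

IrValues = Union[Value, tuple[Value, ...]]

def ir_constant(x) -> IrValues:
    # a numpy scalar lowers to exactly one value, but the static return
    # type is the union, hence the cast at call sites needing ir.Value
    return Value()

current_platform_idx: Value = cast(Value, ir_constant(0))
assert isinstance(current_platform_idx, Value)
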

@@ -17,7 +17,7 @@ from __future__ import annotations
import collections
import contextlib
from collections.abc import Callable, Iterator, Sequence
from collections.abc import Callable, Iterable, Iterator, Sequence
import dataclasses
import functools
from functools import partial
@@ -87,6 +87,16 @@ lowerable_effects: effects_lib.EffectTypeSet = effects_lib.lowerable_effects
# IR Helpers
IrValues = Union[ir.Value, tuple[ir.Value, ...]]
def _is_ir_values(x: IrValues) -> bool:
"""Returns true if `x` is an ir.Value or tuple of ir.Values"""
if isinstance(x, ir.Value):
return True
return (isinstance(x, tuple) and len(x) != 1
and all(isinstance(v, ir.Value) for v in x))
def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
return type_cast(ir.DenseIntElementsAttr,
ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)))
@@ -103,7 +113,7 @@ def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
a, type=ir.IntegerType.get_signless(1), shape=[len(xs)])
def dense_bool_array(xs: Sequence[bool]) -> ir.DenseBoolArrayAttr:
return ir.DenseBoolArrayAttr.get(xs) # type: ignore
return ir.DenseBoolArrayAttr.get(xs) # type: ignore[arg-type]
def i32_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i)
def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i)
@@ -224,7 +234,7 @@ def aval_to_ir_type(aval: core.AbstractValue) -> ir.Type:
# Constants
class ConstantHandler(Protocol):
def __call__(self, val: Any) -> Sequence[ir.Value]:
def __call__(self, val: Any) -> IrValues:
"""Builds an IR representation for a constant `val`.
A JAX value is represented by zero or more IR values."""
@@ -237,42 +247,33 @@ def register_constant_handler(type_: type, handler_fun: ConstantHandler):
def get_constant_handler(type_: type) -> ConstantHandler:
return _constant_handlers[type_]
def ir_constants(val: Any) -> Sequence[ir.Value]:
def ir_constant(val: Any) -> IrValues:
"""Translate a Python `val` to an IR constant, canonicalizing its dtype.
Args:
val: a Python value to be translated to a constant.
Returns:
A representation of the constant as a list of IR values.
A representation of the constant as an IR value or sequence of IR values.
"""
for t in type(val).__mro__:
handler = _constant_handlers.get(t)
if handler:
out = handler(val)
assert all(isinstance(v, ir.Value) for v in out), (type(val), out)
assert _is_ir_values(out), (type(val), out)
return out
if hasattr(val, '__jax_array__'):
return ir_constants(val.__jax_array__())
return ir_constant(val.__jax_array__())
raise TypeError(f"No constant handler for type: {type(val)}")
def ir_constant(val: Any) -> ir.Value:
"""Convenience wrapper around ir_constants for singleton values."""
values = ir_constants(val)
if len(values) != 1:
raise TypeError(f"ir_constant called on {val} which corresponds to "
f"multiple IR values {values}")
return values[0]
def _numpy_array_constant(x: np.ndarray | np.generic) -> Sequence[ir.Value]:
def _numpy_array_constant(x: np.ndarray | np.generic) -> IrValues:
element_type = dtype_to_ir_type(x.dtype)
shape = x.shape
if x.dtype == np.bool_:
x = np.packbits(x, bitorder='little') # type: ignore
x = np.ascontiguousarray(x)
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore
return (hlo.constant(attr),)
return hlo.constant(attr)
def _masked_array_constant_handler(*args, **kwargs):
@@ -281,7 +282,7 @@ def _masked_array_constant_handler(*args, **kwargs):
register_constant_handler(np.ma.MaskedArray, _masked_array_constant_handler)
def _ndarray_constant_handler(val: np.ndarray | np.generic) -> Sequence[ir.Value]:
def _ndarray_constant_handler(val: np.ndarray | np.generic) -> IrValues:
"""Constant handler for ndarray literals, handling zero-size strides.
In most cases this function calls _numpy_array_constant(val) except it has
@@ -308,9 +309,9 @@ def _ndarray_constant_handler(val: np.ndarray | np.generic) -> Sequence[ir.Value
out = hlo.broadcast_in_dim(
ir.RankedTensorType.get(
val.shape, dtype_to_ir_type(collapsed_val.dtype)), # type: ignore
_numpy_array_constant(collapsed_val)[0],
_numpy_array_constant(collapsed_val),
dense_int_array(other_axes)) # type: ignore
return (out,)
return out
else:
return _numpy_array_constant(val)
@@ -330,7 +331,7 @@ for ptype, dtype in dtypes.python_scalar_dtypes.items():
register_constant_handler(ptype, partial(_python_scalar_handler, dtype))
def _token_constant_handler(val):
return [hlo.create_token()]
return hlo.create_token()
register_constant_handler(core.Token, _token_constant_handler)
# Source locations
@@ -725,16 +726,33 @@ def register_lowering(prim: core.Primitive, rule: LoweringRule,
return rule
def _unwrap_singleton_ir_values(x): return x[0] if len(x) == 1 else x
def wrap_singleton_ir_values(x: ir.Value | Sequence[ir.Value]
) -> Sequence[ir.Value]:
"""Adds a consistent tuples to a mixture of tupled and untuple values."""
return (x,) if isinstance(x, ir.Value) else tuple(x)
def flatten_ir_values(xs: Iterable[IrValues]) -> list[ir.Value]:
"""Concatenates/flattens a list of ir.Values or ir.Value sequences."""
out = []
for x in xs:
if isinstance(x, ir.Value):
out.append(x)
else:
out.extend(x)
return out
_unflatten_done = object()
def unflatten_ir_values(xs: Iterable[ir.Value], ns: Sequence[int]) -> list[IrValues]:
"""Splits `xs` into subsequences of lengths `ns`.
Unlike `split_list`, `sum(ns)` must equal `len(xs)`, and groups of length 1
are returned as bare values rather than wrapped in singleton tuples."""
xs_iter = iter(xs)
unflattened: list[IrValues]
unflattened = [next(xs_iter) if n == 1 else tuple(next(xs_iter)
for _ in range(n)) for n in ns]
assert next(xs_iter, _unflatten_done) is _unflatten_done
return unflattened
def _unwrap_singleton_ir_values(x): return x[0] if len(x) == 1 else x
def flatten_lowering_ir_args(
xs: Sequence[ir.Value | Sequence[ir.Value]]
) -> Sequence[ir.Value]:
return util.flatten(map(wrap_singleton_ir_values, xs))
_module_name_regex = re.compile(r"[^\w.-]")
@@ -764,7 +782,7 @@ def eval_dynamic_shape(ctx: LoweringRuleContext,
partial(core.evaluate_shape, shape, ctx.module_context.shape_poly_state.dim_vars),
multiple_results=True)(ctx, *ctx.dim_var_values)
return tuple(operator.index(d) if core.is_constant_dim(d) else d_ir
for d, d_ir in zip(shape, util.flatten(res)))
for d, d_ir in zip(shape, flatten_ir_values(res)))
# TODO: replace usage of eval_dynamic_shape_as_vals with eval_dynamic_shape_as_ivals
def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext,
@@ -1036,13 +1054,13 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args,
return input_output_aliases, out_donated_args
Token = Sequence[ir.Value]
Token = ir.Value
def token_type() -> Sequence[ir.Type]:
return [hlo.TokenType.get()]
def create_token() -> Token:
return wrap_singleton_ir_values(hlo.create_token())
return hlo.create_token()
class TokenSet:
"""An immutable container of tokens to be used to lower effectful jaxprs. When lowering
@@ -1388,24 +1406,24 @@ def lower_jaxpr_to_fun(
]
_, token_args, unflattened_args = util.split_list(
util.unflatten(flat_args, map(len, input_types)),
unflatten_ir_values(flat_args, map(len, input_types)),
[num_dim_vars, num_tokens])
tokens_in = TokenSet(zip(effects, token_args))
args: list[list[ir.Value]] = unflattened_args
args: list[IrValues] = unflattened_args
if name is not None:
callee_name_stack = name_stack.extend(util.wrap_name(name, api_name))
else:
callee_name_stack = name_stack
consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
consts = [ir_constant(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
out_vals, tokens_out = jaxpr_subcomp(
ctx, jaxpr.jaxpr, callee_name_stack, tokens_in,
consts, *args, dim_var_values=dim_var_values)
outs = []
outs: list[IrValues] = []
for eff in effects:
outs.append(wrap_singleton_ir_values(tokens_out.get(eff)))
outs.append(tokens_out.get(eff))
outs.extend(out_vals)
flat_outputs = util.flatten(outs)
flat_outputs = flatten_ir_values(outs)
if not use_sharding_annotations and ir_result_shardings is not None:
flat_outputs = [
@@ -1483,24 +1501,25 @@ def _emit_lowering_rule_as_fun(lowering_rule,
ctx.module_context.symbol_table.insert(func_op)
entry_block = func_op.add_entry_block()
with ir.InsertionPoint(entry_block):
unflattened_args = util.unflatten(entry_block.arguments,
unflattened_args = unflatten_ir_values(entry_block.arguments,
map(len, input_types))
dim_var_values, token_args, unflattened_args = util.split_list(unflattened_args, [num_dim_vars, len(ctx.tokens_in)])
sub_ctx = ctx.replace(tokens_in=TokenSet(zip(effs, token_args)),
dim_var_values=dim_var_values)
outs = lowering_rule(sub_ctx, *_unwrap_singleton_ir_values(unflattened_args))
outs = lowering_rule(sub_ctx, *unflattened_args)
if sub_ctx.tokens_out:
outs = [*[sub_ctx.tokens_out.get(eff) for eff in effs], outs]
func_dialect.return_(util.flatten(map(wrap_singleton_ir_values, outs)))
func_dialect.return_(flatten_ir_values(outs))
return func_op
def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
name_stack: source_info_util.NameStack,
tokens: TokenSet,
consts: Sequence[Sequence[ir.Value]],
*args: Sequence[ir.Value],
consts: Sequence[IrValues],
*args: IrValues,
dim_var_values: Sequence[ir.Value]
) -> tuple[Sequence[Sequence[ir.Value]], TokenSet]:
) -> tuple[Sequence[IrValues], TokenSet]:
"""Lowers a jaxpr into MLIR, inlined into an existing function.
Assumes that an MLIR context, location, and insertion point are set.
@@ -1509,9 +1528,9 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
IR function, in the order of ctx.shape_poly_state.dim_vars.
"""
assert "gpu" not in ctx.platforms
def read(v: core.Atom) -> Sequence[ir.Value]:
def read(v: core.Atom) -> IrValues:
if type(v) is core.Literal:
return ir_constants(xla.canonicalize_dtype(v.val))
return ir_constant(xla.canonicalize_dtype(v.val))
else:
assert isinstance(v, core.Var)
return env[v]
@@ -1522,9 +1541,12 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
else:
return v.aval
def write(v: core.Var, node: Sequence[ir.Value]):
def write(v: core.Var, node: IrValues):
assert node is not None
env[v] = tuple(node)
w: IrValues
w = node if isinstance(node, ir.Value) else tuple(node)
assert _is_ir_values(w), w
env[v] = w
def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None:
if ctx.lowering_parameters.override_lowering_rules is None:
@@ -1534,12 +1556,13 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
return rule
return None
env: dict[core.Var, tuple[ir.Value, ...]] = {}
env: dict[core.Var, IrValues] = {}
assert all(_is_ir_values(v) for v in args), args
assert all(_is_ir_values(v) for v in consts), consts
assert isinstance(name_stack, source_info_util.NameStack), type(name_stack)
assert len(args) == len(jaxpr.invars), (jaxpr, args)
assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts
assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values)
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
@@ -1579,16 +1602,17 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
tokens_out=None, jaxpr_eqn_ctx=eqn.ctx, dim_var_values=dim_var_values)
if config.dynamic_shapes.value:
axis_size_env = {d: read(d)[0]
axis_size_env = {d: read(d)
for a in avals_in if type(a) is core.DShapedArray
for d in a.shape if type(d) is core.Var}
rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)
rule_inputs = map(_unwrap_singleton_ir_values, in_nodes)
assert all(_is_ir_values(v) for v in in_nodes), (eqn, in_nodes)
ans = lower_per_platform(rule_ctx, str(eqn.primitive),
platform_rules, default_rule,
eqn.effects,
*rule_inputs, **eqn.params)
*in_nodes, **eqn.params)
assert all(_is_ir_values(v) for v in ans), (eqn, ans)
if effects:
# If there were ordered effects in the primitive, there should be output
@@ -1606,18 +1630,17 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
tokens = tokens.update_tokens(tokens_out)
try:
out_nodes = tuple(map(wrap_singleton_ir_values, ans))
out_nodes = tuple(ans)
except TypeError as e:
raise ValueError("Output of translation rule must be iterable: "
f"{eqn}, got output {ans}") from e
assert all(isinstance(v, tuple) for v in out_nodes), (ans, eqn)
assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (
ans, "lowering function returned a bad output", eqn)
# assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (
# ans, "lowering function returned a bad output", eqn)
assert len(ans) == len(eqn.outvars), (ans, eqn)
map(write, eqn.outvars, out_nodes)
core.clean_up_dead_vars(eqn, env, last_used)
return map(read, jaxpr.outvars), tokens
return tuple(read(v) for v in jaxpr.outvars), tokens
def _platforms_for_eqn_ctx(eqn_ctx: core.JaxprEqnContext | None
@@ -1635,7 +1658,7 @@ def lower_per_platform(ctx: LoweringRuleContext,
platform_rules: dict[str, LoweringRule],
default_rule: LoweringRule | None,
effects: effects_lib.Effects,
*rule_args: ir.Value,
*rule_args: ir.Value | tuple[ir.Value, ...],
**rule_kwargs) -> Sequence[ir.Value]:
"""Emits code for a primitive for the current lowering platform(s).
@@ -1705,9 +1728,8 @@ def lower_per_platform(ctx: LoweringRuleContext,
# If there is a single rule left just apply the rule, without conditionals.
if len(kept_rules) == 1:
output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
wrapped_out = map(wrap_singleton_ir_values, output)
map(lambda o: wrap_compute_type_in_place(ctx, o.owner),
util.flatten(wrapped_out))
flatten_ir_values(output))
return output
assert len(platforms) > 1 and len(kept_rules) >= 2, (platforms, kept_rules)
@@ -1725,7 +1747,7 @@ def lower_per_platform(ctx: LoweringRuleContext,
for i, p in enumerate(platforms):
branch = rule_idx_op.regions[i].blocks.append()
with ir.InsertionPoint(branch):
hlo.return_(ir_constants(np.int32(platform_to_kept_rules_idx[p])))
hlo.return_([ir_constant(np.int32(platform_to_kept_rules_idx[p]))])
ordered_effects = effects_lib.ordered_effects.filter_in(effects)
rule_out_avals = [core.abstract_token] * len(ordered_effects) + ctx.avals_out
output_types = map(aval_to_ir_types, rule_out_avals)
@@ -1741,32 +1763,31 @@
with ir.InsertionPoint(branch):
output = rule(inner_ctx, *rule_args, **rule_kwargs)
try:
out_nodes = map(wrap_singleton_ir_values, output)
out_nodes = flatten_ir_values(output)
except TypeError as e:
raise ValueError("Output of translation rule must be iterable: "
f"{description}, got output {output}") from e
map(lambda o: wrap_compute_type_in_place(ctx, o.owner),
util.flatten(out_nodes))
map(lambda o: wrap_compute_type_in_place(ctx, o.owner), out_nodes)
if inner_ctx.tokens_out is not None:
assert len(ordered_effects) == len(inner_ctx.tokens_out)
out_nodes = [inner_ctx.tokens_out.get(eff)
for eff in ordered_effects] + out_nodes
hlo.return_(util.flatten(map(wrap_singleton_ir_values, out_nodes)))
hlo.return_(out_nodes)
results = case_op.results
if ordered_effects:
tokens, results = util.split_list(
util.unflatten(results, map(len, output_types)),
unflatten_ir_values(results, map(len, output_types)),
[len(ordered_effects)])
tokens_out = ctx.tokens_in.update_tokens(TokenSet(zip(ordered_effects,
tokens)))
ctx.set_tokens_out(tokens_out)
return results
def _ir_consts(consts):
def _ir_consts(consts) -> list[IrValues]:
unique_consts = {id(const): const for const in consts}
ir_consts = {
id_: ir_constants(xla.canonicalize_dtype(const))
id_: ir_constant(xla.canonicalize_dtype(const))
for id_, const in unique_consts.items()
}
return [ir_consts[id(const)] for const in consts]
@@ -1810,7 +1831,7 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
sub_context = ctx.module_context
out, tokens = jaxpr_subcomp(
sub_context, jaxpr, ctx.name_stack, ctx.tokens_in,
_ir_consts(consts), *map(wrap_singleton_ir_values, args),
_ir_consts(consts), *args,
dim_var_values=ctx.dim_var_values)
ctx.set_tokens_out(tokens)
return out
@@ -1854,6 +1875,7 @@ def check_backend_matches(inner_backend: str | None,
f"inner-jit backend specification {inner_backend}.")
def call_lowering(fn_name, name_stack, call_jaxpr, backend,
ctx: ModuleContext, avals_in,
avals_out, tokens_in, *args,
@@ -1874,8 +1896,8 @@ def call_lowering(fn_name, name_stack, call_jaxpr, backend,
args = (*dim_var_values, *tokens, *args)
call = func_dialect.CallOp(flat_output_types,
ir.FlatSymbolRefAttr.get(symbol_name),
flatten_lowering_ir_args(args))
out_nodes = util.unflatten(call.results, map(len, output_types))
flatten_ir_values(args))
out_nodes = unflatten_ir_values(call.results, map(len, output_types))
tokens, out_nodes = util.split_list(out_nodes, [len(effects)])
tokens_out = tokens_in.update_tokens(TokenSet(zip(effects, tokens)))
return out_nodes, tokens_out
@@ -2181,7 +2203,7 @@ def get_sharding_attr(sharding_proto: xc.OpSharding):
# The MHLO to HLO conversion supports both, and the proto representation is
# more compact.
if len(sharding_proto.tile_assignment_devices) > 100:
return ir.StringAttr.get(sharding_proto.SerializeToString()) # type: ignore
return ir.StringAttr.get(sharding_proto.SerializeToString()) # type: ignore[arg-type]
else:
return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding_proto)))
@@ -2242,8 +2264,8 @@ def cache_lowering(f):
flat_output_types = util.flatten(output_types)
call = func_dialect.CallOp(flat_output_types,
ir.FlatSymbolRefAttr.get(func.name.value),
flatten_lowering_ir_args(args))
return util.unflatten(call.results, map(len, output_types))
flatten_ir_values(args))
return unflatten_ir_values(call.results, map(len, output_types))
return cached_lowering
@@ -2355,13 +2377,13 @@ def xla_fallback_lowering(prim: core.Primitive):
call = func_dialect.CallOp([output_type],
ir.FlatSymbolRefAttr.get(callee_name),
flatten_lowering_ir_args(args)).result
flatten_ir_values(args)).result
if not prim.multiple_results:
return [call]
flat_results = [hlo.get_tuple_element(call, i32_attr(i))
for i in range(len(flat_output_types))]
return util.unflatten(flat_results, map(len, output_types))
return unflatten_ir_values(flat_results, map(len, output_types))
return fallback
@@ -2489,7 +2511,7 @@ def emit_python_callback(
sharding: xc.OpSharding | None = None,
operand_layouts: Sequence[Sequence[int] | None] | None = None,
result_layouts: Sequence[Sequence[int] | None] | None = None,
) -> tuple[Sequence[ir.Value], Any, Any]:
) -> tuple[Sequence[IrValues], Any, Any]:
"""Emits MLIR that calls back to a provided Python function."""
if len(ctx.module_context.platforms) > 1:
raise NotImplementedError("multi-platform lowering for python_callback")
@@ -2668,7 +2690,7 @@ def custom_call(
if backend_config is None:
backend_config_attr = ir.StringAttr.get("")
elif isinstance(backend_config, (str, bytes)):
backend_config_attr = ir.StringAttr.get(backend_config) # type: ignore
backend_config_attr = ir.StringAttr.get(backend_config) # type: ignore[arg-type]
elif isinstance(backend_config, dict):
# TODO(necula): it seems that the CustomCallOp constructor requires that
# backend_config_attr be a string attribute, even though in some cases we

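One subtlety in the jaxpr_subcomp changes above: a lowering rule returns one entry per output variable, each entry being a bare ir.Value or a sequence of them, and write() normalizes sequences to tuples before storing them in the environment; the _is_ir_values assertion then rejects singleton tuples, which must always be stored bare. A simplified sketch of that invariant (stand-in Value class, not the exact JAX code):

class Value:
    """Stand-in for mlir.ir.Value."""

def _is_ir_values(x) -> bool:
    # mirrors the helper above: bare values pass; tuples pass only if they
    # are not singletons (a lone ir.Value must never be wrapped)
    if isinstance(x, Value):
        return True
    return (isinstance(x, tuple) and len(x) != 1
            and all(isinstance(v, Value) for v in x))

env = {}

def write(var, node):
    w = node if isinstance(node, Value) else tuple(node)
    assert _is_ir_values(w), w
    env[var] = w

v0, v1 = Value(), Value()
write("x", v0)         # a bare value is stored as-is
write("y", [v0, v1])   # a list from a rule is normalized to a tuple
write("z", ())         # zero backing ir.Values is also legal
assert isinstance(env["y"], tuple) and env["z"] == ()
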

@@ -1387,13 +1387,12 @@ def _unravel_index_hlo(axis_env):
mod = mlir.ir_constant(np.array(axis_env.sizes[-1], np.uint32))
return hlo.remainder(hlo.divide(hlo.replica_id(), div), mod)
def _hlo_shard(aval, axis_env, xs, in_axis):
def _hlo_shard(aval, axis_env, x, in_axis):
if aval is core.abstract_token:
return xs
return x
elif isinstance(aval, core.ShapedArray):
if dtypes.issubdtype(aval.dtype, dtypes.extended):
aval = aval.dtype._rules.physical_element_aval(aval.dtype)
x, = xs
dims = list(aval.shape)
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
idxs = [zero] * len(dims)
@@ -1402,9 +1401,7 @@ def _hlo_shard(aval, axis_env, xs, in_axis):
dims_unsqueezed.insert(in_axis, 1)
dynamic_slice_result = hlo.dynamic_slice(
x, idxs, mlir.dense_int_array(dims_unsqueezed))
return [
hlo.reshape(mlir.aval_to_ir_type(aval), dynamic_slice_result)
]
return hlo.reshape(mlir.aval_to_ir_type(aval), dynamic_slice_result)
else:
raise TypeError(aval)
@@ -1442,11 +1439,10 @@ def _axis_groups(mesh_spec, mesh_axes):
# TODO(b/110096942): more efficient gather
def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs):
def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, x):
if aval is core.abstract_token:
return xs
return x
elif isinstance(aval, core.ShapedArray):
x, = xs
dims = list(aval.shape)
padded_aval = aval.update(shape=[axis_env.sizes[-1]] + dims)
padded = mlir.full_like_aval(ctx, 0, padded_aval)
@@ -1489,8 +1485,8 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
# Shard the in_nodes that are mapped
in_avals = [v.aval for v in call_jaxpr.invars]
in_nodes_sharded = (
_hlo_shard(aval, new_env, mlir.wrap_singleton_ir_values(in_node), in_axis)
if in_axis is not None else mlir.wrap_singleton_ir_values(in_node)
_hlo_shard(aval, new_env, in_node, in_axis)
if in_axis is not None else in_node
for aval, in_node, in_axis in zip(in_avals, in_nodes, in_axes))
with maybe_extend_axis_env(axis_name, global_axis_size, None):


@@ -873,16 +873,16 @@ def _cond_lowering(ctx, index, *args, branches, linear):
for i, jaxpr in enumerate(branches):
branch = case_op.regions[i].blocks.append()
with ir.InsertionPoint(branch):
consts = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
consts = [mlir.ir_constant(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
out_vals, tokens_out = mlir.jaxpr_subcomp(
ctx.module_context, jaxpr.jaxpr, name_stack.extend(f'branch_{i}_fun'),
tokens_in, consts, *map(mlir.wrap_singleton_ir_values, args),
tokens_in, consts, *args,
dim_var_values=ctx.dim_var_values)
out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
out_vals = [*out_tokens, *out_vals]
hlo.return_(util.flatten(out_vals))
hlo.return_(mlir.flatten_ir_values(out_vals))
tokens_and_outputs = util.unflatten(case_op.results, map(len, output_types))
tokens_and_outputs = mlir.unflatten_ir_values(case_op.results, map(len, output_types))
tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
ctx.set_tokens_out(mlir.TokenSet(zip(ordered_effects, tokens)))
return outputs
@@ -1012,7 +1012,9 @@ def _platform_index_lowering(ctx: mlir.LoweringRuleContext,
def lower_constant(
ctx: mlir.LoweringRuleContext, *, i: int
) -> Sequence[ir.Value]:
return mlir.ir_constants(np.int32(i))
v = mlir.ir_constant(np.int32(i))
assert isinstance(v, ir.Value), v
return [v]
platform_rules: dict[str, mlir.LoweringRule] = {}
for i, ps in enumerate(platforms):
rule = partial(lower_constant, i=i)
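
The lower_constant change above makes explicit an invariant this commit preserves: a lowering rule still returns a sequence with one entry per output variable; it is each entry that is now a bare ir.Value rather than a singleton tuple. A minimal single-output rule under the new convention, with arithmetic standing in for emitting an HLO constant:

def lower_constant_sketch(ctx, *, i: int):
    v = i * 1                 # stand-in for mlir.ir_constant(np.int32(i))
    return [v]                # one output variable -> one-element sequence

out = lower_constant_sketch(None, i=7)
assert out == [7] and not isinstance(out[0], tuple)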


@@ -1722,7 +1722,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
flat_loop_carry_types = util.flatten(loop_carry_types)
args = [*tokens, *args]
flat_args = mlir.flatten_lowering_ir_args(args)
flat_args = mlir.flatten_ir_values(args)
while_op = hlo.WhileOp(flat_loop_carry_types, flat_args)
# Loop condition
@@ -1732,15 +1732,15 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
flat_cond_args = [
cond_block.arguments[i] for i in range(len(flat_loop_carry_types))
]
cond_args = util.unflatten(flat_cond_args, _map(len, loop_carry_types))
cond_args = mlir.unflatten_ir_values(flat_cond_args, _map(len, loop_carry_types))
# Remove tokens from cond args
cond_args = cond_args[num_tokens:]
x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
cond_consts = [
mlir.ir_constants(xla.canonicalize_dtype(x)) for x in cond_jaxpr.consts
mlir.ir_constant(xla.canonicalize_dtype(x)) for x in cond_jaxpr.consts
]
cond_name_stack = name_stack.extend('cond')
((pred,),), _ = mlir.jaxpr_subcomp(
(pred,), _ = mlir.jaxpr_subcomp(
ctx.module_context,
cond_jaxpr.jaxpr,
cond_name_stack,
@@ -1772,13 +1772,13 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
flat_body_args = [
body_block.arguments[i] for i in range(len(flat_loop_carry_types))
]
body_args = util.unflatten(flat_body_args, _map(len, loop_carry_types))
body_args = mlir.unflatten_ir_values(flat_body_args, _map(len, loop_carry_types))
# Tokens are at the front of the args list to the while loop
token_args, body_args = util.split_list(body_args, [num_tokens])
tokens_in = mlir.TokenSet(zip(body_effects, token_args))
x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
body_name_stack = name_stack.extend('body')
body_consts = [mlir.ir_constants(xla.canonicalize_dtype(x))
body_consts = [mlir.ir_constant(xla.canonicalize_dtype(x))
for x in body_jaxpr.consts]
new_z, tokens_out = mlir.jaxpr_subcomp(
ctx.module_context, body_jaxpr.jaxpr, body_name_stack,
@@ -1786,9 +1786,9 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
out_tokens = [tokens_out.get(eff) for eff in body_effects]
if batched:
body_pred_name_stack = name_stack.extend('body_pred')
cond_consts = [mlir.ir_constants(xla.canonicalize_dtype(x))
cond_consts = [mlir.ir_constant(xla.canonicalize_dtype(x))
for x in cond_jaxpr.consts]
((body_pred,),), _ = mlir.jaxpr_subcomp(
(body_pred,), _ = mlir.jaxpr_subcomp(
ctx.module_context, cond_jaxpr.jaxpr, body_pred_name_stack,
mlir.TokenSet(), cond_consts, *(x + z),
dim_var_values=ctx.dim_var_values)
@@ -1796,10 +1796,10 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z,
body_jaxpr.out_avals)
hlo.return_([*util.flatten(out_tokens), *util.flatten(x), *util.flatten(y),
*util.flatten(new_z)])
hlo.return_([*mlir.flatten_ir_values(out_tokens), *mlir.flatten_ir_values(x), *mlir.flatten_ir_values(y),
*mlir.flatten_ir_values(new_z)])
outputs = util.unflatten(while_op.results, _map(len, loop_carry_types))
outputs = mlir.unflatten_ir_values(while_op.results, _map(len, loop_carry_types))
tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts])
if tokens:
ctx.set_tokens_out(mlir.TokenSet(zip(body_effects, tokens)))
@@ -1909,16 +1909,14 @@ state_discharge.register_discharge_rule(while_p)(_while_discharge_rule)
def _pred_bcast_select_hlo(ctx,
pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
pred_aval: core.ShapedArray, pred: ir.Value, x: mlir.IrValues,
y: mlir.IrValues, x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
if x_y_aval is core.abstract_token:
x, = xs
y, = ys
return [hlo.AfterAllOp([x, y]).result]
else:
assert isinstance(x, ir.Value), x
assert isinstance(y, ir.Value), y
assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
x, = xs
y, = ys
assert x.type == y.type, (x.type, y.type)
assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
pred_aval.shape, x_y_aval)
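
The pred unpacking above shows the net effect on jaxpr_subcomp callers: outputs used to arrive as singleton sequences, requiring a double unpack, and now arrive bare, removing one nesting level. Sketched with stand-in return values:

# old: each output of jaxpr_subcomp was a sequence, even for one ir.Value
old_out_vals = [("pred_value",)]
((pred,),) = old_out_vals        # two levels of unpacking

# new: single-ir.Value outputs come back bare
new_out_vals = ["pred_value"]
(pred,) = new_out_vals           # one level
assert pred == "pred_value"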


@@ -2156,8 +2156,8 @@ def _pow_lower(ctx, x, y):
partial(convert_element_type, new_dtype=out_aval.dtype), False)
x_aval_ = x_aval.update(dtype=out_aval.dtype)
y_aval_ = y_aval.update(dtype=out_aval.dtype)
[(x_,)] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x)
[(y_,)] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y)
[x_] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x)
[y_] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y)
ctx_ = ctx.replace(avals_in=[x_aval_, y_aval_])
return _nary_lower_hlo(hlo.power, ctx_, x_, y_)
mlir.register_lowering(pow_p, _pow_lower)
@@ -3980,9 +3980,9 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions):
out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr.jaxpr,
name_stack, mlir.TokenSet(),
jaxpr.consts,
*([a] for a in reducer.arguments),
*reducer.arguments,
dim_var_values=ctx.dim_var_values)
hlo.return_(util.flatten(out_nodes))
hlo.return_(mlir.flatten_ir_values(out_nodes))
return op.results
mlir.register_lowering(reduce_p, _reduce_lower)
@@ -4180,7 +4180,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
aval_out, = ctx.avals_out
dtype = aval_out.dtype
op = hlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x],
mlir.ir_constants(unit_factory(aval_out.dtype)),
[mlir.ir_constant(unit_factory(aval_out.dtype))],
mlir.dense_int_array(axes))
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype))
reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type)
@@ -4350,7 +4350,7 @@ batching.primitive_batchers[sort_p] = _sort_batch_rule
def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in
sort = hlo.SortOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
mlir.flatten_lowering_ir_args(operands),
mlir.flatten_ir_values(operands),
dimension=mlir.i64_attr(dimension),
is_stable=ir.BoolAttr.get(is_stable))
scalar_avals = [aval.update(shape=()) for aval in ctx.avals_in]
@@ -4364,9 +4364,8 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
avals_in=util.flatten(zip(scalar_avals, scalar_avals)),
avals_out=[core.ShapedArray((), np.bool_)])
out = lower_comparator(sub_ctx, *[[a] for a in comparator.arguments],
num_keys=num_keys)
hlo.return_(util.flatten(out))
out = lower_comparator(sub_ctx, *comparator.arguments, num_keys=num_keys)
hlo.return_(mlir.flatten_ir_values(out))
return sort.results
mlir.register_lowering(sort_p, _sort_lower)
@@ -4562,9 +4561,9 @@ def _infeed_lowering(ctx, token, *, shapes, partitions):
mlir.set_sharding(infeed, xla.sharding_to_proto(partitions))
token = infeed.results[-1]
outs = infeed.results[:-1]
return util.unflatten(outs, safe_map(len, output_types)) + [[
return mlir.unflatten_ir_values(outs, safe_map(len, output_types)) + [
token,
]]
]
mlir.register_lowering(infeed_p, _infeed_lowering)
@@ -4596,7 +4595,7 @@ mlir.lowerable_effects.add_type(InOutFeedEffect)
def _outfeed_lowering(ctx, token, *xs, partitions):
outfeed = hlo.OutfeedOp(
mlir.flatten_lowering_ir_args(xs),
mlir.flatten_ir_values(xs),
token,
outfeed_config=ir.StringAttr.get(''))
if partitions is not None:
@@ -4638,7 +4637,7 @@ rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
def _rng_uniform_lowering(ctx, a, b, *, shape):
aval_out, = ctx.avals_out
shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64))
shape = mlir.ir_constant(np.array(aval_out.shape, np.int64))
return [hlo.rng(a, b, shape, hlo.RngDistributionAttr.get('UNIFORM'))]
mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering)
@@ -5173,7 +5172,7 @@ empty_p.def_abstract_eval(lambda *, dtype: core.ShapedArray((), dtype))
def _empty_lower(ctx, *, dtype):
dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype)
phys_aval = core.physical_aval(core.ShapedArray((), dtype))
return mlir.ir_constants(np.zeros(phys_aval.shape, phys_aval.dtype))
return mlir.ir_constant(np.zeros(phys_aval.shape, phys_aval.dtype)),
mlir.register_lowering(empty_p, _empty_lower)
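
A readability hazard in the _empty_lower change above: the trailing comma is load-bearing; it builds the rule's one-element output sequence around the now-bare constant and is easy to misread as a typo:

value = "const"        # stand-in for mlir.ir_constant(np.zeros(...))
outs = (value,)        # what `return mlir.ir_constant(...),` evaluates to
assert len(outs) == 1 and outs[0] is value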


@@ -1097,10 +1097,10 @@ def _triangular_solve_cpu_lower(
if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape)
return [lapack.trsm_hlo(
return lapack.trsm_hlo(
a_aval.dtype, alpha,
a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal,
b_shape_vals=b_shape_vals)]
b_shape_vals=b_shape_vals)
else:
# Fall back to the HLO implementation for unsupported types or batching.
# TODO: Consider swapping XLA for LAPACK in batched case
@@ -1189,7 +1189,7 @@ def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *,
def _lu_pivots_to_permutation_gpu_lowering(lowering, ctx, pivots, *,
permutation_size):
return [lowering(pivots, permutation_size=permutation_size)]
return lowering(pivots, permutation_size=permutation_size)
lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation')


@@ -27,7 +27,6 @@ from jax import tree_util
from jax._src import core
from jax._src import dtypes
from jax._src import sharding_impls
from jax._src import util
from jax._src.core import AxisName, ShapedArray, raise_to_shaped
from jax._src.interpreters import ad
from jax._src.interpreters import batching
@@ -772,7 +771,7 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
shape=np.delete(np.array(aval.shape, dtype=np.int64),
positional_axes))
reducer_ctx = ctx.replace(primitive=None, avals_in=[aval], avals_out=[aval_out])
out, = reducer(reducer_ctx, arg, axes=tuple(positional_axes))[0]
out, = reducer(reducer_ctx, arg, axes=tuple(positional_axes))
return out
args = map(_positional_reduce, ctx.avals_in, args)
if not named_axes:
@@ -805,9 +804,8 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
lower_reducer = mlir.lower_fun(prim.bind, multiple_results=False)
reducer_ctx = ctx.replace(primitive=None,
avals_in=[scalar_aval] * 2, avals_out=[scalar_aval])
out_nodes = lower_reducer(
reducer_ctx, *([a] for a in reducer_block.arguments))
hlo.return_(util.flatten(out_nodes))
out_nodes = lower_reducer(reducer_ctx, *reducer_block.arguments)
hlo.return_(mlir.flatten_ir_values(out_nodes))
return op.result
return [all_reduce(aval, x) for aval, x in zip(ctx.avals_in, args)]
@@ -1410,9 +1408,8 @@ def _reduce_scatter_lowering(
reducer_ctx = ctx.replace(primitive=None,
avals_in=[scalar_aval] * 2,
avals_out=[scalar_aval])
out_nodes = lower_reducer(
reducer_ctx, *([a] for a in reducer_block.arguments))
hlo.return_(util.flatten(out_nodes))
out_nodes = lower_reducer(reducer_ctx, *reducer_block.arguments)
hlo.return_(mlir.flatten_ir_values(out_nodes))
if tiled:
return op.results
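
The reducer lowerings above, like the reduce_window and select_and_scatter ones below, all drop the same indirection: MLIR block arguments are already bare ir.Values, which is exactly what jaxpr_subcomp now expects for each input, so the `[a] for a in block.arguments` wrapping disappears. A before/after sketch with strings as stand-ins:

block_arguments = ["acc", "x"]             # stand-ins for reducer.arguments

old_args = [[a] for a in block_arguments]  # old: one singleton list per input
assert old_args == [["acc"], ["x"]]

new_args = block_arguments                 # new: passed through unchanged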


@@ -2473,7 +2473,7 @@ def _scatter_lower(ctx, operand, indices, updates, *,
if mode == GatherScatterMode.CLIP:
clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
(indices,), = clip_fn(ctx.replace(avals_out=None), operand, indices,
(indices,) = clip_fn(ctx.replace(avals_out=None), operand, indices,
updates, dnums=dimension_numbers)
dnums = dimension_numbers
@@ -2504,9 +2504,9 @@ def _scatter_lower(ctx, operand, indices, updates, *,
raise NotImplementedError('Cannot lower effectful `scatter`.')
out_nodes, _ = mlir.jaxpr_subcomp(
ctx.module_context, update_jaxpr, name_stack, mlir.TokenSet(),
update_consts, (update.arguments[0],), (update.arguments[1],),
update_consts, update.arguments[0], update.arguments[1],
dim_var_values=ctx.dim_var_values)
hlo.return_(util.flatten(out_nodes))
hlo.return_(mlir.flatten_ir_values(out_nodes))
return op.results
mlir.register_lowering(scatter_p, _scatter_lower)
@@ -2532,7 +2532,7 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
if mode == GatherScatterMode.CLIP:
clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
(indices,), = clip_fn(ctx.replace(avals_out=None), operand, indices, updates,
indices, = clip_fn(ctx.replace(avals_out=None), operand, indices, updates,
dnums=dimension_numbers)
aval_out, = ctx.avals_out


@@ -441,9 +441,9 @@ def _generic_reduce_window_lower(
if jaxpr.effects:
raise NotImplementedError('Cannot lower effectful `reduce_window`.')
out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr, ctx.name_stack,
mlir.TokenSet(), consts, *([a] for a in reducer.arguments),
mlir.TokenSet(), consts, *reducer.arguments, # type: ignore[misc]
dim_var_values=ctx.dim_var_values)
return util.flatten(out_nodes)
return mlir.flatten_ir_values(out_nodes)
return mlir.reduce_window(
ctx,
@@ -675,9 +675,9 @@ def _select_and_scatter_lower(
out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, select_jaxpr,
ctx.name_stack,
mlir.TokenSet(), select_consts,
*([a] for a in select.arguments),
*select.arguments,
dim_var_values=ctx.dim_var_values)
hlo.return_(util.flatten(out_nodes))
hlo.return_(mlir.flatten_ir_values(out_nodes))
scatter = op.scatter.blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(scatter):
if scatter_jaxpr.effects:
@@ -685,9 +685,9 @@
out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, scatter_jaxpr,
ctx.name_stack,
mlir.TokenSet(), scatter_consts,
*([a] for a in scatter.arguments),
*scatter.arguments,
dim_var_values=ctx.dim_var_values)
hlo.return_(util.flatten(out_nodes))
hlo.return_(mlir.flatten_ir_values(out_nodes))
return op.results
mlir.register_lowering(select_and_scatter_p, _select_and_scatter_lower)


@@ -1350,7 +1350,7 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
# them!
vectorized_jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(f, local_avals)
# _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
const_nodes = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in consts]
const_nodes = [mlir.ir_constant(xla.canonicalize_dtype(x)) for x in consts]
local_mesh_shape = mesh.local_mesh.shape
tiled_ins = (
@@ -1426,14 +1426,14 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
vectorized_jaxpr, global_out_avals, consts, () = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
sharded_global_in_nodes = [
[mlir.wrap_with_sharding_op(
mlir.wrap_with_sharding_op(
ctx, node, aval,
NamedSharding(mesh, array_mapping_to_axis_resources(aval_axes)
)._to_xla_hlo_sharding(aval.ndim).to_proto())]
if aval_axes else [node]
)._to_xla_hlo_sharding(aval.ndim).to_proto())
if aval_axes else node
for node, aval, aval_axes in zip(global_in_nodes, global_in_avals, mesh_in_axes)
]
const_nodes = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in consts]
const_nodes = [mlir.ir_constant(xla.canonicalize_dtype(x)) for x in consts]
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
@@ -1451,7 +1451,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
NamedSharding(mesh, array_mapping_to_axis_resources(aval_axes)
)._to_xla_hlo_sharding(aval.ndim).to_proto())
if aval_axes else node
for (node,), aval, aval_axes in zip(global_out_nodes, global_out_avals, mesh_out_axes)
for node, aval, aval_axes in zip(global_out_nodes, global_out_avals, mesh_out_axes)
]
return sharded_global_out_nodes
@@ -1485,7 +1485,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
# them!
global_in_avals = ctx.avals_in
vectorized_jaxpr, global_out_avals, consts, () = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
const_nodes = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in consts]
const_nodes = [mlir.ir_constant(xla.canonicalize_dtype(x)) for x in consts]
# We in-line here rather than generating a Call HLO as in the xla_call
# translation rule just because the extra tuple stuff is a pain.
@@ -1498,9 +1498,8 @@
in vectorized_jaxpr.effects):
raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
global_out_nodes, _ = mlir.jaxpr_subcomp(
sub_ctx, vectorized_jaxpr, name_stack,
mlir.TokenSet(), const_nodes, *([n] for n in global_in_nodes),
dim_var_values=ctx.dim_var_values)
sub_ctx, vectorized_jaxpr, name_stack, mlir.TokenSet(), const_nodes,
*global_in_nodes, dim_var_values=ctx.dim_var_values)
return global_out_nodes


@@ -81,7 +81,7 @@ from jax._src.tree_util import (
from jax._src.util import (
HashableFunction, safe_map, safe_zip, wraps,
distributed_debug_log, split_list, weakref_lru_cache,
merge_lists, flatten, unflatten, subs_list, fun_name)
merge_lists, flatten, subs_list, fun_name)
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
@@ -1926,9 +1926,9 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
args = (*ctx.dim_var_values, *tokens_in, *args)
call = func_dialect.CallOp(flat_output_types,
ir.FlatSymbolRefAttr.get(func.name.value),
mlir.flatten_lowering_ir_args(args))
mlir.flatten_ir_values(args))
mlir.wrap_compute_type_in_place(ctx, call)
out_nodes = unflatten(call.results, map(len, output_types))
out_nodes = mlir.unflatten_ir_values(call.results, map(len, output_types))
tokens, out_nodes = split_list(out_nodes, [len(effects)])
tokens_out = ctx.tokens_in.update_tokens(mlir.TokenSet(zip(effects, tokens)))
ctx.set_tokens_out(tokens_out)


@@ -648,7 +648,7 @@ def emit_tf_embedded_graph_custom_call(
util.flatten([mlir.aval_to_ir_types(aval) for aval in result_avals])
)
if ordered:
operands.insert(0, ctx.tokens_in.get(call_tf_ordered_effect)[0])
operands.insert(0, ctx.tokens_in.get(call_tf_ordered_effect))
result_types.insert(0, mlir.token_type()[0])
custom_call = hlo.CustomCallOp(
@@ -668,7 +668,7 @@ def emit_tf_embedded_graph_custom_call(
results = list(custom_call.results)
if ordered:
token = results.pop(0)
ctx.set_tokens_out(mlir.TokenSet({call_tf_ordered_effect: (token,)}))
ctx.set_tokens_out(mlir.TokenSet({call_tf_ordered_effect: token}))
return results


@@ -677,8 +677,7 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
return [mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified)]
def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
aval_in, aval_out, xs):
x, = xs
aval_in, aval_out, x):
axes = {name: i for i, ns in names.items() for name in ns}
ns = _make_scoped_manual_sharding(ctx, mesh, axes)
if dtypes.issubdtype(aval_out.dtype, dtypes.extended):


@@ -671,7 +671,7 @@ def _bcsr_dot_general_gpu_lowering(
# Account for a bug in cusparse: it references indices and data beyond
# the extent of indptr.
(lhs_data,), (lhs_indices,) = _bcsr_correct_out_of_bound_indices_lowered(
lhs_data, lhs_indices = _bcsr_correct_out_of_bound_indices_lowered(
ctx, lhs_data, lhs_indices, lhs_indptr, rhs, shape=lhs_spinfo.shape)
if rhs_aval.ndim == 1:


@@ -40,14 +40,13 @@ from jax._src.interpreters.mlir import (
dense_int_elements as dense_int_elements,
dtype_to_ir_type as dtype_to_ir_type,
emit_python_callback as emit_python_callback,
flatten_lowering_ir_args as flatten_lowering_ir_args,
flatten_ir_values as flatten_lowering_ir_args, # TODO(phawkins): remove me
func_dialect as func_dialect,
hlo as hlo,
i32_attr as i32_attr,
i64_attr as i64_attr,
ir as ir,
ir_constant as ir_constant,
ir_constants as ir_constants,
ir_type_handlers as ir_type_handlers,
jaxpr_subcomp as jaxpr_subcomp,
lower_fun as lower_fun,


@@ -103,8 +103,8 @@ testing_primitive_with_effect_p.def_effectful_abstract_eval(
def lowering_testing_primitive_with_effect(ctx, a, *, effect_class_name: str):
if "Ordered" in effect_class_name:
token_in = ctx.tokens_in.get(_testing_effects[effect_class_name])[0]
ctx.set_tokens_out(mlir.TokenSet({_testing_effects[effect_class_name]: (token_in,)}))
token_in = ctx.tokens_in.get(_testing_effects[effect_class_name])
ctx.set_tokens_out(mlir.TokenSet({_testing_effects[effect_class_name]: token_in}))
return [mlir.hlo.add(a, a)]
mlir.register_lowering(testing_primitive_with_effect_p,


@@ -97,7 +97,7 @@ def function_effect_lowering(ctx, *, effect):
flat_output_types = util.flatten(output_types)
call = mlir.func_dialect.CallOp(flat_output_types,
mlir.ir.FlatSymbolRefAttr.get(func.name.value),
mlir.flatten_lowering_ir_args(in_tokens))
mlir.flatten_ir_values(in_tokens))
tokens, out = util.split_list(call.results, [len(ctx.tokens_in)])
ctx.set_tokens_out(mlir.TokenSet(zip(effs, tokens)))
return out
@@ -120,7 +120,7 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out
del out_avals
token_in = None
if effects.ordered_effects.contains(effect):
token_in = ctx.tokens_in.get(effect)[0]
token_in = ctx.tokens_in.get(effect)
out_op, token_out, _ = mlir.emit_python_callback(
ctx, callback, token_in, list(args), list(ctx.avals_in),