Don't wrap singleton ir.Values with tuples during HLO lowering.
In general a JAX value might correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. However, the common case is that there is exactly one value, and almost all such lists are singletons. To reduce the number of singleton list and tuple objects allocated during MLIR lowering, instead represent singleton values as unwrapped ir.Values, and only use a tuple if there is not exactly one ir.Value backing a JAX value.
This commit is contained in:
parent 9def0f1c00
commit 8ab0c07edc
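(Illustrative note, not part of the commit: the sketch below mirrors the IrValues union and the flatten/unflatten helpers this change adds to jax._src.interpreters.mlir, but substitutes a tiny stand-in class for ir.Value so it runs without an MLIR context. Everything except the names IrValues, flatten_ir_values and unflatten_ir_values is invented for the example.)

from typing import Union

class Value:                      # stand-in for mlir.ir.Value
  def __init__(self, name): self.name = name
  def __repr__(self): return f"Value({self.name})"

# A JAX value lowers to one ir.Value in the common case, or a tuple otherwise.
IrValues = Union[Value, tuple[Value, ...]]

def flatten_ir_values(xs):
  """Concatenates single values and tuples of values into one flat list."""
  out = []
  for x in xs:
    if isinstance(x, Value):
      out.append(x)
    else:
      out.extend(x)
  return out

def unflatten_ir_values(xs, ns):
  """Splits a flat list back into groups; a group of length 1 stays unwrapped."""
  it = iter(xs)
  return [next(it) if n == 1 else tuple(next(it) for _ in range(n)) for n in ns]

a, b, c = Value("a"), Value("b"), Value("c")
vals = [a, (b, c)]                # a singleton and a two-value JAX value
assert flatten_ir_values(vals) == [a, b, c]
assert unflatten_ir_values([a, b, c], [1, 2]) == [a, (b, c)]

The payoff is that the singleton case, which dominates in practice, allocates no wrapper list or tuple at all.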
@@ -40,7 +40,6 @@ from jax._src.interpreters import mlir
 from jax._src.interpreters import partial_eval as pe
 from jax._src.lax import lax as lax_internal
 from jax._src.lax import convolution as lax_convolution
-from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.traceback_util import api_boundary
 from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr

@@ -769,7 +768,7 @@ def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):

 def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool,
                     differentiated: bool, policy, is_gpu_platform=False):
-  jaxpr_args: Sequence[Sequence[ir.Value]]
+  jaxpr_args: Sequence[mlir.IrValues]
   if differentiated and prevent_cse:
     # If we're using the loop or cond lowerings, use the slower lower_fun
     # based path.

@@ -780,11 +779,11 @@ def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool,
                                  is_gpu_platform=is_gpu_platform)

     arg_types = map(mlir.aval_to_ir_types, ctx.avals_in)
-    flat_args = mlir.flatten_lowering_ir_args(args)
+    flat_args = mlir.flatten_ir_values(args)
     barrier_op = hlo.OptimizationBarrierOp(flat_args)
-    jaxpr_args = util.unflatten(barrier_op.results, map(len, arg_types))
+    jaxpr_args = mlir.unflatten_ir_values(barrier_op.results, map(len, arg_types))
   else:
-    jaxpr_args = map(mlir.wrap_singleton_ir_values, args)
+    jaxpr_args = args
   outs, tokens_out = mlir.jaxpr_subcomp(
       ctx.module_context, jaxpr, ctx.name_stack.extend('checkpoint'),
       ctx.tokens_in, (), *jaxpr_args, dim_var_values=ctx.dim_var_values)

@@ -800,7 +799,7 @@ def _optimization_barrier_abstract_eval(*args):

 def _optimization_barrier_lowering_rule(ctx, *args):
   barrier_types = map(mlir.aval_to_ir_types, ctx.avals_in)
-  flat_args = mlir.flatten_lowering_ir_args(args)
+  flat_args = mlir.flatten_ir_values(args)
   barrier_op = hlo.OptimizationBarrierOp(flat_args)
   return util.unflatten(barrier_op.results, map(len, barrier_types))

@@ -1051,7 +1051,7 @@ basearray.Array.register(ArrayImpl)

 def _array_mlir_constant_handler(val):
   try:
-    return mlir.ir_constants(val._value)
+    return mlir.ir_constant(val._value)
   except RuntimeError as e:
     # TODO(yashkatariya): Ideally we would catch a custom exception from
     # `_value` function in ArrayImpl instead of checking the error string.

@@ -441,7 +441,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):

   op_sharding = _callback_op_sharding(ctx.module_context.axis_context, sharding)
   if ordered:
-    token = ctx.tokens_in.get(_OrderedIOEffect)[0]
+    token = ctx.tokens_in.get(_OrderedIOEffect)
     result, token, _ = mlir.emit_python_callback(
         ctx,
         _callback,

@@ -452,7 +452,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
         has_side_effect=True,
         sharding=op_sharding,
     )
-    ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: (token,)}))
+    ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: token}))
   else:
     result, token, _ = mlir.emit_python_callback(
         ctx,

@@ -436,11 +436,10 @@ core.custom_typechecks[custom_jvp_call_p] = _custom_jvp_call_typecheck
 def _custom_jvp_call_mlir_translation(ctx, *args, call_jaxpr, jvp_jaxpr_thunk,
                                       num_consts, symbolic_zeros):
   del jvp_jaxpr_thunk, num_consts, symbolic_zeros
-  args_ = map(mlir.wrap_singleton_ir_values, args)
   consts = mlir._ir_consts(call_jaxpr.consts)
   out, tokens = mlir.jaxpr_subcomp(ctx.module_context, call_jaxpr.jaxpr,
                                    ctx.name_stack, ctx.tokens_in, consts,
-                                   *args_, dim_var_values=ctx.dim_var_values)
+                                   *args, dim_var_values=ctx.dim_var_values)
   ctx.set_tokens_out(tokens)
   return out
 mlir.register_lowering(custom_jvp_call_p, _custom_jvp_call_mlir_translation)

@@ -159,11 +159,11 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
         *flat_args, effect=effect, callback=callback, **params)
     return ()
   if effects.ordered_effects.contains(effect):
-    [token] = ctx.tokens_in.get(effect)
+    token = ctx.tokens_in.get(effect)
     result, token, _ = mlir.emit_python_callback(
         ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out,
         has_side_effect=True)
-    ctx.set_tokens_out(mlir.TokenSet({effect: (token,)}))
+    ctx.set_tokens_out(mlir.TokenSet({effect: token}))
   else:
     result, _, _ = mlir.emit_python_callback(
         ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out,

@@ -23,7 +23,7 @@ import dataclasses
 import functools
 import itertools
 import re
-from typing import Any, Union
+from typing import Any, Union, cast
 import warnings

 from absl import logging

@@ -729,7 +729,7 @@ def _wrap_main_func(
   context = mlir.make_ir_context()
   with context, ir.Location.unknown(context):
     # Make a copy, do not mutate because it may be cached
-    wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module))  # type: ignore
+    wrapped_module = ir.Module.parse(mlir.module_to_bytecode(module))  # type: ignore[arg-type]
     symbol_table = ir.SymbolTable(wrapped_module.operation)
     orig_main = symbol_table["main"]
     orig_main.attributes["sym_visibility"] = ir.StringAttr.get("private")

@@ -838,7 +838,7 @@ def _wrap_main_func(
     orig_main_args: list[ir.Value] = []
     # The platform index and the dimension variables
     for arg, arg_type in zip(
-        list(new_main_op.arguments[0:nr_platform_index_args]) + util.flatten(dim_values),
+        list(new_main_op.arguments[0:nr_platform_index_args]) + mlir.flatten_ir_values(dim_values),
         platform_input_types + dim_var_input_types):
       if arg.type != arg_type:
         orig_main_args.append(hlo.convert(arg_type, arg))

@@ -1327,7 +1327,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
   if len(lowering_platforms) > 1:
     current_platform_idx = ctx.dim_var_values[0]
   else:
-    current_platform_idx = mlir.ir_constant(np.int32(0))
+    current_platform_idx = cast(ir.Value, mlir.ir_constant(np.int32(0)))
   # Compute the rule index based on the current platform
   i32_type = mlir.aval_to_ir_types(core.ShapedArray((), dtype=np.int32))[0]
   if current_platform_idx.type != i32_type:

@@ -1338,8 +1338,8 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
   for i in range(len(lowering_platforms)):
     branch = callee_platform_idx.regions[i].blocks.append()
     with ir.InsertionPoint(branch):
-      hlo.return_(mlir.ir_constants(
-        np.int32(callee_lowering_platform_index[i])))
+      hlo.return_([mlir.ir_constant(
+        np.int32(callee_lowering_platform_index[i]))])
   if callee_platform_idx.result.type != callee_type.inputs[0]:
     callee_platform_idx = hlo.ConvertOp(callee_type.inputs[0],
                                         callee_platform_idx)

@@ -1350,7 +1350,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,

   ordered_effects = exported.ordered_effects
   for eff in ordered_effects:
-    token_in = ctx.tokens_in.get(eff)[0]
+    token_in = ctx.tokens_in.get(eff)
     submodule_args.append(token_in)
   kept_args = [
       convert_shape(a, a_aval, exported_in_aval)

@@ -17,7 +17,7 @@ from __future__ import annotations

 import collections
 import contextlib
-from collections.abc import Callable, Iterator, Sequence
+from collections.abc import Callable, Iterable, Iterator, Sequence
 import dataclasses
 import functools
 from functools import partial

@@ -87,6 +87,16 @@ lowerable_effects: effects_lib.EffectTypeSet = effects_lib.lowerable_effects

 # IR Helpers

+IrValues = Union[ir.Value, tuple[ir.Value, ...]]
+
+def _is_ir_values(x: IrValues) -> bool:
+  """Returns true if `x` is an ir.Value or tuple of ir.Values"""
+  if isinstance(x, ir.Value):
+    return True
+  return (isinstance(x, tuple) and len(x) != 1
+          and all(isinstance(v, ir.Value) for v in x))
+
+
 def dense_int_elements(xs) -> ir.DenseIntElementsAttr:
   return type_cast(ir.DenseIntElementsAttr,
                    ir.DenseIntElementsAttr.get(np.asarray(xs, np.int64)))

@@ -103,7 +113,7 @@ def dense_bool_elements(xs: Sequence[bool]) -> ir.DenseElementsAttr:
       a, type=ir.IntegerType.get_signless(1), shape=[len(xs)])

 def dense_bool_array(xs: Sequence[bool]) -> ir.DenseBoolArrayAttr:
-  return ir.DenseBoolArrayAttr.get(xs)  # type: ignore
+  return ir.DenseBoolArrayAttr.get(xs)  # type: ignore[arg-type]

 def i32_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(32), i)
 def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i)

@@ -224,7 +234,7 @@ def aval_to_ir_type(aval: core.AbstractValue) -> ir.Type:
 # Constants

 class ConstantHandler(Protocol):
-  def __call__(self, val: Any) -> Sequence[ir.Value]:
+  def __call__(self, val: Any) -> IrValues:
     """Builds an IR representation for a constant `val`.

     A JAX value is represented by zero or more IR values."""

@@ -237,42 +247,33 @@ def register_constant_handler(type_: type, handler_fun: ConstantHandler):
 def get_constant_handler(type_: type) -> ConstantHandler:
   return _constant_handlers[type_]

-def ir_constants(val: Any) -> Sequence[ir.Value]:
+def ir_constant(val: Any) -> IrValues:
   """Translate a Python `val` to an IR constant, canonicalizing its dtype.

   Args:
     val: a Python value to be translated to a constant.

   Returns:
-    A representation of the constant as a list of IR values.
+    A representation of the constant as an IR value or sequence of IR values.
   """
   for t in type(val).__mro__:
     handler = _constant_handlers.get(t)
     if handler:
       out = handler(val)
-      assert all(isinstance(v, ir.Value) for v in out), (type(val), out)
+      assert _is_ir_values(out), (type(val), out)
       return out
   if hasattr(val, '__jax_array__'):
-    return ir_constants(val.__jax_array__())
+    return ir_constant(val.__jax_array__())
   raise TypeError(f"No constant handler for type: {type(val)}")

-def ir_constant(val: Any) -> ir.Value:
-  """Convenience wrapper around ir_constants for singleton values."""
-  values = ir_constants(val)
-  if len(values) != 1:
-    raise TypeError(f"ir_constant called on {val} which corresponds to "
-                    f"multiple IR values {values}")
-  return values[0]
-

-def _numpy_array_constant(x: np.ndarray | np.generic) -> Sequence[ir.Value]:
+def _numpy_array_constant(x: np.ndarray | np.generic) -> IrValues:
   element_type = dtype_to_ir_type(x.dtype)
   shape = x.shape
   if x.dtype == np.bool_:
     x = np.packbits(x, bitorder='little')  # type: ignore
   x = np.ascontiguousarray(x)
   attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)  # type: ignore
-  return (hlo.constant(attr),)
+  return hlo.constant(attr)


 def _masked_array_constant_handler(*args, **kwargs):

@@ -281,7 +282,7 @@ def _masked_array_constant_handler(*args, **kwargs):

 register_constant_handler(np.ma.MaskedArray, _masked_array_constant_handler)

-def _ndarray_constant_handler(val: np.ndarray | np.generic) -> Sequence[ir.Value]:
+def _ndarray_constant_handler(val: np.ndarray | np.generic) -> IrValues:
   """Constant handler for ndarray literals, handling zero-size strides.

   In most cases this function calls _numpy_array_constant(val) except it has

@@ -308,9 +309,9 @@ def _ndarray_constant_handler(val: np.ndarray | np.generic) -> Sequence[ir.Value
     out = hlo.broadcast_in_dim(
         ir.RankedTensorType.get(
             val.shape, dtype_to_ir_type(collapsed_val.dtype)),  # type: ignore
-        _numpy_array_constant(collapsed_val)[0],
+        _numpy_array_constant(collapsed_val),
         dense_int_array(other_axes))  # type: ignore
-    return (out,)
+    return out
   else:
     return _numpy_array_constant(val)

@@ -330,7 +331,7 @@ for ptype, dtype in dtypes.python_scalar_dtypes.items():
   register_constant_handler(ptype, partial(_python_scalar_handler, dtype))

 def _token_constant_handler(val):
-  return [hlo.create_token()]
+  return hlo.create_token()
 register_constant_handler(core.Token, _token_constant_handler)

 # Source locations

@@ -725,16 +726,33 @@ def register_lowering(prim: core.Primitive, rule: LoweringRule,
   return rule


-def _unwrap_singleton_ir_values(x): return x[0] if len(x) == 1 else x
-def wrap_singleton_ir_values(x: ir.Value | Sequence[ir.Value]
-                             ) -> Sequence[ir.Value]:
-  """Adds a consistent tuples to a mixture of tupled and untuple values."""
-  return (x,) if isinstance(x, ir.Value) else tuple(x)
+def flatten_ir_values(xs: Iterable[IrValues]) -> list[ir.Value]:
+  """Concatenates/flattens a list of ir.Values or ir.Value sequences."""
+  out = []
+  for x in xs:
+    if isinstance(x, ir.Value):
+      out.append(x)
+    else:
+      out.extend(x)
+  return out
+
+
+_unflatten_done = object()
+
+def unflatten_ir_values(xs: Iterable[ir.Value], ns: Sequence[int]) -> list[IrValues]:
+  """Splits `xs` into subsequences of lengths `ns`.
+
+  Unlike `split_list`, the `sum(ns)` must be equal to `len(xs)`, and if n == 1
+  then values are not wrapped in a singleton list."""
+  xs_iter = iter(xs)
+  unflattened: list[IrValues]
+  unflattened = [next(xs_iter) if n == 1 else tuple(next(xs_iter)
+                                                    for _ in range(n)) for n in ns]
+  assert next(xs_iter, _unflatten_done) is _unflatten_done
+  return unflattened
+
+def _unwrap_singleton_ir_values(x): return x[0] if len(x) == 1 else x

-def flatten_lowering_ir_args(
-    xs: Sequence[ir.Value | Sequence[ir.Value]]
-) -> Sequence[ir.Value]:
-  return util.flatten(map(wrap_singleton_ir_values, xs))

 _module_name_regex = re.compile(r"[^\w.-]")

@@ -764,7 +782,7 @@ def eval_dynamic_shape(ctx: LoweringRuleContext,
       partial(core.evaluate_shape, shape, ctx.module_context.shape_poly_state.dim_vars),
       multiple_results=True)(ctx, *ctx.dim_var_values)
   return tuple(operator.index(d) if core.is_constant_dim(d) else d_ir
-               for d, d_ir in zip(shape, util.flatten(res)))
+               for d, d_ir in zip(shape, flatten_ir_values(res)))

 # TODO: replace usage of eval_dynamic_shape_as_vals with eval_dynamic_shape_as_ivals
 def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext,

@@ -1036,13 +1054,13 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args,

   return input_output_aliases, out_donated_args

-Token = Sequence[ir.Value]
+Token = ir.Value

 def token_type() -> Sequence[ir.Type]:
   return [hlo.TokenType.get()]

 def create_token() -> Token:
-  return wrap_singleton_ir_values(hlo.create_token())
+  return hlo.create_token()

 class TokenSet:
   """An immutable container of tokens to be used to lower effectful jaxprs. When lowering

@@ -1388,24 +1406,24 @@ def lower_jaxpr_to_fun(
     ]

     _, token_args, unflattened_args = util.split_list(
-        util.unflatten(flat_args, map(len, input_types)),
+        unflatten_ir_values(flat_args, map(len, input_types)),
         [num_dim_vars, num_tokens])
     tokens_in = TokenSet(zip(effects, token_args))
-    args: list[list[ir.Value]] = unflattened_args
+    args: list[IrValues] = unflattened_args
     if name is not None:
       callee_name_stack = name_stack.extend(util.wrap_name(name, api_name))
     else:
       callee_name_stack = name_stack
-    consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
+    consts = [ir_constant(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
     out_vals, tokens_out = jaxpr_subcomp(
         ctx, jaxpr.jaxpr, callee_name_stack, tokens_in,
         consts, *args, dim_var_values=dim_var_values)
-    outs = []
+    outs: list[IrValues] = []
     for eff in effects:
-      outs.append(wrap_singleton_ir_values(tokens_out.get(eff)))
+      outs.append(tokens_out.get(eff))
     outs.extend(out_vals)

-    flat_outputs = util.flatten(outs)
+    flat_outputs = flatten_ir_values(outs)

     if not use_sharding_annotations and ir_result_shardings is not None:
       flat_outputs = [

@@ -1483,24 +1501,25 @@ def _emit_lowering_rule_as_fun(lowering_rule,
   ctx.module_context.symbol_table.insert(func_op)
   entry_block = func_op.add_entry_block()
   with ir.InsertionPoint(entry_block):
-    unflattened_args = util.unflatten(entry_block.arguments,
+    unflattened_args = unflatten_ir_values(entry_block.arguments,
                                       map(len, input_types))
     dim_var_values, token_args, unflattened_args = util.split_list(unflattened_args, [num_dim_vars, len(ctx.tokens_in)])
     sub_ctx = ctx.replace(tokens_in=TokenSet(zip(effs, token_args)),
                           dim_var_values=dim_var_values)
-    outs = lowering_rule(sub_ctx, *_unwrap_singleton_ir_values(unflattened_args))
+    outs = lowering_rule(sub_ctx, *unflattened_args)
     if sub_ctx.tokens_out:
       outs = [*[sub_ctx.tokens_out.get(eff) for eff in effs], outs]
-    func_dialect.return_(util.flatten(map(wrap_singleton_ir_values, outs)))
+    func_dialect.return_(flatten_ir_values(outs))
   return func_op


 def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
                   name_stack: source_info_util.NameStack,
                   tokens: TokenSet,
-                  consts: Sequence[Sequence[ir.Value]],
-                  *args: Sequence[ir.Value],
+                  consts: Sequence[IrValues],
+                  *args: IrValues,
                   dim_var_values: Sequence[ir.Value]
-                  ) -> tuple[Sequence[Sequence[ir.Value]], TokenSet]:
+                  ) -> tuple[Sequence[IrValues], TokenSet]:
   """Lowers a jaxpr into MLIR, inlined into an existing function.

   Assumes that an MLIR context, location, and insertion point are set.

@@ -1509,9 +1528,9 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
   IR function, in the order of ctx.shape_poly_state.dim_vars.
   """
   assert "gpu" not in ctx.platforms
-  def read(v: core.Atom) -> Sequence[ir.Value]:
+  def read(v: core.Atom) -> IrValues:
     if type(v) is core.Literal:
-      return ir_constants(xla.canonicalize_dtype(v.val))
+      return ir_constant(xla.canonicalize_dtype(v.val))
     else:
       assert isinstance(v, core.Var)
       return env[v]

@@ -1522,9 +1541,12 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
     else:
       return v.aval

-  def write(v: core.Var, node: Sequence[ir.Value]):
+  def write(v: core.Var, node: IrValues):
     assert node is not None
-    env[v] = tuple(node)
+    w: IrValues
+    w = node if isinstance(node, ir.Value) else tuple(node)
+    assert _is_ir_values(w), w
+    env[v] = w

   def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None:
     if ctx.lowering_parameters.override_lowering_rules is None:

@@ -1534,12 +1556,13 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
       return rule
     return None

-  env: dict[core.Var, tuple[ir.Value, ...]] = {}
+  env: dict[core.Var, IrValues] = {}

+  assert all(_is_ir_values(v) for v in args), args
+  assert all(_is_ir_values(v) for v in consts), consts
   assert isinstance(name_stack, source_info_util.NameStack), type(name_stack)
   assert len(args) == len(jaxpr.invars), (jaxpr, args)
   assert len(consts) == len(jaxpr.constvars), (jaxpr, consts)
-  assert all(isinstance(v, ir.Value) for vs in consts for v in vs), consts
   assert len(ctx.shape_poly_state.dim_vars) == len(dim_var_values), (ctx.shape_poly_state.dim_vars, dim_var_values)
   map(write, jaxpr.constvars, consts)
   map(write, jaxpr.invars, args)

@@ -1579,16 +1602,17 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
         avals_out=map(aval, eqn.outvars), tokens_in=tokens_in,
         tokens_out=None, jaxpr_eqn_ctx=eqn.ctx, dim_var_values=dim_var_values)
     if config.dynamic_shapes.value:
-      axis_size_env = {d: read(d)[0]
+      axis_size_env = {d: read(d)
                        for a in avals_in if type(a) is core.DShapedArray
                        for d in a.shape if type(d) is core.Var}
       rule_ctx = rule_ctx.replace(axis_size_env=axis_size_env)

-    rule_inputs = map(_unwrap_singleton_ir_values, in_nodes)
+    assert all(_is_ir_values(v) for v in in_nodes), (eqn, in_nodes)
     ans = lower_per_platform(rule_ctx, str(eqn.primitive),
                              platform_rules, default_rule,
                              eqn.effects,
-                             *rule_inputs, **eqn.params)
+                             *in_nodes, **eqn.params)
+    assert all(_is_ir_values(v) for v in ans), (eqn, ans)

     if effects:
       # If there were ordered effects in the primitive, there should be output

@@ -1606,18 +1630,17 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
       tokens = tokens.update_tokens(tokens_out)

     try:
-      out_nodes = tuple(map(wrap_singleton_ir_values, ans))
+      out_nodes = tuple(ans)
     except TypeError as e:
       raise ValueError("Output of translation rule must be iterable: "
                        f"{eqn}, got output {ans}") from e

-    assert all(isinstance(v, tuple) for v in out_nodes), (ans, eqn)
-    assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (
-        ans, "lowering function returned a bad output", eqn)
+    # assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (
+    #     ans, "lowering function returned a bad output", eqn)
     assert len(ans) == len(eqn.outvars), (ans, eqn)
     map(write, eqn.outvars, out_nodes)
     core.clean_up_dead_vars(eqn, env, last_used)
-  return map(read, jaxpr.outvars), tokens
+  return tuple(read(v) for v in jaxpr.outvars), tokens


 def _platforms_for_eqn_ctx(eqn_ctx: core.JaxprEqnContext | None

@@ -1635,7 +1658,7 @@ def lower_per_platform(ctx: LoweringRuleContext,
                        platform_rules: dict[str, LoweringRule],
                        default_rule: LoweringRule | None,
                        effects: effects_lib.Effects,
-                       *rule_args: ir.Value,
+                       *rule_args: ir.Value | tuple[ir.Value, ...],
                        **rule_kwargs) -> Sequence[ir.Value]:
   """Emits code for a primitive for the current lowering platform(s).

@@ -1705,9 +1728,8 @@ def lower_per_platform(ctx: LoweringRuleContext,
   # If there is a single rule left just apply the rule, without conditionals.
   if len(kept_rules) == 1:
     output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
-    wrapped_out = map(wrap_singleton_ir_values, output)
     map(lambda o: wrap_compute_type_in_place(ctx, o.owner),
-        util.flatten(wrapped_out))
+        flatten_ir_values(output))
     return output

   assert len(platforms) > 1 and len(kept_rules) >= 2, (platforms, kept_rules)

@@ -1725,7 +1747,7 @@ def lower_per_platform(ctx: LoweringRuleContext,
   for i, p in enumerate(platforms):
     branch = rule_idx_op.regions[i].blocks.append()
     with ir.InsertionPoint(branch):
-      hlo.return_(ir_constants(np.int32(platform_to_kept_rules_idx[p])))
+      hlo.return_([ir_constant(np.int32(platform_to_kept_rules_idx[p]))])
   ordered_effects = effects_lib.ordered_effects.filter_in(effects)
   rule_out_avals = [core.abstract_token] * len(ordered_effects) + ctx.avals_out
   output_types = map(aval_to_ir_types, rule_out_avals)

@@ -1741,32 +1763,31 @@ def lower_per_platform(ctx: LoweringRuleContext,
     with ir.InsertionPoint(branch):
       output = rule(inner_ctx, *rule_args, **rule_kwargs)
       try:
-        out_nodes = map(wrap_singleton_ir_values, output)
+        out_nodes = flatten_ir_values(output)
       except TypeError as e:
         raise ValueError("Output of translation rule must be iterable: "
                          f"{description}, got output {output}") from e
-      map(lambda o: wrap_compute_type_in_place(ctx, o.owner),
-          util.flatten(out_nodes))
+      map(lambda o: wrap_compute_type_in_place(ctx, o.owner), out_nodes)
       if inner_ctx.tokens_out is not None:
         assert len(ordered_effects) == len(inner_ctx.tokens_out)
         out_nodes = [inner_ctx.tokens_out.get(eff)
                      for eff in ordered_effects] + out_nodes
-      hlo.return_(util.flatten(map(wrap_singleton_ir_values, out_nodes)))
+      hlo.return_(out_nodes)

   results = case_op.results
   if ordered_effects:
     tokens, results = util.split_list(
-        util.unflatten(results, map(len, output_types)),
+        unflatten_ir_values(results, map(len, output_types)),
        [len(ordered_effects)])
     tokens_out = ctx.tokens_in.update_tokens(TokenSet(zip(ordered_effects,
                                                           tokens)))
     ctx.set_tokens_out(tokens_out)
   return results

-def _ir_consts(consts):
+def _ir_consts(consts) -> list[IrValues]:
   unique_consts = {id(const): const for const in consts}
   ir_consts = {
-      id_: ir_constants(xla.canonicalize_dtype(const))
+      id_: ir_constant(xla.canonicalize_dtype(const))
       for id_, const in unique_consts.items()
   }
   return [ir_consts[id(const)] for const in consts]

@@ -1810,7 +1831,7 @@ def lower_fun(fun: Callable, multiple_results: bool = True) -> Callable:
       sub_context = ctx.module_context
     out, tokens = jaxpr_subcomp(
         sub_context, jaxpr, ctx.name_stack, ctx.tokens_in,
-        _ir_consts(consts), *map(wrap_singleton_ir_values, args),
+        _ir_consts(consts), *args,
         dim_var_values=ctx.dim_var_values)
     ctx.set_tokens_out(tokens)
     return out

@@ -1854,6 +1875,7 @@ def check_backend_matches(inner_backend: str | None,
         f"inner-jit backend specification {inner_backend}.")


+
 def call_lowering(fn_name, name_stack, call_jaxpr, backend,
                   ctx: ModuleContext, avals_in,
                   avals_out, tokens_in, *args,

@@ -1874,8 +1896,8 @@ def call_lowering(fn_name, name_stack, call_jaxpr, backend,
   args = (*dim_var_values, *tokens, *args)
   call = func_dialect.CallOp(flat_output_types,
                              ir.FlatSymbolRefAttr.get(symbol_name),
-                             flatten_lowering_ir_args(args))
-  out_nodes = util.unflatten(call.results, map(len, output_types))
+                             flatten_ir_values(args))
+  out_nodes = unflatten_ir_values(call.results, map(len, output_types))
   tokens, out_nodes = util.split_list(out_nodes, [len(effects)])
   tokens_out = tokens_in.update_tokens(TokenSet(zip(effects, tokens)))
   return out_nodes, tokens_out

@@ -2181,7 +2203,7 @@ def get_sharding_attr(sharding_proto: xc.OpSharding):
   # The MHLO to HLO conversion supports both, and the proto representation is
   # more compact.
   if len(sharding_proto.tile_assignment_devices) > 100:
-    return ir.StringAttr.get(sharding_proto.SerializeToString())  # type: ignore
+    return ir.StringAttr.get(sharding_proto.SerializeToString())  # type: ignore[arg-type]
   else:
     return ir.StringAttr.get(repr(xc.HloSharding.from_proto(sharding_proto)))

@@ -2242,8 +2264,8 @@ def cache_lowering(f):
     flat_output_types = util.flatten(output_types)
     call = func_dialect.CallOp(flat_output_types,
                                ir.FlatSymbolRefAttr.get(func.name.value),
-                               flatten_lowering_ir_args(args))
-    return util.unflatten(call.results, map(len, output_types))
+                               flatten_ir_values(args))
+    return unflatten_ir_values(call.results, map(len, output_types))
   return cached_lowering


@@ -2355,13 +2377,13 @@ def xla_fallback_lowering(prim: core.Primitive):

     call = func_dialect.CallOp([output_type],
                                ir.FlatSymbolRefAttr.get(callee_name),
-                               flatten_lowering_ir_args(args)).result
+                               flatten_ir_values(args)).result
     if not prim.multiple_results:
       return [call]
     flat_results = [hlo.get_tuple_element(call, i32_attr(i))
                     for i in range(len(flat_output_types))]

-    return util.unflatten(flat_results, map(len, output_types))
+    return unflatten_ir_values(flat_results, map(len, output_types))
   return fallback


@@ -2489,7 +2511,7 @@ def emit_python_callback(
     sharding: xc.OpSharding | None = None,
     operand_layouts: Sequence[Sequence[int] | None] | None = None,
     result_layouts: Sequence[Sequence[int] | None] | None = None,
-) -> tuple[Sequence[ir.Value], Any, Any]:
+) -> tuple[Sequence[IrValues], Any, Any]:
   """Emits MLIR that calls back to a provided Python function."""
   if len(ctx.module_context.platforms) > 1:
     raise NotImplementedError("multi-platform lowering for python_callback")

@@ -2668,7 +2690,7 @@ def custom_call(
   if backend_config is None:
     backend_config_attr = ir.StringAttr.get("")
   elif isinstance(backend_config, (str, bytes)):
-    backend_config_attr = ir.StringAttr.get(backend_config)  # type: ignore
+    backend_config_attr = ir.StringAttr.get(backend_config)  # type: ignore[arg-type]
   elif isinstance(backend_config, dict):
     # TODO(necula): it seems that the CustomCallOp constructor requires that
     # backend_config_attr be a string attribute, even though in some cases we

@@ -1387,13 +1387,12 @@ def _unravel_index_hlo(axis_env):
   mod = mlir.ir_constant(np.array(axis_env.sizes[-1], np.uint32))
   return hlo.remainder(hlo.divide(hlo.replica_id(), div), mod)

-def _hlo_shard(aval, axis_env, xs, in_axis):
+def _hlo_shard(aval, axis_env, x, in_axis):
   if aval is core.abstract_token:
-    return xs
+    return x
   elif isinstance(aval, core.ShapedArray):
     if dtypes.issubdtype(aval.dtype, dtypes.extended):
       aval = aval.dtype._rules.physical_element_aval(aval.dtype)
-    x, = xs
     dims = list(aval.shape)
     zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
     idxs = [zero] * len(dims)

@@ -1402,9 +1401,7 @@ def _hlo_shard(aval, axis_env, x, in_axis):
     dims_unsqueezed.insert(in_axis, 1)
     dynamic_slice_result = hlo.dynamic_slice(
         x, idxs, mlir.dense_int_array(dims_unsqueezed))
-    return [
-        hlo.reshape(mlir.aval_to_ir_type(aval), dynamic_slice_result)
-    ]
+    return hlo.reshape(mlir.aval_to_ir_type(aval), dynamic_slice_result)
   else:
     raise TypeError(aval)

@@ -1442,11 +1439,10 @@ def _axis_groups(mesh_spec, mesh_axes):


 # TODO(b/110096942): more efficient gather
-def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs):
+def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, x):
   if aval is core.abstract_token:
-    return xs
+    return x
   elif isinstance(aval, core.ShapedArray):
-    x, = xs
     dims = list(aval.shape)
     padded_aval = aval.update(shape=[axis_env.sizes[-1]] + dims)
     padded = mlir.full_like_aval(ctx, 0, padded_aval)

@@ -1489,8 +1485,8 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
   # Shard the in_nodes that are mapped
   in_avals = [v.aval for v in call_jaxpr.invars]
   in_nodes_sharded = (
-    _hlo_shard(aval, new_env, mlir.wrap_singleton_ir_values(in_node), in_axis)
-    if in_axis is not None else mlir.wrap_singleton_ir_values(in_node)
+    _hlo_shard(aval, new_env, in_node, in_axis)
+    if in_axis is not None else in_node
     for aval, in_node, in_axis in zip(in_avals, in_nodes, in_axes))

   with maybe_extend_axis_env(axis_name, global_axis_size, None):

@@ -873,16 +873,16 @@ def _cond_lowering(ctx, index, *args, branches, linear):
   for i, jaxpr in enumerate(branches):
     branch = case_op.regions[i].blocks.append()
     with ir.InsertionPoint(branch):
-      consts = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
+      consts = [mlir.ir_constant(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
       out_vals, tokens_out = mlir.jaxpr_subcomp(
           ctx.module_context, jaxpr.jaxpr, name_stack.extend(f'branch_{i}_fun'),
-          tokens_in, consts, *map(mlir.wrap_singleton_ir_values, args),
+          tokens_in, consts, *args,
          dim_var_values=ctx.dim_var_values)
       out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
       out_vals = [*out_tokens, *out_vals]
-      hlo.return_(util.flatten(out_vals))
+      hlo.return_(mlir.flatten_ir_values(out_vals))

-  tokens_and_outputs = util.unflatten(case_op.results, map(len, output_types))
+  tokens_and_outputs = mlir.unflatten_ir_values(case_op.results, map(len, output_types))
   tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
   ctx.set_tokens_out(mlir.TokenSet(zip(ordered_effects, tokens)))
   return outputs

@@ -1012,7 +1012,9 @@ def _platform_index_lowering(ctx: mlir.LoweringRuleContext,
   def lower_constant(
       ctx: mlir.LoweringRuleContext, *, i: int
   ) -> Sequence[ir.Value]:
-    return mlir.ir_constants(np.int32(i))
+    v = mlir.ir_constant(np.int32(i))
+    assert isinstance(v, ir.Value), v
+    return [v]
   platform_rules: dict[str, mlir.LoweringRule] = {}
   for i, ps in enumerate(platforms):
     rule = partial(lower_constant, i=i)

@@ -1722,7 +1722,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
   flat_loop_carry_types = util.flatten(loop_carry_types)
   args = [*tokens, *args]

-  flat_args = mlir.flatten_lowering_ir_args(args)
+  flat_args = mlir.flatten_ir_values(args)
   while_op = hlo.WhileOp(flat_loop_carry_types, flat_args)

   # Loop condition

@@ -1732,15 +1732,15 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
     flat_cond_args = [
         cond_block.arguments[i] for i in range(len(flat_loop_carry_types))
     ]
-    cond_args = util.unflatten(flat_cond_args, _map(len, loop_carry_types))
+    cond_args = mlir.unflatten_ir_values(flat_cond_args, _map(len, loop_carry_types))
     # Remove tokens from cond args
     cond_args = cond_args[num_tokens:]
     x, _, z = util.split_list(cond_args, [cond_nconsts, body_nconsts])
     cond_consts = [
-        mlir.ir_constants(xla.canonicalize_dtype(x)) for x in cond_jaxpr.consts
+        mlir.ir_constant(xla.canonicalize_dtype(x)) for x in cond_jaxpr.consts
     ]
     cond_name_stack = name_stack.extend('cond')
-    ((pred,),), _ = mlir.jaxpr_subcomp(
+    (pred,), _ = mlir.jaxpr_subcomp(
         ctx.module_context,
         cond_jaxpr.jaxpr,
         cond_name_stack,

@@ -1772,13 +1772,13 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
     flat_body_args = [
         body_block.arguments[i] for i in range(len(flat_loop_carry_types))
     ]
-    body_args = util.unflatten(flat_body_args, _map(len, loop_carry_types))
+    body_args = mlir.unflatten_ir_values(flat_body_args, _map(len, loop_carry_types))
     # Tokens are at the front of the args list to the while loop
     token_args, body_args = util.split_list(body_args, [num_tokens])
     tokens_in = mlir.TokenSet(zip(body_effects, token_args))
     x, y, z = util.split_list(body_args, [cond_nconsts, body_nconsts])
     body_name_stack = name_stack.extend('body')
-    body_consts = [mlir.ir_constants(xla.canonicalize_dtype(x))
+    body_consts = [mlir.ir_constant(xla.canonicalize_dtype(x))
                    for x in body_jaxpr.consts]
     new_z, tokens_out = mlir.jaxpr_subcomp(
         ctx.module_context, body_jaxpr.jaxpr, body_name_stack,

@@ -1786,9 +1786,9 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
     out_tokens = [tokens_out.get(eff) for eff in body_effects]
     if batched:
       body_pred_name_stack = name_stack.extend('body_pred')
-      cond_consts = [mlir.ir_constants(xla.canonicalize_dtype(x))
+      cond_consts = [mlir.ir_constant(xla.canonicalize_dtype(x))
                      for x in cond_jaxpr.consts]
-      ((body_pred,),), _ = mlir.jaxpr_subcomp(
+      (body_pred,), _ = mlir.jaxpr_subcomp(
          ctx.module_context, cond_jaxpr.jaxpr, body_pred_name_stack,
          mlir.TokenSet(), cond_consts, *(x + z),
          dim_var_values=ctx.dim_var_values)

@@ -1796,10 +1796,10 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
           partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z,
           body_jaxpr.out_avals)

-    hlo.return_([*util.flatten(out_tokens), *util.flatten(x), *util.flatten(y),
-                 *util.flatten(new_z)])
+    hlo.return_([*mlir.flatten_ir_values(out_tokens), *mlir.flatten_ir_values(x), *mlir.flatten_ir_values(y),
+                 *mlir.flatten_ir_values(new_z)])

-  outputs = util.unflatten(while_op.results, _map(len, loop_carry_types))
+  outputs = mlir.unflatten_ir_values(while_op.results, _map(len, loop_carry_types))
   tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts])
   if tokens:
     ctx.set_tokens_out(mlir.TokenSet(zip(body_effects, tokens)))

@@ -1909,16 +1909,14 @@ state_discharge.register_discharge_rule(while_p)(_while_discharge_rule)


 def _pred_bcast_select_hlo(ctx,
-    pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
-    ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
+    pred_aval: core.ShapedArray, pred: ir.Value, x: mlir.IrValues,
+    y: mlir.IrValues, x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
   if x_y_aval is core.abstract_token:
-    x, = xs
-    y, = ys
     return [hlo.AfterAllOp([x, y]).result]
   else:
+    assert isinstance(x, ir.Value), x
+    assert isinstance(y, ir.Value), y
     assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
-    x, = xs
-    y, = ys
     assert x.type == y.type, (x.type, y.type)
     assert (pred_aval.shape == x_y_aval.shape[:len(pred_aval.shape)]), (
         pred_aval.shape, x_y_aval)

@@ -2156,8 +2156,8 @@ def _pow_lower(ctx, x, y):
       partial(convert_element_type, new_dtype=out_aval.dtype), False)
   x_aval_ = x_aval.update(dtype=out_aval.dtype)
   y_aval_ = y_aval.update(dtype=out_aval.dtype)
-  [(x_,)] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x)
-  [(y_,)] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y)
+  [x_] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x)
+  [y_] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y)
   ctx_ = ctx.replace(avals_in=[x_aval_, y_aval_])
   return _nary_lower_hlo(hlo.power, ctx_, x_, y_)
 mlir.register_lowering(pow_p, _pow_lower)

@@ -3980,9 +3980,9 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions):
     out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr.jaxpr,
                                       name_stack, mlir.TokenSet(),
                                       jaxpr.consts,
-                                      *([a] for a in reducer.arguments),
+                                      *reducer.arguments,
                                       dim_var_values=ctx.dim_var_values)
-    hlo.return_(util.flatten(out_nodes))
+    hlo.return_(mlir.flatten_ir_values(out_nodes))
   return op.results

 mlir.register_lowering(reduce_p, _reduce_lower)

@@ -4180,7 +4180,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
   aval_out, = ctx.avals_out
   dtype = aval_out.dtype
   op = hlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x],
-                    mlir.ir_constants(unit_factory(aval_out.dtype)),
+                    [mlir.ir_constant(unit_factory(aval_out.dtype))],
                     mlir.dense_int_array(axes))
   scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype))
   reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type)

@@ -4350,7 +4350,7 @@ batching.primitive_batchers[sort_p] = _sort_batch_rule
 def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
   assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in
   sort = hlo.SortOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
-                    mlir.flatten_lowering_ir_args(operands),
+                    mlir.flatten_ir_values(operands),
                     dimension=mlir.i64_attr(dimension),
                     is_stable=ir.BoolAttr.get(is_stable))
   scalar_avals = [aval.update(shape=()) for aval in ctx.avals_in]

@@ -4364,9 +4364,8 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
         avals_in=util.flatten(zip(scalar_avals, scalar_avals)),
         avals_out=[core.ShapedArray((), np.bool_)])

-    out = lower_comparator(sub_ctx, *[[a] for a in comparator.arguments],
-                           num_keys=num_keys)
-    hlo.return_(util.flatten(out))
+    out = lower_comparator(sub_ctx, *comparator.arguments, num_keys=num_keys)
+    hlo.return_(mlir.flatten_ir_values(out))
   return sort.results

 mlir.register_lowering(sort_p, _sort_lower)

@@ -4562,9 +4561,9 @@ def _infeed_lowering(ctx, token, *, shapes, partitions):
     mlir.set_sharding(infeed, xla.sharding_to_proto(partitions))
   token = infeed.results[-1]
   outs = infeed.results[:-1]
-  return util.unflatten(outs, safe_map(len, output_types)) + [[
+  return mlir.unflatten_ir_values(outs, safe_map(len, output_types)) + [
       token,
-  ]]
+  ]

 mlir.register_lowering(infeed_p, _infeed_lowering)

@@ -4596,7 +4595,7 @@ mlir.lowerable_effects.add_type(InOutFeedEffect)

 def _outfeed_lowering(ctx, token, *xs, partitions):
   outfeed = hlo.OutfeedOp(
-      mlir.flatten_lowering_ir_args(xs),
+      mlir.flatten_ir_values(xs),
       token,
       outfeed_config=ir.StringAttr.get(''))
   if partitions is not None:

@@ -4638,7 +4637,7 @@ rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)

 def _rng_uniform_lowering(ctx, a, b, *, shape):
   aval_out, = ctx.avals_out
-  shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64))
+  shape = mlir.ir_constant(np.array(aval_out.shape, np.int64))
   return [hlo.rng(a, b, shape, hlo.RngDistributionAttr.get('UNIFORM'))]

 mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering)

@@ -5173,7 +5172,7 @@ empty_p.def_abstract_eval(lambda *, dtype: core.ShapedArray((), dtype))
 def _empty_lower(ctx, *, dtype):
   dtype = dtype if dtypes.issubdtype(dtype, dtypes.extended) else np.dtype(dtype)
   phys_aval = core.physical_aval(core.ShapedArray((), dtype))
-  return mlir.ir_constants(np.zeros(phys_aval.shape, phys_aval.dtype))
+  return mlir.ir_constant(np.zeros(phys_aval.shape, phys_aval.dtype)),
 mlir.register_lowering(empty_p, _empty_lower)


@@ -1097,10 +1097,10 @@ def _triangular_solve_cpu_lower(
   if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
     alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
     b_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, b_aval.shape)
-    return [lapack.trsm_hlo(
+    return lapack.trsm_hlo(
       a_aval.dtype, alpha,
       a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal,
-      b_shape_vals=b_shape_vals)]
+      b_shape_vals=b_shape_vals)
   else:
     # Fall back to the HLO implementation for unsupported types or batching.
     # TODO: Consider swapping XLA for LAPACK in batched case

@@ -1189,7 +1189,7 @@ def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *,

 def _lu_pivots_to_permutation_gpu_lowering(lowering, ctx, pivots, *,
                                            permutation_size):
-  return [lowering(pivots, permutation_size=permutation_size)]
+  return lowering(pivots, permutation_size=permutation_size)


 lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation')

@@ -27,7 +27,6 @@ from jax import tree_util
 from jax._src import core
 from jax._src import dtypes
 from jax._src import sharding_impls
-from jax._src import util
 from jax._src.core import AxisName, ShapedArray, raise_to_shaped
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching

@@ -772,7 +771,7 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
                            shape=np.delete(np.array(aval.shape, dtype=np.int64),
                                            positional_axes))
     reducer_ctx = ctx.replace(primitive=None, avals_in=[aval], avals_out=[aval_out])
-    out, = reducer(reducer_ctx, arg, axes=tuple(positional_axes))[0]
+    out, = reducer(reducer_ctx, arg, axes=tuple(positional_axes))
     return out
   args = map(_positional_reduce, ctx.avals_in, args)
   if not named_axes:

@@ -805,9 +804,8 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
       lower_reducer = mlir.lower_fun(prim.bind, multiple_results=False)
       reducer_ctx = ctx.replace(primitive=None,
                                 avals_in=[scalar_aval] * 2, avals_out=[scalar_aval])
-      out_nodes = lower_reducer(
-          reducer_ctx, *([a] for a in reducer_block.arguments))
-      hlo.return_(util.flatten(out_nodes))
+      out_nodes = lower_reducer(reducer_ctx, *reducer_block.arguments)
+      hlo.return_(mlir.flatten_ir_values(out_nodes))
     return op.result

   return [all_reduce(aval, x) for aval, x in zip(ctx.avals_in, args)]

@@ -1410,9 +1408,8 @@ def _reduce_scatter_lowering(
       reducer_ctx = ctx.replace(primitive=None,
                                 avals_in=[scalar_aval] * 2,
                                 avals_out=[scalar_aval])
-      out_nodes = lower_reducer(
-          reducer_ctx, *([a] for a in reducer_block.arguments))
-      hlo.return_(util.flatten(out_nodes))
+      out_nodes = lower_reducer(reducer_ctx, *reducer_block.arguments)
+      hlo.return_(mlir.flatten_ir_values(out_nodes))

     if tiled:
       return op.results

@@ -2473,7 +2473,7 @@ def _scatter_lower(ctx, operand, indices, updates, *,

   if mode == GatherScatterMode.CLIP:
     clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
-    (indices,), = clip_fn(ctx.replace(avals_out=None), operand, indices,
+    (indices,) = clip_fn(ctx.replace(avals_out=None), operand, indices,
                           updates, dnums=dimension_numbers)

   dnums = dimension_numbers

@@ -2504,9 +2504,9 @@ def _scatter_lower(ctx, operand, indices, updates, *,
       raise NotImplementedError('Cannot lower effectful `scatter`.')
     out_nodes, _ = mlir.jaxpr_subcomp(
         ctx.module_context, update_jaxpr, name_stack, mlir.TokenSet(),
-        update_consts, (update.arguments[0],), (update.arguments[1],),
+        update_consts, update.arguments[0], update.arguments[1],
         dim_var_values=ctx.dim_var_values)
-    hlo.return_(util.flatten(out_nodes))
+    hlo.return_(mlir.flatten_ir_values(out_nodes))
   return op.results

 mlir.register_lowering(scatter_p, _scatter_lower)

@@ -2532,7 +2532,7 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,

   if mode == GatherScatterMode.CLIP:
     clip_fn = mlir.lower_fun(_clamp_scatter_indices, multiple_results=False)
-    (indices,), = clip_fn(ctx.replace(avals_out=None), operand, indices, updates,
+    indices, = clip_fn(ctx.replace(avals_out=None), operand, indices, updates,
                           dnums=dimension_numbers)

   aval_out, = ctx.avals_out

@@ -441,9 +441,9 @@ def _generic_reduce_window_lower(
     if jaxpr.effects:
       raise NotImplementedError('Cannot lower effectful `reduce_window`.')
     out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr, ctx.name_stack,
-                                      mlir.TokenSet(), consts, *([a] for a in reducer.arguments),
+                                      mlir.TokenSet(), consts, *reducer.arguments,  # type: ignore[misc]
                                       dim_var_values=ctx.dim_var_values)
-    return util.flatten(out_nodes)
+    return mlir.flatten_ir_values(out_nodes)

   return mlir.reduce_window(
       ctx,

@@ -675,9 +675,9 @@ def _select_and_scatter_lower(
     out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, select_jaxpr,
                                       ctx.name_stack,
                                       mlir.TokenSet(), select_consts,
-                                      *([a] for a in select.arguments),
+                                      *select.arguments,
                                       dim_var_values=ctx.dim_var_values)
-    hlo.return_(util.flatten(out_nodes))
+    hlo.return_(mlir.flatten_ir_values(out_nodes))
   scatter = op.scatter.blocks.append(scalar_type, scalar_type)
   with ir.InsertionPoint(scatter):
     if scatter_jaxpr.effects:

@@ -685,9 +685,9 @@ def _select_and_scatter_lower(
     out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, scatter_jaxpr,
                                       ctx.name_stack,
                                       mlir.TokenSet(), scatter_consts,
-                                      *([a] for a in scatter.arguments),
+                                      *scatter.arguments,
                                       dim_var_values=ctx.dim_var_values)
-    hlo.return_(util.flatten(out_nodes))
+    hlo.return_(mlir.flatten_ir_values(out_nodes))
   return op.results

 mlir.register_lowering(select_and_scatter_p, _select_and_scatter_lower)

@@ -1350,7 +1350,7 @@ def _xmap_lowering_rule_replica(ctx, *in_nodes,
   # them!
   vectorized_jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(f, local_avals)
   # _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
-  const_nodes = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in consts]
+  const_nodes = [mlir.ir_constant(xla.canonicalize_dtype(x)) for x in consts]

   local_mesh_shape = mesh.local_mesh.shape
   tiled_ins = (

@@ -1426,14 +1426,14 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
   vectorized_jaxpr, global_out_avals, consts, () = pe.trace_to_jaxpr_dynamic(f, global_in_avals)

   sharded_global_in_nodes = [
-    [mlir.wrap_with_sharding_op(
+    mlir.wrap_with_sharding_op(
       ctx, node, aval,
       NamedSharding(mesh, array_mapping_to_axis_resources(aval_axes)
-                    )._to_xla_hlo_sharding(aval.ndim).to_proto())]
-    if aval_axes else [node]
+                    )._to_xla_hlo_sharding(aval.ndim).to_proto())
+    if aval_axes else node
     for node, aval, aval_axes in zip(global_in_nodes, global_in_avals, mesh_in_axes)
   ]
-  const_nodes = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in consts]
+  const_nodes = [mlir.ir_constant(xla.canonicalize_dtype(x)) for x in consts]

   # We in-line here rather than generating a Call HLO as in the xla_call
   # translation rule just because the extra tuple stuff is a pain.

@@ -1451,7 +1451,7 @@ def _xmap_lowering_rule_spmd(ctx, *global_in_nodes,
       NamedSharding(mesh, array_mapping_to_axis_resources(aval_axes)
                     )._to_xla_hlo_sharding(aval.ndim).to_proto())
     if aval_axes else node
-    for (node,), aval, aval_axes in zip(global_out_nodes, global_out_avals, mesh_out_axes)
+    for node, aval, aval_axes in zip(global_out_nodes, global_out_avals, mesh_out_axes)
   ]

   return sharded_global_out_nodes

@@ -1485,7 +1485,7 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
   # them!
   global_in_avals = ctx.avals_in
   vectorized_jaxpr, global_out_avals, consts, () = pe.trace_to_jaxpr_dynamic(f, global_in_avals)
-  const_nodes = [mlir.ir_constants(xla.canonicalize_dtype(x)) for x in consts]
+  const_nodes = [mlir.ir_constant(xla.canonicalize_dtype(x)) for x in consts]

   # We in-line here rather than generating a Call HLO as in the xla_call
   # translation rule just because the extra tuple stuff is a pain.

@@ -1498,9 +1498,8 @@ def _xmap_lowering_rule_spmd_manual(ctx, *global_in_nodes,
          in vectorized_jaxpr.effects):
     raise NotImplementedError('Cannot lower `xmap` with ordered effects.')
   global_out_nodes, _ = mlir.jaxpr_subcomp(
-      sub_ctx, vectorized_jaxpr, name_stack,
-      mlir.TokenSet(), const_nodes, *([n] for n in global_in_nodes),
-      dim_var_values=ctx.dim_var_values)
+      sub_ctx, vectorized_jaxpr, name_stack, mlir.TokenSet(), const_nodes,
+      *global_in_nodes, dim_var_values=ctx.dim_var_values)

   return global_out_nodes

@@ -81,7 +81,7 @@ from jax._src.tree_util import (
 from jax._src.util import (
     HashableFunction, safe_map, safe_zip, wraps,
     distributed_debug_log, split_list, weakref_lru_cache,
-    merge_lists, flatten, unflatten, subs_list, fun_name)
+    merge_lists, flatten, subs_list, fun_name)

 map, unsafe_map = safe_map, map
 zip, unsafe_zip = safe_zip, zip

@@ -1926,9 +1926,9 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
   args = (*ctx.dim_var_values, *tokens_in, *args)
   call = func_dialect.CallOp(flat_output_types,
                              ir.FlatSymbolRefAttr.get(func.name.value),
-                             mlir.flatten_lowering_ir_args(args))
+                             mlir.flatten_ir_values(args))
   mlir.wrap_compute_type_in_place(ctx, call)
-  out_nodes = unflatten(call.results, map(len, output_types))
+  out_nodes = mlir.unflatten_ir_values(call.results, map(len, output_types))
   tokens, out_nodes = split_list(out_nodes, [len(effects)])
   tokens_out = ctx.tokens_in.update_tokens(mlir.TokenSet(zip(effects, tokens)))
   ctx.set_tokens_out(tokens_out)

@@ -648,7 +648,7 @@ def emit_tf_embedded_graph_custom_call(
       util.flatten([mlir.aval_to_ir_types(aval) for aval in result_avals])
   )
   if ordered:
-    operands.insert(0, ctx.tokens_in.get(call_tf_ordered_effect)[0])
+    operands.insert(0, ctx.tokens_in.get(call_tf_ordered_effect))
     result_types.insert(0, mlir.token_type()[0])

   custom_call = hlo.CustomCallOp(

@@ -668,7 +668,7 @@ def emit_tf_embedded_graph_custom_call(
   results = list(custom_call.results)
   if ordered:
     token = results.pop(0)
-    ctx.set_tokens_out(mlir.TokenSet({call_tf_ordered_effect: (token,)}))
+    ctx.set_tokens_out(mlir.TokenSet({call_tf_ordered_effect: token}))

   return results

@@ -677,8 +677,7 @@ def _xla_shard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
   return [mlir.wrap_with_full_to_shard_op(ctx, sx, aval_out, manual_proto, unspecified)]

 def _xla_unshard(ctx: mlir.LoweringRuleContext, mesh, auto, names,
-                 aval_in, aval_out, xs):
-  x, = xs
+                 aval_in, aval_out, x):
   axes = {name: i for i, ns in names.items() for name in ns}
   ns = _make_scoped_manual_sharding(ctx, mesh, axes)
   if dtypes.issubdtype(aval_out.dtype, dtypes.extended):

@@ -671,7 +671,7 @@ def _bcsr_dot_general_gpu_lowering(

   # Account for a bug in cusparse: it references indices and data beyond
   # the extent of indptr.
-  (lhs_data,), (lhs_indices,) = _bcsr_correct_out_of_bound_indices_lowered(
+  lhs_data, lhs_indices = _bcsr_correct_out_of_bound_indices_lowered(
       ctx, lhs_data, lhs_indices, lhs_indptr, rhs, shape=lhs_spinfo.shape)

   if rhs_aval.ndim == 1:

@@ -40,14 +40,13 @@ from jax._src.interpreters.mlir import (
     dense_int_elements as dense_int_elements,
     dtype_to_ir_type as dtype_to_ir_type,
     emit_python_callback as emit_python_callback,
-    flatten_lowering_ir_args as flatten_lowering_ir_args,
+    flatten_ir_values as flatten_lowering_ir_args,  # TODO(phawkins): remove me
     func_dialect as func_dialect,
     hlo as hlo,
     i32_attr as i32_attr,
     i64_attr as i64_attr,
     ir as ir,
     ir_constant as ir_constant,
-    ir_constants as ir_constants,
     ir_type_handlers as ir_type_handlers,
     jaxpr_subcomp as jaxpr_subcomp,
     lower_fun as lower_fun,

@@ -103,8 +103,8 @@ testing_primitive_with_effect_p.def_effectful_abstract_eval(

 def lowering_testing_primitive_with_effect(ctx, a, *, effect_class_name: str):
   if "Ordered" in effect_class_name:
-    token_in = ctx.tokens_in.get(_testing_effects[effect_class_name])[0]
-    ctx.set_tokens_out(mlir.TokenSet({_testing_effects[effect_class_name]: (token_in,)}))
+    token_in = ctx.tokens_in.get(_testing_effects[effect_class_name])
+    ctx.set_tokens_out(mlir.TokenSet({_testing_effects[effect_class_name]: token_in}))
   return [mlir.hlo.add(a, a)]

 mlir.register_lowering(testing_primitive_with_effect_p,

@@ -97,7 +97,7 @@ def function_effect_lowering(ctx, *, effect):
   flat_output_types = util.flatten(output_types)
   call = mlir.func_dialect.CallOp(flat_output_types,
                                   mlir.ir.FlatSymbolRefAttr.get(func.name.value),
-                                  mlir.flatten_lowering_ir_args(in_tokens))
+                                  mlir.flatten_ir_values(in_tokens))
   tokens, out = util.split_list(call.results, [len(ctx.tokens_in)])
   ctx.set_tokens_out(mlir.TokenSet(zip(effs, tokens)))
   return out

@@ -120,7 +120,7 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out
   del out_avals
   token_in = None
   if effects.ordered_effects.contains(effect):
-    token_in = ctx.tokens_in.get(effect)[0]
+    token_in = ctx.tokens_in.get(effect)

   out_op, token_out, _ = mlir.emit_python_callback(
       ctx, callback, token_in, list(args), list(ctx.avals_in),