Use the MLIR-generated convenience functions a_thing(...) instead of writing AThingOp(...).result.

In most cases these are more succinct.

This change does not update Pallas/Mosaic.

PiperOrigin-RevId: 583448254
Peter Hawkins 2023-11-17 11:46:24 -08:00 committed by jax authors
parent e016ce4639
commit 8e8dc263bc
25 changed files with 289 additions and 306 deletions
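For readers skimming the diff, the sketch below (not part of this change) contrasts the two styles. It is a minimal illustration only, assuming the import paths used by the files in this diff (jax._src.lib.mlir) and an active ir.InsertionPoint/ir.Location, as in any lowering rule; the add_then_convert_* helpers are hypothetical.

from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo

def add_then_convert_old(x: ir.Value, y: ir.Value, out_type: ir.Type) -> ir.Value:
  # Op-builder style: each builder returns an op object, so .result
  # (or .results for multi-result ops) is needed to get the ir.Value back.
  s = hlo.AddOp(x, y).result
  return hlo.ConvertOp(out_type, s).result

def add_then_convert_new(x: ir.Value, y: ir.Value, out_type: ir.Type) -> ir.Value:
  # Convenience-function style: hlo.add and hlo.convert return the ir.Value
  # directly, so the trailing .result disappears.
  return hlo.convert(out_type, hlo.add(x, y))

Two recurring consequences are visible throughout the hunks: lowering rules that previously returned SomeOp(...).results now wrap the value in a list (e.g. hlo.AddOp(x, y).results becomes [hlo.add(x, y)]), and generated names that would collide with Python keywords carry a trailing underscore (hlo.return_, hlo.and_, hlo.or_, hlo.not_, func_dialect.return_).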


@ -91,15 +91,15 @@ def shape_tensor(sizes: Sequence[int | ir.RankedTensorType]
return ir_constant(np.array([d], np.int32))
else:
if d.type != i32_type:
d = hlo.ConvertOp(i32_type, d)
return hlo.ReshapeOp(int1d, d).result
d = hlo.convert(i32_type, d)
return hlo.reshape(int1d, d)
ds = map(lower_dim, sizes)
if not ds:
return ir_constant(np.array([], np.int32))
elif len(ds) == 1:
return ds[0]
else:
return hlo.ConcatenateOp(ds, i64_attr(0)).result
return hlo.concatenate(ds, i64_attr(0))
def delegate_lowering(ctx, lowering_fun, *args, **ctx_override_kwargs):
@ -251,7 +251,7 @@ def _numpy_array_constant(x: np.ndarray) -> Sequence[ir.Value]:
x = np.packbits(x, bitorder='little')
x = np.ascontiguousarray(x)
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
return (hlo.ConstantOp(attr).result,)
return (hlo.constant(attr),)
def _masked_array_constant_handler(*args, **kwargs):
@ -284,11 +284,11 @@ def _ndarray_constant_handler(val: np.ndarray) -> Sequence[ir.Value]:
other_axes, = np.where(np.not_equal(0, val.strides))
collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None) # type: ignore
for ax in range(val.ndim))] # type: ignore
out = hlo.BroadcastInDimOp(
out = hlo.broadcast_in_dim(
ir.RankedTensorType.get(
val.shape, dtype_to_ir_type(collapsed_val.dtype)),
_numpy_array_constant(collapsed_val)[0],
dense_int_elements(other_axes)).result
dense_int_elements(other_axes))
return (out,)
else:
return _numpy_array_constant(val)
@ -309,7 +309,7 @@ for ptype, dtype in dtypes.python_scalar_dtypes.items():
register_constant_handler(ptype, partial(_python_scalar_handler, dtype))
def _token_constant_handler(val):
return [hlo.CreateTokenOp().result]
return [hlo.create_token()]
register_constant_handler(core.Token, _token_constant_handler)
# Source locations
@ -646,7 +646,7 @@ def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext,
else:
i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))
if d.type != i32_type: # type: ignore
return hlo.ConvertOp(i32_type, d).result
return hlo.convert(i32_type, d)
else:
return d
return tuple(convert_dim(v) for v in eval_dynamic_shape(ctx, shape))
@ -662,7 +662,7 @@ def eval_dynamic_shape_as_ivals(
else:
i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))
if d.type != i32_type: # type: ignore
return hlo.ConvertOp(i32_type, d).result
return hlo.convert(i32_type, d)
else:
return d
return tuple(convert_dim(v) for v in eval_dynamic_shape(ctx, shape))
@ -912,7 +912,7 @@ def token_type() -> Sequence[ir.Type]:
return [hlo.TokenType.get()]
def create_token() -> Token:
return wrap_singleton_ir_values(hlo.CreateTokenOp().result)
return wrap_singleton_ir_values(hlo.create_token())
class TokenSet:
"""An immutable container of tokens to be used to lower effectful jaxprs. When lowering
@ -1252,7 +1252,7 @@ def lower_jaxpr_to_fun(
args: list[list[ir.Value]] = []
for aval, arg in zip(jaxpr.in_avals, unflattened_args):
if replace_tokens_with_dummy and aval is core.abstract_token:
args.append(hlo.CreateTokenOp().results)
args.append([hlo.create_token()])
else:
args.append(arg)
callee_name_stack = ctx.name_stack.extend(util.wrap_name(name, api_name))
@ -1285,7 +1285,7 @@ def lower_jaxpr_to_fun(
o if mk is None else wrap_with_memory_kind(o, mk, o_aval)
for o, mk, o_aval in zip(flat_outputs, ir_result_memory_kinds, output_avals)]
func_dialect.ReturnOp(flat_outputs)
func_dialect.return_(flat_outputs)
return func_op
@ -1356,7 +1356,7 @@ def _emit_lowering_rule_as_fun(lowering_rule,
outs = lowering_rule(sub_ctx, *_unwrap_singleton_ir_values(unflattened_args))
if sub_ctx.tokens_out:
outs = [*[sub_ctx.tokens_out.get(eff) for eff in effs], outs]
func_dialect.ReturnOp(util.flatten(map(wrap_singleton_ir_values, outs)))
func_dialect.return_(util.flatten(map(wrap_singleton_ir_values, outs)))
return func_op
def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
@ -1500,7 +1500,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
out_nodes = tuple(map(wrap_singleton_ir_values, ans))
except TypeError as e:
raise ValueError("Output of translation rule must be iterable: "
f"{eqn}, got output {ans} from {rule.__name__}") from e
f"{eqn}, got output {ans}") from e
assert all(isinstance(v, tuple) for v in out_nodes), (ans, eqn)
assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (
@ -1595,14 +1595,14 @@ def lower_multi_platform(ctx: LoweringRuleContext,
# Compute the rule index based on the current platform
i32_type = aval_to_ir_types(core.ShapedArray((), dtype=np.int32))[0]
if current_platform_idx.type != i32_type:
current_platform_idx = hlo.ConvertOp(i32_type, current_platform_idx)
current_platform_idx = hlo.convert(i32_type, current_platform_idx)
rule_idx_op = hlo.CaseOp([i32_type],
index=current_platform_idx,
num_branches=len(platforms))
for i, p in enumerate(platforms):
branch = rule_idx_op.regions[i].blocks.append()
with ir.InsertionPoint(branch):
hlo.ReturnOp(ir_constants(np.int32(platform_to_kept_rules_idx[p])))
hlo.return_(ir_constants(np.int32(platform_to_kept_rules_idx[p])))
ordered_effects = effects_lib.ordered_effects.filter_in(effects)
rule_out_avals = [core.abstract_token] * len(ordered_effects) + ctx.avals_out
output_types = map(aval_to_ir_types, rule_out_avals)
@ -1623,7 +1623,7 @@ def lower_multi_platform(ctx: LoweringRuleContext,
assert len(ordered_effects) == len(inner_ctx.tokens_out)
out_nodes = [inner_ctx.tokens_out.get(eff)
for eff in ordered_effects] + out_nodes
hlo.ReturnOp(util.flatten(map(wrap_singleton_ir_values, out_nodes)))
hlo.return_(util.flatten(map(wrap_singleton_ir_values, out_nodes)))
results = case_op.results
if ordered_effects:
@ -1772,17 +1772,17 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue,
else:
if not core.is_constant_shape(aval_out.shape): # type: ignore
shape = eval_dynamic_shape_as_tensor(ctx, aval_out.shape) # type: ignore
return hlo.DynamicBroadcastInDimOp(
return hlo.dynamic_broadcast_in_dim(
aval_to_ir_type(aval_out), op,
shape,
dense_int_elements(broadcast_dimensions),
).result
)
else:
assert all(d != ir.ShapedType.get_dynamic_size()
for d in aval_out.shape), aval_out # type: ignore
return hlo.BroadcastInDimOp(
return hlo.broadcast_in_dim(
aval_to_ir_type(aval_out), op,
dense_int_elements(broadcast_dimensions)).result
dense_int_elements(broadcast_dimensions))
def multi_broadcast_in_dim(ctx: LoweringRuleContext,
ops: Sequence[ir.Value],
@ -1806,11 +1806,11 @@ def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Va
aval_out = core.physical_aval(aval_out)
if not core.is_constant_shape(aval_out.shape): # type: ignore
shape = eval_dynamic_shape_as_tensor(ctx, aval_out.shape) # type: ignore
return hlo.DynamicReshapeOp(
return hlo.dynamic_reshape(
aval_to_ir_type(aval_out), op, shape,
).result
)
else:
return hlo.ReshapeOp(aval_to_ir_type(aval_out), op).result
return hlo.reshape(aval_to_ir_type(aval_out), op)
def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
start_indices, limit_indices, strides) -> ir.Value:
@ -1830,14 +1830,14 @@ def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
start_indices = eval_dynamic_shape_as_tensor(ctx, start_indices)
limit_indices = eval_dynamic_shape_as_tensor(ctx, limit_indices)
strides = eval_dynamic_shape_as_tensor(ctx, strides)
return hlo.RealDynamicSliceOp(
return hlo.real_dynamic_slice(
aval_to_ir_type(aval_out),
x, start_indices, limit_indices, strides).result
x, start_indices, limit_indices, strides)
else:
return hlo.SliceOp(x,
dense_int_elements(start_indices),
dense_int_elements(limit_indices),
dense_int_elements(strides)).result
return hlo.slice(x,
dense_int_elements(start_indices),
dense_int_elements(limit_indices),
dense_int_elements(strides))
def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
start_indices) -> ir.Value:
@ -1859,21 +1859,20 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
# lower to RealDynamicSliceOp, which is a version of SliceOp, and does
# not have the clamping behavior. We clamp start ourselves.
slice_sizes = eval_dynamic_shape_as_tensor(ctx, slice_sizes)
clamped_start = hlo.ClampOp(
clamped_start = hlo.clamp(
shape_tensor([0] * len(start_indices)),
shape_tensor(start_indices),
hlo.SubtractOp(
hlo.subtract(
eval_dynamic_shape_as_tensor(ctx, x_aval.shape), # type: ignore
slice_sizes))
return hlo.RealDynamicSliceOp(
return hlo.real_dynamic_slice(
aval_to_ir_type(aval_out), x,
clamped_start,
hlo.AddOp(clamped_start, slice_sizes).result,
hlo.add(clamped_start, slice_sizes),
shape_tensor([1] * len(start_indices))
).result
)
else:
return hlo.DynamicSliceOp(x, start_indices,
dense_int_elements(slice_sizes)).result
return hlo.dynamic_slice(x, start_indices, dense_int_elements(slice_sizes))
def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
start_indices) -> ir.Value:
@ -1890,36 +1889,35 @@ def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
start_indices=start_indices)
else:
# TODO(necula): handle dynamic shapes
return hlo.DynamicUpdateSliceOp(x, update, start_indices).result
return hlo.dynamic_update_slice(x, update, start_indices)
def pad(ctx: LoweringRuleContext, aval_out,
x, padding_value,
padding_low, padding_high, padding_interior) -> ir.Value:
if all(core.is_constant_shape(s) for s in (padding_low,
padding_high, padding_interior)):
return hlo.PadOp(x, padding_value,
dense_int_elements(padding_low),
dense_int_elements(padding_high),
dense_int_elements(padding_interior)).result
return hlo.pad(x, padding_value,
dense_int_elements(padding_low),
dense_int_elements(padding_high),
dense_int_elements(padding_interior))
else:
padding_low = eval_dynamic_shape_as_tensor(ctx, padding_low)
padding_high = eval_dynamic_shape_as_tensor(ctx, padding_high)
padding_interior = eval_dynamic_shape_as_tensor(ctx, padding_interior)
return hlo.DynamicPadOp(
return hlo.dynamic_pad(
aval_to_ir_type(aval_out),
x, padding_value, padding_low, padding_high, padding_interior).result
x, padding_value, padding_low, padding_high, padding_interior)
def iota(ctx: LoweringRuleContext, aval_out, *, dimension: int):
if not core.is_constant_shape(aval_out.shape):
shape = eval_dynamic_shape_as_tensor(ctx, aval_out.shape)
return hlo.DynamicIotaOp(
return hlo.dynamic_iota(
aval_to_ir_type(aval_out),
shape,
i64_attr(dimension),
).result
)
else:
return hlo.IotaOp(aval_to_ir_type(aval_out),
i64_attr(dimension)).result
return hlo.iota(aval_to_ir_type(aval_out), i64_attr(dimension))
def full_like_aval(ctx: LoweringRuleContext, value, aval: core.ShapedArray) -> ir.Value:
"""Returns an IR constant shaped full of `value` shaped like `aval`."""
@ -1933,7 +1931,7 @@ def zeros_like_lowering(ctx, x):
register_lowering(ad_util.zeros_like_p, zeros_like_lowering)
def add_jaxvals_lowering(ctx, x, y):
return hlo.AddOp(x, y).results
return [hlo.add(x, y)]
register_lowering(ad_util.add_jaxvals_p, add_jaxvals_lowering)
register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x])
@ -1949,7 +1947,7 @@ def compare_hlo(x, y, direction: str, comparison_type: str | None = None):
else:
comparison_type = "FLOAT"
return hlo.CompareOp(
return hlo.compare(
x,
y,
hlo.ComparisonDirectionAttr.get(direction),
@ -1959,20 +1957,18 @@ def _minmax_hlo(op, cmp, x, y):
"""Min/max that compares complex values lexicographically as pairs."""
tensor_type = ir.RankedTensorType(x.type)
if ir.ComplexType.isinstance(tensor_type.element_type):
rx = hlo.RealOp(x).result
ry = hlo.RealOp(y).result
rx = hlo.real(x)
ry = hlo.real(y)
real_eq = compare_hlo(rx, ry, "EQ", "FLOAT")
real_cmp = compare_hlo(rx, ry, cmp, "FLOAT")
imag_cmp = compare_hlo(
hlo.ImagOp(x).result,
hlo.ImagOp(y).result, cmp, "FLOAT")
which = hlo.SelectOp(real_eq, imag_cmp, real_cmp).result
return hlo.SelectOp(which, x, y)
imag_cmp = compare_hlo(hlo.imag(x), hlo.imag(y), cmp, "FLOAT")
which = hlo.select(real_eq, imag_cmp, real_cmp)
return hlo.select(which, x, y)
else:
return op(x, y)
min_hlo = partial(_minmax_hlo, hlo.MinOp, "LT")
max_hlo = partial(_minmax_hlo, hlo.MaxOp, "GT")
min_hlo = partial(_minmax_hlo, hlo.minimum, "LT")
max_hlo = partial(_minmax_hlo, hlo.maximum, "GT")
def convert_hlo(ctx: LoweringRuleContext, x, aval_in, aval_out):
@ -1988,10 +1984,9 @@ def convert_hlo(ctx: LoweringRuleContext, x, aval_in, aval_out):
compare_type = "SIGNED"
else:
compare_type = "UNSIGNED"
x = compare_hlo(x, full_like_aval(ctx, 0, aval_in), "NE",
compare_type).result
x = compare_hlo(x, full_like_aval(ctx, 0, aval_in), "NE", compare_type)
# continue, to adjust the shape if needed
return hlo.ConvertOp(aval_to_ir_type(aval_out), x).result
return hlo.convert(aval_to_ir_type(aval_out), x)
def _wrap_with_spmd_op(name: str,
ctx: LoweringRuleContext,
@ -2184,7 +2179,7 @@ def xla_fallback_lowering(prim: core.Primitive):
flatten_lowering_ir_args(args)).result
if not prim.multiple_results:
return [call]
flat_results = [hlo.GetTupleElementOp(call, i32_attr(i)).result
flat_results = [hlo.get_tuple_element(call, i32_attr(i))
for i in range(len(flat_output_types))]
return util.unflatten(flat_results, map(len, output_types))
@ -2267,7 +2262,7 @@ def _emit_tpu_python_callback(
*,
sharding: xc.OpSharding | None = None
) -> tuple[Sequence[ir.Value], Any]:
token = token or hlo.CreateTokenOp().result
token = token or hlo.create_token()
_wrapped_callback = callback
send_channels = []
@ -2445,7 +2440,7 @@ def emit_python_callback(
if sharding is not None:
set_sharding(result, sharding)
results = [
hlo.GetTupleElementOp(result, i32_attr(i)).result
hlo.get_tuple_element(result, i32_attr(i))
for i in range(len(result_types))
]
if token:
@ -2603,9 +2598,8 @@ def reduce_window(
int2d = aval_to_ir_type(core.ShapedArray((1, 2), np.int32))
def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]):
pads = eval_dynamic_shape_as_tensor(ctx, pad_lo_hi) # i32[2]
return hlo.ReshapeOp(int2d, pads)
d_padding = hlo.ConcatenateOp(list(map(prep_one_pad, padding)),
i64_attr(0)).result
return hlo.reshape(int2d, pads)
d_padding = hlo.concatenate(list(map(prep_one_pad, padding)), i64_attr(0))
# Build the reducer
reducer_type = ir.FunctionType.get(scalar_types + scalar_types,
scalar_types)
@ -2614,8 +2608,7 @@ def reduce_window(
ctx.module_context.symbol_table.insert(reducer)
entry_block = reducer.add_entry_block()
with ir.InsertionPoint(entry_block):
res = reducer_body(entry_block)
hlo.ReturnOp(res)
hlo.return_(reducer_body(entry_block))
rw = custom_call(
"stablehlo.dynamic_reduce_window",
@ -2641,8 +2634,7 @@ def reduce_window(
shape=(len(padding), 2)))
reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
with ir.InsertionPoint(reducer):
res = reducer_body(reducer)
hlo.ReturnOp(res)
hlo.return_(reducer_body(reducer))
return rw.results


@ -1269,8 +1269,7 @@ def _unravel_index_hlo(axis_env):
div = mlir.ir_constant(
np.array(axis_env.nreps // math.prod(axis_env.sizes), np.uint32))
mod = mlir.ir_constant(np.array(axis_env.sizes[-1], np.uint32))
return hlo.RemOp(
hlo.DivOp(hlo.ReplicaIdOp().result, div).result, mod).result
return hlo.remainder(hlo.divide(hlo.replica_id(), div), mod)
def _hlo_shard(aval, axis_env, xs, in_axis):
if aval is core.abstract_token:
@ -1283,10 +1282,10 @@ def _hlo_shard(aval, axis_env, xs, in_axis):
idxs.insert(in_axis, _unravel_index_hlo(axis_env))
dims_unsqueezed = dims.copy()
dims_unsqueezed.insert(in_axis, 1)
dynamic_slice_result = hlo.DynamicSliceOp(
x, idxs, mlir.dense_int_elements(dims_unsqueezed)).result
dynamic_slice_result = hlo.dynamic_slice(
x, idxs, mlir.dense_int_elements(dims_unsqueezed))
return [
hlo.ReshapeOp(mlir.aval_to_ir_type(aval), dynamic_slice_result).result
hlo.reshape(mlir.aval_to_ir_type(aval), dynamic_slice_result)
]
else:
raise TypeError(aval)
@ -1335,19 +1334,18 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs):
padded = mlir.full_like_aval(ctx, 0, padded_aval)
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
idxs = [_unravel_index_hlo(axis_env)] + [zero] * len(dims)
broadcast_result = hlo.BroadcastOp(
x, mlir.dense_int_elements([1])).result
padded = hlo.DynamicUpdateSliceOp(padded, broadcast_result, idxs).result
broadcast_result = hlo.broadcast(x, mlir.dense_int_elements([1]))
padded = hlo.dynamic_update_slice(padded, broadcast_result, idxs)
replica_groups = mlir.dense_int_elements(
axis_groups(axis_env, axis_env.names[-1]))
out = hlo.CrossReplicaSumOp(padded, replica_groups).result
out = hlo.cross_replica_sum(padded, replica_groups)
if out_axis != 0:
# TODO(apaszke,mattjj): Change the indices to DynamicUpdateSlice instead
perm = list(range(1, len(dims)))
perm.insert(out_axis, 0)
transposed_dims = list(dims)
transposed_dims.insert(out_axis, axis_env.sizes[-1])
out = hlo.TransposeOp(out, mlir.dense_int_elements(perm)).result
out = hlo.transpose(out, mlir.dense_int_elements(perm))
return out
else:


@ -295,8 +295,8 @@ def _comparator_builder_mlir(ctx, op_type, is_max_k):
with ir.InsertionPoint(entry_block):
p0, p1, _, _ = entry_block.arguments
direction = hlo.ComparisonDirectionAttr.get('GT' if is_max_k else 'LT')
cmp_result = hlo.CompareOp(p0, p1, comparison_direction=direction)
hlo.ReturnOp(cmp_result)
cmp_result = hlo.compare(p0, p1, comparison_direction=direction)
hlo.return_([cmp_result])
return comparator
@ -321,7 +321,7 @@ def _approx_top_k_lowering(ctx, operand, *, k,
iota = mlir.iota(ctx, core.ShapedArray(ctx.avals_in[0].shape, np.int32),
dimension=reduction_dimension)
init_arg = hlo.ConstantOp(ir.DenseElementsAttr.get(np.int32(-1))).result
init_arg = hlo.constant(ir.DenseElementsAttr.get(np.int32(-1)))
init_val_array = _get_init_val_literal(ctx.avals_in[0].dtype, is_max_k)
init_val = mlir.ir_constant(init_val_array.reshape(()))


@ -860,7 +860,7 @@ def _cond_lowering(ctx, index, *args, branches, linear):
dim_var_values=ctx.dim_var_values)
out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
out_vals = [*out_tokens, *out_vals]
hlo.ReturnOp(util.flatten(out_vals))
hlo.return_(util.flatten(out_vals))
tokens_and_outputs = util.unflatten(case_op.results, map(len, output_types))
tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])


@ -1656,7 +1656,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
pred_ctx,
pred,
axes=tuple(range(len(pred_aval.shape))))
hlo.ReturnOp([pred])
hlo.return_([pred])
# Loop body
body_block = while_op.regions[1].blocks.append(*flat_loop_carry_types)
@ -1687,8 +1687,8 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z,
body_jaxpr.out_avals)
hlo.ReturnOp([*util.flatten(out_tokens), *util.flatten(x),
*util.flatten(y), *util.flatten(new_z)])
hlo.return_([*util.flatten(out_tokens), *util.flatten(x), *util.flatten(y),
*util.flatten(new_z)])
outputs = util.unflatten(while_op.results, _map(len, loop_carry_types))
tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts])


@ -707,7 +707,7 @@ def _conv_general_dilated_lower(
raise NotImplementedError("Convolutions with non-static strides, dilation, feature_group_count, or batch_group_count")
if all(core.is_constant_shape(p) for p in padding):
return [
hlo.ConvolutionOp(
hlo.convolution(
mlir.aval_to_ir_type(aval_out),
lhs,
rhs,
@ -719,7 +719,7 @@ def _conv_general_dilated_lower(
lhs_dilation=mlir.dense_int_elements(lhs_dilation),
rhs_dilation=mlir.dense_int_elements(rhs_dilation),
window_reversal=window_reversal,
precision_config=lax.precision_attr(precision)).result
precision_config=lax.precision_attr(precision))
]
else:
# d_padding will be an array i32[N, 2] with pad_lo and pad_hi for each
@ -731,7 +731,7 @@ def _conv_general_dilated_lower(
d_padding = hlo.ConcatenateOp(list(map(prep_one_pad, padding)),
mlir.i64_attr(0))
return [
hlo.DynamicConvOp(
hlo.dynamic_conv(
mlir.aval_to_ir_type(aval_out),
lhs,
rhs,
@ -743,7 +743,7 @@ def _conv_general_dilated_lower(
lhs_dilation=mlir.dense_int_elements(lhs_dilation),
rhs_dilation=mlir.dense_int_elements(rhs_dilation),
window_reversal=window_reversal,
precision_config=lax.precision_attr(precision)).result
precision_config=lax.precision_attr(precision))
]
mlir.register_lowering(conv_general_dilated_p, _conv_general_dilated_lower)


@ -1680,19 +1680,19 @@ def broadcast_hlo(
dims = mlir.dense_int_elements(
range(len(aval_out.shape) - len(aval.shape), len(aval_out.shape)))
if any(isinstance(d, ir.Value) for d in aval_out.shape):
arg = hlo.DynamicBroadcastInDimOp(
arg = hlo.dynamic_broadcast_in_dim(
mlir.aval_to_ir_type(aval_out), arg,
mlir.shape_tensor(aval_out.shape), dims).result
mlir.shape_tensor(aval_out.shape), dims)
else:
arg = hlo.BroadcastInDimOp(
arg = hlo.broadcast_in_dim(
mlir.aval_to_ir_type(aval.update(shape=aval_out.shape)), arg,
dims).result
dims)
out.append(arg)
return out
def _nary_lower_hlo(op: Callable, ctx,
*args: Union[ir.Value, Sequence[ir.Value]],
explicit_type=False, **params):
explicit_type=False, **params) -> Sequence[ir.Value]:
"""Lowers an elementwise operator to its MLIR equivalent.
Args:
@ -1704,9 +1704,9 @@ def _nary_lower_hlo(op: Callable, ctx,
ctx, args, avals_in, aval_out.shape)
if explicit_type:
return op(mlir.aval_to_ir_type(aval_out), *broadcasted_args).results
return [op(mlir.aval_to_ir_type(aval_out), *broadcasted_args)]
else:
return op(*broadcasted_args).results
return [op(*broadcasted_args)]
_float = {np.floating}
@ -1723,7 +1723,7 @@ _ordered = _int | _float | _bool
neg_p = standard_unop(_num, 'neg')
ad.deflinear2(neg_p, lambda t, operand: [neg(t)])
mlir.register_lowering(neg_p, partial(_nary_lower_hlo, hlo.NegOp))
mlir.register_lowering(neg_p, partial(_nary_lower_hlo, hlo.negate))
sign_p = standard_unop(_num, 'sign')
ad.defjvp_zero(sign_p)
@ -1731,44 +1731,44 @@ ad.defjvp_zero(sign_p)
def _sign_lower_hlo(ctx, x):
x_aval, = ctx.avals_in
if dtypes.issubdtype(x_aval.dtype, np.unsignedinteger):
return hlo.SelectOp(
return [hlo.select(
mlir.compare_hlo(x, mlir.full_like_aval(ctx, 0, x_aval), 'EQ',
'UNSIGNED').result,
'UNSIGNED'),
mlir.full_like_aval(ctx, 0, x_aval),
mlir.full_like_aval(ctx, 1, x_aval)).results
return hlo.SignOp(x).results
mlir.full_like_aval(ctx, 1, x_aval))]
return [hlo.sign(x)]
mlir.register_lowering(sign_p, _sign_lower_hlo)
nextafter_p = standard_naryop([_float, _float], 'nextafter')
mlir.register_lowering(nextafter_p, partial(_nary_lower_hlo, chlo.NextAfterOp))
mlir.register_lowering(nextafter_p, partial(_nary_lower_hlo, chlo.next_after))
floor_p = standard_unop(_float, 'floor')
ad.defjvp_zero(floor_p)
mlir.register_lowering(floor_p, partial(_nary_lower_hlo, hlo.FloorOp))
mlir.register_lowering(floor_p, partial(_nary_lower_hlo, hlo.floor))
ceil_p = standard_unop(_float, 'ceil')
ad.defjvp_zero(ceil_p)
mlir.register_lowering(ceil_p, partial(_nary_lower_hlo, hlo.CeilOp))
mlir.register_lowering(ceil_p, partial(_nary_lower_hlo, hlo.ceil))
round_p = standard_unop(_float, 'round')
ad.defjvp_zero(round_p)
def _round_lower(ctx, x, *, rounding_method):
if rounding_method is RoundingMethod.AWAY_FROM_ZERO:
return hlo.RoundOp(x).results
return [hlo.round_nearest_afz(x)]
else:
assert rounding_method is RoundingMethod.TO_NEAREST_EVEN
return hlo.RoundNearestEvenOp(x).results
return [hlo.round_nearest_even(x)]
mlir.register_lowering(round_p, _round_lower)
is_finite_p = unop(_fixed_dtype(np.bool_), _float, 'is_finite')
ad.defjvp_zero(is_finite_p)
mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.IsFiniteOp))
mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.is_finite))
exp_p = standard_unop(_float | _complex, 'exp')
ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.ExpOp))
mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential))
exp2_p = standard_unop(_float | _complex, 'exp2')
ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans)))
@ -1776,30 +1776,31 @@ def _exp2_lower(ctx, x):
x_aval, = ctx.avals_in
log2 = mlir.ir_constant(np.array(np.log(2), x_aval.dtype))
log2 = mlir.broadcast_in_dim(ctx, log2, x_aval, broadcast_dimensions=())
return hlo.ExpOp(hlo.MulOp(log2, x).result).results
return [hlo.exponential(hlo.multiply(log2, x))]
mlir.register_lowering(exp2_p, _exp2_lower)
log_p = standard_unop(_float | _complex, 'log')
ad.defjvp(log_p, lambda g, x: div(g, x))
mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.LogOp))
mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.log))
expm1_p = standard_unop(_float | _complex, 'expm1')
ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans))))
mlir.register_lowering(expm1_p, partial(_nary_lower_hlo, hlo.Expm1Op))
mlir.register_lowering(expm1_p,
partial(_nary_lower_hlo, hlo.exponential_minus_one))
log1p_p = standard_unop(_float | _complex, 'log1p')
ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x))))
mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.Log1pOp))
mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.log_plus_one))
tanh_p = standard_unop(_float | _complex, 'tanh')
ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)),
sub(_one(x), ans)))
mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.TanhOp))
mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.tanh))
logistic_p = standard_unop(_float | _complex, 'logistic')
ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans))))
# TODO(phawkins): switch to LogisticOp lowering; debug numerical problems.
# mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.LogisticOp))
# mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.logistic))
def logistic_impl(x):
one = _const(x, 1)
@ -1810,11 +1811,11 @@ mlir.register_lowering(logistic_p,
sin_p = standard_unop(_float | _complex, 'sin')
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
mlir.register_lowering(sin_p, partial(_nary_lower_hlo, hlo.SineOp))
mlir.register_lowering(sin_p, partial(_nary_lower_hlo, hlo.sine))
cos_p = standard_unop(_float | _complex, 'cos')
ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
mlir.register_lowering(cos_p, partial(_nary_lower_hlo, hlo.CosineOp))
mlir.register_lowering(cos_p, partial(_nary_lower_hlo, hlo.cosine))
@_upcast_fp16_for_computation
def _tan_impl(x):
@ -1822,7 +1823,7 @@ def _tan_impl(x):
tan_p = standard_unop(_float | _complex, 'tan')
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.TanOp))
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan))
def asin_impl(x):
if dtypes.issubdtype(_dtype(x), np.complexfloating):
@ -1833,7 +1834,7 @@ def asin_impl(x):
asin_p = standard_unop(_float | _complex, 'asin')
ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x))))
mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.AsinOp))
mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.asin))
def acos_impl(x):
if dtypes.issubdtype(_dtype(x), np.complexfloating):
@ -1862,43 +1863,43 @@ def atan_impl(x):
atan_p = standard_unop(_float | _complex, 'atan')
ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x)))
mlir.register_lowering(atan_p, partial(_nary_lower_hlo, chlo.AtanOp))
mlir.register_lowering(atan_p, partial(_nary_lower_hlo, chlo.atan))
atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2')
ad.defjvp(atan2_p,
lambda g, x, y: g * (y / (square(x) + square(y))),
lambda g, x, y: g * -x / (square(x) + square(y)))
mlir.register_lowering(atan2_p, partial(_nary_lower_hlo, hlo.Atan2Op))
mlir.register_lowering(atan2_p, partial(_nary_lower_hlo, hlo.atan2))
sinh_p = standard_unop(_float | _complex, 'sinh')
ad.defjvp(sinh_p, lambda g, x: mul(g, cosh(x)))
mlir.register_lowering(sinh_p, partial(_nary_lower_hlo, chlo.SinhOp))
mlir.register_lowering(sinh_p, partial(_nary_lower_hlo, chlo.sinh))
cosh_p = standard_unop(_float | _complex, 'cosh')
ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x)))
mlir.register_lowering(cosh_p, partial(_nary_lower_hlo, chlo.CoshOp))
mlir.register_lowering(cosh_p, partial(_nary_lower_hlo, chlo.cosh))
asinh_p = standard_unop(_float | _complex, 'asinh')
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x))))
mlir.register_lowering(asinh_p, partial(_nary_lower_hlo, chlo.AsinhOp))
mlir.register_lowering(asinh_p, partial(_nary_lower_hlo, chlo.asinh))
acosh_p = standard_unop(_float | _complex, 'acosh')
ad.defjvp(acosh_p,
lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x)))))
mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.AcoshOp))
mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.acosh))
atanh_p = standard_unop(_float | _complex, 'atanh')
ad.defjvp(atanh_p,
lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x))))
mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.AtanhOp))
mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.atanh))
real_p = unop(_complex_basetype, _complex, 'real')
ad.deflinear2(real_p, lambda t, _: [complex(t, np.zeros((), _dtype(t)))])
mlir.register_lowering(real_p, partial(_nary_lower_hlo, hlo.RealOp))
mlir.register_lowering(real_p, partial(_nary_lower_hlo, hlo.real))
imag_p = unop(_complex_basetype, _complex, 'imag')
ad.deflinear2(imag_p, lambda t, _: [complex(np.zeros((), _dtype(t)), neg(t))])
mlir.register_lowering(imag_p, partial(_nary_lower_hlo, hlo.ImagOp))
mlir.register_lowering(imag_p, partial(_nary_lower_hlo, hlo.imag))
def _complex_transpose_rule(t, x, y):
@ -1923,7 +1924,7 @@ _complex_dtype = lambda dtype, *args: (np.zeros((), dtype) + np.zeros((), np.com
complex_p = naryop(_complex_dtype, [_complex_elem_types, _complex_elem_types],
'complex')
ad.deflinear2(complex_p, _complex_transpose_rule)
mlir.register_lowering(complex_p, partial(_nary_lower_hlo, hlo.ComplexOp))
mlir.register_lowering(complex_p, partial(_nary_lower_hlo, hlo.complex))
conj_p = unop(_complex_dtype, _complex_elem_types | _complex, 'conj')
@ -1950,7 +1951,7 @@ ad.primitive_jvps[conj_p] = partial(ad.linear_jvp, conj_p)
ad.primitive_transposes[conj_p] = _conj_transpose_rule
abs_p = unop(_complex_basetype, _signedint | _float | _complex, 'abs')
mlir.register_lowering(abs_p, partial(_nary_lower_hlo, hlo.AbsOp))
mlir.register_lowering(abs_p, partial(_nary_lower_hlo, hlo.abs))
def _abs_jvp_rule(g, ans, x):
if _iscomplex(x):
@ -1964,18 +1965,18 @@ _maybe_real = lambda x: real(x) if _iscomplex(x) else x
sqrt_p = standard_unop(_float | _complex, 'sqrt')
ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans)))
mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.SqrtOp))
mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.sqrt))
rsqrt_p = standard_unop(_float | _complex, 'rsqrt')
ad.defjvp2(rsqrt_p,
lambda g, ans, x:
mul(g, mul(_const(x, -0.5), div(ans, x))))
mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.RsqrtOp))
mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.rsqrt))
cbrt_p = standard_unop(_float, 'cbrt')
ad.defjvp2(cbrt_p,
lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2))))
mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.CbrtOp))
mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt))
def _pow_dtype_rule(x, y):
if (dtypes.issubdtype(x.dtype, np.inexact) and
@ -2020,7 +2021,7 @@ def _pow_lower(ctx, x, y):
[(x_,)] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x)
[(y_,)] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y)
ctx_ = ctx.replace(avals_in=[x_aval_, y_aval_])
return _nary_lower_hlo(hlo.PowOp, ctx_, x_, y_)
return _nary_lower_hlo(hlo.power, ctx_, x_, y_)
mlir.register_lowering(pow_p, _pow_lower)
@ -2075,26 +2076,25 @@ _replace_zero = lambda x: select(eq(x, _const(x, 0)), _ones(x), x)
not_p = standard_unop(_bool_or_int, 'not')
ad.defjvp_zero(not_p)
mlir.register_lowering(not_p, partial(_nary_lower_hlo, hlo.NotOp))
mlir.register_lowering(not_p, partial(_nary_lower_hlo, hlo.not_))
and_p = standard_naryop([_bool_or_int, _bool_or_int], 'and')
ad.defjvp_zero(and_p)
mlir.register_lowering(and_p, partial(_nary_lower_hlo, hlo.AndOp))
mlir.register_lowering(and_p, partial(_nary_lower_hlo, hlo.and_))
or_p = standard_naryop([_bool_or_int, _bool_or_int], 'or')
ad.defjvp_zero(or_p)
mlir.register_lowering(or_p, partial(_nary_lower_hlo, hlo.OrOp))
mlir.register_lowering(or_p, partial(_nary_lower_hlo, hlo.or_))
xor_p = standard_naryop([_bool_or_int, _bool_or_int], 'xor')
ad.defjvp_zero(xor_p)
mlir.register_lowering(xor_p, partial(_nary_lower_hlo, hlo.XorOp))
mlir.register_lowering(xor_p, partial(_nary_lower_hlo, hlo.xor))
population_count_p = standard_unop(_int, 'population_count')
mlir.register_lowering(population_count_p,
partial(_nary_lower_hlo, hlo.PopulationCountOp))
mlir.register_lowering(population_count_p, partial(_nary_lower_hlo, hlo.popcnt))
clz_p = standard_unop(_int, 'clz')
mlir.register_lowering(clz_p, partial(_nary_lower_hlo, hlo.ClzOp))
mlir.register_lowering(clz_p, partial(_nary_lower_hlo, hlo.count_leading_zeros))
def _add_jvp(primals, tangents):
x, y = primals
@ -2130,7 +2130,7 @@ def _add_inverse(r, x, y):
add_p: Primitive = standard_naryop([_num, _num], 'add')
ad.primitive_jvps[add_p] = _add_jvp
ad.primitive_transposes[add_p] = _add_transpose
mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.AddOp))
mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.add))
def _sub_jvp(primals, tangents):
x, y = primals
@ -2159,7 +2159,7 @@ def _sub_transpose(t, x, y):
sub_p = standard_naryop([_num, _num], 'sub')
ad.primitive_jvps[sub_p] = _sub_jvp
ad.primitive_transposes[sub_p] = _sub_transpose
mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.SubtractOp))
mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.subtract))
def _mul_transpose(ct, x, y):
@ -2185,7 +2185,7 @@ ad.defjvp(mul_p,
lambda xdot, x, y: mul(xdot, y),
lambda ydot, x, y: mul(x, ydot))
ad.primitive_transposes[mul_p] = _mul_transpose
mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.MulOp))
mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.multiply))
def _div_transpose_rule(cotangent, x, y):
assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
@ -2198,14 +2198,14 @@ ad.defjvp(div_p,
lambda g, x, y: div(g, y),
lambda g, x, y: mul(mul(neg(g), x), integer_pow(y, -2)))
ad.primitive_transposes[div_p] = _div_transpose_rule
mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.DivOp))
mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.divide))
rem_p = standard_naryop([_int | _float, _int | _float], 'rem')
ad.defjvp(
rem_p,
lambda g, x, y: _maybe_broadcast(broadcast_shapes(np.shape(x), np.shape(y)), g),
lambda g, x, y: mul(neg(g), mul(sign(div(x, y)), floor(abs(div(x, y))))))
mlir.register_lowering(rem_p, partial(_nary_lower_hlo, hlo.RemOp))
mlir.register_lowering(rem_p, partial(_nary_lower_hlo, hlo.remainder))
def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x):
result_shape = broadcast_shapes(np.shape(x), np.shape(y))
@ -2231,17 +2231,17 @@ mlir.register_lowering(min_p, partial(_nary_lower_hlo, mlir.min_hlo))
shift_left_p = standard_naryop([_int, _int], 'shift_left')
ad.defjvp_zero(shift_left_p)
mlir.register_lowering(shift_left_p, partial(_nary_lower_hlo, hlo.ShiftLeftOp))
mlir.register_lowering(shift_left_p, partial(_nary_lower_hlo, hlo.shift_left))
shift_right_arithmetic_p = standard_naryop([_int, _int], 'shift_right_arithmetic')
ad.defjvp_zero(shift_right_arithmetic_p)
mlir.register_lowering(shift_right_arithmetic_p,
partial(_nary_lower_hlo, hlo.ShiftRightArithmeticOp))
partial(_nary_lower_hlo, hlo.shift_right_arithmetic))
shift_right_logical_p = standard_naryop([_int, _int], 'shift_right_logical')
ad.defjvp_zero(shift_right_logical_p)
mlir.register_lowering(shift_right_logical_p,
partial(_nary_lower_hlo, hlo.ShiftRightLogicalOp))
partial(_nary_lower_hlo, hlo.shift_right_logical))
def _opaque_comparison_hlo(direction, reduction_op, identity, ctx,
avals_in, aval_out, x, y):
@ -2288,7 +2288,7 @@ def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y):
compare_type = "SIGNED"
else:
compare_type = "UNSIGNED"
return mlir.compare_hlo(x, y, direction, compare_type).results
return [mlir.compare_hlo(x, y, direction, compare_type)]
eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq', allow_extended_dtype=True)
ad.defjvp_zero(eq_p)
@ -2426,7 +2426,7 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type):
aval_out, = ctx.avals_out
if (dtypes.issubdtype(aval_in.dtype, np.complexfloating) and
not dtypes.issubdtype(new_dtype, np.complexfloating)):
operand = hlo.RealOp(operand).result
operand = hlo.real(operand)
aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype))
return [mlir.convert_hlo(ctx, operand, aval_in, aval_out)]
@ -2473,7 +2473,7 @@ batching.defvectorized(bitcast_convert_type_p)
def _bitcast_convert_type_lower(ctx, operand, *, new_dtype):
aval_out, = ctx.avals_out
return hlo.BitcastConvertOp(mlir.aval_to_ir_type(aval_out), operand).results
return [hlo.bitcast_convert(mlir.aval_to_ir_type(aval_out), operand)]
mlir.register_lowering(bitcast_convert_type_p, _bitcast_convert_type_lower)
@ -2841,12 +2841,12 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
lhs_contracting_dimensions=list(lhs_contracting),
rhs_contracting_dimensions=list(rhs_contracting))
return [
hlo.DotGeneralOp(
hlo.dot_general(
mlir.aval_to_ir_type(aval_out),
lhs,
rhs,
dot_dnums,
precision_config=precision_attr(precision)).result
precision_config=precision_attr(precision))
]
mlir.register_lowering(dot_general_p, _dot_general_lower)
@ -3143,8 +3143,7 @@ ad.defjvp(clamp_p,
lambda g, min, operand, max:
select(lt(max, operand), g, _zeros(operand)))
batching.primitive_batchers[clamp_p] = _clamp_batch_rule
mlir.register_lowering(
clamp_p, partial(_nary_lower_hlo, hlo.ClampOp))
mlir.register_lowering(clamp_p, partial(_nary_lower_hlo, hlo.clamp))
pe.def_trivial_padding(clamp_p)
def _concatenate_shape_rule(*operands, **kwargs):
@ -3222,7 +3221,7 @@ batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule
pe.padding_rules[concatenate_p] = _concatenate_pad_rule
def _concatenate_lower(ctx, *xs, dimension):
return hlo.ConcatenateOp(xs, mlir.i64_attr(dimension)).results
return [hlo.concatenate(xs, mlir.i64_attr(dimension))]
mlir.register_lowering(concatenate_p, _concatenate_lower)
@ -3424,7 +3423,7 @@ def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions):
def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions):
aval_out, = ctx.avals_out
if dimensions is not None:
x = hlo.TransposeOp(x, mlir.dense_int_elements(dimensions)).result
x = hlo.transpose(x, mlir.dense_int_elements(dimensions))
if dyn_shape:
aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape))
return [mlir.reshape(ctx, x, aval_out)]
@ -3468,7 +3467,7 @@ ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)])
batching.primitive_batchers[rev_p] = _rev_batch_rule
def _rev_lower(ctx, x, *, dimensions):
return hlo.ReverseOp(x, mlir.dense_int_elements(dimensions)).results
return [hlo.reverse(x, mlir.dense_int_elements(dimensions))]
mlir.register_lowering(rev_p, _rev_lower)
@ -3500,7 +3499,7 @@ def _transpose_lower(ctx, x, *, permutation):
aval_out.dtype).shape
trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))]
permutation = [*permutation, *trailing_dims]
return hlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results
return [hlo.transpose(x, mlir.dense_int_elements(permutation))]
transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
'transpose')
@ -3632,7 +3631,7 @@ def _select_hlo_lowering(ctx, which, *cases):
if which_aval.dtype == np.dtype(np.bool_):
assert len(cases) <= 2
if len(cases) == 1: return cases
return hlo.SelectOp(which, cases[1], cases[0]).results
return [hlo.select(which, cases[1], cases[0])]
if dtypes.issubdtype(which_aval.dtype, np.signedinteger):
compare_type = 'SIGNED'
@ -3648,8 +3647,8 @@ def _select_hlo_lowering(ctx, which, *cases):
pred = mlir.compare_hlo(which,
mlir.full_like_aval(ctx, offset + mid, which_aval),
lt, compare_type)
return hlo.SelectOp(pred, _select(offset, cases[:mid]),
_select(offset + mid, cases[mid:])).result
return hlo.select(pred, _select(offset, cases[:mid]),
_select(offset + mid, cases[mid:]))
return [_select(0, cases)]
@ -3791,7 +3790,7 @@ def _reduce_lower(ctx, *values, computation, jaxpr, consts, dimensions):
out_nodes, _ = mlir.jaxpr_subcomp(reducer_ctx, jaxpr, mlir.TokenSet(), consts,
*([a] for a in reducer.arguments),
dim_var_values=ctx.dim_var_values)
hlo.ReturnOp(util.flatten(out_nodes))
hlo.return_(util.flatten(out_nodes))
return op.results
mlir.register_lowering(reduce_p, _reduce_lower)
@ -3982,8 +3981,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype))
reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer_region):
add = reducer(*reducer_region.arguments)
hlo.ReturnOp(add.results)
hlo.return_([reducer(*reducer_region.arguments)])
return op.results
mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, hlo.AddOp,
@ -4021,8 +4019,8 @@ batching.defvectorized(reduce_precision_p)
def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
aval_out, = ctx.avals_out
return hlo.ReducePrecisionOp(operand, mlir.i32_attr(exponent_bits),
mlir.i32_attr(mantissa_bits)).results
return [hlo.reduce_precision(operand, mlir.i32_attr(exponent_bits),
mlir.i32_attr(mantissa_bits))]
mlir.register_lowering(reduce_precision_p, _reduce_precision_lower)
@ -4163,7 +4161,7 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
out = lower_comparator(sub_ctx, *[[a] for a in comparator.arguments],
num_keys=num_keys)
hlo.ReturnOp(util.flatten(out))
hlo.return_(util.flatten(out))
return sort.results
mlir.register_lowering(sort_p, _sort_lower)
@ -4273,7 +4271,7 @@ create_token_p.def_abstract_eval(lambda *_: abstract_token)
def _create_token_lowering(ctx, *operands):
aval_out, = ctx.avals_out
return hlo.CreateTokenOp().results
return [hlo.create_token()]
mlir.register_lowering(create_token_p, _create_token_lowering)
@ -4295,7 +4293,7 @@ after_all_p.def_abstract_eval(_after_all_abstract_eval)
def _after_all_lowering(ctx, *operands):
aval_out, = ctx.avals_out
return hlo.AfterAllOp(operands).results
return [hlo.after_all(operands)]
mlir.register_lowering(after_all_p, _after_all_lowering)
@ -4434,8 +4432,7 @@ rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
def _rng_uniform_lowering(ctx, a, b, *, shape):
aval_out, = ctx.avals_out
shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64))
return hlo.RngOp(a, b, shape,
hlo.RngDistributionAttr.get('UNIFORM')).results
return [hlo.rng(a, b, shape, hlo.RngDistributionAttr.get('UNIFORM'))]
mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering)
@ -4491,9 +4488,9 @@ def _rng_bit_generator_lowering(
rbg_etype = u32_type
rbg_dtype = np.uint32
if key_etype == u32_type:
key = hlo.BitcastConvertOp(
key = hlo.bitcast_convert(
ir.RankedTensorType.get([2], u64_type),
hlo.ReshapeOp(ir.RankedTensorType.get([2, 2], u32_type), key)).result
hlo.reshape(ir.RankedTensorType.get([2, 2], u32_type), key))
algorithm_attr = _rng_algorithm(algorithm)
_, out_vals_aval = ctx.avals_out
if any(not core.is_constant_shape(a.shape) for a in ctx.avals_out):
@ -4511,14 +4508,14 @@ def _rng_bit_generator_lowering(
ir.RankedTensorType.get(shape, rbg_etype),
algorithm_attr, key).results
if key_etype == u32_type:
out_key = hlo.ReshapeOp(
out_key = hlo.reshape(
ir.RankedTensorType.get([4], u32_type),
hlo.BitcastConvertOp(
ir.RankedTensorType.get([2, 2], u32_type), out_key)).result
hlo.bitcast_convert(
ir.RankedTensorType.get([2, 2], u32_type), out_key))
if rbg_etype != etype:
out_vals = hlo.ConvertOp(
out_vals = hlo.convert(
ir.RankedTensorType.get(ir.RankedTensorType(out_vals.type).shape, etype),
out_vals).result
out_vals)
return [out_key, out_vals]


@ -435,7 +435,7 @@ ad.primitive_jvps[cholesky_p] = _cholesky_jvp_rule
batching.primitive_batchers[cholesky_p] = _cholesky_batching_rule
def _cholesky_lowering(ctx, x):
return hlo.CholeskyOp(x, lower=ir.BoolAttr.get(True)).results
return [hlo.cholesky(x, lower=ir.BoolAttr.get(True))]
mlir.register_lowering(cholesky_p, _cholesky_lowering)
@ -621,7 +621,7 @@ def _eigh_jacobi_lowering_rule(ctx, operand, lower, sort_eigenvalues):
if operand_aval.shape[-1] == 0:
reshape_aval = operand_aval.update(shape=operand_aval.shape[:-1])
return [
hlo.RealOp(mlir.reshape(ctx, operand, reshape_aval)).result,
hlo.real(mlir.reshape(ctx, operand, reshape_aval)),
operand,
]
@ -985,10 +985,10 @@ def _triangular_solve_lowering(
transpose = "NO_TRANSPOSE"
else:
transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
return hlo.TriangularSolveOp(
return [hlo.triangular_solve(
a, b, ir.BoolAttr.get(left_side),
ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
hlo.TransposeAttr.get(transpose)).results
hlo.TransposeAttr.get(transpose))]
mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)
@ -998,7 +998,7 @@ def _triangular_solve_cpu_lower(
a_aval, b_aval = ctx.avals_in
if conjugate_a and not transpose_a:
a = chlo.ConjOp(a).result
a = chlo.conj(a)
conjugate_a = False
if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
@ -1014,10 +1014,10 @@ def _triangular_solve_cpu_lower(
transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
else:
transpose = "NO_TRANSPOSE"
return hlo.TriangularSolveOp(a, b, ir.BoolAttr.get(left_side),
ir.BoolAttr.get(lower),
ir.BoolAttr.get(unit_diagonal),
hlo.TransposeAttr.get(transpose)).results
return [hlo.triangular_solve(a, b, ir.BoolAttr.get(left_side),
ir.BoolAttr.get(lower),
ir.BoolAttr.get(unit_diagonal),
hlo.TransposeAttr.get(transpose))]
mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
platform='cpu')
@ -1297,7 +1297,7 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand, *,
lu, pivot, info = getrf_impl(
operand_aval.dtype, operand, a_shape_vals=op_shape_vals)
# Subtract 1 from the pivot to get 0-based indices.
pivot = hlo.SubtractOp(pivot, mlir.full_like_aval(ctx, 1, pivot_aval)).result
pivot = hlo.subtract(pivot, mlir.full_like_aval(ctx, 1, pivot_aval))
ok = mlir.compare_hlo(
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
"GE", "SIGNED")
@ -2380,4 +2380,4 @@ def _broadcasting_select_hlo(ctx, which, which_aval, x, x_aval, y, y_aval) -> ir
which, x, y = mlir.multi_broadcast_in_dim(ctx, (which, x, y),
(which_aval, x_aval, y_aval),
out_shapes)
return hlo.SelectOp(which, x, y).result
return hlo.select(which, x, y)


@ -775,7 +775,7 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
avals_in=[scalar_aval] * 2, avals_out=[scalar_aval])
out_nodes = lower_reducer(
reducer_ctx, *([a] for a in reducer_block.arguments))
hlo.ReturnOp(util.flatten(out_nodes))
hlo.return_(util.flatten(out_nodes))
return op.result
return [all_reduce(aval, x) for aval, x in zip(ctx.avals_in, args)]
@ -1191,7 +1191,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
new_shape = list(x_aval.shape)
new_shape.insert(all_gather_dimension, 1)
broadcast_dimensions = [i for i in range(len(new_shape)) if i != all_gather_dimension]
x = hlo.BroadcastInDimOp(
x = hlo.broadcast_in_dim(
mlir.aval_to_ir_type(x_aval.update(shape=new_shape)), x,
mlir.dense_int_elements(broadcast_dimensions))
replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
@ -1359,12 +1359,12 @@ def _reduce_scatter_lowering(
avals_out=[scalar_aval])
out_nodes = lower_reducer(
reducer_ctx, *([a] for a in reducer_block.arguments))
hlo.ReturnOp(util.flatten(out_nodes))
hlo.return_(util.flatten(out_nodes))
if tiled:
return op.results
else:
return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), op.result).results
return [hlo.reshape(mlir.aval_to_ir_type(aval_out), op.result)]
def _reduce_scatter_lowering_via_reducer(
prim, reducer, ctx, x,
@ -1582,13 +1582,13 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
)
if is_spmd:
device_id = hlo.PartitionIdOp()
device_id = hlo.partition_id()
else:
device_id = hlo.ReplicaIdOp()
unsigned_index = hlo.RemOp(hlo.DivOp(device_id, div), mod)
return hlo.ConvertOp(
device_id = hlo.replica_id()
unsigned_index = hlo.remainder(hlo.divide(device_id, div), mod)
return hlo.convert(
ir.RankedTensorType.get([], ir.IntegerType.get_signless(32)),
unsigned_index).result
unsigned_index)
def _axis_index_lowering(ctx, *, axis_name):
return [


@ -1839,12 +1839,12 @@ def _gather_lower(ctx, operand, indices, *,
return hlo.DynamicGatherOp.build_generic(
results=results, operands=operands, attributes=attributes).results
else:
return hlo.GatherOp(
return [hlo.gather(
operand,
indices,
dnums,
mlir.dense_int_elements(slice_sizes),
indices_are_sorted=ir.BoolAttr.get(indices_are_sorted)).results
indices_are_sorted=ir.BoolAttr.get(indices_are_sorted))]
mlir.register_lowering(gather_p, _gather_lower)
@ -2499,7 +2499,7 @@ def _scatter_lower(ctx, operand, indices, updates, *,
update_ctx, update_jaxpr, mlir.TokenSet(), update_consts,
(update.arguments[0],), (update.arguments[1],),
dim_var_values=ctx.dim_var_values)
hlo.ReturnOp(util.flatten(out_nodes))
hlo.return_(util.flatten(out_nodes))
return op.results
mlir.register_lowering(scatter_p, _scatter_lower)
@ -2555,12 +2555,12 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
reducer = scatter.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer):
add = hlo.AddOp(*reducer.arguments).result
hlo.ReturnOp([add])
hlo.return_([add])
return scatter.result
real = _scatter(hlo.RealOp(operand).result, hlo.RealOp(updates).result)
imag = _scatter(hlo.ImagOp(operand).result, hlo.ImagOp(updates).result)
return hlo.ComplexOp(real, imag).results
real = _scatter(hlo.real(operand), hlo.real(updates))
imag = _scatter(hlo.imag(operand), hlo.imag(updates))
return [hlo.complex(real, imag)]
mlir.register_lowering(scatter_add_p, _scatter_add_lower_gpu, platform="gpu")


@ -625,14 +625,14 @@ ad.defjvp(regularized_incomplete_beta_p,
lgamma_p = standard_unop(_float, 'lgamma')
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))
mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.LgammaOp))
mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.lgamma))
digamma_p = standard_unop(_float, 'digamma')
mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.DigammaOp))
mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.digamma))
ad.defjvp(digamma_p, lambda g, x: mul(g, polygamma(_const(x, 1), x)))
polygamma_p = standard_naryop([_float, _float], 'polygamma')
mlir.register_lowering(polygamma_p, partial(_nary_lower_hlo, chlo.PolygammaOp))
mlir.register_lowering(polygamma_p, partial(_nary_lower_hlo, chlo.polygamma))
ad.defjvp(polygamma_p, polygamma_gradm, polygamma_gradx)
igamma_p = standard_naryop([_float, _float], 'igamma')
@ -659,7 +659,7 @@ mlir.register_lowering(random_gamma_grad_p,
multiple_results=False))
zeta_p = standard_naryop([_float, _float], 'zeta')
mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.ZetaOp))
mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.zeta))
bessel_i0e_p = standard_unop(_float, 'bessel_i0e')
mlir.register_lowering(bessel_i0e_p,
@ -669,7 +669,7 @@ ad.defjvp2(bessel_i0e_p, lambda g, y, x: g * (bessel_i1e(x) - sign(x) * y))
bessel_i1e_p = standard_unop(_float, 'bessel_i1e')
mlir.register_lowering(bessel_i1e_p,
partial(_nary_lower_hlo, chlo.BesselI1eOp))
partial(_nary_lower_hlo, chlo.bessel_i1e))
def _bessel_i1e_jvp(g, y, x):
eps = dtypes.finfo(_dtype(x)).eps
@ -683,14 +683,14 @@ ad.defjvp2(bessel_i1e_p, _bessel_i1e_jvp)
erf_p = standard_unop(_float, 'erf')
ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)),
mul(g, exp(neg(square(x))))))
mlir.register_lowering(erf_p, partial(_nary_lower_hlo, chlo.ErfOp))
mlir.register_lowering(erf_p, partial(_nary_lower_hlo, chlo.erf))
erfc_p = standard_unop(_float, 'erfc')
ad.defjvp(erfc_p, lambda g, x: mul(_const(x, -2. / np.sqrt(np.pi)),
mul(g, exp(neg(square(x))))))
mlir.register_lowering(erfc_p, partial(_nary_lower_hlo, chlo.ErfcOp))
mlir.register_lowering(erfc_p, partial(_nary_lower_hlo, chlo.erfc))
erf_inv_p = standard_unop(_float, 'erf_inv')
ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.),
mul(g, exp(square(ans)))))
mlir.register_lowering(erf_inv_p, partial(_nary_lower_hlo, chlo.ErfInvOp))
mlir.register_lowering(erf_inv_p, partial(_nary_lower_hlo, chlo.erf_inv))


@ -470,7 +470,7 @@ def _reduce_window_lower(
return mlir.reduce_window(ctx,
reducer_name=f"reduce_window_{scalar_aval.dtype}_reducer",
reducer_body=lambda reducer: reduce_op(*reducer.arguments),
reducer_body=lambda reducer: [reduce_op(*reducer.arguments)],
operands=[operand],
init_values=[mlir.full_like_aval(ctx, init_value(scalar_aval.dtype),
scalar_aval)],
@ -482,7 +482,7 @@ def _reduce_window_lower(
mlir.register_lowering(reduce_window_sum_p, partial(
_reduce_window_lower, hlo.AddOp, lambda _: 0))
_reduce_window_lower, hlo.add, lambda _: 0))
mlir.register_lowering(reduce_window_min_p, partial(
_reduce_window_lower, mlir.min_hlo, lax._get_min_identity))
mlir.register_lowering(reduce_window_max_p, partial(
@ -530,7 +530,7 @@ def _select_and_scatter_lower(
mlir.TokenSet(), select_consts,
*([a] for a in select.arguments),
dim_var_values=ctx.dim_var_values)
hlo.ReturnOp(util.flatten(out_nodes))
hlo.return_(util.flatten(out_nodes))
scatter = op.scatter.blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(scatter):
if scatter_jaxpr.effects:
@ -539,7 +539,7 @@ def _select_and_scatter_lower(
mlir.TokenSet(), scatter_consts,
*([a] for a in scatter.arguments),
dim_var_values=ctx.dim_var_values)
hlo.ReturnOp(util.flatten(out_nodes))
hlo.return_(util.flatten(out_nodes))
return op.results
mlir.register_lowering(select_and_scatter_p, _select_and_scatter_lower)
@ -685,27 +685,27 @@ def _select_and_gather_add_lowering(
def pack(a, b, ab_aval):
word_type_ab_aval = ab_aval.update(dtype=word_dtype)
double_word_type_ab_aval = ab_aval.update(dtype=double_word_dtype)
a = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), a)
b = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), b)
a = hlo.ConvertOp(mlir.aval_to_ir_type(double_word_type_ab_aval), a)
b = hlo.ConvertOp(mlir.aval_to_ir_type(double_word_type_ab_aval), b)
a = hlo.ShiftLeftOp(a,
_broadcast_scalar_const(nbits, double_word_type_ab_aval))
return hlo.OrOp(a, b)
a = hlo.bitcast_convert(mlir.aval_to_ir_type(word_type_ab_aval), a)
b = hlo.bitcast_convert(mlir.aval_to_ir_type(word_type_ab_aval), b)
a = hlo.convert(mlir.aval_to_ir_type(double_word_type_ab_aval), a)
b = hlo.convert(mlir.aval_to_ir_type(double_word_type_ab_aval), b)
a = hlo.shift_left(
a, _broadcast_scalar_const(nbits, double_word_type_ab_aval))
return hlo.or_(a, b)
# Unpacks the first element of a double_word_type.
def fst(t):
assert not ir.RankedTensorType(t.type).shape
st = hlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits))
return hlo.BitcastConvertOp(
st = hlo.shift_right_logical(t, const(double_word_dtype, nbits))
return hlo.bitcast_convert(
ir.RankedTensorType.get([], etype),
hlo.ConvertOp(ir.RankedTensorType.get([], word_type), st)).result
hlo.convert(ir.RankedTensorType.get([], word_type), st))
# Unpacks the second element of a double_word_type.
def snd(t, t_aval):
return hlo.BitcastConvertOp(
return hlo.bitcast_convert(
mlir.aval_to_ir_type(t_aval.update(dtype=dtype)),
hlo.ConvertOp(mlir.aval_to_ir_type(t_aval.update(dtype=word_dtype)), t)).result
hlo.convert(mlir.aval_to_ir_type(t_aval.update(dtype=word_dtype)), t))
else:
# The double-word trick above only works if we have a sufficiently large
@ -726,29 +726,27 @@ def _select_and_gather_add_lowering(
# Packs two values into a double_word_type.
def pack(a, b, ab_aval):
word_type_ab_aval = ab_aval.update(dtype=word_dtype)
a = hlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp),
a = hlo.reduce_precision(a, exponent_bits=mlir.i32_attr(nexp),
mantissa_bits=mlir.i32_attr(nmant))
b = hlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp),
b = hlo.reduce_precision(b, exponent_bits=mlir.i32_attr(nexp),
mantissa_bits=mlir.i32_attr(nmant))
a = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), a)
b = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), b)
b = hlo.ShiftRightLogicalOp(
a = hlo.bitcast_convert(mlir.aval_to_ir_type(word_type_ab_aval), a)
b = hlo.bitcast_convert(mlir.aval_to_ir_type(word_type_ab_aval), b)
b = hlo.shift_right_logical(
b, _broadcast_scalar_const(r_nbits, word_type_ab_aval))
return hlo.OrOp(a, b)
return hlo.or_(a, b)
# Unpacks the first element of a double_word_type.
def fst(t):
assert not ir.RankedTensorType(t.type).shape
st = hlo.AndOp(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits))
return hlo.BitcastConvertOp(ir.RankedTensorType.get([], etype),
st).result
st = hlo.and_(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits))
return hlo.bitcast_convert(ir.RankedTensorType.get([], etype), st)
# Unpacks the second element of a double_word_type.
def snd(t, t_aval):
return hlo.BitcastConvertOp(
return hlo.bitcast_convert(
mlir.aval_to_ir_type(t_aval.update(dtype=dtype)),
hlo.ShiftLeftOp(t, _broadcast_scalar_const(r_nbits, t_aval.update(dtype=word_dtype)))
).result
hlo.shift_left(t, _broadcast_scalar_const(r_nbits, t_aval.update(dtype=word_dtype))))
assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
init = -np.inf if select_prim is lax.ge_p else np.inf


@ -1115,9 +1115,9 @@ def _threefry2x32_gpu_lowering(lowering_func, ctx, k1, k2, x1, x2):
out_len = reduce(op.mul, aval_out.shape, 1)
if not core.is_constant_dim(out_len):
length = mlir.eval_dynamic_shape_as_tensor(ctx, [out_len])
length = mlir.hlo.ConvertOp(
length = mlir.hlo.convert(
ir.RankedTensorType.get((1,), ir.IntegerType.get_signless(64)),
length).result
length)
output_shape = mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape)
else:
length = int(out_len) # will be passed statically
@ -1221,20 +1221,20 @@ def iota_2x32_shape_lowering(ctx, *, shape):
aval_u64 = core.ShapedArray(shape, np.dtype('uint64'))
def _add(x: ir.Value, y: ir.Value) -> ir.Value:
return mlir.hlo.AddOp(x, y).result
return mlir.hlo.add(x, y)
def _mul(x: core.DimSize, y: ir.Value) -> ir.Value:
if core.is_constant_dim(x):
x_const = mlir.ir_constant(np.array(x, np.dtype('uint64')))
else:
x_const, = mlir.eval_dynamic_shape(ctx, (x,))
x_const = hlo.ConvertOp(
x_const = hlo.convert(
ir.RankedTensorType.get(
(),
mlir.dtype_to_ir_type(np.dtype('uint64'))), x_const).result
mlir.dtype_to_ir_type(np.dtype('uint64'))), x_const)
x_bcast = mlir.broadcast_in_dim(ctx, x_const, aval_u64,
broadcast_dimensions=[])
return mlir.hlo.MulOp(x_bcast, y).result
return mlir.hlo.multiply(x_bcast, y)
assert len(shape) > 0
@ -1244,10 +1244,9 @@ def iota_2x32_shape_lowering(ctx, *, shape):
shift = mlir.ir_constant(np.array(32, np.dtype('uint64')))
shift = mlir.broadcast_in_dim(ctx, shift, aval_u64,
broadcast_dimensions=[])
counts_shifted = mlir.hlo.ShiftRightLogicalOp(counts, shift).result
counts_lo = mlir.hlo.ConvertOp(mlir.aval_to_ir_type(aval_out), counts).result
counts_hi = mlir.hlo.ConvertOp(mlir.aval_to_ir_type(aval_out),
counts_shifted).result
counts_shifted = mlir.hlo.shift_right_logical(counts, shift)
counts_lo = mlir.hlo.convert(mlir.aval_to_ir_type(aval_out), counts)
counts_hi = mlir.hlo.convert(mlir.aval_to_ir_type(aval_out), counts_shifted)
return counts_hi, counts_lo
mlir.register_lowering(iota_2x32_shape_p, iota_2x32_shape_lowering)


@ -214,7 +214,7 @@ def _tpu_custom_call_lowering(
base64.b64encode(kernel_regeneration_metadata)
)
if multiple_results:
results = [stablehlo.GetTupleElementOp(call, mlir.i32_attr(i)).result
results = [stablehlo.get_tuple_element(call, mlir.i32_attr(i))
for i in range(len(out_avals))]
else:
results = call.results


@ -720,7 +720,7 @@ def _wrap_main_func(
list(new_main_op.arguments[0:nr_platform_index_args]) + util.flatten(dim_values),
platform_input_types + dim_var_input_types):
if arg.type != arg_type:
orig_main_args.append(hlo.ConvertOp(arg_type, arg).result)
orig_main_args.append(hlo.convert(arg_type, arg))
else:
orig_main_args.append(arg)
# Then the token arguments
@ -737,7 +737,7 @@ def _wrap_main_func(
# output_operand_alias attribute. See b/287386268.
for arg, arg_type in zip(new_main_op_array_args, array_input_types):
if arg.type != arg_type:
orig_main_args.append(hlo.ConvertOp(arg_type, arg).result)
orig_main_args.append(hlo.convert(arg_type, arg))
else:
orig_main_args.append(arg)
call = func_dialect.CallOp(orig_output_types,
@ -1166,7 +1166,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.AbstractValue) -> ir.Value:
new_ir_type = mlir.aval_to_ir_type(new_aval)
if x.type != new_ir_type:
return hlo.ConvertOp(mlir.aval_to_ir_type(new_aval), x).result
return hlo.convert(mlir.aval_to_ir_type(new_aval), x)
else:
return x
@ -1210,7 +1210,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
for i in range(len(lowering_platforms)):
branch = callee_platform_idx.regions[i].blocks.append()
with ir.InsertionPoint(branch):
hlo.ReturnOp(mlir.ir_constants(
hlo.return_(mlir.ir_constants(
np.int32(callee_lowering_platform_index[i])))
if callee_platform_idx.result.type != callee_type.inputs[0]:
callee_platform_idx = hlo.ConvertOp(callee_type.inputs[0],


@ -897,7 +897,7 @@ def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *,
res, = mlir.eval_dynamic_shape(ctx, (dim,))
out_type = mlir.aval_to_ir_type(ctx.avals_out[0])
if out_type != res.type: # type: ignore
return mlir.hlo.ConvertOp(out_type, res).results
return [mlir.hlo.convert(out_type, res)]
else:
return [res]
@ -1161,11 +1161,11 @@ def _dimension_size_impl(arg, *, dimension):
dimension_size_p.def_impl(_dimension_size_impl)
def _dimension_size_lowering_rule(ctx, arg, *, dimension):
dim_size = mlir.hlo.GetDimensionSizeOp(arg, dimension)
dim_size = mlir.hlo.get_dimension_size(arg, dimension)
dim_type = mlir.aval_to_ir_type(core.dim_value_aval())
if dim_size.result.type != dim_type:
dim_size = mlir.hlo.ConvertOp(dim_type, dim_size)
return dim_size.results
if dim_size.type != dim_type:
dim_size = mlir.hlo.convert(dim_type, dim_size)
return [dim_size]
mlir.register_lowering(dimension_size_p, _dimension_size_lowering_rule)


@ -570,7 +570,7 @@ def _call_tf_lowering(
ir.FlatSymbolRefAttr.get(fn),
tuple(args_op) + captured_ops)
if result_shape.is_tuple():
flat_results = [hlo.GetTupleElementOp(call, mlir.i32_attr(i)).result
flat_results = [hlo.get_tuple_element(call, mlir.i32_attr(i))
for i in range(len(result_shapes))]
else:
flat_results = call.results


@ -681,8 +681,7 @@ def _bcsr_dot_general_gpu_lowering(
dot_general_fn = csr_matmat_lowering
x_dtype = 'B_dtype'
if rhs_contract[0] == 1:
rhs = hlo.TransposeOp(
rhs, permutation=mlir.dense_int_elements([1, 0])).result
rhs = hlo.transpose(rhs, permutation=mlir.dense_int_elements([1, 0]))
else:
raise ValueError(f"rhs has to be 1d or 2d; get {rhs_aval.ndim}d.")


@ -229,7 +229,7 @@ def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo):
result = coo_todense_hlo(
data, row, col, shape=shape, data_dtype=dtype, index_dtype=row_aval.dtype)
return (
[hlo.TransposeOp(result, mlir.dense_int_elements([1, 0])).result]
[hlo.transpose(result, mlir.dense_int_elements([1, 0]))]
if transpose else [result])


@ -79,9 +79,9 @@ def dynamic_ducc_fft_hlo(
assert 0 not in a_type.shape
u8_type = ir.IntegerType.get_unsigned(8)
descriptor = hlo.ConstantOp(
descriptor = hlo.constant(
ir.DenseElementsAttr.get(
np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type)).result
np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
layout = tuple(range(ndims - 1, -1, -1))
return custom_call(
"dynamic_ducc_fft",


@ -89,9 +89,9 @@ def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *,
def _hlo_zeros_f32(shape):
return hlo.ConstantOp(
return hlo.constant(
ir.DenseElementsAttr.get(
np.zeros(shape, dtype=np.float32), type=ir.F32Type.get())).result
np.zeros(shape, dtype=np.float32), type=ir.F32Type.get()))
def cudnn_rnn_bwd_lowering(ctx, dy, dhn, dcn, x, h0, c0, w, y,


@ -320,7 +320,7 @@ def _syevd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a, *,
if dynamic_batch_dims:
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
operands.append(batch_size_val)
operand_layouts.append(())
@ -406,22 +406,22 @@ def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
[0],
],
operand_output_aliases={0: 0}).results
vt = hlo.TransposeOp(
vt = hlo.transpose(
v,
ir.DenseIntElementsAttr.get(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))).result
ir.DenseIntElementsAttr.get(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd))))
if np.issubdtype(dtype, np.complexfloating):
vt = hlo.ComplexOp(hlo.RealOp(vt), hlo.NegOp(hlo.ImagOp(vt))).result
vt = hlo.complex(hlo.real(vt), hlo.negate(hlo.imag(vt)))
if not full_matrices and not econ:
u = hlo.SliceOp(
u = hlo.slice(
u,
ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
ir.DenseIntElementsAttr.get(np.array(batch_dims + (m, min(m, n)))),
ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))).result
vt = hlo.SliceOp(
ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64)))
vt = hlo.slice(
vt,
ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
ir.DenseIntElementsAttr.get(np.array(batch_dims + (min(m, n), n))),
ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))).result
ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64)))
elif m < n:
lwork, opaque = gpu_solver.build_gesvd_descriptor(
np.dtype(dtype), b, n, m, compute_uv, full_matrices)
@ -538,15 +538,15 @@ def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower):
if not lower and platform == "cu" and m > 1:
start = (0,) * len(batch_dims) + (0,)
end = batch_dims + (1,)
s = hlo.SliceOp(e, intattr(start), intattr(end), intattr([1] * len(start)))
s = hlo.slice(e, intattr(start), intattr(end), intattr([1] * len(start)))
s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type)
s = hlo.BroadcastInDimOp(s_type, s, intattr(range(len(dims) - 1)))
s = hlo.broadcast_in_dim(s_type, s, intattr(range(len(dims) - 1)))
# The diagonals are always real; convert to complex if needed.
s = hlo.ConvertOp(
s = hlo.convert(
ir.RankedTensorType.get(s_type.shape, a_type.element_type), s)
offsets = tuple(hlo.ConstantOp(intattr(i))
offsets = tuple(hlo.constant(intattr(i))
for i in ((0,) * len(batch_dims) + (0, 1)))
a = hlo.DynamicUpdateSliceOp(a, s, offsets).result
a = hlo.dynamic_update_slice(a, s, offsets)
return a, d, e, taus, info


@ -84,21 +84,21 @@ def shape_tensor(sizes: Sequence[Union[int, ir.Value]]) -> ir.Value:
return hlo_const(np.array([d], dtype=np.int32))
else:
if d.type != i32_type:
d = hlo.ConvertOp(i32_type, d).result
return hlo.ReshapeOp(int1d, d).result
d = hlo.convert(i32_type, d)
return hlo.reshape(int1d, d)
ds = [dim_to_i32x1(sz) for sz in sizes]
if not ds:
return hlo_const(np.array([], np.int32))
elif len(ds) == 1:
return ds[0]
else:
return hlo.ConcatenateOp(
ds, ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0)).result
return hlo.concatenate(
ds, ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0))
def hlo_const(x: np.ndarray) -> ir.Value:
assert isinstance(x, np.ndarray)
return hlo.ConstantOp(
ir.DenseElementsAttr.get(x, type=dtype_to_ir_type(x.dtype))).result
return hlo.constant(
ir.DenseElementsAttr.get(x, type=dtype_to_ir_type(x.dtype)))
def hlo_u8(x: int):
return hlo_const(np.array(x, dtype=np.uint8))
@ -116,7 +116,7 @@ def hlo_min(x: DimensionSize, y: DimensionSize) -> DimensionSize:
x = hlo_s32(x)
if type(y) is int:
y = hlo_s32(y)
return hlo.MinOp(x, y).result
return hlo.minimum(x, y)
def hlo_add(x: DimensionSize, y: DimensionSize) -> DimensionSize:
@ -126,7 +126,7 @@ def hlo_add(x: DimensionSize, y: DimensionSize) -> DimensionSize:
x = hlo_s32(x)
if type(y) is int:
y = hlo_s32(y)
return hlo.AddOp(x, y).result
return hlo.add(x, y)
# TODO(necula): this is identical with mlir.custom_call, but meant for use


@ -51,7 +51,7 @@ def trsm_hlo(dtype, alpha, a, b,
num_bd = len(batch_dims_vals)
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
if dtype == np.float32:
fn = "blas_strsm"
@ -119,7 +119,7 @@ def getrf_hlo(dtype, a: ir.Value, *,
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
return custom_call(
fn,
@ -170,7 +170,7 @@ def geqrf_hlo(dtype, a: ir.Value, *,
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
shape_type_pairs: Sequence[ShapeTypePair] = [
(a_shape_vals, a_type.element_type),
(batch_dims_vals + (min(m, n),), a_type.element_type),
@ -211,7 +211,7 @@ def orgqr_hlo(dtype, a: ir.Value, tau, *,
num_bd = len(batch_dims_vals)
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
k = tau_shape_vals[-1]
assert type(k) is int
@ -281,7 +281,7 @@ def potrf_hlo(dtype, a: ir.Value, *, lower=False,
num_bd = len(batch_dims_vals)
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
scalar_layout = []
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
@ -318,7 +318,7 @@ def gesdd_hlo(dtype, a: ir.Value, *, full_matrices=True, compute_uv=True,
num_bd = len(batch_dims_vals)
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
i32_type = ir.IntegerType.get_signless(32)
workspace: list[ShapeTypePair]
@ -445,7 +445,7 @@ def syevd_hlo(dtype, a: ir.Value,
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
scalar_layout = []
shape_layout = [0]
@ -540,7 +540,7 @@ def geev_hlo(dtype, input, *,
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
shape_type_pairs: Sequence[ShapeTypePair] = workspaces + eigvals + [
(input_shape_vals, eigvecs_type),
@ -560,7 +560,7 @@ def geev_hlo(dtype, input, *,
result_shapes=result_shapes,
).results
if real:
return (hlo.ComplexOp(out[3], out[4]).result, out[5], out[6], out[7])
return (hlo.complex(out[3], out[4]), out[5], out[6], out[7])
else:
return out[2:6]
@ -615,7 +615,7 @@ def gees_hlo(dtype, a, *, jobvs=True, sort=False, select=None,
scalar_layout = []
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
shape_type_pairs = workspaces + eigvals + [
(a_shape_vals, etype),
(batch_dims_vals, i32_type),


@ -88,7 +88,7 @@ def lowering_testing_primitive_with_effect(ctx, a, *, effect_class_name: str):
if "Ordered" in effect_class_name:
token_in = ctx.tokens_in.get(_testing_effects[effect_class_name])[0]
ctx.set_tokens_out(mlir.TokenSet({_testing_effects[effect_class_name]: (token_in,)}))
return mlir.hlo.AddOp(a, a).results
return [mlir.hlo.add(a, a)]
mlir.register_lowering(testing_primitive_with_effect_p,
lowering_testing_primitive_with_effect)