mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Use MLIR generated convenience functions athing(...) instead of writing AThingOp(...).result.
In most cases these are more succinct. This change does not update Pallas/Mosaic. PiperOrigin-RevId: 583448254
This commit is contained in:
parent
e016ce4639
commit
8e8dc263bc
@ -91,15 +91,15 @@ def shape_tensor(sizes: Sequence[int | ir.RankedTensorType]
|
||||
return ir_constant(np.array([d], np.int32))
|
||||
else:
|
||||
if d.type != i32_type:
|
||||
d = hlo.ConvertOp(i32_type, d)
|
||||
return hlo.ReshapeOp(int1d, d).result
|
||||
d = hlo.convert(i32_type, d)
|
||||
return hlo.reshape(int1d, d)
|
||||
ds = map(lower_dim, sizes)
|
||||
if not ds:
|
||||
return ir_constant(np.array([], np.int32))
|
||||
elif len(ds) == 1:
|
||||
return ds[0]
|
||||
else:
|
||||
return hlo.ConcatenateOp(ds, i64_attr(0)).result
|
||||
return hlo.concatenate(ds, i64_attr(0))
|
||||
|
||||
|
||||
def delegate_lowering(ctx, lowering_fun, *args, **ctx_override_kwargs):
|
||||
@ -251,7 +251,7 @@ def _numpy_array_constant(x: np.ndarray) -> Sequence[ir.Value]:
|
||||
x = np.packbits(x, bitorder='little')
|
||||
x = np.ascontiguousarray(x)
|
||||
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
|
||||
return (hlo.ConstantOp(attr).result,)
|
||||
return (hlo.constant(attr),)
|
||||
|
||||
|
||||
def _masked_array_constant_handler(*args, **kwargs):
|
||||
@ -284,11 +284,11 @@ def _ndarray_constant_handler(val: np.ndarray) -> Sequence[ir.Value]:
|
||||
other_axes, = np.where(np.not_equal(0, val.strides))
|
||||
collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None) # type: ignore
|
||||
for ax in range(val.ndim))] # type: ignore
|
||||
out = hlo.BroadcastInDimOp(
|
||||
out = hlo.broadcast_in_dim(
|
||||
ir.RankedTensorType.get(
|
||||
val.shape, dtype_to_ir_type(collapsed_val.dtype)),
|
||||
_numpy_array_constant(collapsed_val)[0],
|
||||
dense_int_elements(other_axes)).result
|
||||
dense_int_elements(other_axes))
|
||||
return (out,)
|
||||
else:
|
||||
return _numpy_array_constant(val)
|
||||
@ -309,7 +309,7 @@ for ptype, dtype in dtypes.python_scalar_dtypes.items():
|
||||
register_constant_handler(ptype, partial(_python_scalar_handler, dtype))
|
||||
|
||||
def _token_constant_handler(val):
|
||||
return [hlo.CreateTokenOp().result]
|
||||
return [hlo.create_token()]
|
||||
register_constant_handler(core.Token, _token_constant_handler)
|
||||
|
||||
# Source locations
|
||||
@ -646,7 +646,7 @@ def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext,
|
||||
else:
|
||||
i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))
|
||||
if d.type != i32_type: # type: ignore
|
||||
return hlo.ConvertOp(i32_type, d).result
|
||||
return hlo.convert(i32_type, d)
|
||||
else:
|
||||
return d
|
||||
return tuple(convert_dim(v) for v in eval_dynamic_shape(ctx, shape))
|
||||
@ -662,7 +662,7 @@ def eval_dynamic_shape_as_ivals(
|
||||
else:
|
||||
i32_type = aval_to_ir_type(core.ShapedArray((), np.int32))
|
||||
if d.type != i32_type: # type: ignore
|
||||
return hlo.ConvertOp(i32_type, d).result
|
||||
return hlo.convert(i32_type, d)
|
||||
else:
|
||||
return d
|
||||
return tuple(convert_dim(v) for v in eval_dynamic_shape(ctx, shape))
|
||||
@ -912,7 +912,7 @@ def token_type() -> Sequence[ir.Type]:
|
||||
return [hlo.TokenType.get()]
|
||||
|
||||
def create_token() -> Token:
|
||||
return wrap_singleton_ir_values(hlo.CreateTokenOp().result)
|
||||
return wrap_singleton_ir_values(hlo.create_token())
|
||||
|
||||
class TokenSet:
|
||||
"""An immutable container of tokens to be used to lower effectful jaxprs. When lowering
|
||||
@ -1252,7 +1252,7 @@ def lower_jaxpr_to_fun(
|
||||
args: list[list[ir.Value]] = []
|
||||
for aval, arg in zip(jaxpr.in_avals, unflattened_args):
|
||||
if replace_tokens_with_dummy and aval is core.abstract_token:
|
||||
args.append(hlo.CreateTokenOp().results)
|
||||
args.append([hlo.create_token()])
|
||||
else:
|
||||
args.append(arg)
|
||||
callee_name_stack = ctx.name_stack.extend(util.wrap_name(name, api_name))
|
||||
@ -1285,7 +1285,7 @@ def lower_jaxpr_to_fun(
|
||||
o if mk is None else wrap_with_memory_kind(o, mk, o_aval)
|
||||
for o, mk, o_aval in zip(flat_outputs, ir_result_memory_kinds, output_avals)]
|
||||
|
||||
func_dialect.ReturnOp(flat_outputs)
|
||||
func_dialect.return_(flat_outputs)
|
||||
|
||||
return func_op
|
||||
|
||||
@ -1356,7 +1356,7 @@ def _emit_lowering_rule_as_fun(lowering_rule,
|
||||
outs = lowering_rule(sub_ctx, *_unwrap_singleton_ir_values(unflattened_args))
|
||||
if sub_ctx.tokens_out:
|
||||
outs = [*[sub_ctx.tokens_out.get(eff) for eff in effs], outs]
|
||||
func_dialect.ReturnOp(util.flatten(map(wrap_singleton_ir_values, outs)))
|
||||
func_dialect.return_(util.flatten(map(wrap_singleton_ir_values, outs)))
|
||||
return func_op
|
||||
|
||||
def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
@ -1500,7 +1500,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
out_nodes = tuple(map(wrap_singleton_ir_values, ans))
|
||||
except TypeError as e:
|
||||
raise ValueError("Output of translation rule must be iterable: "
|
||||
f"{eqn}, got output {ans} from {rule.__name__}") from e
|
||||
f"{eqn}, got output {ans}") from e
|
||||
|
||||
assert all(isinstance(v, tuple) for v in out_nodes), (ans, eqn)
|
||||
assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (
|
||||
@ -1595,14 +1595,14 @@ def lower_multi_platform(ctx: LoweringRuleContext,
|
||||
# Compute the rule index based on the current platform
|
||||
i32_type = aval_to_ir_types(core.ShapedArray((), dtype=np.int32))[0]
|
||||
if current_platform_idx.type != i32_type:
|
||||
current_platform_idx = hlo.ConvertOp(i32_type, current_platform_idx)
|
||||
current_platform_idx = hlo.convert(i32_type, current_platform_idx)
|
||||
rule_idx_op = hlo.CaseOp([i32_type],
|
||||
index=current_platform_idx,
|
||||
num_branches=len(platforms))
|
||||
for i, p in enumerate(platforms):
|
||||
branch = rule_idx_op.regions[i].blocks.append()
|
||||
with ir.InsertionPoint(branch):
|
||||
hlo.ReturnOp(ir_constants(np.int32(platform_to_kept_rules_idx[p])))
|
||||
hlo.return_(ir_constants(np.int32(platform_to_kept_rules_idx[p])))
|
||||
ordered_effects = effects_lib.ordered_effects.filter_in(effects)
|
||||
rule_out_avals = [core.abstract_token] * len(ordered_effects) + ctx.avals_out
|
||||
output_types = map(aval_to_ir_types, rule_out_avals)
|
||||
@ -1623,7 +1623,7 @@ def lower_multi_platform(ctx: LoweringRuleContext,
|
||||
assert len(ordered_effects) == len(inner_ctx.tokens_out)
|
||||
out_nodes = [inner_ctx.tokens_out.get(eff)
|
||||
for eff in ordered_effects] + out_nodes
|
||||
hlo.ReturnOp(util.flatten(map(wrap_singleton_ir_values, out_nodes)))
|
||||
hlo.return_(util.flatten(map(wrap_singleton_ir_values, out_nodes)))
|
||||
|
||||
results = case_op.results
|
||||
if ordered_effects:
|
||||
@ -1772,17 +1772,17 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue,
|
||||
else:
|
||||
if not core.is_constant_shape(aval_out.shape): # type: ignore
|
||||
shape = eval_dynamic_shape_as_tensor(ctx, aval_out.shape) # type: ignore
|
||||
return hlo.DynamicBroadcastInDimOp(
|
||||
return hlo.dynamic_broadcast_in_dim(
|
||||
aval_to_ir_type(aval_out), op,
|
||||
shape,
|
||||
dense_int_elements(broadcast_dimensions),
|
||||
).result
|
||||
)
|
||||
else:
|
||||
assert all(d != ir.ShapedType.get_dynamic_size()
|
||||
for d in aval_out.shape), aval_out # type: ignore
|
||||
return hlo.BroadcastInDimOp(
|
||||
return hlo.broadcast_in_dim(
|
||||
aval_to_ir_type(aval_out), op,
|
||||
dense_int_elements(broadcast_dimensions)).result
|
||||
dense_int_elements(broadcast_dimensions))
|
||||
|
||||
def multi_broadcast_in_dim(ctx: LoweringRuleContext,
|
||||
ops: Sequence[ir.Value],
|
||||
@ -1806,11 +1806,11 @@ def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Va
|
||||
aval_out = core.physical_aval(aval_out)
|
||||
if not core.is_constant_shape(aval_out.shape): # type: ignore
|
||||
shape = eval_dynamic_shape_as_tensor(ctx, aval_out.shape) # type: ignore
|
||||
return hlo.DynamicReshapeOp(
|
||||
return hlo.dynamic_reshape(
|
||||
aval_to_ir_type(aval_out), op, shape,
|
||||
).result
|
||||
)
|
||||
else:
|
||||
return hlo.ReshapeOp(aval_to_ir_type(aval_out), op).result
|
||||
return hlo.reshape(aval_to_ir_type(aval_out), op)
|
||||
|
||||
def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
|
||||
start_indices, limit_indices, strides) -> ir.Value:
|
||||
@ -1830,14 +1830,14 @@ def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
|
||||
start_indices = eval_dynamic_shape_as_tensor(ctx, start_indices)
|
||||
limit_indices = eval_dynamic_shape_as_tensor(ctx, limit_indices)
|
||||
strides = eval_dynamic_shape_as_tensor(ctx, strides)
|
||||
return hlo.RealDynamicSliceOp(
|
||||
return hlo.real_dynamic_slice(
|
||||
aval_to_ir_type(aval_out),
|
||||
x, start_indices, limit_indices, strides).result
|
||||
x, start_indices, limit_indices, strides)
|
||||
else:
|
||||
return hlo.SliceOp(x,
|
||||
dense_int_elements(start_indices),
|
||||
dense_int_elements(limit_indices),
|
||||
dense_int_elements(strides)).result
|
||||
return hlo.slice(x,
|
||||
dense_int_elements(start_indices),
|
||||
dense_int_elements(limit_indices),
|
||||
dense_int_elements(strides))
|
||||
|
||||
def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
|
||||
start_indices) -> ir.Value:
|
||||
@ -1859,21 +1859,20 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
|
||||
# lower to RealDynamicSliceOp, which is a version of SliceOp, and does
|
||||
# not have the clamping behavior. We clamp start ourselves.
|
||||
slice_sizes = eval_dynamic_shape_as_tensor(ctx, slice_sizes)
|
||||
clamped_start = hlo.ClampOp(
|
||||
clamped_start = hlo.clamp(
|
||||
shape_tensor([0] * len(start_indices)),
|
||||
shape_tensor(start_indices),
|
||||
hlo.SubtractOp(
|
||||
hlo.subtract(
|
||||
eval_dynamic_shape_as_tensor(ctx, x_aval.shape), # type: ignore
|
||||
slice_sizes))
|
||||
return hlo.RealDynamicSliceOp(
|
||||
return hlo.real_dynamic_slice(
|
||||
aval_to_ir_type(aval_out), x,
|
||||
clamped_start,
|
||||
hlo.AddOp(clamped_start, slice_sizes).result,
|
||||
hlo.add(clamped_start, slice_sizes),
|
||||
shape_tensor([1] * len(start_indices))
|
||||
).result
|
||||
)
|
||||
else:
|
||||
return hlo.DynamicSliceOp(x, start_indices,
|
||||
dense_int_elements(slice_sizes)).result
|
||||
return hlo.dynamic_slice(x, start_indices, dense_int_elements(slice_sizes))
|
||||
|
||||
def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
|
||||
start_indices) -> ir.Value:
|
||||
@ -1890,36 +1889,35 @@ def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
|
||||
start_indices=start_indices)
|
||||
else:
|
||||
# TODO(necula): handle dynamic shapes
|
||||
return hlo.DynamicUpdateSliceOp(x, update, start_indices).result
|
||||
return hlo.dynamic_update_slice(x, update, start_indices)
|
||||
|
||||
def pad(ctx: LoweringRuleContext, aval_out,
|
||||
x, padding_value,
|
||||
padding_low, padding_high, padding_interior) -> ir.Value:
|
||||
if all(core.is_constant_shape(s) for s in (padding_low,
|
||||
padding_high, padding_interior)):
|
||||
return hlo.PadOp(x, padding_value,
|
||||
dense_int_elements(padding_low),
|
||||
dense_int_elements(padding_high),
|
||||
dense_int_elements(padding_interior)).result
|
||||
return hlo.pad(x, padding_value,
|
||||
dense_int_elements(padding_low),
|
||||
dense_int_elements(padding_high),
|
||||
dense_int_elements(padding_interior))
|
||||
else:
|
||||
padding_low = eval_dynamic_shape_as_tensor(ctx, padding_low)
|
||||
padding_high = eval_dynamic_shape_as_tensor(ctx, padding_high)
|
||||
padding_interior = eval_dynamic_shape_as_tensor(ctx, padding_interior)
|
||||
return hlo.DynamicPadOp(
|
||||
return hlo.dynamic_pad(
|
||||
aval_to_ir_type(aval_out),
|
||||
x, padding_value, padding_low, padding_high, padding_interior).result
|
||||
x, padding_value, padding_low, padding_high, padding_interior)
|
||||
|
||||
def iota(ctx: LoweringRuleContext, aval_out, *, dimension: int):
|
||||
if not core.is_constant_shape(aval_out.shape):
|
||||
shape = eval_dynamic_shape_as_tensor(ctx, aval_out.shape)
|
||||
return hlo.DynamicIotaOp(
|
||||
return hlo.dynamic_iota(
|
||||
aval_to_ir_type(aval_out),
|
||||
shape,
|
||||
i64_attr(dimension),
|
||||
).result
|
||||
)
|
||||
else:
|
||||
return hlo.IotaOp(aval_to_ir_type(aval_out),
|
||||
i64_attr(dimension)).result
|
||||
return hlo.iota(aval_to_ir_type(aval_out), i64_attr(dimension))
|
||||
|
||||
def full_like_aval(ctx: LoweringRuleContext, value, aval: core.ShapedArray) -> ir.Value:
|
||||
"""Returns an IR constant shaped full of `value` shaped like `aval`."""
|
||||
@ -1933,7 +1931,7 @@ def zeros_like_lowering(ctx, x):
|
||||
register_lowering(ad_util.zeros_like_p, zeros_like_lowering)
|
||||
|
||||
def add_jaxvals_lowering(ctx, x, y):
|
||||
return hlo.AddOp(x, y).results
|
||||
return [hlo.add(x, y)]
|
||||
register_lowering(ad_util.add_jaxvals_p, add_jaxvals_lowering)
|
||||
|
||||
register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x])
|
||||
@ -1949,7 +1947,7 @@ def compare_hlo(x, y, direction: str, comparison_type: str | None = None):
|
||||
else:
|
||||
comparison_type = "FLOAT"
|
||||
|
||||
return hlo.CompareOp(
|
||||
return hlo.compare(
|
||||
x,
|
||||
y,
|
||||
hlo.ComparisonDirectionAttr.get(direction),
|
||||
@ -1959,20 +1957,18 @@ def _minmax_hlo(op, cmp, x, y):
|
||||
"""Min/max that compares complex values lexicographically as pairs."""
|
||||
tensor_type = ir.RankedTensorType(x.type)
|
||||
if ir.ComplexType.isinstance(tensor_type.element_type):
|
||||
rx = hlo.RealOp(x).result
|
||||
ry = hlo.RealOp(y).result
|
||||
rx = hlo.real(x)
|
||||
ry = hlo.real(y)
|
||||
real_eq = compare_hlo(rx, ry, "EQ", "FLOAT")
|
||||
real_cmp = compare_hlo(rx, ry, cmp, "FLOAT")
|
||||
imag_cmp = compare_hlo(
|
||||
hlo.ImagOp(x).result,
|
||||
hlo.ImagOp(y).result, cmp, "FLOAT")
|
||||
which = hlo.SelectOp(real_eq, imag_cmp, real_cmp).result
|
||||
return hlo.SelectOp(which, x, y)
|
||||
imag_cmp = compare_hlo(hlo.imag(x), hlo.imag(y), cmp, "FLOAT")
|
||||
which = hlo.select(real_eq, imag_cmp, real_cmp)
|
||||
return hlo.select(which, x, y)
|
||||
else:
|
||||
return op(x, y)
|
||||
|
||||
min_hlo = partial(_minmax_hlo, hlo.MinOp, "LT")
|
||||
max_hlo = partial(_minmax_hlo, hlo.MaxOp, "GT")
|
||||
min_hlo = partial(_minmax_hlo, hlo.minimum, "LT")
|
||||
max_hlo = partial(_minmax_hlo, hlo.maximum, "GT")
|
||||
|
||||
|
||||
def convert_hlo(ctx: LoweringRuleContext, x, aval_in, aval_out):
|
||||
@ -1988,10 +1984,9 @@ def convert_hlo(ctx: LoweringRuleContext, x, aval_in, aval_out):
|
||||
compare_type = "SIGNED"
|
||||
else:
|
||||
compare_type = "UNSIGNED"
|
||||
x = compare_hlo(x, full_like_aval(ctx, 0, aval_in), "NE",
|
||||
compare_type).result
|
||||
x = compare_hlo(x, full_like_aval(ctx, 0, aval_in), "NE", compare_type)
|
||||
# continue, to adjust the shape if needed
|
||||
return hlo.ConvertOp(aval_to_ir_type(aval_out), x).result
|
||||
return hlo.convert(aval_to_ir_type(aval_out), x)
|
||||
|
||||
def _wrap_with_spmd_op(name: str,
|
||||
ctx: LoweringRuleContext,
|
||||
@ -2184,7 +2179,7 @@ def xla_fallback_lowering(prim: core.Primitive):
|
||||
flatten_lowering_ir_args(args)).result
|
||||
if not prim.multiple_results:
|
||||
return [call]
|
||||
flat_results = [hlo.GetTupleElementOp(call, i32_attr(i)).result
|
||||
flat_results = [hlo.get_tuple_element(call, i32_attr(i))
|
||||
for i in range(len(flat_output_types))]
|
||||
|
||||
return util.unflatten(flat_results, map(len, output_types))
|
||||
@ -2267,7 +2262,7 @@ def _emit_tpu_python_callback(
|
||||
*,
|
||||
sharding: xc.OpSharding | None = None
|
||||
) -> tuple[Sequence[ir.Value], Any]:
|
||||
token = token or hlo.CreateTokenOp().result
|
||||
token = token or hlo.create_token()
|
||||
_wrapped_callback = callback
|
||||
|
||||
send_channels = []
|
||||
@ -2445,7 +2440,7 @@ def emit_python_callback(
|
||||
if sharding is not None:
|
||||
set_sharding(result, sharding)
|
||||
results = [
|
||||
hlo.GetTupleElementOp(result, i32_attr(i)).result
|
||||
hlo.get_tuple_element(result, i32_attr(i))
|
||||
for i in range(len(result_types))
|
||||
]
|
||||
if token:
|
||||
@ -2603,9 +2598,8 @@ def reduce_window(
|
||||
int2d = aval_to_ir_type(core.ShapedArray((1, 2), np.int32))
|
||||
def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]):
|
||||
pads = eval_dynamic_shape_as_tensor(ctx, pad_lo_hi) # i32[2]
|
||||
return hlo.ReshapeOp(int2d, pads)
|
||||
d_padding = hlo.ConcatenateOp(list(map(prep_one_pad, padding)),
|
||||
i64_attr(0)).result
|
||||
return hlo.reshape(int2d, pads)
|
||||
d_padding = hlo.concatenate(list(map(prep_one_pad, padding)), i64_attr(0))
|
||||
# Build the reducer
|
||||
reducer_type = ir.FunctionType.get(scalar_types + scalar_types,
|
||||
scalar_types)
|
||||
@ -2614,8 +2608,7 @@ def reduce_window(
|
||||
ctx.module_context.symbol_table.insert(reducer)
|
||||
entry_block = reducer.add_entry_block()
|
||||
with ir.InsertionPoint(entry_block):
|
||||
res = reducer_body(entry_block)
|
||||
hlo.ReturnOp(res)
|
||||
hlo.return_(reducer_body(entry_block))
|
||||
|
||||
rw = custom_call(
|
||||
"stablehlo.dynamic_reduce_window",
|
||||
@ -2641,8 +2634,7 @@ def reduce_window(
|
||||
shape=(len(padding), 2)))
|
||||
reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types))
|
||||
with ir.InsertionPoint(reducer):
|
||||
res = reducer_body(reducer)
|
||||
hlo.ReturnOp(res)
|
||||
hlo.return_(reducer_body(reducer))
|
||||
return rw.results
|
||||
|
||||
|
||||
|
@ -1269,8 +1269,7 @@ def _unravel_index_hlo(axis_env):
|
||||
div = mlir.ir_constant(
|
||||
np.array(axis_env.nreps // math.prod(axis_env.sizes), np.uint32))
|
||||
mod = mlir.ir_constant(np.array(axis_env.sizes[-1], np.uint32))
|
||||
return hlo.RemOp(
|
||||
hlo.DivOp(hlo.ReplicaIdOp().result, div).result, mod).result
|
||||
return hlo.remainder(hlo.divide(hlo.replica_id(), div), mod)
|
||||
|
||||
def _hlo_shard(aval, axis_env, xs, in_axis):
|
||||
if aval is core.abstract_token:
|
||||
@ -1283,10 +1282,10 @@ def _hlo_shard(aval, axis_env, xs, in_axis):
|
||||
idxs.insert(in_axis, _unravel_index_hlo(axis_env))
|
||||
dims_unsqueezed = dims.copy()
|
||||
dims_unsqueezed.insert(in_axis, 1)
|
||||
dynamic_slice_result = hlo.DynamicSliceOp(
|
||||
x, idxs, mlir.dense_int_elements(dims_unsqueezed)).result
|
||||
dynamic_slice_result = hlo.dynamic_slice(
|
||||
x, idxs, mlir.dense_int_elements(dims_unsqueezed))
|
||||
return [
|
||||
hlo.ReshapeOp(mlir.aval_to_ir_type(aval), dynamic_slice_result).result
|
||||
hlo.reshape(mlir.aval_to_ir_type(aval), dynamic_slice_result)
|
||||
]
|
||||
else:
|
||||
raise TypeError(aval)
|
||||
@ -1335,19 +1334,18 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs):
|
||||
padded = mlir.full_like_aval(ctx, 0, padded_aval)
|
||||
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
|
||||
idxs = [_unravel_index_hlo(axis_env)] + [zero] * len(dims)
|
||||
broadcast_result = hlo.BroadcastOp(
|
||||
x, mlir.dense_int_elements([1])).result
|
||||
padded = hlo.DynamicUpdateSliceOp(padded, broadcast_result, idxs).result
|
||||
broadcast_result = hlo.broadcast(x, mlir.dense_int_elements([1]))
|
||||
padded = hlo.dynamic_update_slice(padded, broadcast_result, idxs)
|
||||
replica_groups = mlir.dense_int_elements(
|
||||
axis_groups(axis_env, axis_env.names[-1]))
|
||||
out = hlo.CrossReplicaSumOp(padded, replica_groups).result
|
||||
out = hlo.cross_replica_sum(padded, replica_groups)
|
||||
if out_axis != 0:
|
||||
# TODO(apaszke,mattjj): Change the indices to DynamicUpdateSlice instead
|
||||
perm = list(range(1, len(dims)))
|
||||
perm.insert(out_axis, 0)
|
||||
transposed_dims = list(dims)
|
||||
transposed_dims.insert(out_axis, axis_env.sizes[-1])
|
||||
out = hlo.TransposeOp(out, mlir.dense_int_elements(perm)).result
|
||||
out = hlo.transpose(out, mlir.dense_int_elements(perm))
|
||||
|
||||
return out
|
||||
else:
|
||||
|
@ -295,8 +295,8 @@ def _comparator_builder_mlir(ctx, op_type, is_max_k):
|
||||
with ir.InsertionPoint(entry_block):
|
||||
p0, p1, _, _ = entry_block.arguments
|
||||
direction = hlo.ComparisonDirectionAttr.get('GT' if is_max_k else 'LT')
|
||||
cmp_result = hlo.CompareOp(p0, p1, comparison_direction=direction)
|
||||
hlo.ReturnOp(cmp_result)
|
||||
cmp_result = hlo.compare(p0, p1, comparison_direction=direction)
|
||||
hlo.return_([cmp_result])
|
||||
|
||||
return comparator
|
||||
|
||||
@ -321,7 +321,7 @@ def _approx_top_k_lowering(ctx, operand, *, k,
|
||||
iota = mlir.iota(ctx, core.ShapedArray(ctx.avals_in[0].shape, np.int32),
|
||||
dimension=reduction_dimension)
|
||||
|
||||
init_arg = hlo.ConstantOp(ir.DenseElementsAttr.get(np.int32(-1))).result
|
||||
init_arg = hlo.constant(ir.DenseElementsAttr.get(np.int32(-1)))
|
||||
init_val_array = _get_init_val_literal(ctx.avals_in[0].dtype, is_max_k)
|
||||
init_val = mlir.ir_constant(init_val_array.reshape(()))
|
||||
|
||||
|
@ -860,7 +860,7 @@ def _cond_lowering(ctx, index, *args, branches, linear):
|
||||
dim_var_values=ctx.dim_var_values)
|
||||
out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
|
||||
out_vals = [*out_tokens, *out_vals]
|
||||
hlo.ReturnOp(util.flatten(out_vals))
|
||||
hlo.return_(util.flatten(out_vals))
|
||||
|
||||
tokens_and_outputs = util.unflatten(case_op.results, map(len, output_types))
|
||||
tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
|
||||
|
@ -1656,7 +1656,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
pred_ctx,
|
||||
pred,
|
||||
axes=tuple(range(len(pred_aval.shape))))
|
||||
hlo.ReturnOp([pred])
|
||||
hlo.return_([pred])
|
||||
|
||||
# Loop body
|
||||
body_block = while_op.regions[1].blocks.append(*flat_loop_carry_types)
|
||||
@ -1687,8 +1687,8 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
|
||||
partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z,
|
||||
body_jaxpr.out_avals)
|
||||
|
||||
hlo.ReturnOp([*util.flatten(out_tokens), *util.flatten(x),
|
||||
*util.flatten(y), *util.flatten(new_z)])
|
||||
hlo.return_([*util.flatten(out_tokens), *util.flatten(x), *util.flatten(y),
|
||||
*util.flatten(new_z)])
|
||||
|
||||
outputs = util.unflatten(while_op.results, _map(len, loop_carry_types))
|
||||
tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts])
|
||||
|
@ -707,7 +707,7 @@ def _conv_general_dilated_lower(
|
||||
raise NotImplementedError("Convolutions with non-static strides, dilation, feature_group_count, or batch_group_count")
|
||||
if all(core.is_constant_shape(p) for p in padding):
|
||||
return [
|
||||
hlo.ConvolutionOp(
|
||||
hlo.convolution(
|
||||
mlir.aval_to_ir_type(aval_out),
|
||||
lhs,
|
||||
rhs,
|
||||
@ -719,7 +719,7 @@ def _conv_general_dilated_lower(
|
||||
lhs_dilation=mlir.dense_int_elements(lhs_dilation),
|
||||
rhs_dilation=mlir.dense_int_elements(rhs_dilation),
|
||||
window_reversal=window_reversal,
|
||||
precision_config=lax.precision_attr(precision)).result
|
||||
precision_config=lax.precision_attr(precision))
|
||||
]
|
||||
else:
|
||||
# d_padding will be an array i32[N, 2] with pad_lo and pad_hi for each
|
||||
@ -731,7 +731,7 @@ def _conv_general_dilated_lower(
|
||||
d_padding = hlo.ConcatenateOp(list(map(prep_one_pad, padding)),
|
||||
mlir.i64_attr(0))
|
||||
return [
|
||||
hlo.DynamicConvOp(
|
||||
hlo.dynamic_conv(
|
||||
mlir.aval_to_ir_type(aval_out),
|
||||
lhs,
|
||||
rhs,
|
||||
@ -743,7 +743,7 @@ def _conv_general_dilated_lower(
|
||||
lhs_dilation=mlir.dense_int_elements(lhs_dilation),
|
||||
rhs_dilation=mlir.dense_int_elements(rhs_dilation),
|
||||
window_reversal=window_reversal,
|
||||
precision_config=lax.precision_attr(precision)).result
|
||||
precision_config=lax.precision_attr(precision))
|
||||
]
|
||||
|
||||
mlir.register_lowering(conv_general_dilated_p, _conv_general_dilated_lower)
|
||||
|
@ -1680,19 +1680,19 @@ def broadcast_hlo(
|
||||
dims = mlir.dense_int_elements(
|
||||
range(len(aval_out.shape) - len(aval.shape), len(aval_out.shape)))
|
||||
if any(isinstance(d, ir.Value) for d in aval_out.shape):
|
||||
arg = hlo.DynamicBroadcastInDimOp(
|
||||
arg = hlo.dynamic_broadcast_in_dim(
|
||||
mlir.aval_to_ir_type(aval_out), arg,
|
||||
mlir.shape_tensor(aval_out.shape), dims).result
|
||||
mlir.shape_tensor(aval_out.shape), dims)
|
||||
else:
|
||||
arg = hlo.BroadcastInDimOp(
|
||||
arg = hlo.broadcast_in_dim(
|
||||
mlir.aval_to_ir_type(aval.update(shape=aval_out.shape)), arg,
|
||||
dims).result
|
||||
dims)
|
||||
out.append(arg)
|
||||
return out
|
||||
|
||||
def _nary_lower_hlo(op: Callable, ctx,
|
||||
*args: Union[ir.Value, Sequence[ir.Value]],
|
||||
explicit_type=False, **params):
|
||||
explicit_type=False, **params) -> Sequence[ir.Value]:
|
||||
"""Lowers an elementwise operator to its MLIR equivalent.
|
||||
|
||||
Args:
|
||||
@ -1704,9 +1704,9 @@ def _nary_lower_hlo(op: Callable, ctx,
|
||||
ctx, args, avals_in, aval_out.shape)
|
||||
|
||||
if explicit_type:
|
||||
return op(mlir.aval_to_ir_type(aval_out), *broadcasted_args).results
|
||||
return [op(mlir.aval_to_ir_type(aval_out), *broadcasted_args)]
|
||||
else:
|
||||
return op(*broadcasted_args).results
|
||||
return [op(*broadcasted_args)]
|
||||
|
||||
|
||||
_float = {np.floating}
|
||||
@ -1723,7 +1723,7 @@ _ordered = _int | _float | _bool
|
||||
|
||||
neg_p = standard_unop(_num, 'neg')
|
||||
ad.deflinear2(neg_p, lambda t, operand: [neg(t)])
|
||||
mlir.register_lowering(neg_p, partial(_nary_lower_hlo, hlo.NegOp))
|
||||
mlir.register_lowering(neg_p, partial(_nary_lower_hlo, hlo.negate))
|
||||
|
||||
sign_p = standard_unop(_num, 'sign')
|
||||
ad.defjvp_zero(sign_p)
|
||||
@ -1731,44 +1731,44 @@ ad.defjvp_zero(sign_p)
|
||||
def _sign_lower_hlo(ctx, x):
|
||||
x_aval, = ctx.avals_in
|
||||
if dtypes.issubdtype(x_aval.dtype, np.unsignedinteger):
|
||||
return hlo.SelectOp(
|
||||
return [hlo.select(
|
||||
mlir.compare_hlo(x, mlir.full_like_aval(ctx, 0, x_aval), 'EQ',
|
||||
'UNSIGNED').result,
|
||||
'UNSIGNED'),
|
||||
mlir.full_like_aval(ctx, 0, x_aval),
|
||||
mlir.full_like_aval(ctx, 1, x_aval)).results
|
||||
return hlo.SignOp(x).results
|
||||
mlir.full_like_aval(ctx, 1, x_aval))]
|
||||
return [hlo.sign(x)]
|
||||
|
||||
mlir.register_lowering(sign_p, _sign_lower_hlo)
|
||||
|
||||
nextafter_p = standard_naryop([_float, _float], 'nextafter')
|
||||
mlir.register_lowering(nextafter_p, partial(_nary_lower_hlo, chlo.NextAfterOp))
|
||||
mlir.register_lowering(nextafter_p, partial(_nary_lower_hlo, chlo.next_after))
|
||||
|
||||
floor_p = standard_unop(_float, 'floor')
|
||||
ad.defjvp_zero(floor_p)
|
||||
mlir.register_lowering(floor_p, partial(_nary_lower_hlo, hlo.FloorOp))
|
||||
mlir.register_lowering(floor_p, partial(_nary_lower_hlo, hlo.floor))
|
||||
|
||||
ceil_p = standard_unop(_float, 'ceil')
|
||||
ad.defjvp_zero(ceil_p)
|
||||
mlir.register_lowering(ceil_p, partial(_nary_lower_hlo, hlo.CeilOp))
|
||||
mlir.register_lowering(ceil_p, partial(_nary_lower_hlo, hlo.ceil))
|
||||
|
||||
round_p = standard_unop(_float, 'round')
|
||||
ad.defjvp_zero(round_p)
|
||||
|
||||
def _round_lower(ctx, x, *, rounding_method):
|
||||
if rounding_method is RoundingMethod.AWAY_FROM_ZERO:
|
||||
return hlo.RoundOp(x).results
|
||||
return [hlo.round_nearest_afz(x)]
|
||||
else:
|
||||
assert rounding_method is RoundingMethod.TO_NEAREST_EVEN
|
||||
return hlo.RoundNearestEvenOp(x).results
|
||||
return [hlo.round_nearest_even(x)]
|
||||
mlir.register_lowering(round_p, _round_lower)
|
||||
|
||||
is_finite_p = unop(_fixed_dtype(np.bool_), _float, 'is_finite')
|
||||
ad.defjvp_zero(is_finite_p)
|
||||
mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.IsFiniteOp))
|
||||
mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.is_finite))
|
||||
|
||||
exp_p = standard_unop(_float | _complex, 'exp')
|
||||
ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
|
||||
mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.ExpOp))
|
||||
mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential))
|
||||
|
||||
exp2_p = standard_unop(_float | _complex, 'exp2')
|
||||
ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans)))
|
||||
@ -1776,30 +1776,31 @@ def _exp2_lower(ctx, x):
|
||||
x_aval, = ctx.avals_in
|
||||
log2 = mlir.ir_constant(np.array(np.log(2), x_aval.dtype))
|
||||
log2 = mlir.broadcast_in_dim(ctx, log2, x_aval, broadcast_dimensions=())
|
||||
return hlo.ExpOp(hlo.MulOp(log2, x).result).results
|
||||
return [hlo.exponential(hlo.multiply(log2, x))]
|
||||
mlir.register_lowering(exp2_p, _exp2_lower)
|
||||
|
||||
log_p = standard_unop(_float | _complex, 'log')
|
||||
ad.defjvp(log_p, lambda g, x: div(g, x))
|
||||
mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.LogOp))
|
||||
mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.log))
|
||||
|
||||
expm1_p = standard_unop(_float | _complex, 'expm1')
|
||||
ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans))))
|
||||
mlir.register_lowering(expm1_p, partial(_nary_lower_hlo, hlo.Expm1Op))
|
||||
mlir.register_lowering(expm1_p,
|
||||
partial(_nary_lower_hlo, hlo.exponential_minus_one))
|
||||
|
||||
log1p_p = standard_unop(_float | _complex, 'log1p')
|
||||
ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x))))
|
||||
mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.Log1pOp))
|
||||
mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.log_plus_one))
|
||||
|
||||
tanh_p = standard_unop(_float | _complex, 'tanh')
|
||||
ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)),
|
||||
sub(_one(x), ans)))
|
||||
mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.TanhOp))
|
||||
mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.tanh))
|
||||
|
||||
logistic_p = standard_unop(_float | _complex, 'logistic')
|
||||
ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans))))
|
||||
# TODO(phawkins): switch to LogisticOp lowering; debug numerical problems.
|
||||
# mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.LogisticOp))
|
||||
# mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.logistic))
|
||||
|
||||
def logistic_impl(x):
|
||||
one = _const(x, 1)
|
||||
@ -1810,11 +1811,11 @@ mlir.register_lowering(logistic_p,
|
||||
|
||||
sin_p = standard_unop(_float | _complex, 'sin')
|
||||
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
|
||||
mlir.register_lowering(sin_p, partial(_nary_lower_hlo, hlo.SineOp))
|
||||
mlir.register_lowering(sin_p, partial(_nary_lower_hlo, hlo.sine))
|
||||
|
||||
cos_p = standard_unop(_float | _complex, 'cos')
|
||||
ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
|
||||
mlir.register_lowering(cos_p, partial(_nary_lower_hlo, hlo.CosineOp))
|
||||
mlir.register_lowering(cos_p, partial(_nary_lower_hlo, hlo.cosine))
|
||||
|
||||
@_upcast_fp16_for_computation
|
||||
def _tan_impl(x):
|
||||
@ -1822,7 +1823,7 @@ def _tan_impl(x):
|
||||
|
||||
tan_p = standard_unop(_float | _complex, 'tan')
|
||||
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
|
||||
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.TanOp))
|
||||
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan))
|
||||
|
||||
def asin_impl(x):
|
||||
if dtypes.issubdtype(_dtype(x), np.complexfloating):
|
||||
@ -1833,7 +1834,7 @@ def asin_impl(x):
|
||||
|
||||
asin_p = standard_unop(_float | _complex, 'asin')
|
||||
ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x))))
|
||||
mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.AsinOp))
|
||||
mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.asin))
|
||||
|
||||
def acos_impl(x):
|
||||
if dtypes.issubdtype(_dtype(x), np.complexfloating):
|
||||
@ -1862,43 +1863,43 @@ def atan_impl(x):
|
||||
|
||||
atan_p = standard_unop(_float | _complex, 'atan')
|
||||
ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x)))
|
||||
mlir.register_lowering(atan_p, partial(_nary_lower_hlo, chlo.AtanOp))
|
||||
mlir.register_lowering(atan_p, partial(_nary_lower_hlo, chlo.atan))
|
||||
|
||||
atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2')
|
||||
ad.defjvp(atan2_p,
|
||||
lambda g, x, y: g * (y / (square(x) + square(y))),
|
||||
lambda g, x, y: g * -x / (square(x) + square(y)))
|
||||
mlir.register_lowering(atan2_p, partial(_nary_lower_hlo, hlo.Atan2Op))
|
||||
mlir.register_lowering(atan2_p, partial(_nary_lower_hlo, hlo.atan2))
|
||||
|
||||
sinh_p = standard_unop(_float | _complex, 'sinh')
|
||||
ad.defjvp(sinh_p, lambda g, x: mul(g, cosh(x)))
|
||||
mlir.register_lowering(sinh_p, partial(_nary_lower_hlo, chlo.SinhOp))
|
||||
mlir.register_lowering(sinh_p, partial(_nary_lower_hlo, chlo.sinh))
|
||||
|
||||
cosh_p = standard_unop(_float | _complex, 'cosh')
|
||||
ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x)))
|
||||
mlir.register_lowering(cosh_p, partial(_nary_lower_hlo, chlo.CoshOp))
|
||||
mlir.register_lowering(cosh_p, partial(_nary_lower_hlo, chlo.cosh))
|
||||
|
||||
asinh_p = standard_unop(_float | _complex, 'asinh')
|
||||
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x))))
|
||||
mlir.register_lowering(asinh_p, partial(_nary_lower_hlo, chlo.AsinhOp))
|
||||
mlir.register_lowering(asinh_p, partial(_nary_lower_hlo, chlo.asinh))
|
||||
|
||||
acosh_p = standard_unop(_float | _complex, 'acosh')
|
||||
ad.defjvp(acosh_p,
|
||||
lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x)))))
|
||||
mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.AcoshOp))
|
||||
mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.acosh))
|
||||
|
||||
atanh_p = standard_unop(_float | _complex, 'atanh')
|
||||
ad.defjvp(atanh_p,
|
||||
lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x))))
|
||||
mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.AtanhOp))
|
||||
mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.atanh))
|
||||
|
||||
real_p = unop(_complex_basetype, _complex, 'real')
|
||||
ad.deflinear2(real_p, lambda t, _: [complex(t, np.zeros((), _dtype(t)))])
|
||||
mlir.register_lowering(real_p, partial(_nary_lower_hlo, hlo.RealOp))
|
||||
mlir.register_lowering(real_p, partial(_nary_lower_hlo, hlo.real))
|
||||
|
||||
imag_p = unop(_complex_basetype, _complex, 'imag')
|
||||
ad.deflinear2(imag_p, lambda t, _: [complex(np.zeros((), _dtype(t)), neg(t))])
|
||||
mlir.register_lowering(imag_p, partial(_nary_lower_hlo, hlo.ImagOp))
|
||||
mlir.register_lowering(imag_p, partial(_nary_lower_hlo, hlo.imag))
|
||||
|
||||
|
||||
def _complex_transpose_rule(t, x, y):
|
||||
@ -1923,7 +1924,7 @@ _complex_dtype = lambda dtype, *args: (np.zeros((), dtype) + np.zeros((), np.com
|
||||
complex_p = naryop(_complex_dtype, [_complex_elem_types, _complex_elem_types],
|
||||
'complex')
|
||||
ad.deflinear2(complex_p, _complex_transpose_rule)
|
||||
mlir.register_lowering(complex_p, partial(_nary_lower_hlo, hlo.ComplexOp))
|
||||
mlir.register_lowering(complex_p, partial(_nary_lower_hlo, hlo.complex))
|
||||
|
||||
conj_p = unop(_complex_dtype, _complex_elem_types | _complex, 'conj')
|
||||
|
||||
@ -1950,7 +1951,7 @@ ad.primitive_jvps[conj_p] = partial(ad.linear_jvp, conj_p)
|
||||
ad.primitive_transposes[conj_p] = _conj_transpose_rule
|
||||
|
||||
abs_p = unop(_complex_basetype, _signedint | _float | _complex, 'abs')
|
||||
mlir.register_lowering(abs_p, partial(_nary_lower_hlo, hlo.AbsOp))
|
||||
mlir.register_lowering(abs_p, partial(_nary_lower_hlo, hlo.abs))
|
||||
|
||||
def _abs_jvp_rule(g, ans, x):
|
||||
if _iscomplex(x):
|
||||
@ -1964,18 +1965,18 @@ _maybe_real = lambda x: real(x) if _iscomplex(x) else x
|
||||
|
||||
sqrt_p = standard_unop(_float | _complex, 'sqrt')
|
||||
ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans)))
|
||||
mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.SqrtOp))
|
||||
mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.sqrt))
|
||||
|
||||
rsqrt_p = standard_unop(_float | _complex, 'rsqrt')
|
||||
ad.defjvp2(rsqrt_p,
|
||||
lambda g, ans, x:
|
||||
mul(g, mul(_const(x, -0.5), div(ans, x))))
|
||||
mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.RsqrtOp))
|
||||
mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.rsqrt))
|
||||
|
||||
cbrt_p = standard_unop(_float, 'cbrt')
|
||||
ad.defjvp2(cbrt_p,
|
||||
lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2))))
|
||||
mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.CbrtOp))
|
||||
mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt))
|
||||
|
||||
def _pow_dtype_rule(x, y):
|
||||
if (dtypes.issubdtype(x.dtype, np.inexact) and
|
||||
@ -2020,7 +2021,7 @@ def _pow_lower(ctx, x, y):
|
||||
[(x_,)] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x)
|
||||
[(y_,)] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y)
|
||||
ctx_ = ctx.replace(avals_in=[x_aval_, y_aval_])
|
||||
return _nary_lower_hlo(hlo.PowOp, ctx_, x_, y_)
|
||||
return _nary_lower_hlo(hlo.power, ctx_, x_, y_)
|
||||
mlir.register_lowering(pow_p, _pow_lower)
|
||||
|
||||
|
||||
@ -2075,26 +2076,25 @@ _replace_zero = lambda x: select(eq(x, _const(x, 0)), _ones(x), x)
|
||||
|
||||
not_p = standard_unop(_bool_or_int, 'not')
|
||||
ad.defjvp_zero(not_p)
|
||||
mlir.register_lowering(not_p, partial(_nary_lower_hlo, hlo.NotOp))
|
||||
mlir.register_lowering(not_p, partial(_nary_lower_hlo, hlo.not_))
|
||||
|
||||
and_p = standard_naryop([_bool_or_int, _bool_or_int], 'and')
|
||||
ad.defjvp_zero(and_p)
|
||||
mlir.register_lowering(and_p, partial(_nary_lower_hlo, hlo.AndOp))
|
||||
mlir.register_lowering(and_p, partial(_nary_lower_hlo, hlo.and_))
|
||||
|
||||
or_p = standard_naryop([_bool_or_int, _bool_or_int], 'or')
|
||||
ad.defjvp_zero(or_p)
|
||||
mlir.register_lowering(or_p, partial(_nary_lower_hlo, hlo.OrOp))
|
||||
mlir.register_lowering(or_p, partial(_nary_lower_hlo, hlo.or_))
|
||||
|
||||
xor_p = standard_naryop([_bool_or_int, _bool_or_int], 'xor')
|
||||
ad.defjvp_zero(xor_p)
|
||||
mlir.register_lowering(xor_p, partial(_nary_lower_hlo, hlo.XorOp))
|
||||
mlir.register_lowering(xor_p, partial(_nary_lower_hlo, hlo.xor))
|
||||
|
||||
population_count_p = standard_unop(_int, 'population_count')
|
||||
mlir.register_lowering(population_count_p,
|
||||
partial(_nary_lower_hlo, hlo.PopulationCountOp))
|
||||
mlir.register_lowering(population_count_p, partial(_nary_lower_hlo, hlo.popcnt))
|
||||
|
||||
clz_p = standard_unop(_int, 'clz')
|
||||
mlir.register_lowering(clz_p, partial(_nary_lower_hlo, hlo.ClzOp))
|
||||
mlir.register_lowering(clz_p, partial(_nary_lower_hlo, hlo.count_leading_zeros))
|
||||
|
||||
def _add_jvp(primals, tangents):
|
||||
x, y = primals
|
||||
@ -2130,7 +2130,7 @@ def _add_inverse(r, x, y):
|
||||
add_p: Primitive = standard_naryop([_num, _num], 'add')
|
||||
ad.primitive_jvps[add_p] = _add_jvp
|
||||
ad.primitive_transposes[add_p] = _add_transpose
|
||||
mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.AddOp))
|
||||
mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.add))
|
||||
|
||||
def _sub_jvp(primals, tangents):
|
||||
x, y = primals
|
||||
@ -2159,7 +2159,7 @@ def _sub_transpose(t, x, y):
|
||||
sub_p = standard_naryop([_num, _num], 'sub')
|
||||
ad.primitive_jvps[sub_p] = _sub_jvp
|
||||
ad.primitive_transposes[sub_p] = _sub_transpose
|
||||
mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.SubtractOp))
|
||||
mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.subtract))
|
||||
|
||||
|
||||
def _mul_transpose(ct, x, y):
|
||||
@ -2185,7 +2185,7 @@ ad.defjvp(mul_p,
|
||||
lambda xdot, x, y: mul(xdot, y),
|
||||
lambda ydot, x, y: mul(x, ydot))
|
||||
ad.primitive_transposes[mul_p] = _mul_transpose
|
||||
mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.MulOp))
|
||||
mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.multiply))
|
||||
|
||||
def _div_transpose_rule(cotangent, x, y):
|
||||
assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
|
||||
@ -2198,14 +2198,14 @@ ad.defjvp(div_p,
|
||||
lambda g, x, y: div(g, y),
|
||||
lambda g, x, y: mul(mul(neg(g), x), integer_pow(y, -2)))
|
||||
ad.primitive_transposes[div_p] = _div_transpose_rule
|
||||
mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.DivOp))
|
||||
mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.divide))
|
||||
|
||||
rem_p = standard_naryop([_int | _float, _int | _float], 'rem')
|
||||
ad.defjvp(
|
||||
rem_p,
|
||||
lambda g, x, y: _maybe_broadcast(broadcast_shapes(np.shape(x), np.shape(y)), g),
|
||||
lambda g, x, y: mul(neg(g), mul(sign(div(x, y)), floor(abs(div(x, y))))))
|
||||
mlir.register_lowering(rem_p, partial(_nary_lower_hlo, hlo.RemOp))
|
||||
mlir.register_lowering(rem_p, partial(_nary_lower_hlo, hlo.remainder))
|
||||
|
||||
def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x):
|
||||
result_shape = broadcast_shapes(np.shape(x), np.shape(y))
|
||||
@ -2231,17 +2231,17 @@ mlir.register_lowering(min_p, partial(_nary_lower_hlo, mlir.min_hlo))
|
||||
|
||||
shift_left_p = standard_naryop([_int, _int], 'shift_left')
|
||||
ad.defjvp_zero(shift_left_p)
|
||||
mlir.register_lowering(shift_left_p, partial(_nary_lower_hlo, hlo.ShiftLeftOp))
|
||||
mlir.register_lowering(shift_left_p, partial(_nary_lower_hlo, hlo.shift_left))
|
||||
|
||||
shift_right_arithmetic_p = standard_naryop([_int, _int], 'shift_right_arithmetic')
|
||||
ad.defjvp_zero(shift_right_arithmetic_p)
|
||||
mlir.register_lowering(shift_right_arithmetic_p,
|
||||
partial(_nary_lower_hlo, hlo.ShiftRightArithmeticOp))
|
||||
partial(_nary_lower_hlo, hlo.shift_right_arithmetic))
|
||||
|
||||
shift_right_logical_p = standard_naryop([_int, _int], 'shift_right_logical')
|
||||
ad.defjvp_zero(shift_right_logical_p)
|
||||
mlir.register_lowering(shift_right_logical_p,
|
||||
partial(_nary_lower_hlo, hlo.ShiftRightLogicalOp))
|
||||
partial(_nary_lower_hlo, hlo.shift_right_logical))
|
||||
|
||||
def _opaque_comparison_hlo(direction, reduction_op, identity, ctx,
|
||||
avals_in, aval_out, x, y):
|
||||
@ -2288,7 +2288,7 @@ def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y):
|
||||
compare_type = "SIGNED"
|
||||
else:
|
||||
compare_type = "UNSIGNED"
|
||||
return mlir.compare_hlo(x, y, direction, compare_type).results
|
||||
return [mlir.compare_hlo(x, y, direction, compare_type)]
|
||||
|
||||
eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq', allow_extended_dtype=True)
|
||||
ad.defjvp_zero(eq_p)
|
||||
@ -2426,7 +2426,7 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type):
|
||||
aval_out, = ctx.avals_out
|
||||
if (dtypes.issubdtype(aval_in.dtype, np.complexfloating) and
|
||||
not dtypes.issubdtype(new_dtype, np.complexfloating)):
|
||||
operand = hlo.RealOp(operand).result
|
||||
operand = hlo.real(operand)
|
||||
aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype))
|
||||
return [mlir.convert_hlo(ctx, operand, aval_in, aval_out)]
|
||||
|
||||
@ -2473,7 +2473,7 @@ batching.defvectorized(bitcast_convert_type_p)
|
||||
|
||||
def _bitcast_convert_type_lower(ctx, operand, *, new_dtype):
|
||||
aval_out, = ctx.avals_out
|
||||
return hlo.BitcastConvertOp(mlir.aval_to_ir_type(aval_out), operand).results
|
||||
return [hlo.bitcast_convert(mlir.aval_to_ir_type(aval_out), operand)]
|
||||
|
||||
mlir.register_lowering(bitcast_convert_type_p, _bitcast_convert_type_lower)
|
||||
|
||||
@ -2841,12 +2841,12 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
|
||||
lhs_contracting_dimensions=list(lhs_contracting),
|
||||
rhs_contracting_dimensions=list(rhs_contracting))
|
||||
return [
|
||||
hlo.DotGeneralOp(
|
||||
hlo.dot_general(
|
||||
mlir.aval_to_ir_type(aval_out),
|
||||
lhs,
|
||||
rhs,
|
||||
dot_dnums,
|
||||
precision_config=precision_attr(precision)).result
|
||||
precision_config=precision_attr(precision))
|
||||
]
|
||||
|
||||
mlir.register_lowering(dot_general_p, _dot_general_lower)
|
||||
@ -3143,8 +3143,7 @@ ad.defjvp(clamp_p,
|
||||
lambda g, min, operand, max:
|
||||
select(lt(max, operand), g, _zeros(operand)))
|
||||
batching.primitive_batchers[clamp_p] = _clamp_batch_rule
|
||||
mlir.register_lowering(
|
||||
clamp_p, partial(_nary_lower_hlo, hlo.ClampOp))
|
||||
mlir.register_lowering(clamp_p, partial(_nary_lower_hlo, hlo.clamp))
|
||||
pe.def_trivial_padding(clamp_p)
|
||||
|
||||
def _concatenate_shape_rule(*operands, **kwargs):
|
||||
@ -3222,7 +3221,7 @@ batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule
|
||||
pe.padding_rules[concatenate_p] = _concatenate_pad_rule
|
||||
|
||||
def _concatenate_lower(ctx, *xs, dimension):
|
||||
return hlo.ConcatenateOp(xs, mlir.i64_attr(dimension)).results
|
||||
return [hlo.concatenate(xs, mlir.i64_attr(dimension))]
|
||||
mlir.register_lowering(concatenate_p, _concatenate_lower)
|
||||
|
||||
|
||||
@ -3424,7 +3423,7 @@ def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions):
|
||||
def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions):
|
||||
aval_out, = ctx.avals_out
|
||||
if dimensions is not None:
|
||||
x = hlo.TransposeOp(x, mlir.dense_int_elements(dimensions)).result
|
||||
x = hlo.transpose(x, mlir.dense_int_elements(dimensions))
|
||||
if dyn_shape:
|
||||
aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape))
|
||||
return [mlir.reshape(ctx, x, aval_out)]
|
||||
@ -3468,7 +3467,7 @@ ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)])
|
||||
batching.primitive_batchers[rev_p] = _rev_batch_rule
|
||||
|
||||
def _rev_lower(ctx, x, *, dimensions):
|
||||
return hlo.ReverseOp(x, mlir.dense_int_elements(dimensions)).results
|
||||
return [hlo.reverse(x, mlir.dense_int_elements(dimensions))]
|
||||
mlir.register_lowering(rev_p, _rev_lower)
|
||||
|
||||
|
||||
@ -3500,7 +3499,7 @@ def _transpose_lower(ctx, x, *, permutation):
|
||||
aval_out.dtype).shape
|
||||
trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))]
|
||||
permutation = [*permutation, *trailing_dims]
|
||||
return hlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results
|
||||
return [hlo.transpose(x, mlir.dense_int_elements(permutation))]
|
||||
|
||||
transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
|
||||
'transpose')
|
||||
@ -3632,7 +3631,7 @@ def _select_hlo_lowering(ctx, which, *cases):
|
||||
if which_aval.dtype == np.dtype(np.bool_):
|
||||
assert len(cases) <= 2
|
||||
if len(cases) == 1: return cases
|
||||
return hlo.SelectOp(which, cases[1], cases[0]).results
|
||||
return [hlo.select(which, cases[1], cases[0])]
|
||||
|
||||
if dtypes.issubdtype(which_aval.dtype, np.signedinteger):
|
||||
compare_type = 'SIGNED'
|
||||
@ -3648,8 +3647,8 @@ def _select_hlo_lowering(ctx, which, *cases):
|
||||
pred = mlir.compare_hlo(which,
|
||||
mlir.full_like_aval(ctx, offset + mid, which_aval),
|
||||
lt, compare_type)
|
||||
return hlo.SelectOp(pred, _select(offset, cases[:mid]),
|
||||
_select(offset + mid, cases[mid:])).result
|
||||
return hlo.select(pred, _select(offset, cases[:mid]),
|
||||
_select(offset + mid, cases[mid:]))
|
||||
|
||||
return [_select(0, cases)]
|
||||
|
||||
@ -3791,7 +3790,7 @@ def _reduce_lower(ctx, *values, computation, jaxpr, consts, dimensions):
|
||||
out_nodes, _ = mlir.jaxpr_subcomp(reducer_ctx, jaxpr, mlir.TokenSet(), consts,
|
||||
*([a] for a in reducer.arguments),
|
||||
dim_var_values=ctx.dim_var_values)
|
||||
hlo.ReturnOp(util.flatten(out_nodes))
|
||||
hlo.return_(util.flatten(out_nodes))
|
||||
return op.results
|
||||
|
||||
mlir.register_lowering(reduce_p, _reduce_lower)
|
||||
@ -3982,8 +3981,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
|
||||
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype))
|
||||
reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(reducer_region):
|
||||
add = reducer(*reducer_region.arguments)
|
||||
hlo.ReturnOp(add.results)
|
||||
hlo.return_([reducer(*reducer_region.arguments)])
|
||||
return op.results
|
||||
|
||||
mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, hlo.AddOp,
|
||||
@ -4021,8 +4019,8 @@ batching.defvectorized(reduce_precision_p)
|
||||
|
||||
def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
|
||||
aval_out, = ctx.avals_out
|
||||
return hlo.ReducePrecisionOp(operand, mlir.i32_attr(exponent_bits),
|
||||
mlir.i32_attr(mantissa_bits)).results
|
||||
return [hlo.reduce_precision(operand, mlir.i32_attr(exponent_bits),
|
||||
mlir.i32_attr(mantissa_bits))]
|
||||
|
||||
mlir.register_lowering(reduce_precision_p, _reduce_precision_lower)
|
||||
|
||||
@ -4163,7 +4161,7 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
|
||||
|
||||
out = lower_comparator(sub_ctx, *[[a] for a in comparator.arguments],
|
||||
num_keys=num_keys)
|
||||
hlo.ReturnOp(util.flatten(out))
|
||||
hlo.return_(util.flatten(out))
|
||||
return sort.results
|
||||
|
||||
mlir.register_lowering(sort_p, _sort_lower)
|
||||
@ -4273,7 +4271,7 @@ create_token_p.def_abstract_eval(lambda *_: abstract_token)
|
||||
|
||||
def _create_token_lowering(ctx, *operands):
|
||||
aval_out, = ctx.avals_out
|
||||
return hlo.CreateTokenOp().results
|
||||
return [hlo.create_token()]
|
||||
mlir.register_lowering(create_token_p, _create_token_lowering)
|
||||
|
||||
|
||||
@ -4295,7 +4293,7 @@ after_all_p.def_abstract_eval(_after_all_abstract_eval)
|
||||
|
||||
def _after_all_lowering(ctx, *operands):
|
||||
aval_out, = ctx.avals_out
|
||||
return hlo.AfterAllOp(operands).results
|
||||
return [hlo.after_all(operands)]
|
||||
mlir.register_lowering(after_all_p, _after_all_lowering)
|
||||
|
||||
|
||||
@ -4434,8 +4432,7 @@ rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
|
||||
def _rng_uniform_lowering(ctx, a, b, *, shape):
|
||||
aval_out, = ctx.avals_out
|
||||
shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64))
|
||||
return hlo.RngOp(a, b, shape,
|
||||
hlo.RngDistributionAttr.get('UNIFORM')).results
|
||||
return [hlo.rng(a, b, shape, hlo.RngDistributionAttr.get('UNIFORM'))]
|
||||
|
||||
mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering)
|
||||
|
||||
@ -4491,9 +4488,9 @@ def _rng_bit_generator_lowering(
|
||||
rbg_etype = u32_type
|
||||
rbg_dtype = np.uint32
|
||||
if key_etype == u32_type:
|
||||
key = hlo.BitcastConvertOp(
|
||||
key = hlo.bitcast_convert(
|
||||
ir.RankedTensorType.get([2], u64_type),
|
||||
hlo.ReshapeOp(ir.RankedTensorType.get([2, 2], u32_type), key)).result
|
||||
hlo.reshape(ir.RankedTensorType.get([2, 2], u32_type), key))
|
||||
algorithm_attr = _rng_algorithm(algorithm)
|
||||
_, out_vals_aval = ctx.avals_out
|
||||
if any(not core.is_constant_shape(a.shape) for a in ctx.avals_out):
|
||||
@ -4511,14 +4508,14 @@ def _rng_bit_generator_lowering(
|
||||
ir.RankedTensorType.get(shape, rbg_etype),
|
||||
algorithm_attr, key).results
|
||||
if key_etype == u32_type:
|
||||
out_key = hlo.ReshapeOp(
|
||||
out_key = hlo.reshape(
|
||||
ir.RankedTensorType.get([4], u32_type),
|
||||
hlo.BitcastConvertOp(
|
||||
ir.RankedTensorType.get([2, 2], u32_type), out_key)).result
|
||||
hlo.bitcast_convert(
|
||||
ir.RankedTensorType.get([2, 2], u32_type), out_key))
|
||||
if rbg_etype != etype:
|
||||
out_vals = hlo.ConvertOp(
|
||||
out_vals = hlo.convert(
|
||||
ir.RankedTensorType.get(ir.RankedTensorType(out_vals.type).shape, etype),
|
||||
out_vals).result
|
||||
out_vals)
|
||||
return [out_key, out_vals]
|
||||
|
||||
|
||||
|
@ -435,7 +435,7 @@ ad.primitive_jvps[cholesky_p] = _cholesky_jvp_rule
|
||||
batching.primitive_batchers[cholesky_p] = _cholesky_batching_rule
|
||||
|
||||
def _cholesky_lowering(ctx, x):
|
||||
return hlo.CholeskyOp(x, lower=ir.BoolAttr.get(True)).results
|
||||
return [hlo.cholesky(x, lower=ir.BoolAttr.get(True))]
|
||||
|
||||
mlir.register_lowering(cholesky_p, _cholesky_lowering)
|
||||
|
||||
@ -621,7 +621,7 @@ def _eigh_jacobi_lowering_rule(ctx, operand, lower, sort_eigenvalues):
|
||||
if operand_aval.shape[-1] == 0:
|
||||
reshape_aval = operand_aval.update(shape=operand_aval.shape[:-1])
|
||||
return [
|
||||
hlo.RealOp(mlir.reshape(ctx, operand, reshape_aval)).result,
|
||||
hlo.real(mlir.reshape(ctx, operand, reshape_aval)),
|
||||
operand,
|
||||
]
|
||||
|
||||
@ -985,10 +985,10 @@ def _triangular_solve_lowering(
|
||||
transpose = "NO_TRANSPOSE"
|
||||
else:
|
||||
transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
|
||||
return hlo.TriangularSolveOp(
|
||||
return [hlo.triangular_solve(
|
||||
a, b, ir.BoolAttr.get(left_side),
|
||||
ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
|
||||
hlo.TransposeAttr.get(transpose)).results
|
||||
hlo.TransposeAttr.get(transpose))]
|
||||
|
||||
mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)
|
||||
|
||||
@ -998,7 +998,7 @@ def _triangular_solve_cpu_lower(
|
||||
a_aval, b_aval = ctx.avals_in
|
||||
|
||||
if conjugate_a and not transpose_a:
|
||||
a = chlo.ConjOp(a).result
|
||||
a = chlo.conj(a)
|
||||
conjugate_a = False
|
||||
if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
|
||||
alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
|
||||
@ -1014,10 +1014,10 @@ def _triangular_solve_cpu_lower(
|
||||
transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
|
||||
else:
|
||||
transpose = "NO_TRANSPOSE"
|
||||
return hlo.TriangularSolveOp(a, b, ir.BoolAttr.get(left_side),
|
||||
ir.BoolAttr.get(lower),
|
||||
ir.BoolAttr.get(unit_diagonal),
|
||||
hlo.TransposeAttr.get(transpose)).results
|
||||
return [hlo.triangular_solve(a, b, ir.BoolAttr.get(left_side),
|
||||
ir.BoolAttr.get(lower),
|
||||
ir.BoolAttr.get(unit_diagonal),
|
||||
hlo.TransposeAttr.get(transpose))]
|
||||
|
||||
mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
|
||||
platform='cpu')
@@ -1297,7 +1297,7 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand, *,
lu, pivot, info = getrf_impl(
operand_aval.dtype, operand, a_shape_vals=op_shape_vals)
# Subtract 1 from the pivot to get 0-based indices.
pivot = hlo.SubtractOp(pivot, mlir.full_like_aval(ctx, 1, pivot_aval)).result
pivot = hlo.subtract(pivot, mlir.full_like_aval(ctx, 1, pivot_aval))
ok = mlir.compare_hlo(
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
"GE", "SIGNED")
@@ -2380,4 +2380,4 @@ def _broadcasting_select_hlo(ctx, which, which_aval, x, x_aval, y, y_aval) -> ir
which, x, y = mlir.multi_broadcast_in_dim(ctx, (which, x, y),
(which_aval, x_aval, y_aval),
out_shapes)
return hlo.SelectOp(which, x, y).result
return hlo.select(which, x, y)

@@ -775,7 +775,7 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
avals_in=[scalar_aval] * 2, avals_out=[scalar_aval])
out_nodes = lower_reducer(
reducer_ctx, *([a] for a in reducer_block.arguments))
hlo.ReturnOp(util.flatten(out_nodes))
hlo.return_(util.flatten(out_nodes))
return op.result

return [all_reduce(aval, x) for aval, x in zip(ctx.avals_in, args)]
@@ -1191,7 +1191,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
new_shape = list(x_aval.shape)
new_shape.insert(all_gather_dimension, 1)
broadcast_dimensions = [i for i in range(len(new_shape)) if i != all_gather_dimension]
x = hlo.BroadcastInDimOp(
x = hlo.broadcast_in_dim(
mlir.aval_to_ir_type(x_aval.update(shape=new_shape)), x,
mlir.dense_int_elements(broadcast_dimensions))
replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
@@ -1359,12 +1359,12 @@ def _reduce_scatter_lowering(
avals_out=[scalar_aval])
out_nodes = lower_reducer(
reducer_ctx, *([a] for a in reducer_block.arguments))
hlo.ReturnOp(util.flatten(out_nodes))
hlo.return_(util.flatten(out_nodes))

if tiled:
return op.results
else:
return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), op.result).results
return [hlo.reshape(mlir.aval_to_ir_type(aval_out), op.result)]

def _reduce_scatter_lowering_via_reducer(
prim, reducer, ctx, x,
@@ -1582,13 +1582,13 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
(sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext),
)
if is_spmd:
device_id = hlo.PartitionIdOp()
device_id = hlo.partition_id()
else:
device_id = hlo.ReplicaIdOp()
unsigned_index = hlo.RemOp(hlo.DivOp(device_id, div), mod)
return hlo.ConvertOp(
device_id = hlo.replica_id()
unsigned_index = hlo.remainder(hlo.divide(device_id, div), mod)
return hlo.convert(
ir.RankedTensorType.get([], ir.IntegerType.get_signless(32)),
unsigned_index).result
unsigned_index)

def _axis_index_lowering(ctx, *, axis_name):
return [

@@ -1839,12 +1839,12 @@ def _gather_lower(ctx, operand, indices, *,
return hlo.DynamicGatherOp.build_generic(
results=results, operands=operands, attributes=attributes).results
else:
return hlo.GatherOp(
return [hlo.gather(
operand,
indices,
dnums,
mlir.dense_int_elements(slice_sizes),
indices_are_sorted=ir.BoolAttr.get(indices_are_sorted)).results
indices_are_sorted=ir.BoolAttr.get(indices_are_sorted))]

mlir.register_lowering(gather_p, _gather_lower)

@@ -2499,7 +2499,7 @@ def _scatter_lower(ctx, operand, indices, updates, *,
update_ctx, update_jaxpr, mlir.TokenSet(), update_consts,
(update.arguments[0],), (update.arguments[1],),
dim_var_values=ctx.dim_var_values)
hlo.ReturnOp(util.flatten(out_nodes))
hlo.return_(util.flatten(out_nodes))
return op.results

mlir.register_lowering(scatter_p, _scatter_lower)
@@ -2555,12 +2555,12 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
reducer = scatter.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer):
add = hlo.AddOp(*reducer.arguments).result
hlo.ReturnOp([add])
hlo.return_([add])
return scatter.result

real = _scatter(hlo.RealOp(operand).result, hlo.RealOp(updates).result)
imag = _scatter(hlo.ImagOp(operand).result, hlo.ImagOp(updates).result)
return hlo.ComplexOp(real, imag).results
real = _scatter(hlo.real(operand), hlo.real(updates))
imag = _scatter(hlo.imag(operand), hlo.imag(updates))
return [hlo.complex(real, imag)]

mlir.register_lowering(scatter_add_p, _scatter_add_lower_gpu, platform="gpu")
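The GPU scatter-add path above handles complex operands by scattering the real and imaginary parts separately and recombining them with hlo.complex. A plain-NumPy sketch of that decomposition, with made-up data, offered as an illustration only:

import numpy as np

operand = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=np.complex64)
updates = np.array([10 + 20j, 30 + 40j], dtype=np.complex64)
idx = np.array([0, 2])

real = operand.real.copy()
imag = operand.imag.copy()
np.add.at(real, idx, updates.real)   # scatter-add the real parts
np.add.at(imag, idx, updates.imag)   # scatter-add the imaginary parts
out = real + 1j * imag               # recombine, as hlo.complex does in the lowering

expected = operand.copy()
np.add.at(expected, idx, updates)    # reference: scatter-add directly on complex values
np.testing.assert_allclose(out, expected)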
@@ -625,14 +625,14 @@ ad.defjvp(regularized_incomplete_beta_p,

lgamma_p = standard_unop(_float, 'lgamma')
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))
mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.LgammaOp))
mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.lgamma))

digamma_p = standard_unop(_float, 'digamma')
mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.DigammaOp))
mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.digamma))
ad.defjvp(digamma_p, lambda g, x: mul(g, polygamma(_const(x, 1), x)))

polygamma_p = standard_naryop([_float, _float], 'polygamma')
mlir.register_lowering(polygamma_p, partial(_nary_lower_hlo, chlo.PolygammaOp))
mlir.register_lowering(polygamma_p, partial(_nary_lower_hlo, chlo.polygamma))
ad.defjvp(polygamma_p, polygamma_gradm, polygamma_gradx)

igamma_p = standard_naryop([_float, _float], 'igamma')
@@ -659,7 +659,7 @@ mlir.register_lowering(random_gamma_grad_p,
multiple_results=False))

zeta_p = standard_naryop([_float, _float], 'zeta')
mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.ZetaOp))
mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.zeta))

bessel_i0e_p = standard_unop(_float, 'bessel_i0e')
mlir.register_lowering(bessel_i0e_p,
@@ -669,7 +669,7 @@ ad.defjvp2(bessel_i0e_p, lambda g, y, x: g * (bessel_i1e(x) - sign(x) * y))

bessel_i1e_p = standard_unop(_float, 'bessel_i1e')
mlir.register_lowering(bessel_i1e_p,
partial(_nary_lower_hlo, chlo.BesselI1eOp))
partial(_nary_lower_hlo, chlo.bessel_i1e))

def _bessel_i1e_jvp(g, y, x):
eps = dtypes.finfo(_dtype(x)).eps
@@ -683,14 +683,14 @@ ad.defjvp2(bessel_i1e_p, _bessel_i1e_jvp)
erf_p = standard_unop(_float, 'erf')
ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)),
mul(g, exp(neg(square(x))))))
mlir.register_lowering(erf_p, partial(_nary_lower_hlo, chlo.ErfOp))
mlir.register_lowering(erf_p, partial(_nary_lower_hlo, chlo.erf))

erfc_p = standard_unop(_float, 'erfc')
ad.defjvp(erfc_p, lambda g, x: mul(_const(x, -2. / np.sqrt(np.pi)),
mul(g, exp(neg(square(x))))))
mlir.register_lowering(erfc_p, partial(_nary_lower_hlo, chlo.ErfcOp))
mlir.register_lowering(erfc_p, partial(_nary_lower_hlo, chlo.erfc))

erf_inv_p = standard_unop(_float, 'erf_inv')
ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.),
mul(g, exp(square(ans)))))
mlir.register_lowering(erf_inv_p, partial(_nary_lower_hlo, chlo.ErfInvOp))
mlir.register_lowering(erf_inv_p, partial(_nary_lower_hlo, chlo.erf_inv))

@@ -470,7 +470,7 @@ def _reduce_window_lower(

return mlir.reduce_window(ctx,
reducer_name=f"reduce_window_{scalar_aval.dtype}_reducer",
reducer_body=lambda reducer: reduce_op(*reducer.arguments),
reducer_body=lambda reducer: [reduce_op(*reducer.arguments)],
operands=[operand],
init_values=[mlir.full_like_aval(ctx, init_value(scalar_aval.dtype),
scalar_aval)],
@@ -482,7 +482,7 @@ def _reduce_window_lower(

mlir.register_lowering(reduce_window_sum_p, partial(
_reduce_window_lower, hlo.AddOp, lambda _: 0))
_reduce_window_lower, hlo.add, lambda _: 0))
mlir.register_lowering(reduce_window_min_p, partial(
_reduce_window_lower, mlir.min_hlo, lax._get_min_identity))
mlir.register_lowering(reduce_window_max_p, partial(
@@ -530,7 +530,7 @@ def _select_and_scatter_lower(
mlir.TokenSet(), select_consts,
*([a] for a in select.arguments),
dim_var_values=ctx.dim_var_values)
hlo.ReturnOp(util.flatten(out_nodes))
hlo.return_(util.flatten(out_nodes))
scatter = op.scatter.blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(scatter):
if scatter_jaxpr.effects:
@@ -539,7 +539,7 @@ def _select_and_scatter_lower(
mlir.TokenSet(), scatter_consts,
*([a] for a in scatter.arguments),
dim_var_values=ctx.dim_var_values)
hlo.ReturnOp(util.flatten(out_nodes))
hlo.return_(util.flatten(out_nodes))
return op.results

mlir.register_lowering(select_and_scatter_p, _select_and_scatter_lower)
@@ -685,27 +685,27 @@ def _select_and_gather_add_lowering(
def pack(a, b, ab_aval):
word_type_ab_aval = ab_aval.update(dtype=word_dtype)
double_word_type_ab_aval = ab_aval.update(dtype=double_word_dtype)
a = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), a)
b = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), b)
a = hlo.ConvertOp(mlir.aval_to_ir_type(double_word_type_ab_aval), a)
b = hlo.ConvertOp(mlir.aval_to_ir_type(double_word_type_ab_aval), b)
a = hlo.ShiftLeftOp(a,
_broadcast_scalar_const(nbits, double_word_type_ab_aval))
return hlo.OrOp(a, b)
a = hlo.bitcast_convert(mlir.aval_to_ir_type(word_type_ab_aval), a)
b = hlo.bitcast_convert(mlir.aval_to_ir_type(word_type_ab_aval), b)
a = hlo.convert(mlir.aval_to_ir_type(double_word_type_ab_aval), a)
b = hlo.convert(mlir.aval_to_ir_type(double_word_type_ab_aval), b)
a = hlo.shift_left(
a, _broadcast_scalar_const(nbits, double_word_type_ab_aval))
return hlo.or_(a, b)

# Unpacks the first element of a double_word_type.
def fst(t):
assert not ir.RankedTensorType(t.type).shape
st = hlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits))
return hlo.BitcastConvertOp(
st = hlo.shift_right_logical(t, const(double_word_dtype, nbits))
return hlo.bitcast_convert(
ir.RankedTensorType.get([], etype),
hlo.ConvertOp(ir.RankedTensorType.get([], word_type), st)).result
hlo.convert(ir.RankedTensorType.get([], word_type), st))

# Unpacks the second element of a double_word_type.
def snd(t, t_aval):
return hlo.BitcastConvertOp(
return hlo.bitcast_convert(
mlir.aval_to_ir_type(t_aval.update(dtype=dtype)),
hlo.ConvertOp(mlir.aval_to_ir_type(t_aval.update(dtype=word_dtype)), t)).result
hlo.convert(mlir.aval_to_ir_type(t_aval.update(dtype=word_dtype)), t))

else:
# The double-word trick above only works if we have a sufficiently large
@@ -726,29 +726,27 @@ def _select_and_gather_add_lowering(
# Packs two values into a double_word_type.
def pack(a, b, ab_aval):
word_type_ab_aval = ab_aval.update(dtype=word_dtype)
a = hlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp),
a = hlo.reduce_precision(a, exponent_bits=mlir.i32_attr(nexp),
mantissa_bits=mlir.i32_attr(nmant))
b = hlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp),
b = hlo.reduce_precision(b, exponent_bits=mlir.i32_attr(nexp),
mantissa_bits=mlir.i32_attr(nmant))
a = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), a)
b = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), b)
b = hlo.ShiftRightLogicalOp(
a = hlo.bitcast_convert(mlir.aval_to_ir_type(word_type_ab_aval), a)
b = hlo.bitcast_convert(mlir.aval_to_ir_type(word_type_ab_aval), b)
b = hlo.shift_right_logical(
b, _broadcast_scalar_const(r_nbits, word_type_ab_aval))
return hlo.OrOp(a, b)
return hlo.or_(a, b)

# Unpacks the first element of a double_word_type.
def fst(t):
assert not ir.RankedTensorType(t.type).shape
st = hlo.AndOp(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits))
return hlo.BitcastConvertOp(ir.RankedTensorType.get([], etype),
st).result
st = hlo.and_(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits))
return hlo.bitcast_convert(ir.RankedTensorType.get([], etype), st)

# Unpacks the second element of a double_word_type.
def snd(t, t_aval):
return hlo.BitcastConvertOp(
return hlo.bitcast_convert(
mlir.aval_to_ir_type(t_aval.update(dtype=dtype)),
hlo.ShiftLeftOp(t, _broadcast_scalar_const(r_nbits, t_aval.update(dtype=word_dtype)))
).result
hlo.shift_left(t, _broadcast_scalar_const(r_nbits, t_aval.update(dtype=word_dtype))))

assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
init = -np.inf if select_prim is lax.ge_p else np.inf
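Both pack/fst/snd variants above express the same idea with HLO ops: two values are bit-packed into a single wider word (full-width words in the first variant, reduced-precision words in the second) so that one reduction can carry both values at once. A plain-NumPy illustration of the full-width arithmetic, offered as a sketch only, not as the lowering itself:

import numpy as np

def pack(a, b):
  # Bitcast each f32 to u32, widen to u64, and place `a` in the high 32 bits.
  wa = np.asarray(a, np.float32).view(np.uint32).astype(np.uint64)
  wb = np.asarray(b, np.float32).view(np.uint32).astype(np.uint64)
  return (wa << np.uint64(32)) | wb

def fst(t):
  # High word -> the first packed value.
  return (t >> np.uint64(32)).astype(np.uint32).view(np.float32)

def snd(t):
  # Low word -> the second packed value.
  return (t & np.uint64(0xFFFFFFFF)).astype(np.uint32).view(np.float32)

t = pack(3.5, -1.25)
assert fst(t) == np.float32(3.5) and snd(t) == np.float32(-1.25)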

@@ -1115,9 +1115,9 @@ def _threefry2x32_gpu_lowering(lowering_func, ctx, k1, k2, x1, x2):
out_len = reduce(op.mul, aval_out.shape, 1)
if not core.is_constant_dim(out_len):
length = mlir.eval_dynamic_shape_as_tensor(ctx, [out_len])
length = mlir.hlo.ConvertOp(
length = mlir.hlo.convert(
ir.RankedTensorType.get((1,), ir.IntegerType.get_signless(64)),
length).result
length)
output_shape = mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape)
else:
length = int(out_len) # will be passed statically
@@ -1221,20 +1221,20 @@ def iota_2x32_shape_lowering(ctx, *, shape):
aval_u64 = core.ShapedArray(shape, np.dtype('uint64'))

def _add(x: ir.Value, y: ir.Value) -> ir.Value:
return mlir.hlo.AddOp(x, y).result
return mlir.hlo.add(x, y)

def _mul(x: core.DimSize, y: ir.Value) -> ir.Value:
if core.is_constant_dim(x):
x_const = mlir.ir_constant(np.array(x, np.dtype('uint64')))
else:
x_const, = mlir.eval_dynamic_shape(ctx, (x,))
x_const = hlo.ConvertOp(
x_const = hlo.convert(
ir.RankedTensorType.get(
(),
mlir.dtype_to_ir_type(np.dtype('uint64'))), x_const).result
mlir.dtype_to_ir_type(np.dtype('uint64'))), x_const)
x_bcast = mlir.broadcast_in_dim(ctx, x_const, aval_u64,
broadcast_dimensions=[])
return mlir.hlo.MulOp(x_bcast, y).result
return mlir.hlo.multiply(x_bcast, y)

assert len(shape) > 0

@@ -1244,10 +1244,9 @@ def iota_2x32_shape_lowering(ctx, *, shape):
shift = mlir.ir_constant(np.array(32, np.dtype('uint64')))
shift = mlir.broadcast_in_dim(ctx, shift, aval_u64,
broadcast_dimensions=[])
counts_shifted = mlir.hlo.ShiftRightLogicalOp(counts, shift).result
counts_lo = mlir.hlo.ConvertOp(mlir.aval_to_ir_type(aval_out), counts).result
counts_hi = mlir.hlo.ConvertOp(mlir.aval_to_ir_type(aval_out),
counts_shifted).result
counts_shifted = mlir.hlo.shift_right_logical(counts, shift)
counts_lo = mlir.hlo.convert(mlir.aval_to_ir_type(aval_out), counts)
counts_hi = mlir.hlo.convert(mlir.aval_to_ir_type(aval_out), counts_shifted)
return counts_hi, counts_lo
mlir.register_lowering(iota_2x32_shape_p, iota_2x32_shape_lowering)
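The hunk above materializes 64-bit counts and then splits each one into its high and low 32-bit halves with a logical right shift by 32 and two converts (converting down to a narrower integer type keeps the low bits). The same arithmetic in NumPy, as an illustrative sketch only:

import numpy as np

counts = np.array([0, 1, 2**32 + 5], dtype=np.uint64)    # stand-in for the 64-bit counts
counts_hi = (counts >> np.uint64(32)).astype(np.uint32)  # high 32 bits of each count
counts_lo = counts.astype(np.uint32)                     # narrowing keeps the low 32 bits
recombined = (counts_hi.astype(np.uint64) << np.uint64(32)) | counts_lo
assert (recombined == counts).all()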

@@ -214,7 +214,7 @@ def _tpu_custom_call_lowering(
base64.b64encode(kernel_regeneration_metadata)
)
if multiple_results:
results = [stablehlo.GetTupleElementOp(call, mlir.i32_attr(i)).result
results = [stablehlo.get_tuple_element(call, mlir.i32_attr(i))
for i in range(len(out_avals))]
else:
results = call.results

@@ -720,7 +720,7 @@ def _wrap_main_func(
list(new_main_op.arguments[0:nr_platform_index_args]) + util.flatten(dim_values),
platform_input_types + dim_var_input_types):
if arg.type != arg_type:
orig_main_args.append(hlo.ConvertOp(arg_type, arg).result)
orig_main_args.append(hlo.convert(arg_type, arg))
else:
orig_main_args.append(arg)
# Then the token arguments
@@ -737,7 +737,7 @@ def _wrap_main_func(
# output_operand_alias attribute. See b/287386268.
for arg, arg_type in zip(new_main_op_array_args, array_input_types):
if arg.type != arg_type:
orig_main_args.append(hlo.ConvertOp(arg_type, arg).result)
orig_main_args.append(hlo.convert(arg_type, arg))
else:
orig_main_args.append(arg)
call = func_dialect.CallOp(orig_output_types,
@@ -1166,7 +1166,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.AbstractValue) -> ir.Value:
new_ir_type = mlir.aval_to_ir_type(new_aval)
if x.type != new_ir_type:
return hlo.ConvertOp(mlir.aval_to_ir_type(new_aval), x).result
return hlo.convert(mlir.aval_to_ir_type(new_aval), x)
else:
return x

@@ -1210,7 +1210,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
for i in range(len(lowering_platforms)):
branch = callee_platform_idx.regions[i].blocks.append()
with ir.InsertionPoint(branch):
hlo.ReturnOp(mlir.ir_constants(
hlo.return_(mlir.ir_constants(
np.int32(callee_lowering_platform_index[i])))
if callee_platform_idx.result.type != callee_type.inputs[0]:
callee_platform_idx = hlo.ConvertOp(callee_type.inputs[0],

@@ -897,7 +897,7 @@ def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *,
res, = mlir.eval_dynamic_shape(ctx, (dim,))
out_type = mlir.aval_to_ir_type(ctx.avals_out[0])
if out_type != res.type: # type: ignore
return mlir.hlo.ConvertOp(out_type, res).results
return [mlir.hlo.convert(out_type, res)]
else:
return [res]

@@ -1161,11 +1161,11 @@ def _dimension_size_impl(arg, *, dimension):
dimension_size_p.def_impl(_dimension_size_impl)

def _dimension_size_lowering_rule(ctx, arg, *, dimension):
dim_size = mlir.hlo.GetDimensionSizeOp(arg, dimension)
dim_size = mlir.hlo.get_dimension_size(arg, dimension)
dim_type = mlir.aval_to_ir_type(core.dim_value_aval())
if dim_size.result.type != dim_type:
dim_size = mlir.hlo.ConvertOp(dim_type, dim_size)
return dim_size.results
if dim_size.type != dim_type:
dim_size = mlir.hlo.convert(dim_type, dim_size)
return [dim_size]

mlir.register_lowering(dimension_size_p, _dimension_size_lowering_rule)

@@ -570,7 +570,7 @@ def _call_tf_lowering(
ir.FlatSymbolRefAttr.get(fn),
tuple(args_op) + captured_ops)
if result_shape.is_tuple():
flat_results = [hlo.GetTupleElementOp(call, mlir.i32_attr(i)).result
flat_results = [hlo.get_tuple_element(call, mlir.i32_attr(i))
for i in range(len(result_shapes))]
else:
flat_results = call.results

@@ -681,8 +681,7 @@ def _bcsr_dot_general_gpu_lowering(
dot_general_fn = csr_matmat_lowering
x_dtype = 'B_dtype'
if rhs_contract[0] == 1:
rhs = hlo.TransposeOp(
rhs, permutation=mlir.dense_int_elements([1, 0])).result
rhs = hlo.transpose(rhs, permutation=mlir.dense_int_elements([1, 0]))
else:
raise ValueError(f"rhs has to be 1d or 2d; get {rhs_aval.ndim}d.")

@@ -229,7 +229,7 @@ def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo):
result = coo_todense_hlo(
data, row, col, shape=shape, data_dtype=dtype, index_dtype=row_aval.dtype)
return (
[hlo.TransposeOp(result, mlir.dense_int_elements([1, 0])).result]
[hlo.transpose(result, mlir.dense_int_elements([1, 0]))]
if transpose else [result])

@@ -79,9 +79,9 @@ def dynamic_ducc_fft_hlo(
assert 0 not in a_type.shape

u8_type = ir.IntegerType.get_unsigned(8)
descriptor = hlo.ConstantOp(
descriptor = hlo.constant(
ir.DenseElementsAttr.get(
np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type)).result
np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
layout = tuple(range(ndims - 1, -1, -1))
return custom_call(
"dynamic_ducc_fft",

@@ -89,9 +89,9 @@ def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *,

def _hlo_zeros_f32(shape):
return hlo.ConstantOp(
return hlo.constant(
ir.DenseElementsAttr.get(
np.zeros(shape, dtype=np.float32), type=ir.F32Type.get())).result
np.zeros(shape, dtype=np.float32), type=ir.F32Type.get()))

def cudnn_rnn_bwd_lowering(ctx, dy, dhn, dcn, x, h0, c0, w, y,

@@ -320,7 +320,7 @@ def _syevd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a, *,
if dynamic_batch_dims:
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
operands.append(batch_size_val)
operand_layouts.append(())

@@ -406,22 +406,22 @@ def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
[0],
],
operand_output_aliases={0: 0}).results
vt = hlo.TransposeOp(
vt = hlo.transpose(
v,
ir.DenseIntElementsAttr.get(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))).result
ir.DenseIntElementsAttr.get(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd))))
if np.issubdtype(dtype, np.complexfloating):
vt = hlo.ComplexOp(hlo.RealOp(vt), hlo.NegOp(hlo.ImagOp(vt))).result
vt = hlo.complex(hlo.real(vt), hlo.negate(hlo.imag(vt)))
if not full_matrices and not econ:
u = hlo.SliceOp(
u = hlo.slice(
u,
ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
ir.DenseIntElementsAttr.get(np.array(batch_dims + (m, min(m, n)))),
ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))).result
vt = hlo.SliceOp(
ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64)))
vt = hlo.slice(
vt,
ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
ir.DenseIntElementsAttr.get(np.array(batch_dims + (min(m, n), n))),
ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))).result
ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64)))
elif m < n:
lwork, opaque = gpu_solver.build_gesvd_descriptor(
np.dtype(dtype), b, n, m, compute_uv, full_matrices)
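In the SVD hunk above, vt is built by transposing the trailing two dimensions of v and, for complex dtypes, negating the imaginary part via hlo.complex(hlo.real(vt), hlo.negate(hlo.imag(vt))), i.e. forming the conjugate transpose. A NumPy rendering of that step with one batch dimension, offered as an illustration only:

import numpy as np

v = np.arange(8, dtype=np.float32).reshape(2, 2, 2) + 1j * np.ones((2, 2, 2), np.float32)
vt = np.transpose(v, (0, 2, 1))          # keep the batch dim, swap the last two axes
vt = vt.real + 1j * (-vt.imag)           # negate the imaginary part
np.testing.assert_allclose(vt, np.conj(np.swapaxes(v, -1, -2)))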
@@ -538,15 +538,15 @@ def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower):
if not lower and platform == "cu" and m > 1:
start = (0,) * len(batch_dims) + (0,)
end = batch_dims + (1,)
s = hlo.SliceOp(e, intattr(start), intattr(end), intattr([1] * len(start)))
s = hlo.slice(e, intattr(start), intattr(end), intattr([1] * len(start)))
s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type)
s = hlo.BroadcastInDimOp(s_type, s, intattr(range(len(dims) - 1)))
s = hlo.broadcast_in_dim(s_type, s, intattr(range(len(dims) - 1)))
# The diagonals are always real; convert to complex if needed.
s = hlo.ConvertOp(
s = hlo.convert(
ir.RankedTensorType.get(s_type.shape, a_type.element_type), s)
offsets = tuple(hlo.ConstantOp(intattr(i))
offsets = tuple(hlo.constant(intattr(i))
for i in ((0,) * len(batch_dims) + (0, 1)))
a = hlo.DynamicUpdateSliceOp(a, s, offsets).result
a = hlo.dynamic_update_slice(a, s, offsets)

return a, d, e, taus, info

@@ -84,21 +84,21 @@ def shape_tensor(sizes: Sequence[Union[int, ir.Value]]) -> ir.Value:
return hlo_const(np.array([d], dtype=np.int32))
else:
if d.type != i32_type:
d = hlo.ConvertOp(i32_type, d).result
return hlo.ReshapeOp(int1d, d).result
d = hlo.convert(i32_type, d)
return hlo.reshape(int1d, d)
ds = [dim_to_i32x1(sz) for sz in sizes]
if not ds:
return hlo_const(np.array([], np.int32))
elif len(ds) == 1:
return ds[0]
else:
return hlo.ConcatenateOp(
ds, ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0)).result
return hlo.concatenate(
ds, ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0))

def hlo_const(x: np.ndarray) -> ir.Value:
assert isinstance(x, np.ndarray)
return hlo.ConstantOp(
ir.DenseElementsAttr.get(x, type=dtype_to_ir_type(x.dtype))).result
return hlo.constant(
ir.DenseElementsAttr.get(x, type=dtype_to_ir_type(x.dtype)))

def hlo_u8(x: int):
return hlo_const(np.array(x, dtype=np.uint8))
@@ -116,7 +116,7 @@ def hlo_min(x: DimensionSize, y: DimensionSize) -> DimensionSize:
x = hlo_s32(x)
if type(y) is int:
y = hlo_s32(y)
return hlo.MinOp(x, y).result
return hlo.minimum(x, y)

def hlo_add(x: DimensionSize, y: DimensionSize) -> DimensionSize:
@@ -126,7 +126,7 @@ def hlo_add(x: DimensionSize, y: DimensionSize) -> DimensionSize:
x = hlo_s32(x)
if type(y) is int:
y = hlo_s32(y)
return hlo.AddOp(x, y).result
return hlo.add(x, y)

# TODO(necula): this is identical with mlir.custom_call, but meant for use

@@ -51,7 +51,7 @@ def trsm_hlo(dtype, alpha, a, b,
num_bd = len(batch_dims_vals)
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))

if dtype == np.float32:
fn = "blas_strsm"
@@ -119,7 +119,7 @@ def getrf_hlo(dtype, a: ir.Value, *,

batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))

return custom_call(
fn,
@@ -170,7 +170,7 @@ def geqrf_hlo(dtype, a: ir.Value, *,

batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
shape_type_pairs: Sequence[ShapeTypePair] = [
(a_shape_vals, a_type.element_type),
(batch_dims_vals + (min(m, n),), a_type.element_type),
@@ -211,7 +211,7 @@ def orgqr_hlo(dtype, a: ir.Value, tau, *,
num_bd = len(batch_dims_vals)
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))

k = tau_shape_vals[-1]
assert type(k) is int
@@ -281,7 +281,7 @@ def potrf_hlo(dtype, a: ir.Value, *, lower=False,
num_bd = len(batch_dims_vals)
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))

scalar_layout = []
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
@@ -318,7 +318,7 @@ def gesdd_hlo(dtype, a: ir.Value, *, full_matrices=True, compute_uv=True,
num_bd = len(batch_dims_vals)
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))

i32_type = ir.IntegerType.get_signless(32)
workspace: list[ShapeTypePair]
@@ -445,7 +445,7 @@ def syevd_hlo(dtype, a: ir.Value,

batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))

scalar_layout = []
shape_layout = [0]
@@ -540,7 +540,7 @@ def geev_hlo(dtype, input, *,

batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))

shape_type_pairs: Sequence[ShapeTypePair] = workspaces + eigvals + [
(input_shape_vals, eigvecs_type),
@@ -560,7 +560,7 @@ def geev_hlo(dtype, input, *,
result_shapes=result_shapes,
).results
if real:
return (hlo.ComplexOp(out[3], out[4]).result, out[5], out[6], out[7])
return (hlo.complex(out[3], out[4]), out[5], out[6], out[7])
else:
return out[2:6]

@@ -615,7 +615,7 @@ def gees_hlo(dtype, a, *, jobvs=True, sort=False, select=None,
scalar_layout = []
batch_size_val = hlo_s32(1)
for b_v in batch_dims_vals:
batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result
batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v))
shape_type_pairs = workspaces + eigvals + [
(a_shape_vals, etype),
(batch_dims_vals, i32_type),

@@ -88,7 +88,7 @@ def lowering_testing_primitive_with_effect(ctx, a, *, effect_class_name: str):
if "Ordered" in effect_class_name:
token_in = ctx.tokens_in.get(_testing_effects[effect_class_name])[0]
ctx.set_tokens_out(mlir.TokenSet({_testing_effects[effect_class_name]: (token_in,)}))
return mlir.hlo.AddOp(a, a).results
return [mlir.hlo.add(a, a)]

mlir.register_lowering(testing_primitive_with_effect_p,
lowering_testing_primitive_with_effect)