From 8e8dc263bc1d6c9c569b2658d5d89ed2abd9ef04 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 17 Nov 2023 11:46:24 -0800 Subject: [PATCH] Use MLIR generated convenience functions athing(...) instead of writing AThingOp(...).result. In most cases these are more succinct. This change does not update Pallas/Mosaic. PiperOrigin-RevId: 583448254 --- jax/_src/interpreters/mlir.py | 136 ++++++++--------- jax/_src/interpreters/pxla.py | 18 +-- jax/_src/lax/ann.py | 6 +- jax/_src/lax/control_flow/conditionals.py | 2 +- jax/_src/lax/control_flow/loops.py | 6 +- jax/_src/lax/convolution.py | 8 +- jax/_src/lax/lax.py | 177 +++++++++++----------- jax/_src/lax/linalg.py | 22 +-- jax/_src/lax/parallel.py | 18 +-- jax/_src/lax/slicing.py | 14 +- jax/_src/lax/special.py | 16 +- jax/_src/lax/windowed_reductions.py | 54 ++++--- jax/_src/prng.py | 19 ++- jax/_src/tpu_custom_call.py | 2 +- jax/experimental/export/export.py | 8 +- jax/experimental/export/shape_poly.py | 10 +- jax/experimental/jax2tf/call_tf.py | 2 +- jax/experimental/sparse/bcsr.py | 3 +- jax/experimental/sparse/coo.py | 2 +- jaxlib/ducc_fft.py | 4 +- jaxlib/gpu_rnn.py | 4 +- jaxlib/gpu_solver.py | 26 ++-- jaxlib/hlo_helpers.py | 16 +- jaxlib/lapack.py | 20 +-- tests/export_test.py | 2 +- 25 files changed, 289 insertions(+), 306 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index ad7f503fb..944b208db 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -91,15 +91,15 @@ def shape_tensor(sizes: Sequence[int | ir.RankedTensorType] return ir_constant(np.array([d], np.int32)) else: if d.type != i32_type: - d = hlo.ConvertOp(i32_type, d) - return hlo.ReshapeOp(int1d, d).result + d = hlo.convert(i32_type, d) + return hlo.reshape(int1d, d) ds = map(lower_dim, sizes) if not ds: return ir_constant(np.array([], np.int32)) elif len(ds) == 1: return ds[0] else: - return hlo.ConcatenateOp(ds, i64_attr(0)).result + return hlo.concatenate(ds, i64_attr(0)) def delegate_lowering(ctx, lowering_fun, *args, **ctx_override_kwargs): @@ -251,7 +251,7 @@ def _numpy_array_constant(x: np.ndarray) -> Sequence[ir.Value]: x = np.packbits(x, bitorder='little') x = np.ascontiguousarray(x) attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape) - return (hlo.ConstantOp(attr).result,) + return (hlo.constant(attr),) def _masked_array_constant_handler(*args, **kwargs): @@ -284,11 +284,11 @@ def _ndarray_constant_handler(val: np.ndarray) -> Sequence[ir.Value]: other_axes, = np.where(np.not_equal(0, val.strides)) collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None) # type: ignore for ax in range(val.ndim))] # type: ignore - out = hlo.BroadcastInDimOp( + out = hlo.broadcast_in_dim( ir.RankedTensorType.get( val.shape, dtype_to_ir_type(collapsed_val.dtype)), _numpy_array_constant(collapsed_val)[0], - dense_int_elements(other_axes)).result + dense_int_elements(other_axes)) return (out,) else: return _numpy_array_constant(val) @@ -309,7 +309,7 @@ for ptype, dtype in dtypes.python_scalar_dtypes.items(): register_constant_handler(ptype, partial(_python_scalar_handler, dtype)) def _token_constant_handler(val): - return [hlo.CreateTokenOp().result] + return [hlo.create_token()] register_constant_handler(core.Token, _token_constant_handler) # Source locations @@ -646,7 +646,7 @@ def eval_dynamic_shape_as_vals(ctx: LoweringRuleContext, else: i32_type = aval_to_ir_type(core.ShapedArray((), np.int32)) if d.type != i32_type: # type: ignore - return hlo.ConvertOp(i32_type, d).result + 
return hlo.convert(i32_type, d) else: return d return tuple(convert_dim(v) for v in eval_dynamic_shape(ctx, shape)) @@ -662,7 +662,7 @@ def eval_dynamic_shape_as_ivals( else: i32_type = aval_to_ir_type(core.ShapedArray((), np.int32)) if d.type != i32_type: # type: ignore - return hlo.ConvertOp(i32_type, d).result + return hlo.convert(i32_type, d) else: return d return tuple(convert_dim(v) for v in eval_dynamic_shape(ctx, shape)) @@ -912,7 +912,7 @@ def token_type() -> Sequence[ir.Type]: return [hlo.TokenType.get()] def create_token() -> Token: - return wrap_singleton_ir_values(hlo.CreateTokenOp().result) + return wrap_singleton_ir_values(hlo.create_token()) class TokenSet: """An immutable container of tokens to be used to lower effectful jaxprs. When lowering @@ -1252,7 +1252,7 @@ def lower_jaxpr_to_fun( args: list[list[ir.Value]] = [] for aval, arg in zip(jaxpr.in_avals, unflattened_args): if replace_tokens_with_dummy and aval is core.abstract_token: - args.append(hlo.CreateTokenOp().results) + args.append([hlo.create_token()]) else: args.append(arg) callee_name_stack = ctx.name_stack.extend(util.wrap_name(name, api_name)) @@ -1285,7 +1285,7 @@ def lower_jaxpr_to_fun( o if mk is None else wrap_with_memory_kind(o, mk, o_aval) for o, mk, o_aval in zip(flat_outputs, ir_result_memory_kinds, output_avals)] - func_dialect.ReturnOp(flat_outputs) + func_dialect.return_(flat_outputs) return func_op @@ -1356,7 +1356,7 @@ def _emit_lowering_rule_as_fun(lowering_rule, outs = lowering_rule(sub_ctx, *_unwrap_singleton_ir_values(unflattened_args)) if sub_ctx.tokens_out: outs = [*[sub_ctx.tokens_out.get(eff) for eff in effs], outs] - func_dialect.ReturnOp(util.flatten(map(wrap_singleton_ir_values, outs))) + func_dialect.return_(util.flatten(map(wrap_singleton_ir_values, outs))) return func_op def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, @@ -1500,7 +1500,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, out_nodes = tuple(map(wrap_singleton_ir_values, ans)) except TypeError as e: raise ValueError("Output of translation rule must be iterable: " - f"{eqn}, got output {ans} from {rule.__name__}") from e + f"{eqn}, got output {ans}") from e assert all(isinstance(v, tuple) for v in out_nodes), (ans, eqn) assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), ( @@ -1595,14 +1595,14 @@ def lower_multi_platform(ctx: LoweringRuleContext, # Compute the rule index based on the current platform i32_type = aval_to_ir_types(core.ShapedArray((), dtype=np.int32))[0] if current_platform_idx.type != i32_type: - current_platform_idx = hlo.ConvertOp(i32_type, current_platform_idx) + current_platform_idx = hlo.convert(i32_type, current_platform_idx) rule_idx_op = hlo.CaseOp([i32_type], index=current_platform_idx, num_branches=len(platforms)) for i, p in enumerate(platforms): branch = rule_idx_op.regions[i].blocks.append() with ir.InsertionPoint(branch): - hlo.ReturnOp(ir_constants(np.int32(platform_to_kept_rules_idx[p]))) + hlo.return_(ir_constants(np.int32(platform_to_kept_rules_idx[p]))) ordered_effects = effects_lib.ordered_effects.filter_in(effects) rule_out_avals = [core.abstract_token] * len(ordered_effects) + ctx.avals_out output_types = map(aval_to_ir_types, rule_out_avals) @@ -1623,7 +1623,7 @@ def lower_multi_platform(ctx: LoweringRuleContext, assert len(ordered_effects) == len(inner_ctx.tokens_out) out_nodes = [inner_ctx.tokens_out.get(eff) for eff in ordered_effects] + out_nodes - hlo.ReturnOp(util.flatten(map(wrap_singleton_ir_values, out_nodes))) + 
hlo.return_(util.flatten(map(wrap_singleton_ir_values, out_nodes))) results = case_op.results if ordered_effects: @@ -1772,17 +1772,17 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, else: if not core.is_constant_shape(aval_out.shape): # type: ignore shape = eval_dynamic_shape_as_tensor(ctx, aval_out.shape) # type: ignore - return hlo.DynamicBroadcastInDimOp( + return hlo.dynamic_broadcast_in_dim( aval_to_ir_type(aval_out), op, shape, dense_int_elements(broadcast_dimensions), - ).result + ) else: assert all(d != ir.ShapedType.get_dynamic_size() for d in aval_out.shape), aval_out # type: ignore - return hlo.BroadcastInDimOp( + return hlo.broadcast_in_dim( aval_to_ir_type(aval_out), op, - dense_int_elements(broadcast_dimensions)).result + dense_int_elements(broadcast_dimensions)) def multi_broadcast_in_dim(ctx: LoweringRuleContext, ops: Sequence[ir.Value], @@ -1806,11 +1806,11 @@ def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Va aval_out = core.physical_aval(aval_out) if not core.is_constant_shape(aval_out.shape): # type: ignore shape = eval_dynamic_shape_as_tensor(ctx, aval_out.shape) # type: ignore - return hlo.DynamicReshapeOp( + return hlo.dynamic_reshape( aval_to_ir_type(aval_out), op, shape, - ).result + ) else: - return hlo.ReshapeOp(aval_to_ir_type(aval_out), op).result + return hlo.reshape(aval_to_ir_type(aval_out), op) def slice_op(ctx: LoweringRuleContext, x, aval_out, *, start_indices, limit_indices, strides) -> ir.Value: @@ -1830,14 +1830,14 @@ def slice_op(ctx: LoweringRuleContext, x, aval_out, *, start_indices = eval_dynamic_shape_as_tensor(ctx, start_indices) limit_indices = eval_dynamic_shape_as_tensor(ctx, limit_indices) strides = eval_dynamic_shape_as_tensor(ctx, strides) - return hlo.RealDynamicSliceOp( + return hlo.real_dynamic_slice( aval_to_ir_type(aval_out), - x, start_indices, limit_indices, strides).result + x, start_indices, limit_indices, strides) else: - return hlo.SliceOp(x, - dense_int_elements(start_indices), - dense_int_elements(limit_indices), - dense_int_elements(strides)).result + return hlo.slice(x, + dense_int_elements(start_indices), + dense_int_elements(limit_indices), + dense_int_elements(strides)) def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *, start_indices) -> ir.Value: @@ -1859,21 +1859,20 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *, # lower to RealDynamicSliceOp, which is a version of SliceOp, and does # not have the clamping behavior. We clamp start ourselves. 
slice_sizes = eval_dynamic_shape_as_tensor(ctx, slice_sizes) - clamped_start = hlo.ClampOp( + clamped_start = hlo.clamp( shape_tensor([0] * len(start_indices)), shape_tensor(start_indices), - hlo.SubtractOp( + hlo.subtract( eval_dynamic_shape_as_tensor(ctx, x_aval.shape), # type: ignore slice_sizes)) - return hlo.RealDynamicSliceOp( + return hlo.real_dynamic_slice( aval_to_ir_type(aval_out), x, clamped_start, - hlo.AddOp(clamped_start, slice_sizes).result, + hlo.add(clamped_start, slice_sizes), shape_tensor([1] * len(start_indices)) - ).result + ) else: - return hlo.DynamicSliceOp(x, start_indices, - dense_int_elements(slice_sizes)).result + return hlo.dynamic_slice(x, start_indices, dense_int_elements(slice_sizes)) def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *, start_indices) -> ir.Value: @@ -1890,36 +1889,35 @@ def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *, start_indices=start_indices) else: # TODO(necula): handle dynamic shapes - return hlo.DynamicUpdateSliceOp(x, update, start_indices).result + return hlo.dynamic_update_slice(x, update, start_indices) def pad(ctx: LoweringRuleContext, aval_out, x, padding_value, padding_low, padding_high, padding_interior) -> ir.Value: if all(core.is_constant_shape(s) for s in (padding_low, padding_high, padding_interior)): - return hlo.PadOp(x, padding_value, - dense_int_elements(padding_low), - dense_int_elements(padding_high), - dense_int_elements(padding_interior)).result + return hlo.pad(x, padding_value, + dense_int_elements(padding_low), + dense_int_elements(padding_high), + dense_int_elements(padding_interior)) else: padding_low = eval_dynamic_shape_as_tensor(ctx, padding_low) padding_high = eval_dynamic_shape_as_tensor(ctx, padding_high) padding_interior = eval_dynamic_shape_as_tensor(ctx, padding_interior) - return hlo.DynamicPadOp( + return hlo.dynamic_pad( aval_to_ir_type(aval_out), - x, padding_value, padding_low, padding_high, padding_interior).result + x, padding_value, padding_low, padding_high, padding_interior) def iota(ctx: LoweringRuleContext, aval_out, *, dimension: int): if not core.is_constant_shape(aval_out.shape): shape = eval_dynamic_shape_as_tensor(ctx, aval_out.shape) - return hlo.DynamicIotaOp( + return hlo.dynamic_iota( aval_to_ir_type(aval_out), shape, i64_attr(dimension), - ).result + ) else: - return hlo.IotaOp(aval_to_ir_type(aval_out), - i64_attr(dimension)).result + return hlo.iota(aval_to_ir_type(aval_out), i64_attr(dimension)) def full_like_aval(ctx: LoweringRuleContext, value, aval: core.ShapedArray) -> ir.Value: """Returns an IR constant shaped full of `value` shaped like `aval`.""" @@ -1933,7 +1931,7 @@ def zeros_like_lowering(ctx, x): register_lowering(ad_util.zeros_like_p, zeros_like_lowering) def add_jaxvals_lowering(ctx, x, y): - return hlo.AddOp(x, y).results + return [hlo.add(x, y)] register_lowering(ad_util.add_jaxvals_p, add_jaxvals_lowering) register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x]) @@ -1949,7 +1947,7 @@ def compare_hlo(x, y, direction: str, comparison_type: str | None = None): else: comparison_type = "FLOAT" - return hlo.CompareOp( + return hlo.compare( x, y, hlo.ComparisonDirectionAttr.get(direction), @@ -1959,20 +1957,18 @@ def _minmax_hlo(op, cmp, x, y): """Min/max that compares complex values lexicographically as pairs.""" tensor_type = ir.RankedTensorType(x.type) if ir.ComplexType.isinstance(tensor_type.element_type): - rx = hlo.RealOp(x).result - ry = hlo.RealOp(y).result + rx = hlo.real(x) + ry = hlo.real(y) real_eq = 
compare_hlo(rx, ry, "EQ", "FLOAT") real_cmp = compare_hlo(rx, ry, cmp, "FLOAT") - imag_cmp = compare_hlo( - hlo.ImagOp(x).result, - hlo.ImagOp(y).result, cmp, "FLOAT") - which = hlo.SelectOp(real_eq, imag_cmp, real_cmp).result - return hlo.SelectOp(which, x, y) + imag_cmp = compare_hlo(hlo.imag(x), hlo.imag(y), cmp, "FLOAT") + which = hlo.select(real_eq, imag_cmp, real_cmp) + return hlo.select(which, x, y) else: return op(x, y) -min_hlo = partial(_minmax_hlo, hlo.MinOp, "LT") -max_hlo = partial(_minmax_hlo, hlo.MaxOp, "GT") +min_hlo = partial(_minmax_hlo, hlo.minimum, "LT") +max_hlo = partial(_minmax_hlo, hlo.maximum, "GT") def convert_hlo(ctx: LoweringRuleContext, x, aval_in, aval_out): @@ -1988,10 +1984,9 @@ def convert_hlo(ctx: LoweringRuleContext, x, aval_in, aval_out): compare_type = "SIGNED" else: compare_type = "UNSIGNED" - x = compare_hlo(x, full_like_aval(ctx, 0, aval_in), "NE", - compare_type).result + x = compare_hlo(x, full_like_aval(ctx, 0, aval_in), "NE", compare_type) # continue, to adjust the shape if needed - return hlo.ConvertOp(aval_to_ir_type(aval_out), x).result + return hlo.convert(aval_to_ir_type(aval_out), x) def _wrap_with_spmd_op(name: str, ctx: LoweringRuleContext, @@ -2184,7 +2179,7 @@ def xla_fallback_lowering(prim: core.Primitive): flatten_lowering_ir_args(args)).result if not prim.multiple_results: return [call] - flat_results = [hlo.GetTupleElementOp(call, i32_attr(i)).result + flat_results = [hlo.get_tuple_element(call, i32_attr(i)) for i in range(len(flat_output_types))] return util.unflatten(flat_results, map(len, output_types)) @@ -2267,7 +2262,7 @@ def _emit_tpu_python_callback( *, sharding: xc.OpSharding | None = None ) -> tuple[Sequence[ir.Value], Any]: - token = token or hlo.CreateTokenOp().result + token = token or hlo.create_token() _wrapped_callback = callback send_channels = [] @@ -2445,7 +2440,7 @@ def emit_python_callback( if sharding is not None: set_sharding(result, sharding) results = [ - hlo.GetTupleElementOp(result, i32_attr(i)).result + hlo.get_tuple_element(result, i32_attr(i)) for i in range(len(result_types)) ] if token: @@ -2603,9 +2598,8 @@ def reduce_window( int2d = aval_to_ir_type(core.ShapedArray((1, 2), np.int32)) def prep_one_pad(pad_lo_hi: tuple[core.DimSize, core.DimSize]): pads = eval_dynamic_shape_as_tensor(ctx, pad_lo_hi) # i32[2] - return hlo.ReshapeOp(int2d, pads) - d_padding = hlo.ConcatenateOp(list(map(prep_one_pad, padding)), - i64_attr(0)).result + return hlo.reshape(int2d, pads) + d_padding = hlo.concatenate(list(map(prep_one_pad, padding)), i64_attr(0)) # Build the reducer reducer_type = ir.FunctionType.get(scalar_types + scalar_types, scalar_types) @@ -2614,8 +2608,7 @@ def reduce_window( ctx.module_context.symbol_table.insert(reducer) entry_block = reducer.add_entry_block() with ir.InsertionPoint(entry_block): - res = reducer_body(entry_block) - hlo.ReturnOp(res) + hlo.return_(reducer_body(entry_block)) rw = custom_call( "stablehlo.dynamic_reduce_window", @@ -2641,8 +2634,7 @@ def reduce_window( shape=(len(padding), 2))) reducer = rw.regions[0].blocks.append(*(scalar_types + scalar_types)) with ir.InsertionPoint(reducer): - res = reducer_body(reducer) - hlo.ReturnOp(res) + hlo.return_(reducer_body(reducer)) return rw.results diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 6fbd8dcba..234bf5540 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1269,8 +1269,7 @@ def _unravel_index_hlo(axis_env): div = mlir.ir_constant( np.array(axis_env.nreps // 
math.prod(axis_env.sizes), np.uint32)) mod = mlir.ir_constant(np.array(axis_env.sizes[-1], np.uint32)) - return hlo.RemOp( - hlo.DivOp(hlo.ReplicaIdOp().result, div).result, mod).result + return hlo.remainder(hlo.divide(hlo.replica_id(), div), mod) def _hlo_shard(aval, axis_env, xs, in_axis): if aval is core.abstract_token: @@ -1283,10 +1282,10 @@ def _hlo_shard(aval, axis_env, xs, in_axis): idxs.insert(in_axis, _unravel_index_hlo(axis_env)) dims_unsqueezed = dims.copy() dims_unsqueezed.insert(in_axis, 1) - dynamic_slice_result = hlo.DynamicSliceOp( - x, idxs, mlir.dense_int_elements(dims_unsqueezed)).result + dynamic_slice_result = hlo.dynamic_slice( + x, idxs, mlir.dense_int_elements(dims_unsqueezed)) return [ - hlo.ReshapeOp(mlir.aval_to_ir_type(aval), dynamic_slice_result).result + hlo.reshape(mlir.aval_to_ir_type(aval), dynamic_slice_result) ] else: raise TypeError(aval) @@ -1335,19 +1334,18 @@ def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs): padded = mlir.full_like_aval(ctx, 0, padded_aval) zero = mlir.ir_constant(np.zeros((), dtype=np.uint32)) idxs = [_unravel_index_hlo(axis_env)] + [zero] * len(dims) - broadcast_result = hlo.BroadcastOp( - x, mlir.dense_int_elements([1])).result - padded = hlo.DynamicUpdateSliceOp(padded, broadcast_result, idxs).result + broadcast_result = hlo.broadcast(x, mlir.dense_int_elements([1])) + padded = hlo.dynamic_update_slice(padded, broadcast_result, idxs) replica_groups = mlir.dense_int_elements( axis_groups(axis_env, axis_env.names[-1])) - out = hlo.CrossReplicaSumOp(padded, replica_groups).result + out = hlo.cross_replica_sum(padded, replica_groups) if out_axis != 0: # TODO(apaszke,mattjj): Change the indices to DynamicUpdateSlice instead perm = list(range(1, len(dims))) perm.insert(out_axis, 0) transposed_dims = list(dims) transposed_dims.insert(out_axis, axis_env.sizes[-1]) - out = hlo.TransposeOp(out, mlir.dense_int_elements(perm)).result + out = hlo.transpose(out, mlir.dense_int_elements(perm)) return out else: diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index 22abf2c6b..683c4754f 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -295,8 +295,8 @@ def _comparator_builder_mlir(ctx, op_type, is_max_k): with ir.InsertionPoint(entry_block): p0, p1, _, _ = entry_block.arguments direction = hlo.ComparisonDirectionAttr.get('GT' if is_max_k else 'LT') - cmp_result = hlo.CompareOp(p0, p1, comparison_direction=direction) - hlo.ReturnOp(cmp_result) + cmp_result = hlo.compare(p0, p1, comparison_direction=direction) + hlo.return_([cmp_result]) return comparator @@ -321,7 +321,7 @@ def _approx_top_k_lowering(ctx, operand, *, k, iota = mlir.iota(ctx, core.ShapedArray(ctx.avals_in[0].shape, np.int32), dimension=reduction_dimension) - init_arg = hlo.ConstantOp(ir.DenseElementsAttr.get(np.int32(-1))).result + init_arg = hlo.constant(ir.DenseElementsAttr.get(np.int32(-1))) init_val_array = _get_init_val_literal(ctx.avals_in[0].dtype, is_max_k) init_val = mlir.ir_constant(init_val_array.reshape(())) diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index ab54f96d9..376e35852 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -860,7 +860,7 @@ def _cond_lowering(ctx, index, *args, branches, linear): dim_var_values=ctx.dim_var_values) out_tokens = [tokens_out.get(eff) for eff in ordered_effects] out_vals = [*out_tokens, *out_vals] - hlo.ReturnOp(util.flatten(out_vals)) + hlo.return_(util.flatten(out_vals)) 
tokens_and_outputs = util.unflatten(case_op.results, map(len, output_types)) tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens]) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 8c581bf27..49da39f8b 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1656,7 +1656,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts, pred_ctx, pred, axes=tuple(range(len(pred_aval.shape)))) - hlo.ReturnOp([pred]) + hlo.return_([pred]) # Loop body body_block = while_op.regions[1].blocks.append(*flat_loop_carry_types) @@ -1687,8 +1687,8 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts, partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z, body_jaxpr.out_avals) - hlo.ReturnOp([*util.flatten(out_tokens), *util.flatten(x), - *util.flatten(y), *util.flatten(new_z)]) + hlo.return_([*util.flatten(out_tokens), *util.flatten(x), *util.flatten(y), + *util.flatten(new_z)]) outputs = util.unflatten(while_op.results, _map(len, loop_carry_types)) tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts]) diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 0aed39ce9..10064285a 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -707,7 +707,7 @@ def _conv_general_dilated_lower( raise NotImplementedError("Convolutions with non-static strides, dilation, feature_group_count, or batch_group_count") if all(core.is_constant_shape(p) for p in padding): return [ - hlo.ConvolutionOp( + hlo.convolution( mlir.aval_to_ir_type(aval_out), lhs, rhs, @@ -719,7 +719,7 @@ def _conv_general_dilated_lower( lhs_dilation=mlir.dense_int_elements(lhs_dilation), rhs_dilation=mlir.dense_int_elements(rhs_dilation), window_reversal=window_reversal, - precision_config=lax.precision_attr(precision)).result + precision_config=lax.precision_attr(precision)) ] else: # d_padding will be an array i32[N, 2] with pad_lo and pad_hi for each @@ -731,7 +731,7 @@ def _conv_general_dilated_lower( d_padding = hlo.ConcatenateOp(list(map(prep_one_pad, padding)), mlir.i64_attr(0)) return [ - hlo.DynamicConvOp( + hlo.dynamic_conv( mlir.aval_to_ir_type(aval_out), lhs, rhs, @@ -743,7 +743,7 @@ def _conv_general_dilated_lower( lhs_dilation=mlir.dense_int_elements(lhs_dilation), rhs_dilation=mlir.dense_int_elements(rhs_dilation), window_reversal=window_reversal, - precision_config=lax.precision_attr(precision)).result + precision_config=lax.precision_attr(precision)) ] mlir.register_lowering(conv_general_dilated_p, _conv_general_dilated_lower) diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 43f671184..b4dc7201c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -1680,19 +1680,19 @@ def broadcast_hlo( dims = mlir.dense_int_elements( range(len(aval_out.shape) - len(aval.shape), len(aval_out.shape))) if any(isinstance(d, ir.Value) for d in aval_out.shape): - arg = hlo.DynamicBroadcastInDimOp( + arg = hlo.dynamic_broadcast_in_dim( mlir.aval_to_ir_type(aval_out), arg, - mlir.shape_tensor(aval_out.shape), dims).result + mlir.shape_tensor(aval_out.shape), dims) else: - arg = hlo.BroadcastInDimOp( + arg = hlo.broadcast_in_dim( mlir.aval_to_ir_type(aval.update(shape=aval_out.shape)), arg, - dims).result + dims) out.append(arg) return out def _nary_lower_hlo(op: Callable, ctx, *args: Union[ir.Value, Sequence[ir.Value]], - explicit_type=False, **params): + explicit_type=False, **params) -> Sequence[ir.Value]: """Lowers an 
elementwise operator to its MLIR equivalent. Args: @@ -1704,9 +1704,9 @@ def _nary_lower_hlo(op: Callable, ctx, ctx, args, avals_in, aval_out.shape) if explicit_type: - return op(mlir.aval_to_ir_type(aval_out), *broadcasted_args).results + return [op(mlir.aval_to_ir_type(aval_out), *broadcasted_args)] else: - return op(*broadcasted_args).results + return [op(*broadcasted_args)] _float = {np.floating} @@ -1723,7 +1723,7 @@ _ordered = _int | _float | _bool neg_p = standard_unop(_num, 'neg') ad.deflinear2(neg_p, lambda t, operand: [neg(t)]) -mlir.register_lowering(neg_p, partial(_nary_lower_hlo, hlo.NegOp)) +mlir.register_lowering(neg_p, partial(_nary_lower_hlo, hlo.negate)) sign_p = standard_unop(_num, 'sign') ad.defjvp_zero(sign_p) @@ -1731,44 +1731,44 @@ ad.defjvp_zero(sign_p) def _sign_lower_hlo(ctx, x): x_aval, = ctx.avals_in if dtypes.issubdtype(x_aval.dtype, np.unsignedinteger): - return hlo.SelectOp( + return [hlo.select( mlir.compare_hlo(x, mlir.full_like_aval(ctx, 0, x_aval), 'EQ', - 'UNSIGNED').result, + 'UNSIGNED'), mlir.full_like_aval(ctx, 0, x_aval), - mlir.full_like_aval(ctx, 1, x_aval)).results - return hlo.SignOp(x).results + mlir.full_like_aval(ctx, 1, x_aval))] + return [hlo.sign(x)] mlir.register_lowering(sign_p, _sign_lower_hlo) nextafter_p = standard_naryop([_float, _float], 'nextafter') -mlir.register_lowering(nextafter_p, partial(_nary_lower_hlo, chlo.NextAfterOp)) +mlir.register_lowering(nextafter_p, partial(_nary_lower_hlo, chlo.next_after)) floor_p = standard_unop(_float, 'floor') ad.defjvp_zero(floor_p) -mlir.register_lowering(floor_p, partial(_nary_lower_hlo, hlo.FloorOp)) +mlir.register_lowering(floor_p, partial(_nary_lower_hlo, hlo.floor)) ceil_p = standard_unop(_float, 'ceil') ad.defjvp_zero(ceil_p) -mlir.register_lowering(ceil_p, partial(_nary_lower_hlo, hlo.CeilOp)) +mlir.register_lowering(ceil_p, partial(_nary_lower_hlo, hlo.ceil)) round_p = standard_unop(_float, 'round') ad.defjvp_zero(round_p) def _round_lower(ctx, x, *, rounding_method): if rounding_method is RoundingMethod.AWAY_FROM_ZERO: - return hlo.RoundOp(x).results + return [hlo.round_nearest_afz(x)] else: assert rounding_method is RoundingMethod.TO_NEAREST_EVEN - return hlo.RoundNearestEvenOp(x).results + return [hlo.round_nearest_even(x)] mlir.register_lowering(round_p, _round_lower) is_finite_p = unop(_fixed_dtype(np.bool_), _float, 'is_finite') ad.defjvp_zero(is_finite_p) -mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.IsFiniteOp)) +mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.is_finite)) exp_p = standard_unop(_float | _complex, 'exp') ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans)) -mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.ExpOp)) +mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.exponential)) exp2_p = standard_unop(_float | _complex, 'exp2') ad.defjvp2(exp2_p, lambda g, ans, x: mul(log(_const(x, 2)), mul(g, ans))) @@ -1776,30 +1776,31 @@ def _exp2_lower(ctx, x): x_aval, = ctx.avals_in log2 = mlir.ir_constant(np.array(np.log(2), x_aval.dtype)) log2 = mlir.broadcast_in_dim(ctx, log2, x_aval, broadcast_dimensions=()) - return hlo.ExpOp(hlo.MulOp(log2, x).result).results + return [hlo.exponential(hlo.multiply(log2, x))] mlir.register_lowering(exp2_p, _exp2_lower) log_p = standard_unop(_float | _complex, 'log') ad.defjvp(log_p, lambda g, x: div(g, x)) -mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.LogOp)) +mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.log)) expm1_p = standard_unop(_float | _complex, 
'expm1') ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans)))) -mlir.register_lowering(expm1_p, partial(_nary_lower_hlo, hlo.Expm1Op)) +mlir.register_lowering(expm1_p, + partial(_nary_lower_hlo, hlo.exponential_minus_one)) log1p_p = standard_unop(_float | _complex, 'log1p') ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x)))) -mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.Log1pOp)) +mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.log_plus_one)) tanh_p = standard_unop(_float | _complex, 'tanh') ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)), sub(_one(x), ans))) -mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.TanhOp)) +mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.tanh)) logistic_p = standard_unop(_float | _complex, 'logistic') ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans)))) # TODO(phawkins): switch to LogisticOp lowering; debug numerical problems. -# mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.LogisticOp)) +# mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.logistic)) def logistic_impl(x): one = _const(x, 1) @@ -1810,11 +1811,11 @@ mlir.register_lowering(logistic_p, sin_p = standard_unop(_float | _complex, 'sin') ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) -mlir.register_lowering(sin_p, partial(_nary_lower_hlo, hlo.SineOp)) +mlir.register_lowering(sin_p, partial(_nary_lower_hlo, hlo.sine)) cos_p = standard_unop(_float | _complex, 'cos') ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x)))) -mlir.register_lowering(cos_p, partial(_nary_lower_hlo, hlo.CosineOp)) +mlir.register_lowering(cos_p, partial(_nary_lower_hlo, hlo.cosine)) @_upcast_fp16_for_computation def _tan_impl(x): @@ -1822,7 +1823,7 @@ def _tan_impl(x): tan_p = standard_unop(_float | _complex, 'tan') ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans))) -mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.TanOp)) +mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan)) def asin_impl(x): if dtypes.issubdtype(_dtype(x), np.complexfloating): @@ -1833,7 +1834,7 @@ def asin_impl(x): asin_p = standard_unop(_float | _complex, 'asin') ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x)))) -mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.AsinOp)) +mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.asin)) def acos_impl(x): if dtypes.issubdtype(_dtype(x), np.complexfloating): @@ -1862,43 +1863,43 @@ def atan_impl(x): atan_p = standard_unop(_float | _complex, 'atan') ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x))) -mlir.register_lowering(atan_p, partial(_nary_lower_hlo, chlo.AtanOp)) +mlir.register_lowering(atan_p, partial(_nary_lower_hlo, chlo.atan)) atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2') ad.defjvp(atan2_p, lambda g, x, y: g * (y / (square(x) + square(y))), lambda g, x, y: g * -x / (square(x) + square(y))) -mlir.register_lowering(atan2_p, partial(_nary_lower_hlo, hlo.Atan2Op)) +mlir.register_lowering(atan2_p, partial(_nary_lower_hlo, hlo.atan2)) sinh_p = standard_unop(_float | _complex, 'sinh') ad.defjvp(sinh_p, lambda g, x: mul(g, cosh(x))) -mlir.register_lowering(sinh_p, partial(_nary_lower_hlo, chlo.SinhOp)) +mlir.register_lowering(sinh_p, partial(_nary_lower_hlo, chlo.sinh)) cosh_p = standard_unop(_float | _complex, 'cosh') ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x))) -mlir.register_lowering(cosh_p, partial(_nary_lower_hlo, chlo.CoshOp)) 
+mlir.register_lowering(cosh_p, partial(_nary_lower_hlo, chlo.cosh)) asinh_p = standard_unop(_float | _complex, 'asinh') ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x)))) -mlir.register_lowering(asinh_p, partial(_nary_lower_hlo, chlo.AsinhOp)) +mlir.register_lowering(asinh_p, partial(_nary_lower_hlo, chlo.asinh)) acosh_p = standard_unop(_float | _complex, 'acosh') ad.defjvp(acosh_p, lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x))))) -mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.AcoshOp)) +mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.acosh)) atanh_p = standard_unop(_float | _complex, 'atanh') ad.defjvp(atanh_p, lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x)))) -mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.AtanhOp)) +mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.atanh)) real_p = unop(_complex_basetype, _complex, 'real') ad.deflinear2(real_p, lambda t, _: [complex(t, np.zeros((), _dtype(t)))]) -mlir.register_lowering(real_p, partial(_nary_lower_hlo, hlo.RealOp)) +mlir.register_lowering(real_p, partial(_nary_lower_hlo, hlo.real)) imag_p = unop(_complex_basetype, _complex, 'imag') ad.deflinear2(imag_p, lambda t, _: [complex(np.zeros((), _dtype(t)), neg(t))]) -mlir.register_lowering(imag_p, partial(_nary_lower_hlo, hlo.ImagOp)) +mlir.register_lowering(imag_p, partial(_nary_lower_hlo, hlo.imag)) def _complex_transpose_rule(t, x, y): @@ -1923,7 +1924,7 @@ _complex_dtype = lambda dtype, *args: (np.zeros((), dtype) + np.zeros((), np.com complex_p = naryop(_complex_dtype, [_complex_elem_types, _complex_elem_types], 'complex') ad.deflinear2(complex_p, _complex_transpose_rule) -mlir.register_lowering(complex_p, partial(_nary_lower_hlo, hlo.ComplexOp)) +mlir.register_lowering(complex_p, partial(_nary_lower_hlo, hlo.complex)) conj_p = unop(_complex_dtype, _complex_elem_types | _complex, 'conj') @@ -1950,7 +1951,7 @@ ad.primitive_jvps[conj_p] = partial(ad.linear_jvp, conj_p) ad.primitive_transposes[conj_p] = _conj_transpose_rule abs_p = unop(_complex_basetype, _signedint | _float | _complex, 'abs') -mlir.register_lowering(abs_p, partial(_nary_lower_hlo, hlo.AbsOp)) +mlir.register_lowering(abs_p, partial(_nary_lower_hlo, hlo.abs)) def _abs_jvp_rule(g, ans, x): if _iscomplex(x): @@ -1964,18 +1965,18 @@ _maybe_real = lambda x: real(x) if _iscomplex(x) else x sqrt_p = standard_unop(_float | _complex, 'sqrt') ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans))) -mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.SqrtOp)) +mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.sqrt)) rsqrt_p = standard_unop(_float | _complex, 'rsqrt') ad.defjvp2(rsqrt_p, lambda g, ans, x: mul(g, mul(_const(x, -0.5), div(ans, x)))) -mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.RsqrtOp)) +mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.rsqrt)) cbrt_p = standard_unop(_float, 'cbrt') ad.defjvp2(cbrt_p, lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2)))) -mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.CbrtOp)) +mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.cbrt)) def _pow_dtype_rule(x, y): if (dtypes.issubdtype(x.dtype, np.inexact) and @@ -2020,7 +2021,7 @@ def _pow_lower(ctx, x, y): [(x_,)] = convert(ctx.replace(avals_in=[x_aval], avals_out=[x_aval_]), x) [(y_,)] = convert(ctx.replace(avals_in=[y_aval], avals_out=[y_aval_]), y) ctx_ = ctx.replace(avals_in=[x_aval_, y_aval_]) - return _nary_lower_hlo(hlo.PowOp, ctx_, x_, 
y_) + return _nary_lower_hlo(hlo.power, ctx_, x_, y_) mlir.register_lowering(pow_p, _pow_lower) @@ -2075,26 +2076,25 @@ _replace_zero = lambda x: select(eq(x, _const(x, 0)), _ones(x), x) not_p = standard_unop(_bool_or_int, 'not') ad.defjvp_zero(not_p) -mlir.register_lowering(not_p, partial(_nary_lower_hlo, hlo.NotOp)) +mlir.register_lowering(not_p, partial(_nary_lower_hlo, hlo.not_)) and_p = standard_naryop([_bool_or_int, _bool_or_int], 'and') ad.defjvp_zero(and_p) -mlir.register_lowering(and_p, partial(_nary_lower_hlo, hlo.AndOp)) +mlir.register_lowering(and_p, partial(_nary_lower_hlo, hlo.and_)) or_p = standard_naryop([_bool_or_int, _bool_or_int], 'or') ad.defjvp_zero(or_p) -mlir.register_lowering(or_p, partial(_nary_lower_hlo, hlo.OrOp)) +mlir.register_lowering(or_p, partial(_nary_lower_hlo, hlo.or_)) xor_p = standard_naryop([_bool_or_int, _bool_or_int], 'xor') ad.defjvp_zero(xor_p) -mlir.register_lowering(xor_p, partial(_nary_lower_hlo, hlo.XorOp)) +mlir.register_lowering(xor_p, partial(_nary_lower_hlo, hlo.xor)) population_count_p = standard_unop(_int, 'population_count') -mlir.register_lowering(population_count_p, - partial(_nary_lower_hlo, hlo.PopulationCountOp)) +mlir.register_lowering(population_count_p, partial(_nary_lower_hlo, hlo.popcnt)) clz_p = standard_unop(_int, 'clz') -mlir.register_lowering(clz_p, partial(_nary_lower_hlo, hlo.ClzOp)) +mlir.register_lowering(clz_p, partial(_nary_lower_hlo, hlo.count_leading_zeros)) def _add_jvp(primals, tangents): x, y = primals @@ -2130,7 +2130,7 @@ def _add_inverse(r, x, y): add_p: Primitive = standard_naryop([_num, _num], 'add') ad.primitive_jvps[add_p] = _add_jvp ad.primitive_transposes[add_p] = _add_transpose -mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.AddOp)) +mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.add)) def _sub_jvp(primals, tangents): x, y = primals @@ -2159,7 +2159,7 @@ def _sub_transpose(t, x, y): sub_p = standard_naryop([_num, _num], 'sub') ad.primitive_jvps[sub_p] = _sub_jvp ad.primitive_transposes[sub_p] = _sub_transpose -mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.SubtractOp)) +mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.subtract)) def _mul_transpose(ct, x, y): @@ -2185,7 +2185,7 @@ ad.defjvp(mul_p, lambda xdot, x, y: mul(xdot, y), lambda ydot, x, y: mul(x, ydot)) ad.primitive_transposes[mul_p] = _mul_transpose -mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.MulOp)) +mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.multiply)) def _div_transpose_rule(cotangent, x, y): assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y) @@ -2198,14 +2198,14 @@ ad.defjvp(div_p, lambda g, x, y: div(g, y), lambda g, x, y: mul(mul(neg(g), x), integer_pow(y, -2))) ad.primitive_transposes[div_p] = _div_transpose_rule -mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.DivOp)) +mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.divide)) rem_p = standard_naryop([_int | _float, _int | _float], 'rem') ad.defjvp( rem_p, lambda g, x, y: _maybe_broadcast(broadcast_shapes(np.shape(x), np.shape(y)), g), lambda g, x, y: mul(neg(g), mul(sign(div(x, y)), floor(abs(div(x, y)))))) -mlir.register_lowering(rem_p, partial(_nary_lower_hlo, hlo.RemOp)) +mlir.register_lowering(rem_p, partial(_nary_lower_hlo, hlo.remainder)) def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x): result_shape = broadcast_shapes(np.shape(x), np.shape(y)) @@ -2231,17 +2231,17 @@ mlir.register_lowering(min_p, partial(_nary_lower_hlo, mlir.min_hlo)) shift_left_p = 
standard_naryop([_int, _int], 'shift_left') ad.defjvp_zero(shift_left_p) -mlir.register_lowering(shift_left_p, partial(_nary_lower_hlo, hlo.ShiftLeftOp)) +mlir.register_lowering(shift_left_p, partial(_nary_lower_hlo, hlo.shift_left)) shift_right_arithmetic_p = standard_naryop([_int, _int], 'shift_right_arithmetic') ad.defjvp_zero(shift_right_arithmetic_p) mlir.register_lowering(shift_right_arithmetic_p, - partial(_nary_lower_hlo, hlo.ShiftRightArithmeticOp)) + partial(_nary_lower_hlo, hlo.shift_right_arithmetic)) shift_right_logical_p = standard_naryop([_int, _int], 'shift_right_logical') ad.defjvp_zero(shift_right_logical_p) mlir.register_lowering(shift_right_logical_p, - partial(_nary_lower_hlo, hlo.ShiftRightLogicalOp)) + partial(_nary_lower_hlo, hlo.shift_right_logical)) def _opaque_comparison_hlo(direction, reduction_op, identity, ctx, avals_in, aval_out, x, y): @@ -2288,7 +2288,7 @@ def _compare_lower_hlo(direction: str, total_order: bool, ctx, x, y): compare_type = "SIGNED" else: compare_type = "UNSIGNED" - return mlir.compare_hlo(x, y, direction, compare_type).results + return [mlir.compare_hlo(x, y, direction, compare_type)] eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq', allow_extended_dtype=True) ad.defjvp_zero(eq_p) @@ -2426,7 +2426,7 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type): aval_out, = ctx.avals_out if (dtypes.issubdtype(aval_in.dtype, np.complexfloating) and not dtypes.issubdtype(new_dtype, np.complexfloating)): - operand = hlo.RealOp(operand).result + operand = hlo.real(operand) aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype)) return [mlir.convert_hlo(ctx, operand, aval_in, aval_out)] @@ -2473,7 +2473,7 @@ batching.defvectorized(bitcast_convert_type_p) def _bitcast_convert_type_lower(ctx, operand, *, new_dtype): aval_out, = ctx.avals_out - return hlo.BitcastConvertOp(mlir.aval_to_ir_type(aval_out), operand).results + return [hlo.bitcast_convert(mlir.aval_to_ir_type(aval_out), operand)] mlir.register_lowering(bitcast_convert_type_p, _bitcast_convert_type_lower) @@ -2841,12 +2841,12 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, lhs_contracting_dimensions=list(lhs_contracting), rhs_contracting_dimensions=list(rhs_contracting)) return [ - hlo.DotGeneralOp( + hlo.dot_general( mlir.aval_to_ir_type(aval_out), lhs, rhs, dot_dnums, - precision_config=precision_attr(precision)).result + precision_config=precision_attr(precision)) ] mlir.register_lowering(dot_general_p, _dot_general_lower) @@ -3143,8 +3143,7 @@ ad.defjvp(clamp_p, lambda g, min, operand, max: select(lt(max, operand), g, _zeros(operand))) batching.primitive_batchers[clamp_p] = _clamp_batch_rule -mlir.register_lowering( - clamp_p, partial(_nary_lower_hlo, hlo.ClampOp)) +mlir.register_lowering(clamp_p, partial(_nary_lower_hlo, hlo.clamp)) pe.def_trivial_padding(clamp_p) def _concatenate_shape_rule(*operands, **kwargs): @@ -3222,7 +3221,7 @@ batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule pe.padding_rules[concatenate_p] = _concatenate_pad_rule def _concatenate_lower(ctx, *xs, dimension): - return hlo.ConcatenateOp(xs, mlir.i64_attr(dimension)).results + return [hlo.concatenate(xs, mlir.i64_attr(dimension))] mlir.register_lowering(concatenate_p, _concatenate_lower) @@ -3424,7 +3423,7 @@ def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions): def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions): aval_out, = ctx.avals_out if dimensions is not None: - x = hlo.TransposeOp(x, 
mlir.dense_int_elements(dimensions)).result + x = hlo.transpose(x, mlir.dense_int_elements(dimensions)) if dyn_shape: aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape)) return [mlir.reshape(ctx, x, aval_out)] @@ -3468,7 +3467,7 @@ ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)]) batching.primitive_batchers[rev_p] = _rev_batch_rule def _rev_lower(ctx, x, *, dimensions): - return hlo.ReverseOp(x, mlir.dense_int_elements(dimensions)).results + return [hlo.reverse(x, mlir.dense_int_elements(dimensions))] mlir.register_lowering(rev_p, _rev_lower) @@ -3500,7 +3499,7 @@ def _transpose_lower(ctx, x, *, permutation): aval_out.dtype).shape trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))] permutation = [*permutation, *trailing_dims] - return hlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results + return [hlo.transpose(x, mlir.dense_int_elements(permutation))] transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype, 'transpose') @@ -3632,7 +3631,7 @@ def _select_hlo_lowering(ctx, which, *cases): if which_aval.dtype == np.dtype(np.bool_): assert len(cases) <= 2 if len(cases) == 1: return cases - return hlo.SelectOp(which, cases[1], cases[0]).results + return [hlo.select(which, cases[1], cases[0])] if dtypes.issubdtype(which_aval.dtype, np.signedinteger): compare_type = 'SIGNED' @@ -3648,8 +3647,8 @@ def _select_hlo_lowering(ctx, which, *cases): pred = mlir.compare_hlo(which, mlir.full_like_aval(ctx, offset + mid, which_aval), lt, compare_type) - return hlo.SelectOp(pred, _select(offset, cases[:mid]), - _select(offset + mid, cases[mid:])).result + return hlo.select(pred, _select(offset, cases[:mid]), + _select(offset + mid, cases[mid:])) return [_select(0, cases)] @@ -3791,7 +3790,7 @@ def _reduce_lower(ctx, *values, computation, jaxpr, consts, dimensions): out_nodes, _ = mlir.jaxpr_subcomp(reducer_ctx, jaxpr, mlir.TokenSet(), consts, *([a] for a in reducer.arguments), dim_var_values=ctx.dim_var_values) - hlo.ReturnOp(util.flatten(out_nodes)) + hlo.return_(util.flatten(out_nodes)) return op.results mlir.register_lowering(reduce_p, _reduce_lower) @@ -3982,8 +3981,7 @@ def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes): scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype)) reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(reducer_region): - add = reducer(*reducer_region.arguments) - hlo.ReturnOp(add.results) + hlo.return_([reducer(*reducer_region.arguments)]) return op.results mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, hlo.AddOp, @@ -4021,8 +4019,8 @@ batching.defvectorized(reduce_precision_p) def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits): aval_out, = ctx.avals_out - return hlo.ReducePrecisionOp(operand, mlir.i32_attr(exponent_bits), - mlir.i32_attr(mantissa_bits)).results + return [hlo.reduce_precision(operand, mlir.i32_attr(exponent_bits), + mlir.i32_attr(mantissa_bits))] mlir.register_lowering(reduce_precision_p, _reduce_precision_lower) @@ -4163,7 +4161,7 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys): out = lower_comparator(sub_ctx, *[[a] for a in comparator.arguments], num_keys=num_keys) - hlo.ReturnOp(util.flatten(out)) + hlo.return_(util.flatten(out)) return sort.results mlir.register_lowering(sort_p, _sort_lower) @@ -4273,7 +4271,7 @@ create_token_p.def_abstract_eval(lambda *_: abstract_token) def _create_token_lowering(ctx, *operands): aval_out, = ctx.avals_out - return 
hlo.CreateTokenOp().results + return [hlo.create_token()] mlir.register_lowering(create_token_p, _create_token_lowering) @@ -4295,7 +4293,7 @@ after_all_p.def_abstract_eval(_after_all_abstract_eval) def _after_all_lowering(ctx, *operands): aval_out, = ctx.avals_out - return hlo.AfterAllOp(operands).results + return [hlo.after_all(operands)] mlir.register_lowering(after_all_p, _after_all_lowering) @@ -4434,8 +4432,7 @@ rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval) def _rng_uniform_lowering(ctx, a, b, *, shape): aval_out, = ctx.avals_out shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64)) - return hlo.RngOp(a, b, shape, - hlo.RngDistributionAttr.get('UNIFORM')).results + return [hlo.rng(a, b, shape, hlo.RngDistributionAttr.get('UNIFORM'))] mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering) @@ -4491,9 +4488,9 @@ def _rng_bit_generator_lowering( rbg_etype = u32_type rbg_dtype = np.uint32 if key_etype == u32_type: - key = hlo.BitcastConvertOp( + key = hlo.bitcast_convert( ir.RankedTensorType.get([2], u64_type), - hlo.ReshapeOp(ir.RankedTensorType.get([2, 2], u32_type), key)).result + hlo.reshape(ir.RankedTensorType.get([2, 2], u32_type), key)) algorithm_attr = _rng_algorithm(algorithm) _, out_vals_aval = ctx.avals_out if any(not core.is_constant_shape(a.shape) for a in ctx.avals_out): @@ -4511,14 +4508,14 @@ def _rng_bit_generator_lowering( ir.RankedTensorType.get(shape, rbg_etype), algorithm_attr, key).results if key_etype == u32_type: - out_key = hlo.ReshapeOp( + out_key = hlo.reshape( ir.RankedTensorType.get([4], u32_type), - hlo.BitcastConvertOp( - ir.RankedTensorType.get([2, 2], u32_type), out_key)).result + hlo.bitcast_convert( + ir.RankedTensorType.get([2, 2], u32_type), out_key)) if rbg_etype != etype: - out_vals = hlo.ConvertOp( + out_vals = hlo.convert( ir.RankedTensorType.get(ir.RankedTensorType(out_vals.type).shape, etype), - out_vals).result + out_vals) return [out_key, out_vals] diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index d9217e6f4..cfeacd3e8 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -435,7 +435,7 @@ ad.primitive_jvps[cholesky_p] = _cholesky_jvp_rule batching.primitive_batchers[cholesky_p] = _cholesky_batching_rule def _cholesky_lowering(ctx, x): - return hlo.CholeskyOp(x, lower=ir.BoolAttr.get(True)).results + return [hlo.cholesky(x, lower=ir.BoolAttr.get(True))] mlir.register_lowering(cholesky_p, _cholesky_lowering) @@ -621,7 +621,7 @@ def _eigh_jacobi_lowering_rule(ctx, operand, lower, sort_eigenvalues): if operand_aval.shape[-1] == 0: reshape_aval = operand_aval.update(shape=operand_aval.shape[:-1]) return [ - hlo.RealOp(mlir.reshape(ctx, operand, reshape_aval)).result, + hlo.real(mlir.reshape(ctx, operand, reshape_aval)), operand, ] @@ -985,10 +985,10 @@ def _triangular_solve_lowering( transpose = "NO_TRANSPOSE" else: transpose = "ADJOINT" if conjugate_a else "TRANSPOSE" - return hlo.TriangularSolveOp( + return [hlo.triangular_solve( a, b, ir.BoolAttr.get(left_side), ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal), - hlo.TransposeAttr.get(transpose)).results + hlo.TransposeAttr.get(transpose))] mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering) @@ -998,7 +998,7 @@ def _triangular_solve_cpu_lower( a_aval, b_aval = ctx.avals_in if conjugate_a and not transpose_a: - a = chlo.ConjOp(a).result + a = chlo.conj(a) conjugate_a = False if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types: alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype)) @@ 
-1014,10 +1014,10 @@ def _triangular_solve_cpu_lower( transpose = "ADJOINT" if conjugate_a else "TRANSPOSE" else: transpose = "NO_TRANSPOSE" - return hlo.TriangularSolveOp(a, b, ir.BoolAttr.get(left_side), - ir.BoolAttr.get(lower), - ir.BoolAttr.get(unit_diagonal), - hlo.TransposeAttr.get(transpose)).results + return [hlo.triangular_solve(a, b, ir.BoolAttr.get(left_side), + ir.BoolAttr.get(lower), + ir.BoolAttr.get(unit_diagonal), + hlo.TransposeAttr.get(transpose))] mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower, platform='cpu') @@ -1297,7 +1297,7 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand, *, lu, pivot, info = getrf_impl( operand_aval.dtype, operand, a_shape_vals=op_shape_vals) # Subtract 1 from the pivot to get 0-based indices. - pivot = hlo.SubtractOp(pivot, mlir.full_like_aval(ctx, 1, pivot_aval)).result + pivot = hlo.subtract(pivot, mlir.full_like_aval(ctx, 1, pivot_aval)) ok = mlir.compare_hlo( info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))), "GE", "SIGNED") @@ -2380,4 +2380,4 @@ def _broadcasting_select_hlo(ctx, which, which_aval, x, x_aval, y, y_aval) -> ir which, x, y = mlir.multi_broadcast_in_dim(ctx, (which, x, y), (which_aval, x_aval, y_aval), out_shapes) - return hlo.SelectOp(which, x, y).result + return hlo.select(which, x, y) diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 089b30b40..ca42033cc 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -775,7 +775,7 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups): avals_in=[scalar_aval] * 2, avals_out=[scalar_aval]) out_nodes = lower_reducer( reducer_ctx, *([a] for a in reducer_block.arguments)) - hlo.ReturnOp(util.flatten(out_nodes)) + hlo.return_(util.flatten(out_nodes)) return op.result return [all_reduce(aval, x) for aval, x in zip(ctx.avals_in, args)] @@ -1191,7 +1191,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, new_shape = list(x_aval.shape) new_shape.insert(all_gather_dimension, 1) broadcast_dimensions = [i for i in range(len(new_shape)) if i != all_gather_dimension] - x = hlo.BroadcastInDimOp( + x = hlo.broadcast_in_dim( mlir.aval_to_ir_type(x_aval.update(shape=new_shape)), x, mlir.dense_int_elements(broadcast_dimensions)) replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name, @@ -1359,12 +1359,12 @@ def _reduce_scatter_lowering( avals_out=[scalar_aval]) out_nodes = lower_reducer( reducer_ctx, *([a] for a in reducer_block.arguments)) - hlo.ReturnOp(util.flatten(out_nodes)) + hlo.return_(util.flatten(out_nodes)) if tiled: return op.results else: - return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), op.result).results + return [hlo.reshape(mlir.aval_to_ir_type(aval_out), op.result)] def _reduce_scatter_lowering_via_reducer( prim, reducer, ctx, x, @@ -1582,13 +1582,13 @@ def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): (sharding_impls.SPMDAxisContext, sharding_impls.ShardingContext), ) if is_spmd: - device_id = hlo.PartitionIdOp() + device_id = hlo.partition_id() else: - device_id = hlo.ReplicaIdOp() - unsigned_index = hlo.RemOp(hlo.DivOp(device_id, div), mod) - return hlo.ConvertOp( + device_id = hlo.replica_id() + unsigned_index = hlo.remainder(hlo.divide(device_id, div), mod) + return hlo.convert( ir.RankedTensorType.get([], ir.IntegerType.get_signless(32)), - unsigned_index).result + unsigned_index) def _axis_index_lowering(ctx, *, axis_name): return [ diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 
d7e05ee2a..56240fe0a 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -1839,12 +1839,12 @@ def _gather_lower(ctx, operand, indices, *, return hlo.DynamicGatherOp.build_generic( results=results, operands=operands, attributes=attributes).results else: - return hlo.GatherOp( + return [hlo.gather( operand, indices, dnums, mlir.dense_int_elements(slice_sizes), - indices_are_sorted=ir.BoolAttr.get(indices_are_sorted)).results + indices_are_sorted=ir.BoolAttr.get(indices_are_sorted))] mlir.register_lowering(gather_p, _gather_lower) @@ -2499,7 +2499,7 @@ def _scatter_lower(ctx, operand, indices, updates, *, update_ctx, update_jaxpr, mlir.TokenSet(), update_consts, (update.arguments[0],), (update.arguments[1],), dim_var_values=ctx.dim_var_values) - hlo.ReturnOp(util.flatten(out_nodes)) + hlo.return_(util.flatten(out_nodes)) return op.results mlir.register_lowering(scatter_p, _scatter_lower) @@ -2555,12 +2555,12 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates, reducer = scatter.regions[0].blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(reducer): add = hlo.AddOp(*reducer.arguments).result - hlo.ReturnOp([add]) + hlo.return_([add]) return scatter.result - real = _scatter(hlo.RealOp(operand).result, hlo.RealOp(updates).result) - imag = _scatter(hlo.ImagOp(operand).result, hlo.ImagOp(updates).result) - return hlo.ComplexOp(real, imag).results + real = _scatter(hlo.real(operand), hlo.real(updates)) + imag = _scatter(hlo.imag(operand), hlo.imag(updates)) + return [hlo.complex(real, imag)] mlir.register_lowering(scatter_add_p, _scatter_add_lower_gpu, platform="gpu") diff --git a/jax/_src/lax/special.py b/jax/_src/lax/special.py index e862009f6..f6cd8bc0d 100644 --- a/jax/_src/lax/special.py +++ b/jax/_src/lax/special.py @@ -625,14 +625,14 @@ ad.defjvp(regularized_incomplete_beta_p, lgamma_p = standard_unop(_float, 'lgamma') ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x))) -mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.LgammaOp)) +mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.lgamma)) digamma_p = standard_unop(_float, 'digamma') -mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.DigammaOp)) +mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.digamma)) ad.defjvp(digamma_p, lambda g, x: mul(g, polygamma(_const(x, 1), x))) polygamma_p = standard_naryop([_float, _float], 'polygamma') -mlir.register_lowering(polygamma_p, partial(_nary_lower_hlo, chlo.PolygammaOp)) +mlir.register_lowering(polygamma_p, partial(_nary_lower_hlo, chlo.polygamma)) ad.defjvp(polygamma_p, polygamma_gradm, polygamma_gradx) igamma_p = standard_naryop([_float, _float], 'igamma') @@ -659,7 +659,7 @@ mlir.register_lowering(random_gamma_grad_p, multiple_results=False)) zeta_p = standard_naryop([_float, _float], 'zeta') -mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.ZetaOp)) +mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.zeta)) bessel_i0e_p = standard_unop(_float, 'bessel_i0e') mlir.register_lowering(bessel_i0e_p, @@ -669,7 +669,7 @@ ad.defjvp2(bessel_i0e_p, lambda g, y, x: g * (bessel_i1e(x) - sign(x) * y)) bessel_i1e_p = standard_unop(_float, 'bessel_i1e') mlir.register_lowering(bessel_i1e_p, - partial(_nary_lower_hlo, chlo.BesselI1eOp)) + partial(_nary_lower_hlo, chlo.bessel_i1e)) def _bessel_i1e_jvp(g, y, x): eps = dtypes.finfo(_dtype(x)).eps @@ -683,14 +683,14 @@ ad.defjvp2(bessel_i1e_p, _bessel_i1e_jvp) erf_p = standard_unop(_float, 'erf') ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. 
/ np.sqrt(np.pi)), mul(g, exp(neg(square(x)))))) -mlir.register_lowering(erf_p, partial(_nary_lower_hlo, chlo.ErfOp)) +mlir.register_lowering(erf_p, partial(_nary_lower_hlo, chlo.erf)) erfc_p = standard_unop(_float, 'erfc') ad.defjvp(erfc_p, lambda g, x: mul(_const(x, -2. / np.sqrt(np.pi)), mul(g, exp(neg(square(x)))))) -mlir.register_lowering(erfc_p, partial(_nary_lower_hlo, chlo.ErfcOp)) +mlir.register_lowering(erfc_p, partial(_nary_lower_hlo, chlo.erfc)) erf_inv_p = standard_unop(_float, 'erf_inv') ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.), mul(g, exp(square(ans))))) -mlir.register_lowering(erf_inv_p, partial(_nary_lower_hlo, chlo.ErfInvOp)) +mlir.register_lowering(erf_inv_p, partial(_nary_lower_hlo, chlo.erf_inv)) diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index e7fe2a604..ce9c1c34c 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -470,7 +470,7 @@ def _reduce_window_lower( return mlir.reduce_window(ctx, reducer_name=f"reduce_window_{scalar_aval.dtype}_reducer", - reducer_body=lambda reducer: reduce_op(*reducer.arguments), + reducer_body=lambda reducer: [reduce_op(*reducer.arguments)], operands=[operand], init_values=[mlir.full_like_aval(ctx, init_value(scalar_aval.dtype), scalar_aval)], @@ -482,7 +482,7 @@ def _reduce_window_lower( mlir.register_lowering(reduce_window_sum_p, partial( - _reduce_window_lower, hlo.AddOp, lambda _: 0)) + _reduce_window_lower, hlo.add, lambda _: 0)) mlir.register_lowering(reduce_window_min_p, partial( _reduce_window_lower, mlir.min_hlo, lax._get_min_identity)) mlir.register_lowering(reduce_window_max_p, partial( @@ -530,7 +530,7 @@ def _select_and_scatter_lower( mlir.TokenSet(), select_consts, *([a] for a in select.arguments), dim_var_values=ctx.dim_var_values) - hlo.ReturnOp(util.flatten(out_nodes)) + hlo.return_(util.flatten(out_nodes)) scatter = op.scatter.blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(scatter): if scatter_jaxpr.effects: @@ -539,7 +539,7 @@ def _select_and_scatter_lower( mlir.TokenSet(), scatter_consts, *([a] for a in scatter.arguments), dim_var_values=ctx.dim_var_values) - hlo.ReturnOp(util.flatten(out_nodes)) + hlo.return_(util.flatten(out_nodes)) return op.results mlir.register_lowering(select_and_scatter_p, _select_and_scatter_lower) @@ -685,27 +685,27 @@ def _select_and_gather_add_lowering( def pack(a, b, ab_aval): word_type_ab_aval = ab_aval.update(dtype=word_dtype) double_word_type_ab_aval = ab_aval.update(dtype=double_word_dtype) - a = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), a) - b = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), b) - a = hlo.ConvertOp(mlir.aval_to_ir_type(double_word_type_ab_aval), a) - b = hlo.ConvertOp(mlir.aval_to_ir_type(double_word_type_ab_aval), b) - a = hlo.ShiftLeftOp(a, - _broadcast_scalar_const(nbits, double_word_type_ab_aval)) - return hlo.OrOp(a, b) + a = hlo.bitcast_convert(mlir.aval_to_ir_type(word_type_ab_aval), a) + b = hlo.bitcast_convert(mlir.aval_to_ir_type(word_type_ab_aval), b) + a = hlo.convert(mlir.aval_to_ir_type(double_word_type_ab_aval), a) + b = hlo.convert(mlir.aval_to_ir_type(double_word_type_ab_aval), b) + a = hlo.shift_left( + a, _broadcast_scalar_const(nbits, double_word_type_ab_aval)) + return hlo.or_(a, b) # Unpacks the first element of a double_word_type. 
def fst(t): assert not ir.RankedTensorType(t.type).shape - st = hlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits)) - return hlo.BitcastConvertOp( + st = hlo.shift_right_logical(t, const(double_word_dtype, nbits)) + return hlo.bitcast_convert( ir.RankedTensorType.get([], etype), - hlo.ConvertOp(ir.RankedTensorType.get([], word_type), st)).result + hlo.convert(ir.RankedTensorType.get([], word_type), st)) # Unpacks the second element of a double_word_type. def snd(t, t_aval): - return hlo.BitcastConvertOp( + return hlo.bitcast_convert( mlir.aval_to_ir_type(t_aval.update(dtype=dtype)), - hlo.ConvertOp(mlir.aval_to_ir_type(t_aval.update(dtype=word_dtype)), t)).result + hlo.convert(mlir.aval_to_ir_type(t_aval.update(dtype=word_dtype)), t)) else: # The double-word trick above only works if we have a sufficiently large @@ -726,29 +726,27 @@ def _select_and_gather_add_lowering( # Packs two values into a double_word_type. def pack(a, b, ab_aval): word_type_ab_aval = ab_aval.update(dtype=word_dtype) - a = hlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp), + a = hlo.reduce_precision(a, exponent_bits=mlir.i32_attr(nexp), mantissa_bits=mlir.i32_attr(nmant)) - b = hlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp), + b = hlo.reduce_precision(b, exponent_bits=mlir.i32_attr(nexp), mantissa_bits=mlir.i32_attr(nmant)) - a = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), a) - b = hlo.BitcastConvertOp(mlir.aval_to_ir_type(word_type_ab_aval), b) - b = hlo.ShiftRightLogicalOp( + a = hlo.bitcast_convert(mlir.aval_to_ir_type(word_type_ab_aval), a) + b = hlo.bitcast_convert(mlir.aval_to_ir_type(word_type_ab_aval), b) + b = hlo.shift_right_logical( b, _broadcast_scalar_const(r_nbits, word_type_ab_aval)) - return hlo.OrOp(a, b) + return hlo.or_(a, b) # Unpacks the first element of a double_word_type. def fst(t): assert not ir.RankedTensorType(t.type).shape - st = hlo.AndOp(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits)) - return hlo.BitcastConvertOp(ir.RankedTensorType.get([], etype), - st).result + st = hlo.and_(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits)) + return hlo.bitcast_convert(ir.RankedTensorType.get([], etype), st) # Unpacks the second element of a double_word_type. 
def snd(t, t_aval): - return hlo.BitcastConvertOp( + return hlo.bitcast_convert( mlir.aval_to_ir_type(t_aval.update(dtype=dtype)), - hlo.ShiftLeftOp(t, _broadcast_scalar_const(r_nbits, t_aval.update(dtype=word_dtype))) - ).result + hlo.shift_left(t, _broadcast_scalar_const(r_nbits, t_aval.update(dtype=word_dtype)))) assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim init = -np.inf if select_prim is lax.ge_p else np.inf diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 67670517a..187c7645f 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -1115,9 +1115,9 @@ def _threefry2x32_gpu_lowering(lowering_func, ctx, k1, k2, x1, x2): out_len = reduce(op.mul, aval_out.shape, 1) if not core.is_constant_dim(out_len): length = mlir.eval_dynamic_shape_as_tensor(ctx, [out_len]) - length = mlir.hlo.ConvertOp( + length = mlir.hlo.convert( ir.RankedTensorType.get((1,), ir.IntegerType.get_signless(64)), - length).result + length) output_shape = mlir.eval_dynamic_shape_as_tensor(ctx, aval_out.shape) else: length = int(out_len) # will be passed statically @@ -1221,20 +1221,20 @@ def iota_2x32_shape_lowering(ctx, *, shape): aval_u64 = core.ShapedArray(shape, np.dtype('uint64')) def _add(x: ir.Value, y: ir.Value) -> ir.Value: - return mlir.hlo.AddOp(x, y).result + return mlir.hlo.add(x, y) def _mul(x: core.DimSize, y: ir.Value) -> ir.Value: if core.is_constant_dim(x): x_const = mlir.ir_constant(np.array(x, np.dtype('uint64'))) else: x_const, = mlir.eval_dynamic_shape(ctx, (x,)) - x_const = hlo.ConvertOp( + x_const = hlo.convert( ir.RankedTensorType.get( (), - mlir.dtype_to_ir_type(np.dtype('uint64'))), x_const).result + mlir.dtype_to_ir_type(np.dtype('uint64'))), x_const) x_bcast = mlir.broadcast_in_dim(ctx, x_const, aval_u64, broadcast_dimensions=[]) - return mlir.hlo.MulOp(x_bcast, y).result + return mlir.hlo.multiply(x_bcast, y) assert len(shape) > 0 @@ -1244,10 +1244,9 @@ def iota_2x32_shape_lowering(ctx, *, shape): shift = mlir.ir_constant(np.array(32, np.dtype('uint64'))) shift = mlir.broadcast_in_dim(ctx, shift, aval_u64, broadcast_dimensions=[]) - counts_shifted = mlir.hlo.ShiftRightLogicalOp(counts, shift).result - counts_lo = mlir.hlo.ConvertOp(mlir.aval_to_ir_type(aval_out), counts).result - counts_hi = mlir.hlo.ConvertOp(mlir.aval_to_ir_type(aval_out), - counts_shifted).result + counts_shifted = mlir.hlo.shift_right_logical(counts, shift) + counts_lo = mlir.hlo.convert(mlir.aval_to_ir_type(aval_out), counts) + counts_hi = mlir.hlo.convert(mlir.aval_to_ir_type(aval_out), counts_shifted) return counts_hi, counts_lo mlir.register_lowering(iota_2x32_shape_p, iota_2x32_shape_lowering) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index 61d5a2a40..75972e147 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -214,7 +214,7 @@ def _tpu_custom_call_lowering( base64.b64encode(kernel_regeneration_metadata) ) if multiple_results: - results = [stablehlo.GetTupleElementOp(call, mlir.i32_attr(i)).result + results = [stablehlo.get_tuple_element(call, mlir.i32_attr(i)) for i in range(len(out_avals))] else: results = call.results diff --git a/jax/experimental/export/export.py b/jax/experimental/export/export.py index 3a59c69f2..1c7ee4ab5 100644 --- a/jax/experimental/export/export.py +++ b/jax/experimental/export/export.py @@ -720,7 +720,7 @@ def _wrap_main_func( list(new_main_op.arguments[0:nr_platform_index_args]) + util.flatten(dim_values), platform_input_types + dim_var_input_types): if arg.type != arg_type: - 
orig_main_args.append(hlo.ConvertOp(arg_type, arg).result) + orig_main_args.append(hlo.convert(arg_type, arg)) else: orig_main_args.append(arg) # Then the token arguments @@ -737,7 +737,7 @@ def _wrap_main_func( # output_operand_alias attribute. See b/287386268. for arg, arg_type in zip(new_main_op_array_args, array_input_types): if arg.type != arg_type: - orig_main_args.append(hlo.ConvertOp(arg_type, arg).result) + orig_main_args.append(hlo.convert(arg_type, arg)) else: orig_main_args.append(arg) call = func_dialect.CallOp(orig_output_types, @@ -1166,7 +1166,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.AbstractValue) -> ir.Value: new_ir_type = mlir.aval_to_ir_type(new_aval) if x.type != new_ir_type: - return hlo.ConvertOp(mlir.aval_to_ir_type(new_aval), x).result + return hlo.convert(mlir.aval_to_ir_type(new_aval), x) else: return x @@ -1210,7 +1210,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, for i in range(len(lowering_platforms)): branch = callee_platform_idx.regions[i].blocks.append() with ir.InsertionPoint(branch): - hlo.ReturnOp(mlir.ir_constants( + hlo.return_(mlir.ir_constants( np.int32(callee_lowering_platform_index[i]))) if callee_platform_idx.result.type != callee_type.inputs[0]: callee_platform_idx = hlo.ConvertOp(callee_type.inputs[0], diff --git a/jax/experimental/export/shape_poly.py b/jax/experimental/export/shape_poly.py index e66f0d531..5f573308d 100644 --- a/jax/experimental/export/shape_poly.py +++ b/jax/experimental/export/shape_poly.py @@ -897,7 +897,7 @@ def _dim_as_value_lowering(ctx: mlir.LoweringRuleContext, *, res, = mlir.eval_dynamic_shape(ctx, (dim,)) out_type = mlir.aval_to_ir_type(ctx.avals_out[0]) if out_type != res.type: # type: ignore - return mlir.hlo.ConvertOp(out_type, res).results + return [mlir.hlo.convert(out_type, res)] else: return [res] @@ -1161,11 +1161,11 @@ def _dimension_size_impl(arg, *, dimension): dimension_size_p.def_impl(_dimension_size_impl) def _dimension_size_lowering_rule(ctx, arg, *, dimension): - dim_size = mlir.hlo.GetDimensionSizeOp(arg, dimension) + dim_size = mlir.hlo.get_dimension_size(arg, dimension) dim_type = mlir.aval_to_ir_type(core.dim_value_aval()) - if dim_size.result.type != dim_type: - dim_size = mlir.hlo.ConvertOp(dim_type, dim_size) - return dim_size.results + if dim_size.type != dim_type: + dim_size = mlir.hlo.convert(dim_type, dim_size) + return [dim_size] mlir.register_lowering(dimension_size_p, _dimension_size_lowering_rule) diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index ab9ff7f9a..6da5af5a9 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -570,7 +570,7 @@ def _call_tf_lowering( ir.FlatSymbolRefAttr.get(fn), tuple(args_op) + captured_ops) if result_shape.is_tuple(): - flat_results = [hlo.GetTupleElementOp(call, mlir.i32_attr(i)).result + flat_results = [hlo.get_tuple_element(call, mlir.i32_attr(i)) for i in range(len(result_shapes))] else: flat_results = call.results diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 7943b94c0..c84c28c4f 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -681,8 +681,7 @@ def _bcsr_dot_general_gpu_lowering( dot_general_fn = csr_matmat_lowering x_dtype = 'B_dtype' if rhs_contract[0] == 1: - rhs = hlo.TransposeOp( - rhs, permutation=mlir.dense_int_elements([1, 0])).result + rhs = hlo.transpose(rhs, 
permutation=mlir.dense_int_elements([1, 0])) else: raise ValueError(f"rhs has to be 1d or 2d; get {rhs_aval.ndim}d.") diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py index 269441565..8cf2aa814 100644 --- a/jax/experimental/sparse/coo.py +++ b/jax/experimental/sparse/coo.py @@ -229,7 +229,7 @@ def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo): result = coo_todense_hlo( data, row, col, shape=shape, data_dtype=dtype, index_dtype=row_aval.dtype) return ( - [hlo.TransposeOp(result, mlir.dense_int_elements([1, 0])).result] + [hlo.transpose(result, mlir.dense_int_elements([1, 0]))] if transpose else [result]) diff --git a/jaxlib/ducc_fft.py b/jaxlib/ducc_fft.py index be28099e0..36d98762a 100644 --- a/jaxlib/ducc_fft.py +++ b/jaxlib/ducc_fft.py @@ -79,9 +79,9 @@ def dynamic_ducc_fft_hlo( assert 0 not in a_type.shape u8_type = ir.IntegerType.get_unsigned(8) - descriptor = hlo.ConstantOp( + descriptor = hlo.constant( ir.DenseElementsAttr.get( - np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type)).result + np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type)) layout = tuple(range(ndims - 1, -1, -1)) return custom_call( "dynamic_ducc_fft", diff --git a/jaxlib/gpu_rnn.py b/jaxlib/gpu_rnn.py index 4dde76a4e..59192c999 100644 --- a/jaxlib/gpu_rnn.py +++ b/jaxlib/gpu_rnn.py @@ -89,9 +89,9 @@ def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *, def _hlo_zeros_f32(shape): - return hlo.ConstantOp( + return hlo.constant( ir.DenseElementsAttr.get( - np.zeros(shape, dtype=np.float32), type=ir.F32Type.get())).result + np.zeros(shape, dtype=np.float32), type=ir.F32Type.get())) def cudnn_rnn_bwd_lowering(ctx, dy, dhn, dcn, x, h0, c0, w, y, diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index e24d8b89e..782d7eadc 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -320,7 +320,7 @@ def _syevd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a, *, if dynamic_batch_dims: batch_size_val = hlo_s32(1) for b_v in batch_dims_vals: - batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) operands.append(batch_size_val) operand_layouts.append(()) @@ -406,22 +406,22 @@ def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a, [0], ], operand_output_aliases={0: 0}).results - vt = hlo.TransposeOp( + vt = hlo.transpose( v, - ir.DenseIntElementsAttr.get(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))).result + ir.DenseIntElementsAttr.get(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))) if np.issubdtype(dtype, np.complexfloating): - vt = hlo.ComplexOp(hlo.RealOp(vt), hlo.NegOp(hlo.ImagOp(vt))).result + vt = hlo.complex(hlo.real(vt), hlo.negate(hlo.imag(vt))) if not full_matrices and not econ: - u = hlo.SliceOp( + u = hlo.slice( u, ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)), ir.DenseIntElementsAttr.get(np.array(batch_dims + (m, min(m, n)))), - ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))).result - vt = hlo.SliceOp( + ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))) + vt = hlo.slice( vt, ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)), ir.DenseIntElementsAttr.get(np.array(batch_dims + (min(m, n), n))), - ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))).result + ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))) elif m < n: lwork, opaque = gpu_solver.build_gesvd_descriptor( np.dtype(dtype), b, n, m, compute_uv, full_matrices) @@ 
-538,15 +538,15 @@ def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower): if not lower and platform == "cu" and m > 1: start = (0,) * len(batch_dims) + (0,) end = batch_dims + (1,) - s = hlo.SliceOp(e, intattr(start), intattr(end), intattr([1] * len(start))) + s = hlo.slice(e, intattr(start), intattr(end), intattr([1] * len(start))) s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type) - s = hlo.BroadcastInDimOp(s_type, s, intattr(range(len(dims) - 1))) + s = hlo.broadcast_in_dim(s_type, s, intattr(range(len(dims) - 1))) # The diagonals are always real; convert to complex if needed. - s = hlo.ConvertOp( + s = hlo.convert( ir.RankedTensorType.get(s_type.shape, a_type.element_type), s) - offsets = tuple(hlo.ConstantOp(intattr(i)) + offsets = tuple(hlo.constant(intattr(i)) for i in ((0,) * len(batch_dims) + (0, 1))) - a = hlo.DynamicUpdateSliceOp(a, s, offsets).result + a = hlo.dynamic_update_slice(a, s, offsets) return a, d, e, taus, info diff --git a/jaxlib/hlo_helpers.py b/jaxlib/hlo_helpers.py index 386c8fc5d..727642f5c 100644 --- a/jaxlib/hlo_helpers.py +++ b/jaxlib/hlo_helpers.py @@ -84,21 +84,21 @@ def shape_tensor(sizes: Sequence[Union[int, ir.Value]]) -> ir.Value: return hlo_const(np.array([d], dtype=np.int32)) else: if d.type != i32_type: - d = hlo.ConvertOp(i32_type, d).result - return hlo.ReshapeOp(int1d, d).result + d = hlo.convert(i32_type, d) + return hlo.reshape(int1d, d) ds = [dim_to_i32x1(sz) for sz in sizes] if not ds: return hlo_const(np.array([], np.int32)) elif len(ds) == 1: return ds[0] else: - return hlo.ConcatenateOp( - ds, ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0)).result + return hlo.concatenate( + ds, ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 0)) def hlo_const(x: np.ndarray) -> ir.Value: assert isinstance(x, np.ndarray) - return hlo.ConstantOp( - ir.DenseElementsAttr.get(x, type=dtype_to_ir_type(x.dtype))).result + return hlo.constant( + ir.DenseElementsAttr.get(x, type=dtype_to_ir_type(x.dtype))) def hlo_u8(x: int): return hlo_const(np.array(x, dtype=np.uint8)) @@ -116,7 +116,7 @@ def hlo_min(x: DimensionSize, y: DimensionSize) -> DimensionSize: x = hlo_s32(x) if type(y) is int: y = hlo_s32(y) - return hlo.MinOp(x, y).result + return hlo.minimum(x, y) def hlo_add(x: DimensionSize, y: DimensionSize) -> DimensionSize: @@ -126,7 +126,7 @@ def hlo_add(x: DimensionSize, y: DimensionSize) -> DimensionSize: x = hlo_s32(x) if type(y) is int: y = hlo_s32(y) - return hlo.AddOp(x, y).result + return hlo.add(x, y) # TODO(necula): this is identical with mlir.custom_call, but meant for use diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index 48b7f9378..c30740e00 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -51,7 +51,7 @@ def trsm_hlo(dtype, alpha, a, b, num_bd = len(batch_dims_vals) batch_size_val = hlo_s32(1) for b_v in batch_dims_vals: - batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) if dtype == np.float32: fn = "blas_strsm" @@ -119,7 +119,7 @@ def getrf_hlo(dtype, a: ir.Value, *, batch_size_val = hlo_s32(1) for b_v in batch_dims_vals: - batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) return custom_call( fn, @@ -170,7 +170,7 @@ def geqrf_hlo(dtype, a: ir.Value, *, batch_size_val = hlo_s32(1) for b_v in batch_dims_vals: - batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result + batch_size_val = hlo.multiply(batch_size_val, 
ensure_hlo_s32(b_v)) shape_type_pairs: Sequence[ShapeTypePair] = [ (a_shape_vals, a_type.element_type), (batch_dims_vals + (min(m, n),), a_type.element_type), @@ -211,7 +211,7 @@ def orgqr_hlo(dtype, a: ir.Value, tau, *, num_bd = len(batch_dims_vals) batch_size_val = hlo_s32(1) for b_v in batch_dims_vals: - batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) k = tau_shape_vals[-1] assert type(k) is int @@ -281,7 +281,7 @@ def potrf_hlo(dtype, a: ir.Value, *, lower=False, num_bd = len(batch_dims_vals) batch_size_val = hlo_s32(1) for b_v in batch_dims_vals: - batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) scalar_layout = [] layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) @@ -318,7 +318,7 @@ def gesdd_hlo(dtype, a: ir.Value, *, full_matrices=True, compute_uv=True, num_bd = len(batch_dims_vals) batch_size_val = hlo_s32(1) for b_v in batch_dims_vals: - batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) i32_type = ir.IntegerType.get_signless(32) workspace: list[ShapeTypePair] @@ -445,7 +445,7 @@ def syevd_hlo(dtype, a: ir.Value, batch_size_val = hlo_s32(1) for b_v in batch_dims_vals: - batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) scalar_layout = [] shape_layout = [0] @@ -540,7 +540,7 @@ def geev_hlo(dtype, input, *, batch_size_val = hlo_s32(1) for b_v in batch_dims_vals: - batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) shape_type_pairs: Sequence[ShapeTypePair] = workspaces + eigvals + [ (input_shape_vals, eigvecs_type), @@ -560,7 +560,7 @@ def geev_hlo(dtype, input, *, result_shapes=result_shapes, ).results if real: - return (hlo.ComplexOp(out[3], out[4]).result, out[5], out[6], out[7]) + return (hlo.complex(out[3], out[4]), out[5], out[6], out[7]) else: return out[2:6] @@ -615,7 +615,7 @@ def gees_hlo(dtype, a, *, jobvs=True, sort=False, select=None, scalar_layout = [] batch_size_val = hlo_s32(1) for b_v in batch_dims_vals: - batch_size_val = hlo.MulOp(batch_size_val, ensure_hlo_s32(b_v)).result + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) shape_type_pairs = workspaces + eigvals + [ (a_shape_vals, etype), (batch_dims_vals, i32_type), diff --git a/tests/export_test.py b/tests/export_test.py index c79e9384e..e6c7f6d9e 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -88,7 +88,7 @@ def lowering_testing_primitive_with_effect(ctx, a, *, effect_class_name: str): if "Ordered" in effect_class_name: token_in = ctx.tokens_in.get(_testing_effects[effect_class_name])[0] ctx.set_tokens_out(mlir.TokenSet({_testing_effects[effect_class_name]: (token_in,)})) - return mlir.hlo.AddOp(a, a).results + return [mlir.hlo.add(a, a)] mlir.register_lowering(testing_primitive_with_effect_p, lowering_testing_primitive_with_effect)
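
The hunks above all apply the same mechanical rewrite: the tablegen-generated snake_case builders (hlo.add, hlo.convert, hlo.constant, and so on) return ir.Value results directly, so call sites drop the .result/.results unwrapping used with the CamelCase op classes, and lowering rules that must return a sequence now wrap the single value in a plain Python list (for example, return [hlo.gather(...)] in slicing.py or return [mlir.hlo.add(a, a)] in export_test.py). The snippet below is a minimal standalone sketch of the two styles side by side; it is not part of the patch, and it assumes jaxlib's bundled MLIR Python bindings are importable as jaxlib.mlir with the StableHLO dialect module available as jaxlib.mlir.dialects.stablehlo.

    import numpy as np
    from jaxlib.mlir import ir
    from jaxlib.mlir.dialects import stablehlo as hlo

    with ir.Context() as ctx, ir.Location.unknown(ctx):
        hlo.register_dialect(ctx)
        module = ir.Module.create()
        with ir.InsertionPoint(module.body):
            # Build a constant operand; hlo.constant already returns an ir.Value.
            attr = ir.DenseElementsAttr.get(np.ones((2, 2), np.float32))
            x = hlo.constant(attr)
            # Op-class style: construct the op, then unwrap its single result.
            old = hlo.AddOp(x, x).result
            # Convenience-function style: the generated wrapper returns the ir.Value directly.
            new = hlo.add(x, x)
            assert old.type == new.type  # both are tensor<2x2xf32> values

Both forms build the same stablehlo.add operation; only the Python-side return convention differs, which is why the diff is almost entirely a one-for-one substitution plus the occasional list wrapper where a lowering rule's contract requires a sequence of values.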