diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py
index 71a9831e8..7e88fe03b 100644
--- a/jax/_src/lax/convolution.py
+++ b/jax/_src/lax/convolution.py
@@ -717,16 +717,23 @@ def _complex_mul(mul, x, y):
   k3 = mul(x_im, lax.add(y_re, y_im))
   return lax.complex(lax.sub(k1, k3), lax.add(k1, k2))
 
+
+_real_dtype = lambda dtype: np.finfo(dtype).dtype
+
 def _conv_general_dilated_lower(
     ctx, avals_in, avals_out, lhs, rhs, *, window_strides, padding,
     lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
-    batch_group_count, precision, expand_complex_convolutions=False,
-    **unused_kwargs):
+    batch_group_count, precision, preferred_element_type,
+    expand_complex_convolutions=False, **unused_kwargs):
   lhs_aval, rhs_aval = avals_in
   aval_out, = avals_out
   assert isinstance(dimension_numbers, ConvDimensionNumbers)
   dtype = lhs_aval.dtype
   if expand_complex_convolutions and np.issubdtype(dtype, np.complexfloating):
+    if preferred_element_type is not None:
+      # Convert complex dtype to types used for real and imaginary parts
+      assert np.issubdtype(preferred_element_type, np.complexfloating)
+      preferred_element_type = _real_dtype(preferred_element_type)
     complex_conv = mlir.lower_fun(
       partial(
         _complex_mul,
@@ -734,7 +741,8 @@ def _conv_general_dilated_lower(
         padding=padding, lhs_dilation=lhs_dilation,
         rhs_dilation=rhs_dilation, dimension_numbers=dimension_numbers,
         feature_group_count=feature_group_count,
-        batch_group_count=batch_group_count, precision=precision)),
+        batch_group_count=batch_group_count, precision=precision,
+        preferred_element_type=preferred_element_type)),
       multiple_results=False)
     return complex_conv(ctx, avals_in, avals_out, lhs, rhs)
 
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 1a3a32d3b..8960f79fb 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -1784,7 +1784,7 @@ mlir.register_lowering(complex_p, partial(_nary_lower_mhlo, mhlo.ComplexOp))
 
 conj_p = unop(_complex_dtype, _complex_elem_types | _complex, 'conj')
 
-def _conj_impl(x, *, input_dtype):
+def _conj_impl(x, **kw):
   if dtypes.issubdtype(x.dtype, np.complexfloating):
     return complex(real(x), -imag(x))
   else:
@@ -2075,30 +2075,6 @@ def _minmax_translation_rule(ctx, avals_in, avals_out, x, y, *, op_minmax=None,
   else:
     return [op_minmax(x, y)]
 
-def _minmax_mhlo(op, cmp, x, y):
-  """Min/max that compares complex values lexicographically as pairs."""
-  tensor_type = ir.RankedTensorType(x.type)
-  if ir.ComplexType.isinstance(tensor_type.element_type):
-    rx = mhlo.RealOp(x).result
-    ry = mhlo.RealOp(y).result
-    dims = [tensor_type.get_dim_size(i) for i in range(tensor_type.rank)]
-    bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1))
-    real_eq = mhlo.CompareOp(bool_shape, rx, ry, ir.StringAttr.get("EQ"),
-                             ir.StringAttr.get("FLOAT"))
-    real_cmp = mhlo.CompareOp(bool_shape, rx, ry,
-                              ir.StringAttr.get(cmp),
-                              ir.StringAttr.get("FLOAT"))
-    imag_cmp = mhlo.CompareOp(bool_shape, mhlo.ImagOp(x).result,
-                              mhlo.ImagOp(y).result,
-                              ir.StringAttr.get(cmp),
-                              ir.StringAttr.get("FLOAT"))
-    which = mhlo.SelectOp(real_eq, imag_cmp, real_cmp).result
-    return mhlo.SelectOp(which, x, y)
-  else:
-    return op(x, y)
-
-_min_mhlo = partial(_minmax_mhlo, mhlo.MinOp, "LT")
-_max_mhlo = partial(_minmax_mhlo, mhlo.MaxOp, "GT")
-
 max_p: core.Primitive = standard_naryop(
     [_any, _any], 'max', translation_rule=partial(
@@ -2106,7 +2082,7 @@ max_p: core.Primitive = standard_naryop(
 ad.defjvp2(max_p,
            lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
            lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
-mlir.register_lowering(max_p, partial(_nary_lower_mhlo, _max_mhlo))
+mlir.register_lowering(max_p, partial(_nary_lower_mhlo, mlir.max_mhlo))
 
 min_p: core.Primitive = standard_naryop(
     [_any, _any], 'min', translation_rule=partial(
@@ -2114,7 +2090,7 @@ min_p: core.Primitive = standard_naryop(
 ad.defjvp2(min_p,
            lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
            lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
-mlir.register_lowering(min_p, partial(_nary_lower_mhlo, _min_mhlo))
+mlir.register_lowering(min_p, partial(_nary_lower_mhlo, mlir.min_mhlo))
 
 shift_left_p = standard_naryop([_int, _int], 'shift_left')
 ad.defjvp_zero(shift_left_p)
@@ -3590,9 +3566,9 @@ mlir.register_lowering(reduce_or_p, partial(_unary_reduce_lower, mhlo.OrOp,
                                             lambda dtype: np.array(False, dtype)))
 mlir.register_lowering(reduce_and_p, partial(_unary_reduce_lower, mhlo.AndOp,
                                              lambda dtype: np.array(True, dtype)))
-mlir.register_lowering(reduce_min_p, partial(_unary_reduce_lower, _min_mhlo,
+mlir.register_lowering(reduce_min_p, partial(_unary_reduce_lower, mlir.min_mhlo,
                                              _get_min_identity))
-mlir.register_lowering(reduce_max_p, partial(_unary_reduce_lower, _max_mhlo,
+mlir.register_lowering(reduce_max_p, partial(_unary_reduce_lower, mlir.max_mhlo,
                                              _get_max_identity))
 
 
@@ -3960,7 +3936,7 @@ xla.register_translation(infeed_p, _infeed_translation_rule)
 
 def _infeed_lowering(ctx, avals_in, avals_out, token, *, shapes, partitions):
   assert partitions is None, partitions  # TODO(phawkins): implement me.
-  output_types = map(mlir.aval_to_ir_types, avals_out[:-1])
+  output_types = safe_map(mlir.aval_to_ir_types, avals_out[:-1])
   flat_output_types = util.flatten(output_types)
   output_tuple_type = ir.TupleType.get_tuple(flat_output_types)
   # TODO(phawkins): verify `shapes` have a major-to-minor layout.
diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index 488d2c6ac..c8023522e 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -1984,7 +1984,8 @@ def _real_dtype(dtype): return np.finfo(dtype).dtype
 def _scatter_add_lower_gpu(ctx, avals_in, avals_out, operand, indices, updates,
                            *, update_jaxpr, update_consts, dimension_numbers,
                            indices_are_sorted, unique_indices, mode):
-  if operand.dtype != np.complex128:
+  operand_aval_in, _, updates_aval_in = avals_in
+  if operand_aval_in.dtype != np.complex128:
     return _scatter_lower(ctx, avals_in, avals_out, operand, indices, updates,
                           update_jaxpr=update_jaxpr,
                           update_consts=update_consts,
@@ -1992,7 +1993,6 @@ def _scatter_add_lower_gpu(ctx, avals_in, avals_out, operand, indices, updates,
                           indices_are_sorted=indices_are_sorted,
                           unique_indices=unique_indices, mode=mode)
   assert mode == GatherScatterMode.PROMISE_IN_BOUNDS, mode
-  _, _, updates_aval_in = avals_in
   aval_out, = avals_out
   dnums = dimension_numbers
   scatter_dnums = mhlo.ScatterDimensionNumbers.get(
diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py
index bbb290eec..79440d944 100644
--- a/jax/_src/lax/windowed_reductions.py
+++ b/jax/_src/lax/windowed_reductions.py
@@ -510,9 +510,9 @@ def _reduce_window_lower(
 
 mlir.register_lowering(reduce_window_sum_p, partial(
     _reduce_window_lower, mhlo.AddOp, lambda _: 0))
 mlir.register_lowering(reduce_window_min_p, partial(
-    _reduce_window_lower, mhlo.MinOp, lax._get_min_identity))
+    _reduce_window_lower, mlir.min_mhlo, lax._get_min_identity))
 mlir.register_lowering(reduce_window_max_p, partial(
-    _reduce_window_lower, mhlo.MaxOp, lax._get_max_identity))
+    _reduce_window_lower, mlir.max_mhlo, lax._get_max_identity))
 
 
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 9779e6bb8..de500024e 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -39,6 +39,7 @@ from jax._src.util import prod, unzip2
 from jax.tree_util import tree_multimap, tree_all, tree_map, tree_reduce
 from jax._src.lib import xla_bridge
 from jax._src import dispatch
+from jax.interpreters import mlir
 from jax.interpreters import xla
 from jax.experimental.maps import mesh
 
@@ -368,18 +369,25 @@ def count_jit_and_pmap_compiles():
   # No need to clear any caches since we generally jit and pmap fresh callables
   # in tests.
-  jaxpr_subcomp = xla.jaxpr_subcomp
+  xla_jaxpr_subcomp = xla.jaxpr_subcomp
+  mlir_jaxpr_subcomp = mlir.jaxpr_subcomp
   count = [0]
 
-  def jaxpr_subcomp_and_count(*args, **kwargs):
-    count[0] += 1
-    return jaxpr_subcomp(*args, **kwargs)
+  def xla_jaxpr_subcomp_and_count(*args, **kwargs):
+    count[0] += 1
+    return xla_jaxpr_subcomp(*args, **kwargs)
 
-  xla.jaxpr_subcomp = jaxpr_subcomp_and_count
+  def mlir_jaxpr_subcomp_and_count(*args, **kwargs):
+    count[0] += 1
+    return mlir_jaxpr_subcomp(*args, **kwargs)
+
+  xla.jaxpr_subcomp = xla_jaxpr_subcomp_and_count
+  mlir.jaxpr_subcomp = mlir_jaxpr_subcomp_and_count
   try:
     yield count
   finally:
-    xla.jaxpr_subcomp = jaxpr_subcomp
+    xla.jaxpr_subcomp = xla_jaxpr_subcomp
+    mlir.jaxpr_subcomp = mlir_jaxpr_subcomp
 
 
 @contextmanager
 def assert_num_jit_and_pmap_compilations(times):
diff --git a/jax/experimental/ann.py b/jax/experimental/ann.py
index f4d8c12e4..419c0eadb 100644
--- a/jax/experimental/ann.py
+++ b/jax/experimental/ann.py
@@ -178,9 +178,10 @@ def _comparator_builder(operand, op_type, is_max_k):
   return c.build(cmp_result)
 
 
-def _approx_top_k_tpu_translation(c, operand, k, reduction_dimension,
-                                  recall_target, is_max_k,
+def _approx_top_k_tpu_translation(ctx, avals_in, avals_out, operand, *, k,
+                                  reduction_dimension, recall_target, is_max_k,
                                   reduction_input_size_override):
+  c = ctx.builder
   op_shape = c.get_shape(operand)
   if not op_shape.is_array():
     raise ValueError('operand must be an array, but was {}'.format(op_shape))
@@ -203,14 +204,17 @@ def _approx_top_k_tpu_translation(c, operand, k, reduction_dimension,
                                 reduction_dimension)
   init_val = xc.ops.Constant(c, init_literal)
   init_arg = xc.ops.Constant(c, np.int32(-1))
-  return xc.ops.ApproxTopK(c, [operand, iota], [init_val, init_arg], k,
-                           reduction_dimension, comparator, recall_target, True,
-                           reduction_input_size_override)
+  out = xc.ops.ApproxTopK(c, [operand, iota], [init_val, init_arg], k,
+                          reduction_dimension, comparator, recall_target, True,
+                          reduction_input_size_override)
+  return xla.xla_destructure(c, out)
 
 
-def _approx_top_k_fallback_translation(c, operand, k, reduction_dimension,
+def _approx_top_k_fallback_translation(ctx, avals_in, avals_out, operand, *, k,
+                                       reduction_dimension,
                                        recall_target, is_max_k,
                                        reduction_input_size_override):
+  c = ctx.builder
   op_shape = c.get_shape(operand)
   if not op_shape.is_array():
     raise ValueError('operand must be an array, but was {}'.format(op_shape))
@@ -226,7 +230,7 @@ def _approx_top_k_fallback_translation(c, operand, k, reduction_dimension,
   args = xc.ops.GetTupleElement(val_arg, 1)
   sliced_vals = xc.ops.SliceInDim(vals, 0, k, 1, reduction_dimension)
   sliced_args = xc.ops.SliceInDim(args, 0, k, 1, reduction_dimension)
-  return xc.ops.Tuple(c, [sliced_vals, sliced_args])
+  return sliced_vals, sliced_args
 
 
 def _approx_top_k_batch_rule(batched_args, batch_dims, *, k,
@@ -292,11 +296,8 @@ approx_top_k_p = core.Primitive('approx_top_k')
 approx_top_k_p.multiple_results = True
 approx_top_k_p.def_impl(partial(xla.apply_primitive, approx_top_k_p))
 approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval)
-xla.backend_specific_translations['tpu'][
-    approx_top_k_p] = _approx_top_k_tpu_translation
-xla.backend_specific_translations['cpu'][
-    approx_top_k_p] = _approx_top_k_fallback_translation
-xla.backend_specific_translations['gpu'][
-    approx_top_k_p] = _approx_top_k_fallback_translation
+xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
+xla.register_translation(approx_top_k_p, _approx_top_k_tpu_translation,
+                         platform='tpu')
 batching.primitive_batchers[approx_top_k_p] = _approx_top_k_batch_rule
 ad.primitive_jvps[approx_top_k_p] = _approx_top_k_jvp
diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py
index 01868955a..6bbf69171 100644
--- a/jax/interpreters/mlir.py
+++ b/jax/interpreters/mlir.py
@@ -69,7 +69,7 @@ def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i)
 
 # IR Types
 # Non-canonicalized dtype to IR type mapping.
-_dtype_to_ir_type : Dict[np.dtype, Callable[[], ir.Type]] = {
+dtype_to_ir_type : Dict[np.dtype, Callable[[], ir.Type]] = {
   np.dtype(dtypes.float0): partial(ir.IntegerType.get_signless, 1),
   np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
   np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
@@ -91,13 +91,13 @@ _dtype_to_ir_type : Dict[np.dtype, Callable[[], ir.Type]] = {
 
 def _array_ir_types(aval: core.ShapedArray) -> ir.Type:
   try:
-    ir_type_factory = _dtype_to_ir_type[aval.dtype]
+    ir_type_factory = dtype_to_ir_type[aval.dtype]
   except KeyError as err:
     raise TypeError(
         f"No dtype_to_ir_type handler for dtype: {aval.dtype}") from err
   return (ir.RankedTensorType.get(aval.shape, ir_type_factory()),)
 
-_ir_type_handlers: Dict[Type[core.AbstractValue],
+ir_type_handlers: Dict[Type[core.AbstractValue],
                         Callable[[Any], Sequence[ir.Type]]] = {}
 
 def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[ir.Type]:
@@ -106,14 +106,14 @@ def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[ir.Type]:
   In general, a JAX value may be represented by multiple IR values, so this
   function returns multiple types."""
   try:
-    return _ir_type_handlers[type(aval)](aval)
+    return ir_type_handlers[type(aval)](aval)
   except KeyError as err:
     raise TypeError(f"No ir_type_handler for aval type: {type(aval)}") from err
 
-_ir_type_handlers[core.AbstractUnit] = lambda _: ()
-_ir_type_handlers[core.ShapedArray] = _array_ir_types
-_ir_type_handlers[core.ConcreteArray] = _array_ir_types
-_ir_type_handlers[core.AbstractToken] = lambda _: [mhlo.TokenType.get()]
+ir_type_handlers[core.AbstractUnit] = lambda _: ()
+ir_type_handlers[core.ShapedArray] = _array_ir_types
+ir_type_handlers[core.ConcreteArray] = _array_ir_types
+ir_type_handlers[core.AbstractToken] = lambda _: [mhlo.TokenType.get()]
 
 def aval_to_ir_type(aval: core.AbstractValue) -> ir.Type:
   """Convenience wrapper around aval_to_ir_types for single types.
@@ -310,6 +310,7 @@ def register_lowering(prim: core.Primitive, rule: LoweringRule,
 def _unwrap_singleton_ir_values(x): return x[0] if len(x) == 1 else x
 def wrap_singleton_ir_values(x: Union[ir.Value, Sequence[ir.Value]]
                              ) -> Sequence[ir.Value]:
+  """Adds consistent tuples to a mixture of tupled and untupled values."""
   return (x,) if isinstance(x, ir.Value) else tuple(x)
 
 def flatten_lowering_ir_args(
@@ -568,6 +569,32 @@ register_lowering(ad_util.stop_gradient_p,
                   lambda ctx, avals_in, avals_out, x: [x])
 
+
+def _minmax_mhlo(op, cmp, x, y):
+  """Min/max that compares complex values lexicographically as pairs."""
+  tensor_type = ir.RankedTensorType(x.type)
+  if ir.ComplexType.isinstance(tensor_type.element_type):
+    rx = mhlo.RealOp(x).result
+    ry = mhlo.RealOp(y).result
+    dims = [tensor_type.get_dim_size(i) for i in range(tensor_type.rank)]
+    bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1))
+    real_eq = mhlo.CompareOp(bool_shape, rx, ry, ir.StringAttr.get("EQ"),
+                             ir.StringAttr.get("FLOAT"))
+    real_cmp = mhlo.CompareOp(bool_shape, rx, ry,
+                              ir.StringAttr.get(cmp),
+                              ir.StringAttr.get("FLOAT"))
+    imag_cmp = mhlo.CompareOp(bool_shape, mhlo.ImagOp(x).result,
+                              mhlo.ImagOp(y).result,
+                              ir.StringAttr.get(cmp),
+                              ir.StringAttr.get("FLOAT"))
+    which = mhlo.SelectOp(real_eq, imag_cmp, real_cmp).result
+    return mhlo.SelectOp(which, x, y)
+  else:
+    return op(x, y)
+
+min_mhlo = partial(_minmax_mhlo, mhlo.MinOp, "LT")
+max_mhlo = partial(_minmax_mhlo, mhlo.MaxOp, "GT")
+
+
 # MLIR lowerings for lax primitives
 
 def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringContext,
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index b71a7b15b..1df788d9f 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -1637,7 +1637,8 @@ def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
     x = mhlo.ConvertOp(mlir.aval_to_ir_type(aval), x).result
     dims = list(aval.shape)
-    padded = mlir.full_like_aval(0, aval.update(shape=[axis_env.sizes[-1]] + dims))
+    padded = mlir.full_like_aval(
+        0, aval.update(shape=[axis_env.sizes[-1]] + dims))
     zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
     idxs = [_unravel_index_mhlo(axis_env)] + [zero] * len(dims)
     padded = mhlo.DynamicUpdateSliceOp(
@@ -1695,7 +1696,8 @@ def _pmap_lowering(ctx, avals_in, avals_out, *in_nodes, axis_name,
       axis_env=new_env,
       name_stack=xla.extend_name_stack(ctx.name_stack,
                                        util.wrap_name(name, 'pmap')))
-  sharded_outs = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, (), *in_nodes_sharded)
+  sharded_outs = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, (),
+                                    *in_nodes_sharded)
   out_avals = [v.aval for v in call_jaxpr.outvars]
   outs = [_mhlo_unshard(aval, new_env, out_axis, shard, platform=ctx.platform)
           for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)]
diff --git a/tests/api_test.py b/tests/api_test.py
index 77d676400..eab57b4d1 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -45,6 +45,7 @@ from jax._src import api, dtypes
 from jax.core import Primitive
 from jax.errors import UnexpectedTracerError
 from jax.interpreters import ad
+from jax.interpreters import mlir
 from jax.interpreters import xla
 from jax.interpreters import pxla
 from jax.interpreters.sharded_jit import PartitionSpec as P
@@ -243,6 +244,8 @@ class CPPJitTest(jtu.BufferDonationTestCase):
   # Jit and Donate arguments
 
   def test_jit_donate_argnums_warning_raised(self):
+    if jax.config.jax_enable_mlir:
+      raise unittest.SkipTest("Buffer donation not yet implemented via MLIR")
     x = jnp.array([1.0, 2.0], jnp.float32)
     y = jnp.array([1, 2], jnp.int32)
     f = self.jit(lambda x, y: x.sum() + y.sum(), donate_argnums=(0, 1))
@@ -984,8 +987,8 @@ class APITest(jtu.JaxTestCase):
 
     foo_p.def_abstract_eval(lambda x: x)
 
-    jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
-                     "XLA translation rule for primitive 'foo' not found")
+    jtu.check_raises_regexp(lambda: jit(foo)(1.0), NotImplementedError,
+                            ".* rule for primitive 'foo' not found.*")
 
     foo_p.def_impl(lambda x: x)
     ad.defjvp(foo_p, lambda g, x: foo(g))
@@ -2406,18 +2409,24 @@ class APITest(jtu.JaxTestCase):
     def g(x):
       return f(2, x)
 
-    jaxpr_subcomp = xla.jaxpr_subcomp
+    xla_jaxpr_subcomp = xla.jaxpr_subcomp
+    mlir_jaxpr_subcomp = mlir.jaxpr_subcomp
     jaxprs = []
 
-    def jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
-      jaxprs.append(jaxpr)
-      return jaxpr_subcomp(c, jaxpr, *args, **kwargs)
+    def xla_jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
+      jaxprs.append(jaxpr)
+      return xla_jaxpr_subcomp(c, jaxpr, *args, **kwargs)
+    def mlir_jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
+      jaxprs.append(jaxpr)
+      return mlir_jaxpr_subcomp(c, jaxpr, *args, **kwargs)
 
     try:
-      xla.jaxpr_subcomp = jaxpr_subcomp_and_collect
+      xla.jaxpr_subcomp = xla_jaxpr_subcomp_and_collect
+      mlir.jaxpr_subcomp = mlir_jaxpr_subcomp_and_collect
       ans = g(3)
     finally:
-      xla.jaxpr_subcomp = jaxpr_subcomp
+      xla.jaxpr_subcomp = xla_jaxpr_subcomp
+      mlir.jaxpr_subcomp = mlir_jaxpr_subcomp
 
     self.assertEqual(ans, (7, 3))
     self.assertLen(jaxprs, 2)
diff --git a/tests/custom_object_test.py b/tests/custom_object_test.py
index db8808b57..31b54545c 100644
--- a/tests/custom_object_test.py
+++ b/tests/custom_object_test.py
@@ -21,7 +21,9 @@ import jax.numpy as jnp
 from jax import core, jit, lax, make_jaxpr
 from jax._src import device_array
 from jax._src import dispatch
+from jax.interpreters import mlir
 from jax.interpreters import xla
+from jax._src.lib.mlir import ir
 from jax._src.lib import xla_bridge, xla_client
 xops = xla_client.ops
 xb = xla_bridge
@@ -137,6 +139,15 @@ dispatch.num_buffers_handlers[AbstractSparseArray] = lambda _: 2
 xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
 xla.register_constant_handler(SparseArray, sparse_array_constant_handler)
 
+def sparse_array_mlir_type_handler(a):
+  return (
+    ir.RankedTensorType.get(
+      a.data_aval.shape, mlir.dtype_to_ir_type[a.data_aval.dtype]()),
+    ir.RankedTensorType.get(
+      a.indices_aval.shape, mlir.dtype_to_ir_type[a.indices_aval.dtype]()),
+  )
+
+mlir.ir_type_handlers[AbstractSparseArray] = sparse_array_mlir_type_handler
 
 sp_indices_p = core.Primitive('sp_indices')
@@ -155,6 +166,11 @@ def _sp_indices_translation_rule(ctx, avals_in, avals_out, data, indices):
 # because it leads to infinite recursion.
 xla.register_translation(sp_indices_p, _sp_indices_translation_rule)
 
+def _sp_indices_mhlo_lowering(ctx, avals_in, avals_out, data_and_indices):
+  return [data_and_indices[1]]
+
+mlir.register_lowering(sp_indices_p, _sp_indices_mhlo_lowering)
+
 sp_data_p = core.Primitive('sp_data')
 
 @sp_data_p.def_impl
@@ -172,6 +188,11 @@ def _sp_data_translation_rule(ctx, avals_in, avals_out, data, indices):
 # because it leads to infinite recursion.
 xla.register_translation(sp_data_p, _sp_data_translation_rule)
 
+def _sp_data_mhlo_lowering(ctx, avals_in, avals_out, data_and_indices):
+  return [data_and_indices[0]]
+
+mlir.register_lowering(sp_data_p, _sp_data_mhlo_lowering)
+
 def identity(x):
   return identity_p.bind(x)
@@ -189,6 +210,10 @@ xla.register_translation(
     identity_p,
     xla.lower_fun(_identity_impl, multiple_results=False, new_style=True))
 
+
+mlir.register_lowering(
+    identity_p, mlir.lower_fun(_identity_impl, multiple_results=False))
+
 def split(x):
   return split_p.bind(x)
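
For reference, below is a minimal sketch (not part of the patch) of the registration pattern this change establishes for primitives, mirroring the `identity_p` example in `tests/custom_object_test.py`: the same impl is registered once through the old-style XLA translation path (`xla.register_translation` + `xla.lower_fun`) and once through the new MHLO path (`mlir.register_lowering` + `mlir.lower_fun`). The primitive `double_p` and its impl are hypothetical names for illustration, and the snippet assumes a JAX checkout at this revision where `jax.interpreters.mlir` exposes these entry points.

```python
# Hypothetical example, not part of this diff: register both lowering paths
# for a toy primitive, following the identity_p pattern above.
import jax
from jax import core
from jax.interpreters import mlir, xla

double_p = core.Primitive('double')

@double_p.def_impl
def _double_impl(x):
  # Implementation in terms of existing JAX ops; also traceable by lower_fun.
  return x * 2

@double_p.def_abstract_eval
def _double_abstract_eval(x):
  return core.ShapedArray(x.shape, x.dtype)

# Old path: an XLA translation rule derived from the impl.
xla.register_translation(
    double_p,
    xla.lower_fun(_double_impl, multiple_results=False, new_style=True))

# New path: an MHLO lowering derived from the same impl, used when the
# MLIR lowering pipeline is enabled.
mlir.register_lowering(
    double_p, mlir.lower_fun(_double_impl, multiple_results=False))

print(jax.jit(double_p.bind)(3.0))  # 6.0
```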