(NFC) Prepare for migration from producing MHLO to producing StableHLO

This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose, in
order to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir (see the
     sketch after this list).
  2) Documentation (changelog, JEPs, IR examples, etc.).
  3) One rare situation where the prose mentions both "StableHLO" and "MHLO"
     in the same sentence, so both names are needed to disambiguate.
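
For example, the public API spellings from 1) keep working after this change
(a minimal sketch; exact method availability depends on the JAX version):

  import jax
  import jax.numpy as jnp

  lowered = jax.jit(jnp.sin).lower(jnp.float32(1.0))
  print(lowered.as_text(dialect="mhlo"))        # public "mhlo" spelling kept
  module = lowered.compiler_ir(dialect="mhlo")  # likewise unchanged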

PiperOrigin-RevId: 495771153
Eugene Burmako 2022-12-15 20:59:34 -08:00 committed by jax authors
parent 523c6f7a53
commit b8ae8e3fa1
49 changed files with 991 additions and 882 deletions

View File

@ -169,7 +169,7 @@ def prepare_wheel(sources_path):
copy_to_jaxlib("__main__/jaxlib/init.py", dst_filename="__init__.py")
copy_to_jaxlib(f"__main__/jaxlib/cpu_feature_guard.{pyext}")
copy_to_jaxlib("__main__/jaxlib/lapack.py")
copy_to_jaxlib("__main__/jaxlib/mhlo_helpers.py")
copy_to_jaxlib("__main__/jaxlib/hlo_helpers.py")
copy_to_jaxlib("__main__/jaxlib/ducc_fft.py")
copy_to_jaxlib("__main__/jaxlib/gpu_prng.py")
copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py")

View File

@ -34,7 +34,7 @@ from jax._src.api_util import flatten_fun, shaped_abstractify
from jax._src.lax import lax as lax_internal
from jax._src.lax import convolution as lax_convolution
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.traceback_util import api_boundary
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
safe_zip, merge_lists, weakref_lru_cache)
@ -623,9 +623,9 @@ def _optimization_barrier_lowering_rule(ctx, *args):
flat_barrier_types = util.flatten(barrier_types)
flat_args = mlir.flatten_lowering_ir_args(args)
if xc.mlir_api_version < 40:
barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
barrier_op = hlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
else:
barrier_op = mhlo.OptimizationBarrierOp(flat_args)
barrier_op = hlo.OptimizationBarrierOp(flat_args)
return util.unflatten(barrier_op.results, map(len, barrier_types))
def _optimization_barrier(arg):
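
This hunk also shows the mlir_api_version gate that recurs throughout the
change: jaxlib's MLIR binding version decides which op-builder signature the
dialect exposes. A standalone sketch of the same pattern (here with
CreateTokenOp, as in a later hunk; it assumes an active MLIR context and
insertion point, as inside a lowering rule):

  from jax._src.lib import xla_client as xc
  from jax._src.lib.mlir.dialects import hlo

  def _make_token():
    if xc.mlir_api_version < 40:
      # Older bindings require the result type to be passed explicitly.
      return hlo.CreateTokenOp(hlo.TokenType.get()).results
    # Newer bindings infer the token result type.
    return hlo.CreateTokenOp().results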

View File

@ -1205,8 +1205,8 @@ unreachable_p.def_impl(unreachable_impl)
# Translation raises an exception
# TODO(frostig,mattjj): We have no good way to translate a function
# that errs. Since MHLO lowering over-approximates concrete evaluation,
# we err on MHLO lowering for the time being.
# that errs. Since MLIR lowering over-approximates concrete evaluation,
# we err on MLIR lowering for the time being.
mlir.register_lowering(unreachable_p, unreachable_impl)
# Abstract evaluation proceeds without issue, to allow for staging

View File

@ -38,7 +38,7 @@ from jax._src import util
from jax._src.lax import control_flow as lcf
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
import jax.numpy as jnp
import numpy as np
@ -335,7 +335,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
# partitioner runs so we keep it alive by attaching it to the executable.
ctx.module_context.add_keepalive(sharding_callback_info)
mhlo.CustomCallOp([value.type], [value],
hlo.CustomCallOp([value.type], [value],
call_target_name=ir.StringAttr.get(
_INSPECT_SHARDING_CALL_NAME),
has_side_effect=ir.BoolAttr.get(True),

View File

@ -67,7 +67,7 @@ FLAGS = flags.FLAGS
flags.DEFINE_string(
'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
help="Path to which HLO/MHLO IR that is emitted by JAX as input to the "
help="Path to which the IR that is emitted by JAX as input to the "
"compiler should be dumped as text files. Optional. If omitted, JAX "
"will not dump IR.")

View File

@ -41,7 +41,7 @@ from jax._src.traceback_util import api_boundary
from jax._src.util import (safe_map, extend_name_stack, split_list,
partition_list)
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
import numpy as np
from jax._src.lax.control_flow.common import (
@ -806,10 +806,10 @@ def _cond_lowering(ctx, index, *args, branches, linear):
*output_token_types, *map(mlir.aval_to_ir_types, ctx.avals_out)]
flat_output_types = util.flatten(output_types)
# mhlo.CaseOp takes a single argument 'index' and the corresponding blocks
# CaseOp takes a single argument 'index' and the corresponding blocks
# have no arguments; the computation within the block uses implicit
# captures.
case_op = mhlo.CaseOp(flat_output_types, index=index,
case_op = hlo.CaseOp(flat_output_types, index=index,
num_branches=len(branches))
name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')
for i, jaxpr in enumerate(branches):
@ -824,7 +824,7 @@ def _cond_lowering(ctx, index, *args, branches, linear):
dim_var_values=ctx.dim_var_values)
out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
out_vals = [*out_tokens, *out_vals]
mhlo.ReturnOp(util.flatten(out_vals))
hlo.ReturnOp(util.flatten(out_vals))
tokens_and_outputs = util.unflatten(case_op.results, map(len, output_types))
tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
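
For reference, the user-facing control-flow primitive this rule lowers
(usage sketch with standard jax.lax APIs):

  import jax.numpy as jnp
  from jax import lax

  branches = (lambda x: x + 1.0, lambda x: x * 2.0)
  print(lax.switch(1, branches, jnp.float32(3.0)))  # 6.0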

View File

@ -43,7 +43,7 @@ from jax._src.lax import slicing
from jax._src.lax import windowed_reductions
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.ufuncs import logaddexp
from jax._src.traceback_util import api_boundary
from jax._src.util import (
@ -146,7 +146,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
output arrays. (None is actually an empty pytree.)
Also unlike that Python version, :func:`~scan` is a JAX primitive and is
lowered to a single XLA While HLO. That makes it useful for reducing
lowered to a single WhileOp. That makes it useful for reducing
compilation times for JIT-compiled functions, since native Python
loop constructs in an :func:`~jax.jit` function are unrolled, leading to large
XLA computations.
@ -1041,7 +1041,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
return val
Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
to a single XLA While HLO. That makes it useful for reducing compilation times
to a single WhileOp. That makes it useful for reducing compilation times
for jit-compiled functions, since native Python loop constructs in an ``@jit``
function are unrolled, leading to large XLA computations.
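
A minimal sketch of the claim in both docstrings, using standard jax.lax APIs
(the per-step output is None, i.e. an empty pytree, as described above):

  import jax
  import jax.numpy as jnp
  from jax import lax

  def sum_scan(xs):
    # Lowers to a single While op instead of 1000 unrolled adds.
    total, _ = lax.scan(lambda c, x: (c + x, None), jnp.float32(0.0), xs)
    return total

  print(jax.jit(sum_scan)(jnp.ones(1000, jnp.float32)))  # 1000.0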
@ -1420,7 +1420,7 @@ def _while_transpose_error(*_, **kwargs):
# break
# token, x = body(token, x)
# ```
# Unfortunately, with an MHLO while we can't (1) return multiple values
# Unfortunately, with a WhileOp we can't (1) return multiple values
# from a `cond` and (2) can't break a while loop. We thus adopt the
# following rewrite strategy:
# ```
@ -1471,7 +1471,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
args = [*tokens, *args]
flat_args = mlir.flatten_lowering_ir_args(args)
while_op = mhlo.WhileOp(flat_loop_carry_types, flat_args)
while_op = hlo.WhileOp(flat_loop_carry_types, flat_args)
# Loop condition
cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
@ -1498,12 +1498,12 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
tokens_in=mlir.TokenSet(),
tokens_out=None)
pred, = lax._unary_reduce_lower(
mhlo.OrOp,
hlo.OrOp,
lambda dtype: np.array(False, dtype),
pred_ctx,
pred,
axes=tuple(range(len(pred_aval.shape))))
mhlo.ReturnOp([pred])
hlo.ReturnOp([pred])
# Loop body
body_block = while_op.regions[1].blocks.append(*flat_loop_carry_types)
@ -1531,10 +1531,10 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
_map(mlir.ir_constants, cond_jaxpr.consts),
*(x + z), dim_var_values=ctx.dim_var_values)
new_z = _map(
partial(_pred_bcast_select_mhlo, ctx, pred_aval, body_pred), new_z, z,
partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z,
body_jaxpr.out_avals)
mhlo.ReturnOp([*util.flatten(out_tokens), *util.flatten(x),
hlo.ReturnOp([*util.flatten(out_tokens), *util.flatten(x),
*util.flatten(y), *util.flatten(new_z)])
outputs = util.unflatten(while_op.results, _map(len, loop_carry_types))
@ -1566,16 +1566,16 @@ mlir.register_lowering(while_p, _while_lowering)
core.custom_typechecks[while_p] = _while_typecheck
def _pred_bcast_select_mhlo(ctx,
def _pred_bcast_select_hlo(ctx,
pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
if x_y_aval is core.abstract_token:
x, = xs
y, = ys
if xc.mlir_api_version < 40:
return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
return [hlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
else:
return [mhlo.AfterAllOp([x, y]).result]
return [hlo.AfterAllOp([x, y]).result]
else:
assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
x, = xs
@ -1589,7 +1589,7 @@ def _pred_bcast_select_mhlo(ctx,
x_y_shape = x_y_aval.shape
bcast_pred = mlir.broadcast_in_dim(ctx, pred, core.DShapedArray(x_y_shape, np.dtype(np.bool_)),
broadcast_dimensions=list(range(len(pred_aval.shape))))
return mhlo.SelectOp(bcast_pred, x, y).results
return hlo.SelectOp(bcast_pred, x, y).results
### fori_loop

View File

@ -26,7 +26,7 @@ from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src import util
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_client
_max = builtins.max
@ -688,7 +688,7 @@ def _conv_general_dilated_lower(
return complex_conv(ctx, lhs, rhs)
lhs_spec, rhs_spec, out_spec = dimension_numbers
dnums = mhlo.ConvDimensionNumbers.get(
dnums = hlo.ConvDimensionNumbers.get(
input_batch_dimension=lhs_spec[0],
input_feature_dimension=lhs_spec[1],
input_spatial_dimensions=list(lhs_spec[2:]),
@ -703,7 +703,7 @@ def _conv_general_dilated_lower(
padding = np.zeros((0, 2), dtype=np.int64)
window_reversal = mlir.dense_bool_elements([False] * num_spatial_dims)
return [
mhlo.ConvolutionOp(
hlo.ConvolutionOp(
mlir.aval_to_ir_type(aval_out),
lhs,
rhs,

View File

@ -26,7 +26,7 @@ from jax._src.util import prod
from jax import lax
from jax.interpreters import ad
from jax.interpreters import batching
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_client
from jax._src.lib import ducc_fft
from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact
@ -104,7 +104,7 @@ def fft_abstract_eval(x, fft_type, fft_lengths):
def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
return [
mhlo.FftOp(x, mhlo.FftTypeAttr.get(fft_type.name),
hlo.FftOp(x, hlo.FftTypeAttr.get(fft_type.name),
mlir.dense_int_elements(fft_lengths)).result
]
@ -113,8 +113,12 @@ def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths):
if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
raise NotImplementedError("Shape polymorphism for custom call is not implemented (fft); b/261671778")
x_aval, = ctx.avals_in
if xla_client.mlir_api_version < 41:
return [ducc_fft.ducc_fft_mhlo(x, x_aval.dtype, fft_type=fft_type,
fft_lengths=fft_lengths)]
else:
return [ducc_fft.ducc_fft_hlo(x, x_aval.dtype, fft_type=fft_type,
fft_lengths=fft_lengths)]
def _naive_rfft(x, fft_lengths):
y = fft(x, xla_client.FftType.FFT, fft_lengths)

View File

@ -58,7 +58,7 @@ from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.lax.utils import (
_input_dtype,
standard_abstract_eval,
@ -1635,10 +1635,10 @@ def _maybe_broadcast(target_shape, x):
squeeze_shape = [x_shape[i] for i in dims]
return broadcast_in_dim(reshape(x, squeeze_shape), target_shape, dims)
def broadcast_mhlo(
def broadcast_hlo(
aval_out: core.ShapedArray, avals: Sequence[core.ShapedArray],
args: Sequence[ir.Value]) -> Sequence[ir.Value]:
"""Broadcasts MHLO values with broadcast-compatible shapes to the same shape.
"""Broadcasts HLO values with broadcast-compatible shapes to the same shape.
"""
out = []
for aval, arg in zip(avals, args):
@ -1647,24 +1647,23 @@ def broadcast_mhlo(
dims = mlir.dense_int_elements(
range(len(aval_out.shape) - len(aval.shape), len(aval_out.shape)))
if any(isinstance(d, ir.Value) for d in aval_out.shape):
arg = mhlo.DynamicBroadcastInDimOp(
arg = hlo.DynamicBroadcastInDimOp(
mlir.aval_to_ir_type(aval_out), arg,
mlir.shape_tensor(aval_out.shape), dims).result
else:
arg = mhlo.BroadcastInDimOp(
arg = hlo.BroadcastInDimOp(
mlir.aval_to_ir_type(aval.update(shape=aval_out.shape)), arg,
dims).result
out.append(arg)
return out
def _nary_lower_mhlo(op: Callable, ctx,
def _nary_lower_hlo(op: Callable, ctx,
*args: Union[ir.Value, Sequence[ir.Value]],
explicit_type=False, **params):
"""Lowers an elementwise operator to its MHLO/CHLO equivalent.
"""Lowers an elementwise operator to its MLIR equivalent.
Args:
explicit_type: does the MHLO/CHLO operator require its output type to be
provided?
explicit_type: does the MLIR op require its output type to be provided?
"""
del params
avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out
@ -1696,76 +1695,76 @@ _ordered = _int | _float | _bool
neg_p = standard_unop(_num, 'neg')
ad.deflinear2(neg_p, lambda t, operand: [neg(t)])
mlir.register_lowering(neg_p, partial(_nary_lower_mhlo, mhlo.NegOp))
mlir.register_lowering(neg_p, partial(_nary_lower_hlo, hlo.NegOp))
sign_p = standard_unop(_num, 'sign')
ad.defjvp_zero(sign_p)
def _sign_lower_mhlo(ctx, x):
def _sign_lower_hlo(ctx, x):
x_aval, = ctx.avals_in
if dtypes.issubdtype(x_aval.dtype, np.unsignedinteger):
return mhlo.SelectOp(
mlir.compare_mhlo(x, mlir.full_like_aval(ctx, 0, x_aval), 'EQ',
return hlo.SelectOp(
mlir.compare_hlo(x, mlir.full_like_aval(ctx, 0, x_aval), 'EQ',
'UNSIGNED').result,
mlir.full_like_aval(ctx, 0, x_aval),
mlir.full_like_aval(ctx, 1, x_aval)).results
return mhlo.SignOp(x).results
return hlo.SignOp(x).results
mlir.register_lowering(sign_p, _sign_lower_mhlo)
mlir.register_lowering(sign_p, _sign_lower_hlo)
nextafter_p = standard_naryop([_float, _float], 'nextafter')
mlir.register_lowering(nextafter_p, partial(_nary_lower_mhlo, chlo.NextAfterOp))
mlir.register_lowering(nextafter_p, partial(_nary_lower_hlo, chlo.NextAfterOp))
floor_p = standard_unop(_float, 'floor')
ad.defjvp_zero(floor_p)
mlir.register_lowering(floor_p, partial(_nary_lower_mhlo, mhlo.FloorOp))
mlir.register_lowering(floor_p, partial(_nary_lower_hlo, hlo.FloorOp))
ceil_p = standard_unop(_float, 'ceil')
ad.defjvp_zero(ceil_p)
mlir.register_lowering(ceil_p, partial(_nary_lower_mhlo, mhlo.CeilOp))
mlir.register_lowering(ceil_p, partial(_nary_lower_hlo, hlo.CeilOp))
round_p = standard_unop(_float, 'round')
ad.defjvp_zero(round_p)
def _round_lower(ctx, x, *, rounding_method):
if rounding_method is RoundingMethod.AWAY_FROM_ZERO:
return mhlo.RoundOp(x).results
return hlo.RoundOp(x).results
else:
assert rounding_method is RoundingMethod.TO_NEAREST_EVEN
return mhlo.RoundNearestEvenOp(x).results
return hlo.RoundNearestEvenOp(x).results
mlir.register_lowering(round_p, _round_lower)
is_finite_p = unop(_fixed_dtype(np.bool_), _float, 'is_finite')
ad.defjvp_zero(is_finite_p)
mlir.register_lowering(is_finite_p, partial(_nary_lower_mhlo, mhlo.IsFiniteOp))
mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.IsFiniteOp))
exp_p = standard_unop(_float | _complex, 'exp')
ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
# For exp_p it is more efficient to use the reconstructed output for the vjp
# rule instead of computing it again from the input.
mlir.register_lowering(exp_p, partial(_nary_lower_mhlo, mhlo.ExpOp))
mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.ExpOp))
log_p = standard_unop(_float | _complex, 'log')
ad.defjvp(log_p, lambda g, x: div(g, x))
mlir.register_lowering(log_p, partial(_nary_lower_mhlo, mhlo.LogOp))
mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.LogOp))
expm1_p = standard_unop(_float | _complex, 'expm1')
ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans))))
mlir.register_lowering(expm1_p, partial(_nary_lower_mhlo, mhlo.Expm1Op))
mlir.register_lowering(expm1_p, partial(_nary_lower_hlo, hlo.Expm1Op))
log1p_p = standard_unop(_float | _complex, 'log1p')
ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x))))
mlir.register_lowering(log1p_p, partial(_nary_lower_mhlo, mhlo.Log1pOp))
mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.Log1pOp))
tanh_p = standard_unop(_float | _complex, 'tanh')
ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)),
sub(_one(x), ans)))
mlir.register_lowering(tanh_p, partial(_nary_lower_mhlo, mhlo.TanhOp))
mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.TanhOp))
logistic_p = standard_unop(_float | _complex, 'logistic')
ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans))))
# TODO(phawkins): switch to mhlo.logistic lowering; debug numerical problems.
# mlir.register_lowering(logistic_p, partial(_nary_lower_mhlo, mhlo.LogisticOp))
# TODO(phawkins): switch to LogisticOp lowering; debug numerical problems.
# mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.LogisticOp))
def logistic_impl(x):
one = _const(x, 1)
@ -1776,11 +1775,11 @@ mlir.register_lowering(logistic_p,
sin_p = standard_unop(_float | _complex, 'sin')
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
mlir.register_lowering(sin_p, partial(_nary_lower_mhlo, mhlo.SineOp))
mlir.register_lowering(sin_p, partial(_nary_lower_hlo, hlo.SineOp))
cos_p = standard_unop(_float | _complex, 'cos')
ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
mlir.register_lowering(cos_p, partial(_nary_lower_mhlo, mhlo.CosineOp))
mlir.register_lowering(cos_p, partial(_nary_lower_hlo, hlo.CosineOp))
@_upcast_fp16_for_computation
def _tan_impl(x):
@ -1788,7 +1787,7 @@ def _tan_impl(x):
tan_p = standard_unop(_float | _complex, 'tan')
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
mlir.register_lowering(tan_p, partial(_nary_lower_mhlo, chlo.TanOp))
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.TanOp))
def asin_impl(x):
if dtypes.issubdtype(_dtype(x), np.complexfloating):
@ -1799,7 +1798,7 @@ def asin_impl(x):
asin_p = standard_unop(_float | _complex, 'asin')
ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x))))
mlir.register_lowering(asin_p, partial(_nary_lower_mhlo, chlo.AsinOp))
mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.AsinOp))
def acos_impl(x):
if dtypes.issubdtype(_dtype(x), np.complexfloating):
@ -1828,35 +1827,35 @@ def atan_impl(x):
atan_p = standard_unop(_float | _complex, 'atan')
ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x)))
mlir.register_lowering(atan_p, partial(_nary_lower_mhlo, chlo.AtanOp))
mlir.register_lowering(atan_p, partial(_nary_lower_hlo, chlo.AtanOp))
atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2')
ad.defjvp(atan2_p,
lambda g, x, y: g * (y / (square(x) + square(y))),
lambda g, x, y: g * -x / (square(x) + square(y)))
mlir.register_lowering(atan2_p, partial(_nary_lower_mhlo, mhlo.Atan2Op))
mlir.register_lowering(atan2_p, partial(_nary_lower_hlo, hlo.Atan2Op))
sinh_p = standard_unop(_float | _complex, 'sinh')
ad.defjvp(sinh_p, lambda g, x: mul(g, cosh(x)))
mlir.register_lowering(sinh_p, partial(_nary_lower_mhlo, chlo.SinhOp))
mlir.register_lowering(sinh_p, partial(_nary_lower_hlo, chlo.SinhOp))
cosh_p = standard_unop(_float | _complex, 'cosh')
ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x)))
mlir.register_lowering(cosh_p, partial(_nary_lower_mhlo, chlo.CoshOp))
mlir.register_lowering(cosh_p, partial(_nary_lower_hlo, chlo.CoshOp))
asinh_p = standard_unop(_float | _complex, 'asinh')
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x))))
mlir.register_lowering(asinh_p, partial(_nary_lower_mhlo, chlo.AsinhOp))
mlir.register_lowering(asinh_p, partial(_nary_lower_hlo, chlo.AsinhOp))
acosh_p = standard_unop(_float | _complex, 'acosh')
ad.defjvp(acosh_p,
lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x)))))
mlir.register_lowering(acosh_p, partial(_nary_lower_mhlo, chlo.AcoshOp))
mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.AcoshOp))
atanh_p = standard_unop(_float | _complex, 'atanh')
ad.defjvp(atanh_p,
lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x))))
mlir.register_lowering(atanh_p, partial(_nary_lower_mhlo, chlo.AtanhOp))
mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.AtanhOp))
regularized_incomplete_beta_p = standard_naryop(
[_float, _float, _float], 'regularized_incomplete_beta')
@ -1880,10 +1879,10 @@ ad.defjvp(regularized_incomplete_beta_p,
lgamma_p = standard_unop(_float, 'lgamma')
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))
mlir.register_lowering(lgamma_p, partial(_nary_lower_mhlo, chlo.LgammaOp))
mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.LgammaOp))
digamma_p = standard_unop(_float, 'digamma')
mlir.register_lowering(digamma_p, partial(_nary_lower_mhlo, chlo.DigammaOp))
mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.DigammaOp))
igamma_p = standard_naryop([_float, _float], 'igamma')
xla.register_translation(igamma_p, partial(_broadcast_translate, xops.Igamma))
@ -1919,7 +1918,7 @@ ad.defjvp2(bessel_i0e_p, lambda g, y, x: g * (bessel_i1e(x) - sign(x) * y))
bessel_i1e_p = standard_unop(_float, 'bessel_i1e')
mlir.register_lowering(bessel_i1e_p,
partial(_nary_lower_mhlo, chlo.BesselI1eOp))
partial(_nary_lower_hlo, chlo.BesselI1eOp))
def _bessel_i1e_jvp(g, y, x):
eps = dtypes.finfo(_dtype(x)).eps
@ -1933,12 +1932,12 @@ ad.defjvp2(bessel_i1e_p, _bessel_i1e_jvp)
erf_p = standard_unop(_float, 'erf')
ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)),
mul(g, exp(neg(square(x))))))
mlir.register_lowering(erf_p, partial(_nary_lower_mhlo, chlo.ErfOp))
mlir.register_lowering(erf_p, partial(_nary_lower_hlo, chlo.ErfOp))
erfc_p = standard_unop(_float, 'erfc')
ad.defjvp(erfc_p, lambda g, x: mul(_const(x, -2. / np.sqrt(np.pi)),
mul(g, exp(neg(square(x))))))
mlir.register_lowering(erfc_p, partial(_nary_lower_mhlo, chlo.ErfcOp))
mlir.register_lowering(erfc_p, partial(_nary_lower_hlo, chlo.ErfcOp))
erf_inv_p = standard_unop(_float, 'erf_inv')
ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.),
@ -1947,11 +1946,11 @@ xla.register_translation(erf_inv_p, standard_translate(erf_inv_p))
real_p = unop(_complex_basetype, _complex, 'real')
ad.deflinear2(real_p, lambda t, _: [complex(t, np.zeros((), _dtype(t)))])
mlir.register_lowering(real_p, partial(_nary_lower_mhlo, mhlo.RealOp))
mlir.register_lowering(real_p, partial(_nary_lower_hlo, hlo.RealOp))
imag_p = unop(_complex_basetype, _complex, 'imag')
ad.deflinear2(imag_p, lambda t, _: [complex(np.zeros((), _dtype(t)), neg(t))])
mlir.register_lowering(imag_p, partial(_nary_lower_mhlo, mhlo.ImagOp))
mlir.register_lowering(imag_p, partial(_nary_lower_hlo, hlo.ImagOp))
def _complex_transpose_rule(t, x, y):
@ -1976,7 +1975,7 @@ _complex_dtype = lambda dtype, *args: (np.zeros((), dtype) + np.zeros((), np.com
complex_p = naryop(_complex_dtype, [_complex_elem_types, _complex_elem_types],
'complex')
ad.deflinear2(complex_p, _complex_transpose_rule)
mlir.register_lowering(complex_p, partial(_nary_lower_mhlo, mhlo.ComplexOp))
mlir.register_lowering(complex_p, partial(_nary_lower_hlo, hlo.ComplexOp))
conj_p = unop(_complex_dtype, _complex_elem_types | _complex, 'conj')
@ -2001,7 +2000,7 @@ ad.primitive_jvps[conj_p] = partial(ad.linear_jvp, conj_p)
ad.primitive_transposes[conj_p] = _conj_transpose_rule
abs_p = unop(_complex_basetype, _num, 'abs')
mlir.register_lowering(abs_p, partial(_nary_lower_mhlo, mhlo.AbsOp))
mlir.register_lowering(abs_p, partial(_nary_lower_hlo, hlo.AbsOp))
def _abs_jvp_rule(g, ans, x):
if _iscomplex(x):
@ -2015,18 +2014,18 @@ _maybe_real = lambda x: real(x) if _iscomplex(x) else x
sqrt_p = standard_unop(_float | _complex, 'sqrt')
ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans)))
mlir.register_lowering(sqrt_p, partial(_nary_lower_mhlo, mhlo.SqrtOp))
mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.SqrtOp))
rsqrt_p = standard_unop(_float | _complex, 'rsqrt')
ad.defjvp2(rsqrt_p,
lambda g, ans, x:
mul(g, mul(_const(x, -0.5), div(ans, x))))
mlir.register_lowering(rsqrt_p, partial(_nary_lower_mhlo, mhlo.RsqrtOp))
mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.RsqrtOp))
cbrt_p = standard_unop(_float, 'cbrt')
ad.defjvp2(cbrt_p,
lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2))))
mlir.register_lowering(cbrt_p, partial(_nary_lower_mhlo, mhlo.CbrtOp))
mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.CbrtOp))
pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow')
@ -2037,7 +2036,7 @@ def _pow_jvp_rhs(g, ans, x, y):
return mul(g, mul(log(_replace_zero(x)), ans))
ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
mlir.register_lowering(pow_p, partial(_nary_lower_mhlo, mhlo.PowOp))
mlir.register_lowering(pow_p, partial(_nary_lower_hlo, hlo.PowOp))
def _integer_pow_dtype_rule(x, *, y):
@ -2077,8 +2076,8 @@ def _integer_pow(x, *, y):
def _integer_pow_lowering(ctx, x, *, y):
lowering = mlir.lower_fun(_integer_pow, multiple_results=False)
# TODO(b/217551391): emitting an out-of-line call leads to a large
# expansion when the MHLO is lowered to HLO, because the HLO lowering
# clones the callee. Consider unconditionally caching when the MHLO->HLO
# expansion when the MLIR is lowered to HLO, because the HLO lowering
# clones the callee. Consider unconditionally caching when the MLIR->HLO
# lowering doesn't expand the program.
if y >= 4:
lowering = mlir.cache_lowering(lowering)
@ -2090,26 +2089,26 @@ _replace_zero = lambda x: select(eq(x, _const(x, 0)), _ones(x), x)
not_p = standard_unop(_bool_or_int, 'not')
ad.defjvp_zero(not_p)
mlir.register_lowering(not_p, partial(_nary_lower_mhlo, mhlo.NotOp))
mlir.register_lowering(not_p, partial(_nary_lower_hlo, hlo.NotOp))
and_p = standard_naryop([_bool_or_int, _bool_or_int], 'and')
ad.defjvp_zero(and_p)
mlir.register_lowering(and_p, partial(_nary_lower_mhlo, mhlo.AndOp))
mlir.register_lowering(and_p, partial(_nary_lower_hlo, hlo.AndOp))
or_p = standard_naryop([_bool_or_int, _bool_or_int], 'or')
ad.defjvp_zero(or_p)
mlir.register_lowering(or_p, partial(_nary_lower_mhlo, mhlo.OrOp))
mlir.register_lowering(or_p, partial(_nary_lower_hlo, hlo.OrOp))
xor_p = standard_naryop([_bool_or_int, _bool_or_int], 'xor')
ad.defjvp_zero(xor_p)
mlir.register_lowering(xor_p, partial(_nary_lower_mhlo, mhlo.XorOp))
mlir.register_lowering(xor_p, partial(_nary_lower_hlo, hlo.XorOp))
population_count_p = standard_unop(_int, 'population_count')
mlir.register_lowering(population_count_p,
partial(_nary_lower_mhlo, mhlo.PopulationCountOp))
partial(_nary_lower_hlo, hlo.PopulationCountOp))
clz_p = standard_unop(_int, 'clz')
mlir.register_lowering(clz_p, partial(_nary_lower_mhlo, mhlo.ClzOp))
mlir.register_lowering(clz_p, partial(_nary_lower_hlo, hlo.ClzOp))
def _add_jvp(primals, tangents):
x, y = primals
@ -2145,7 +2144,7 @@ def _add_inverse(r, x, y):
add_p: Primitive = standard_naryop([_num, _num], 'add')
ad.primitive_jvps[add_p] = _add_jvp
ad.primitive_transposes[add_p] = _add_transpose
mlir.register_lowering(add_p, partial(_nary_lower_mhlo, mhlo.AddOp))
mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.AddOp))
def _sub_jvp(primals, tangents):
x, y = primals
@ -2174,7 +2173,7 @@ def _sub_transpose(t, x, y):
sub_p = standard_naryop([_num, _num], 'sub')
ad.primitive_jvps[sub_p] = _sub_jvp
ad.primitive_transposes[sub_p] = _sub_transpose
mlir.register_lowering(sub_p, partial(_nary_lower_mhlo, mhlo.SubtractOp))
mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.SubtractOp))
def _mul_transpose(ct, x, y):
@ -2200,7 +2199,7 @@ ad.defjvp(mul_p,
lambda xdot, x, y: mul(xdot, y),
lambda ydot, x, y: mul(x, ydot))
ad.primitive_transposes[mul_p] = _mul_transpose
mlir.register_lowering(mul_p, partial(_nary_lower_mhlo, mhlo.MulOp))
mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.MulOp))
def _div_transpose_rule(cotangent, x, y):
assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
@ -2213,14 +2212,14 @@ ad.defjvp(div_p,
lambda g, x, y: div(g, y),
lambda g, x, y: mul(mul(neg(g), x), integer_pow(y, -2)))
ad.primitive_transposes[div_p] = _div_transpose_rule
mlir.register_lowering(div_p, partial(_nary_lower_mhlo, mhlo.DivOp))
mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.DivOp))
rem_p = standard_naryop([_int | _float, _int | _float], 'rem')
ad.defjvp(
rem_p,
lambda g, x, y: _maybe_broadcast(broadcast_shapes(np.shape(x), np.shape(y)), g),
lambda g, x, y: mul(neg(g), mul(sign(div(x, y)), floor(abs(div(x, y))))))
mlir.register_lowering(rem_p, partial(_nary_lower_mhlo, mhlo.RemOp))
mlir.register_lowering(rem_p, partial(_nary_lower_hlo, hlo.RemOp))
def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x):
result_shape = broadcast_shapes(np.shape(x), np.shape(y))
@ -2236,29 +2235,29 @@ max_p: core.Primitive = standard_naryop([_any, _any], 'max')
ad.defjvp2(max_p,
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
mlir.register_lowering(max_p, partial(_nary_lower_mhlo, mlir.max_mhlo))
mlir.register_lowering(max_p, partial(_nary_lower_hlo, mlir.max_hlo))
min_p: core.Primitive = standard_naryop([_any, _any], 'min')
ad.defjvp2(min_p,
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
mlir.register_lowering(min_p, partial(_nary_lower_mhlo, mlir.min_mhlo))
mlir.register_lowering(min_p, partial(_nary_lower_hlo, mlir.min_hlo))
shift_left_p = standard_naryop([_int, _int], 'shift_left')
ad.defjvp_zero(shift_left_p)
mlir.register_lowering(shift_left_p, partial(_nary_lower_mhlo, mhlo.ShiftLeftOp))
mlir.register_lowering(shift_left_p, partial(_nary_lower_hlo, hlo.ShiftLeftOp))
shift_right_arithmetic_p = standard_naryop([_int, _int], 'shift_right_arithmetic')
ad.defjvp_zero(shift_right_arithmetic_p)
mlir.register_lowering(shift_right_arithmetic_p,
partial(_nary_lower_mhlo, mhlo.ShiftRightArithmeticOp))
partial(_nary_lower_hlo, hlo.ShiftRightArithmeticOp))
shift_right_logical_p = standard_naryop([_int, _int], 'shift_right_logical')
ad.defjvp_zero(shift_right_logical_p)
mlir.register_lowering(shift_right_logical_p,
partial(_nary_lower_mhlo, mhlo.ShiftRightLogicalOp))
partial(_nary_lower_hlo, hlo.ShiftRightLogicalOp))
def _compare_lower_mhlo(direction: str, ctx, x, y):
def _compare_lower_hlo(direction: str, ctx, x, y):
avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out
x_dtype = avals_in[0].dtype
x, y = mlir.multi_broadcast_in_dim(ctx, (x, y), avals_in, aval_out.shape)
@ -2269,31 +2268,31 @@ def _compare_lower_mhlo(direction: str, ctx, x, y):
compare_type = "SIGNED"
else:
compare_type = "UNSIGNED"
return mlir.compare_mhlo(x, y, direction, compare_type).results
return mlir.compare_hlo(x, y, direction, compare_type).results
eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq')
ad.defjvp_zero(eq_p)
mlir.register_lowering(eq_p, partial(_compare_lower_mhlo, "EQ"))
mlir.register_lowering(eq_p, partial(_compare_lower_hlo, "EQ"))
ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne')
ad.defjvp_zero(ne_p)
mlir.register_lowering(ne_p, partial(_compare_lower_mhlo, "NE"))
mlir.register_lowering(ne_p, partial(_compare_lower_hlo, "NE"))
ge_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'ge')
ad.defjvp_zero(ge_p)
mlir.register_lowering(ge_p, partial(_compare_lower_mhlo, "GE"))
mlir.register_lowering(ge_p, partial(_compare_lower_hlo, "GE"))
gt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'gt')
ad.defjvp_zero(gt_p)
mlir.register_lowering(gt_p, partial(_compare_lower_mhlo, "GT"))
mlir.register_lowering(gt_p, partial(_compare_lower_hlo, "GT"))
le_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'le')
ad.defjvp_zero(le_p)
mlir.register_lowering(le_p, partial(_compare_lower_mhlo, "LE"))
mlir.register_lowering(le_p, partial(_compare_lower_hlo, "LE"))
lt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'lt')
ad.defjvp_zero(lt_p)
mlir.register_lowering(lt_p, partial(_compare_lower_mhlo, "LT"))
mlir.register_lowering(lt_p, partial(_compare_lower_hlo, "LT"))
def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type):
@ -2392,9 +2391,9 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type):
aval_out, = ctx.avals_out
if (dtypes.issubdtype(aval_in.dtype, np.complexfloating) and
not dtypes.issubdtype(new_dtype, np.complexfloating)):
operand = mhlo.RealOp(operand).result
operand = hlo.RealOp(operand).result
aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype))
return [mlir.convert_mhlo(ctx, operand, aval_in, aval_out)]
return [mlir.convert_hlo(ctx, operand, aval_in, aval_out)]
mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
@ -2419,7 +2418,7 @@ batching.defvectorized(bitcast_convert_type_p)
def _bitcast_convert_type_lower(ctx, operand, *, new_dtype):
aval_out, = ctx.avals_out
return mhlo.BitcastConvertOp(mlir.aval_to_ir_type(aval_out), operand).results
return hlo.BitcastConvertOp(mlir.aval_to_ir_type(aval_out), operand).results
mlir.register_lowering(bitcast_convert_type_p, _bitcast_convert_type_lower)
@ -2708,7 +2707,7 @@ def precision_attr(precision: PrecisionType) -> ir.ArrayAttr:
else:
full_precision = precision
return ir.ArrayAttr.get(
[mhlo.PrecisionAttr.get(str(p)) for p in full_precision])
[hlo.PrecisionAttr.get(str(p)) for p in full_precision])
@ -2723,19 +2722,19 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
if ctx.module_context.platform == "cpu":
if lhs_aval.dtype == np.float16:
f32 = mlir.dtype_to_ir_type(np.dtype(np.float32))
lhs = mhlo.ConvertOp(ir.RankedTensorType.get(lhs_aval.shape, f32),
lhs = hlo.ConvertOp(ir.RankedTensorType.get(lhs_aval.shape, f32),
lhs).result
if rhs_aval.dtype == np.float16:
f32 = mlir.dtype_to_ir_type(np.dtype(np.float32))
rhs = mhlo.ConvertOp(ir.RankedTensorType.get(rhs_aval.shape, f32),
rhs = hlo.ConvertOp(ir.RankedTensorType.get(rhs_aval.shape, f32),
rhs).result
dot_dnums = mhlo.DotDimensionNumbers.get(
dot_dnums = hlo.DotDimensionNumbers.get(
lhs_batching_dimensions=list(lhs_batch),
rhs_batching_dimensions=list(rhs_batch),
lhs_contracting_dimensions=list(lhs_contracting),
rhs_contracting_dimensions=list(rhs_contracting))
return [
mhlo.DotGeneralOp(
hlo.DotGeneralOp(
mlir.aval_to_ir_type(aval_out),
lhs,
rhs,
@ -3005,7 +3004,7 @@ ad.defjvp(clamp_p,
select(lt(max, operand), g, _zeros(operand)))
batching.primitive_batchers[clamp_p] = _clamp_batch_rule
mlir.register_lowering(
clamp_p, partial(_nary_lower_mhlo, mhlo.ClampOp))
clamp_p, partial(_nary_lower_hlo, hlo.ClampOp))
pe.def_trivial_padding(clamp_p)
def _concatenate_shape_rule(*operands, **kwargs):
@ -3083,7 +3082,7 @@ batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule
pe.padding_rules[concatenate_p] = _concatenate_pad_rule
def _concatenate_lower(ctx, *xs, dimension):
return mhlo.ConcatenateOp(xs, mlir.i64_attr(dimension)).results
return hlo.ConcatenateOp(xs, mlir.i64_attr(dimension)).results
mlir.register_lowering(concatenate_p, _concatenate_lower)
@ -3158,7 +3157,7 @@ batching.primitive_batchers[pad_p] = _pad_batch_rule
def _pad_lower(ctx, x, padding_value, *, padding_config):
low, high, interior = util.unzip3(padding_config)
return mhlo.PadOp(x, padding_value,
return hlo.PadOp(x, padding_value,
mlir.dense_int_elements(low),
mlir.dense_int_elements(high),
mlir.dense_int_elements(interior)).results
@ -3302,7 +3301,7 @@ def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions):
def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions):
aval_out, = ctx.avals_out
if dimensions is not None:
x = mhlo.TransposeOp(x, mlir.dense_int_elements(dimensions)).result
x = hlo.TransposeOp(x, mlir.dense_int_elements(dimensions)).result
if dyn_shape:
aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape))
return [mlir.reshape(ctx, x, aval_out)]
@ -3346,7 +3345,7 @@ ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)])
batching.primitive_batchers[rev_p] = _rev_batch_rule
def _rev_lower(ctx, x, *, dimensions):
return mhlo.ReverseOp(x, mlir.dense_int_elements(dimensions)).results
return hlo.ReverseOp(x, mlir.dense_int_elements(dimensions)).results
mlir.register_lowering(rev_p, _rev_lower)
@ -3370,7 +3369,7 @@ def _transpose_lower(ctx, x, *, permutation):
aval_out, = ctx.avals_out
if core.is_opaque_dtype(aval_out.dtype):
return [aval_out.dtype._rules.transpose_mlir(ctx, aval_out, x, permutation=permutation)]
return mhlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results
return hlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results
transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
'transpose')
@ -3470,12 +3469,12 @@ def _select_jvp(primals, tangents):
out_dot = select_n(which, *case_tangents)
return out, out_dot
def _select_mhlo_lowering(ctx, which, *cases):
def _select_hlo_lowering(ctx, which, *cases):
which_aval = ctx.avals_in[0]
if which_aval.dtype == np.dtype(np.bool_):
assert len(cases) <= 2
if len(cases) == 1: return cases
return mhlo.SelectOp(which, cases[1], cases[0]).results
return hlo.SelectOp(which, cases[1], cases[0]).results
if dtypes.issubdtype(which_aval.dtype, np.signedinteger):
compare_type = 'SIGNED'
@ -3488,10 +3487,10 @@ def _select_mhlo_lowering(ctx, which, *cases):
if len(cases) == 1:
return cases[0]
mid = len(cases) // 2
pred = mlir.compare_mhlo(which,
pred = mlir.compare_hlo(which,
mlir.full_like_aval(ctx, offset + mid, which_aval),
lt, compare_type)
return mhlo.SelectOp(pred, _select(offset, cases[:mid]),
return hlo.SelectOp(pred, _select(offset, cases[:mid]),
_select(offset + mid, cases[mid:])).result
return [_select(0, cases)]
@ -3502,7 +3501,7 @@ select_n_p = standard_primitive(
ad.primitive_jvps[select_n_p] = _select_jvp
ad.primitive_transposes[select_n_p] = _select_transpose_rule
batching.primitive_batchers[select_n_p] = _select_batch_rule
mlir.register_lowering(select_n_p, _select_mhlo_lowering)
mlir.register_lowering(select_n_p, _select_hlo_lowering)
pe.def_trivial_padding(select_n_p)
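
For reference, the primitive whose lowering is renamed here (usage sketch;
with an integer predicate, the rule above binary-searches over the cases):

  import jax.numpy as jnp
  from jax import lax

  which = jnp.array([0, 1, 2], dtype=jnp.int32)
  cases = (jnp.zeros(3), jnp.ones(3), jnp.full(3, 2.0))
  print(lax.select_n(which, *cases))  # [0. 1. 2.]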
@ -3622,7 +3621,7 @@ def _reduce_lower(ctx, *values, computation, jaxpr, consts, dimensions):
assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in
operands, init_values = util.split_list(values, [len(values) // 2])
init_value_avals = ctx.avals_in[len(values) // 2:]
op = mhlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
op = hlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
operands, init_values, mlir.dense_int_elements(dimensions))
ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
reducer = op.regions[0].blocks.append(*(ir_types + ir_types))
@ -3633,7 +3632,7 @@ def _reduce_lower(ctx, *values, computation, jaxpr, consts, dimensions):
out_nodes, _ = mlir.jaxpr_subcomp(reducer_ctx, jaxpr, mlir.TokenSet(), consts,
*([a] for a in reducer.arguments),
dim_var_values=ctx.dim_var_values)
mhlo.ReturnOp(util.flatten(out_nodes))
hlo.ReturnOp(util.flatten(out_nodes))
return op.results
mlir.register_lowering(reduce_p, _reduce_lower)
@ -3818,29 +3817,29 @@ batching.defreducer(reduce_xor_p)
def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
aval_out, = ctx.avals_out
dtype = aval_out.dtype
op = mhlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x],
op = hlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x],
mlir.ir_constants(unit_factory(aval_out.dtype)),
mlir.dense_int_elements(axes))
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype))
reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer_region):
add = reducer(*reducer_region.arguments)
mhlo.ReturnOp(add.results)
hlo.ReturnOp(add.results)
return op.results
mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, mhlo.AddOp,
mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, hlo.AddOp,
_get_sum_identity))
mlir.register_lowering(reduce_prod_p, partial(_unary_reduce_lower, mhlo.MulOp,
mlir.register_lowering(reduce_prod_p, partial(_unary_reduce_lower, hlo.MulOp,
_get_prod_identity))
mlir.register_lowering(reduce_or_p, partial(_unary_reduce_lower, mhlo.OrOp,
mlir.register_lowering(reduce_or_p, partial(_unary_reduce_lower, hlo.OrOp,
_get_bitwise_or_identity))
mlir.register_lowering(reduce_and_p, partial(_unary_reduce_lower, mhlo.AndOp,
mlir.register_lowering(reduce_and_p, partial(_unary_reduce_lower, hlo.AndOp,
_get_bitwise_and_identity))
mlir.register_lowering(reduce_xor_p, partial(_unary_reduce_lower, mhlo.XorOp,
mlir.register_lowering(reduce_xor_p, partial(_unary_reduce_lower, hlo.XorOp,
_get_bitwise_or_identity))
mlir.register_lowering(reduce_min_p, partial(_unary_reduce_lower, mlir.min_mhlo,
mlir.register_lowering(reduce_min_p, partial(_unary_reduce_lower, mlir.min_hlo,
_get_min_identity))
mlir.register_lowering(reduce_max_p, partial(_unary_reduce_lower, mlir.max_mhlo,
mlir.register_lowering(reduce_max_p, partial(_unary_reduce_lower, mlir.max_hlo,
_get_max_identity))
@ -3864,7 +3863,7 @@ batching.defvectorized(reduce_precision_p)
def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
aval_out, = ctx.avals_out
return mhlo.ReducePrecisionOp(operand, mlir.i32_attr(exponent_bits),
return hlo.ReducePrecisionOp(operand, mlir.i32_attr(exponent_bits),
mlir.i32_attr(mantissa_bits)).results
mlir.register_lowering(reduce_precision_p, _reduce_precision_lower)
@ -4014,7 +4013,7 @@ batching.primitive_batchers[sort_p] = _sort_batch_rule
def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in
sort = mhlo.SortOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
sort = hlo.SortOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
mlir.flatten_lowering_ir_args(operands),
dimension=mlir.i64_attr(dimension),
is_stable=ir.BoolAttr.get(is_stable))
@ -4031,7 +4030,7 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
out = lower_comparator(sub_ctx, *[[a] for a in comparator.arguments],
num_keys=num_keys)
mhlo.ReturnOp(util.flatten(out))
hlo.ReturnOp(util.flatten(out))
return sort.results
mlir.register_lowering(sort_p, _sort_lower)
@ -4134,10 +4133,9 @@ create_token_p.def_abstract_eval(lambda *_: abstract_token)
def _create_token_lowering(ctx, *operands):
aval_out, = ctx.avals_out
if xc.mlir_api_version < 40:
return mhlo.CreateTokenOp(mlir.aval_to_ir_type(aval_out)).results
return hlo.CreateTokenOp(mlir.aval_to_ir_type(aval_out)).results
else:
return mhlo.CreateTokenOp().results
return hlo.CreateTokenOp().results
mlir.register_lowering(create_token_p, _create_token_lowering)
@ -4160,10 +4158,9 @@ after_all_p.def_abstract_eval(_after_all_abstract_eval)
def _after_all_lowering(ctx, *operands):
aval_out, = ctx.avals_out
if xc.mlir_api_version < 40:
return mhlo.AfterAllOp(mlir.aval_to_ir_type(aval_out), operands).results
return hlo.AfterAllOp(mlir.aval_to_ir_type(aval_out), operands).results
else:
return mhlo.AfterAllOp(operands).results
return hlo.AfterAllOp(operands).results
mlir.register_lowering(after_all_p, _after_all_lowering)
@ -4215,8 +4212,8 @@ def _infeed_lowering(ctx, token, *, shapes, partitions):
for i in range(len(aval.shape) - 1, -1, -1)])
for aval in shapes
])
infeed = mhlo.InfeedOp(
flat_output_types + [mhlo.TokenType.get()],
infeed = hlo.InfeedOp(
flat_output_types + [hlo.TokenType.get()],
token,
infeed_config=ir.StringAttr.get(''),
layout=layouts)
@ -4259,13 +4256,13 @@ mlir.lowerable_effects.add(InOutFeedEffect.Outfeed)
def _outfeed_lowering(ctx, token, *xs, partitions):
token_aval = ctx.avals_in[0]
if xc.mlir_api_version < 40:
outfeed = mhlo.OutfeedOp(
outfeed = hlo.OutfeedOp(
mlir.aval_to_ir_type(token_aval),
mlir.flatten_lowering_ir_args(xs),
token,
outfeed_config=ir.StringAttr.get(''))
else:
outfeed = mhlo.OutfeedOp(
outfeed = hlo.OutfeedOp(
mlir.flatten_lowering_ir_args(xs),
token,
outfeed_config=ir.StringAttr.get(''))
@ -4308,8 +4305,8 @@ def _rng_uniform_lowering(ctx, a, b, *, shape):
aval_out, = ctx.avals_out
shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64),
canonicalize_types=False)
return mhlo.RngOp(a, b, shape,
mhlo.RngDistributionAttr.get('UNIFORM')).results
return hlo.RngOp(a, b, shape,
hlo.RngDistributionAttr.get('UNIFORM')).results
mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering)
@ -4331,11 +4328,11 @@ RandomAlgorithm.__str__ = lambda algorithm: algorithm.name # type: ignore[assig
def _rng_algorithm(algorithm: RandomAlgorithm):
if algorithm == RandomAlgorithm.RNG_THREE_FRY:
return mhlo.RngAlgorithmAttr.get("THREE_FRY")
return hlo.RngAlgorithmAttr.get("THREE_FRY")
elif algorithm == RandomAlgorithm.RNG_PHILOX:
return mhlo.RngAlgorithmAttr.get("PHILOX")
return hlo.RngAlgorithmAttr.get("PHILOX")
elif algorithm == RandomAlgorithm.RNG_DEFAULT:
return mhlo.RngAlgorithmAttr.get("DEFAULT")
return hlo.RngAlgorithmAttr.get("DEFAULT")
else:
assert False
@ -4362,21 +4359,21 @@ def _rng_bit_generator_lowering(
else:
rbg_etype = u32_type
if key_etype == u32_type:
key = mhlo.BitcastConvertOp(
key = hlo.BitcastConvertOp(
ir.RankedTensorType.get([2], u64_type),
mhlo.ReshapeOp(ir.RankedTensorType.get([2, 2], u32_type), key)).result
hlo.ReshapeOp(ir.RankedTensorType.get([2, 2], u32_type), key)).result
algorithm_attr = _rng_algorithm(algorithm)
out_key, out_vals = mhlo.RngBitGeneratorOp(
out_key, out_vals = hlo.RngBitGeneratorOp(
key.type,
ir.RankedTensorType.get(shape, rbg_etype),
algorithm_attr, key).results
if key_etype == u32_type:
out_key = mhlo.ReshapeOp(
out_key = hlo.ReshapeOp(
ir.RankedTensorType.get([4], u32_type),
mhlo.BitcastConvertOp(
hlo.BitcastConvertOp(
ir.RankedTensorType.get([2, 2], u32_type), out_key)).result
if rbg_etype != etype:
out_vals = mhlo.ConvertOp(
out_vals = hlo.ConvertOp(
ir.RankedTensorType.get(ir.RankedTensorType(out_vals.type).shape, etype),
out_vals).result
return [out_key, out_vals]
@ -4523,13 +4520,13 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension):
aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape))
if not core.is_constant_shape(aval_out.shape):
shape = mlir.eval_dynamic_shape(ctx, aval_out.shape)
return mhlo.DynamicIotaOp(
return hlo.DynamicIotaOp(
mlir.aval_to_ir_type(aval_out),
mlir.shape_tensor(shape),
mlir.i64_attr(dimension),
).results
else:
return mhlo.IotaOp(mlir.aval_to_ir_type(aval_out),
return hlo.IotaOp(mlir.aval_to_ir_type(aval_out),
mlir.i64_attr(dimension)).results
mlir.register_lowering(iota_p, _iota_lower)

View File

@ -52,7 +52,7 @@ from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.typing import Array, ArrayLike
xops = xla_client.ops
@ -418,7 +418,7 @@ ad.primitive_jvps[cholesky_p] = _cholesky_jvp_rule
batching.primitive_batchers[cholesky_p] = _cholesky_batching_rule
def _cholesky_lowering(ctx, x):
return mhlo.CholeskyOp(x, lower=ir.BoolAttr.get(True)).results
return hlo.CholeskyOp(x, lower=ir.BoolAttr.get(True)).results
mlir.register_lowering(cholesky_p, _cholesky_lowering)
@ -429,22 +429,28 @@ def _cholesky_cpu_gpu_lowering(potrf_impl, ctx, operand):
out_aval, = ctx.avals_out
batch_dims = operand_aval.shape[:-2]
result, info = potrf_impl(operand_aval.dtype, operand, lower=True)
ok = mlir.compare_mhlo(
ok = mlir.compare_hlo(
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
"EQ", "SIGNED")
select_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
return [_broadcasting_select_mhlo(
return [_broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok,
select_aval,
broadcast_dimensions=range(len(batch_dims))),
select_aval,
result, out_aval, _nan_like_mhlo(ctx, out_aval), out_aval)]
result, out_aval, _nan_like_hlo(ctx, out_aval), out_aval)]
if xla_client.mlir_api_version < 41:
mlir.register_lowering(
cholesky_p,
partial(_cholesky_cpu_gpu_lowering, lapack.potrf_mhlo),
platform='cpu')
else:
mlir.register_lowering(
cholesky_p,
partial(_cholesky_cpu_gpu_lowering, lapack.potrf_hlo),
platform='cpu')
# Asymmetric eigendecomposition
@ -491,42 +497,47 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
out_aval = ctx.avals_out[0]
batch_dims = operand_aval.shape[:-2]
if xla_client.mlir_api_version < 41:
w, vl, vr, info = lapack.geev_mhlo(operand_aval.dtype, operand,
jobvl=compute_left_eigenvectors,
jobvr=compute_right_eigenvectors)
else:
w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand,
jobvl=compute_left_eigenvectors,
jobvr=compute_right_eigenvectors)
ok = mlir.compare_mhlo(
ok = mlir.compare_hlo(
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
"EQ", "SIGNED")
select_w_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
w = _broadcasting_select_mhlo(
w = _broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok, select_w_aval,
broadcast_dimensions=range(len(batch_dims))),
select_w_aval,
w, out_aval, _nan_like_mhlo(ctx, out_aval), out_aval)
w, out_aval, _nan_like_hlo(ctx, out_aval), out_aval)
output = [w]
if compute_left_eigenvectors:
aval = ctx.avals_out[len(output)]
select_vl_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
vl = _broadcasting_select_mhlo(
vl = _broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok, select_vl_aval,
broadcast_dimensions=range(len(batch_dims))),
select_vl_aval,
vl, aval, _nan_like_mhlo(ctx, aval), aval)
vl, aval, _nan_like_hlo(ctx, aval), aval)
output.append(vl)
if compute_right_eigenvectors:
aval = ctx.avals_out[len(output)]
select_vr_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
vr = _broadcasting_select_mhlo(
vr = _broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok, select_vr_aval,
broadcast_dimensions=range(len(batch_dims))),
select_vr_aval,
vr, aval, _nan_like_mhlo(ctx, aval), aval)
vr, aval, _nan_like_hlo(ctx, aval), aval)
output.append(vr)
return output
@ -645,21 +656,21 @@ def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower,
batch_dims = operand_aval.shape[:-2]
v, w, info = syevd_impl(operand_aval.dtype, operand, lower=lower)
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
ok = mlir.compare_mhlo(info, zeros, "EQ", "SIGNED")
ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
select_v_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
v = _broadcasting_select_mhlo(
v = _broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok, select_v_aval,
broadcast_dimensions=range(len(batch_dims))),
select_v_aval,
v, v_aval, _nan_like_mhlo(ctx, v_aval), v_aval)
v, v_aval, _nan_like_hlo(ctx, v_aval), v_aval)
select_w_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
w = _broadcasting_select_mhlo(
w = _broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok, select_w_aval,
broadcast_dimensions=range(len(batch_dims))),
select_w_aval,
w, w_aval, _nan_like_mhlo(ctx, w_aval), w_aval)
w, w_aval, _nan_like_hlo(ctx, w_aval), w_aval)
return [v, w]
def _eigh_tpu_impl(x, *, lower, sort_eigenvalues):
@ -742,9 +753,14 @@ eigh_p.def_abstract_eval(_eigh_abstract_eval)
ad.primitive_jvps[eigh_p] = _eigh_jvp_rule
batching.primitive_batchers[eigh_p] = _eigh_batching_rule
if xla_client.mlir_api_version < 41:
mlir.register_lowering(
eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_mhlo),
platform='cpu')
else:
mlir.register_lowering(
eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_hlo),
platform='cpu')
if gpu_solver is not None:
mlir.register_lowering(
@ -882,15 +898,15 @@ def _triangular_solve_lowering(
else:
transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
if mlir_api_version < 36:
return mhlo.TriangularSolveOp(
return hlo.TriangularSolveOp(
mlir.aval_to_ir_type(out_aval), a, b, ir.BoolAttr.get(left_side),
ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
mhlo.TransposeAttr.get(transpose)).results
hlo.TransposeAttr.get(transpose)).results
else:
return mhlo.TriangularSolveOp(
return hlo.TriangularSolveOp(
a, b, ir.BoolAttr.get(left_side),
ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
mhlo.TransposeAttr.get(transpose)).results
hlo.TransposeAttr.get(transpose)).results
mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)
@ -904,9 +920,14 @@ def _triangular_solve_cpu_lower(
conjugate_a = False
if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
if xla_client.mlir_api_version < 41:
return [lapack.trsm_mhlo(
a_aval.dtype, alpha,
a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)]
else:
return [lapack.trsm_hlo(
a_aval.dtype, alpha,
a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)]
else:
# Fall back to the HLO implementation for unsupported types or batching.
# TODO: Consider swapping XLA for LAPACK in batched case
@ -915,15 +936,15 @@ def _triangular_solve_cpu_lower(
else:
transpose = "NO_TRANSPOSE"
if mlir_api_version < 36:
return mhlo.TriangularSolveOp(b.type, a, b, ir.BoolAttr.get(left_side),
return hlo.TriangularSolveOp(b.type, a, b, ir.BoolAttr.get(left_side),
ir.BoolAttr.get(lower),
ir.BoolAttr.get(unit_diagonal),
mhlo.TransposeAttr.get(transpose)).results
hlo.TransposeAttr.get(transpose)).results
else:
return mhlo.TriangularSolveOp(a, b, ir.BoolAttr.get(left_side),
return hlo.TriangularSolveOp(a, b, ir.BoolAttr.get(left_side),
ir.BoolAttr.get(lower),
ir.BoolAttr.get(unit_diagonal),
mhlo.TransposeAttr.get(transpose)).results
hlo.TransposeAttr.get(transpose)).results
mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
platform='cpu')
@ -1186,17 +1207,17 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand):
m = operand_aval.shape[-2]
lu, pivot, info = getrf_impl(operand_aval.dtype, operand)
# Subtract 1 from the pivot to get 0-based indices.
pivot = mhlo.SubtractOp(pivot, mlir.full_like_aval(ctx, 1, pivot_aval)).result
ok = mlir.compare_mhlo(
pivot = hlo.SubtractOp(pivot, mlir.full_like_aval(ctx, 1, pivot_aval)).result
ok = mlir.compare_hlo(
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
"GE", "SIGNED")
select_lu_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
lu = _broadcasting_select_mhlo(
lu = _broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok, select_lu_aval,
broadcast_dimensions=range(len(batch_dims))),
select_lu_aval,
lu, out_aval, _nan_like_mhlo(ctx, out_aval), out_aval)
lu, out_aval, _nan_like_hlo(ctx, out_aval), out_aval)
sub_ctx = ctx.replace(primitive=None, avals_in=[pivot_aval], avals_out=[perm_aval])
perm_fn = mlir.lower_fun(lambda x: lu_pivots_to_permutation(x, m),
multiple_results=False)
@ -1216,9 +1237,14 @@ mlir.register_lowering(lu_p, mlir.lower_fun(_lu_python, multiple_results=True))
ad.primitive_jvps[lu_p] = _lu_jvp_rule
batching.primitive_batchers[lu_p] = _lu_batching_rule
if xla_client.mlir_api_version < 41:
mlir.register_lowering(lu_p,
partial(_lu_cpu_gpu_lowering, lapack.getrf_mhlo),
platform='cpu')
else:
mlir.register_lowering(lu_p,
partial(_lu_cpu_gpu_lowering, lapack.getrf_hlo),
platform='cpu')
mlir.register_lowering(
lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.cuda_getrf),
@ -1339,15 +1365,15 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a):
else:
a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
ok = mlir.compare_mhlo(info_geqrf, zeros, "EQ", "SIGNED")
ok = mlir.compare_hlo(info_geqrf, zeros, "EQ", "SIGNED")
select_ok_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_))
ok_a = mlir.broadcast_in_dim(ctx, ok, select_ok_a_aval,
broadcast_dimensions=range(len(batch_dims)))
a_out = _broadcasting_select_mhlo(ctx, ok_a, select_ok_a_aval, a_out, a_aval, _nan_like_mhlo(ctx, a_aval), a_aval)
a_out = _broadcasting_select_hlo(ctx, ok_a, select_ok_a_aval, a_out, a_aval, _nan_like_hlo(ctx, a_aval), a_aval)
select_ok_taus_aval = ShapedArray(batch_dims + [1], np.dtype(np.bool_))
ok_taus = mlir.broadcast_in_dim(ctx, ok, select_ok_taus_aval,
broadcast_dimensions=range(len(batch_dims)))
taus = _broadcasting_select_mhlo(ctx, ok_taus, select_ok_taus_aval, taus, taus_aval, _nan_like_mhlo(ctx, taus_aval), taus_aval)
taus = _broadcasting_select_hlo(ctx, ok_taus, select_ok_taus_aval, taus, taus_aval, _nan_like_hlo(ctx, taus_aval), taus_aval)
return a_out, taus
geqrf_p = Primitive('geqrf')
@ -1357,9 +1383,14 @@ geqrf_p.def_abstract_eval(_geqrf_abstract_eval)
batching.primitive_batchers[geqrf_p] = _geqrf_batching_rule
xla.register_translation(geqrf_p, _geqrf_translation_rule)
if xla_client.mlir_api_version < 41:
mlir.register_lowering(
geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_mhlo, None),
platform='cpu')
else:
mlir.register_lowering(
geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_hlo, None),
platform='cpu')
mlir.register_lowering(
geqrf_p,
partial(_geqrf_cpu_gpu_lowering, gpu_solver.cuda_geqrf,
@ -1425,11 +1456,11 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus)
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
ok = mlir.compare_mhlo(info_orgqr, zeros, "EQ", "SIGNED")
ok = mlir.compare_hlo(info_orgqr, zeros, "EQ", "SIGNED")
select_a_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
ok = mlir.broadcast_in_dim(ctx, ok, select_a_aval,
broadcast_dimensions=range(len(batch_dims)))
a = _broadcasting_select_mhlo(ctx, ok, select_a_aval, a, a_aval, _nan_like_mhlo(ctx, a_aval), a_aval)
a = _broadcasting_select_hlo(ctx, ok, select_a_aval, a, a_aval, _nan_like_hlo(ctx, a_aval), a_aval)
return [a]
@ -1439,10 +1470,16 @@ householder_product_p.def_abstract_eval(_householder_product_abstract_eval)
batching.primitive_batchers[householder_product_p] = _householder_product_batching_rule
xla.register_translation(householder_product_p, _householder_product_translation_rule)
if xla_client.mlir_api_version < 41:
mlir.register_lowering(
householder_product_p,
partial(_householder_product_cpu_gpu_lowering, lapack.orgqr_mhlo),
platform='cpu')
else:
mlir.register_lowering(
householder_product_p,
partial(_householder_product_cpu_gpu_lowering, lapack.orgqr_hlo),
platform='cpu')
mlir.register_lowering(
householder_product_p,
partial(_householder_product_cpu_gpu_lowering, gpu_solver.cuda_orgqr),
@ -1642,32 +1679,32 @@ def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
full_matrices=full_matrices,
compute_uv=compute_uv)
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
ok = mlir.compare_mhlo(info, zeros, "EQ", "SIGNED")
ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
select_s_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
s = _broadcasting_select_mhlo(
s = _broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok, select_s_aval,
broadcast_dimensions=range(len(batch_dims))),
select_s_aval,
s, s_aval, _nan_like_mhlo(ctx, s_aval), s_aval)
s, s_aval, _nan_like_hlo(ctx, s_aval), s_aval)
result = [s]
if compute_uv:
u_aval, vt_aval = ctx.avals_out[1:]
select_u_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
u = _broadcasting_select_mhlo(
u = _broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok, select_u_aval,
broadcast_dimensions=range(len(batch_dims))),
select_u_aval,
u, u_aval, _nan_like_mhlo(ctx, u_aval), u_aval)
u, u_aval, _nan_like_hlo(ctx, u_aval), u_aval)
select_v_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
vt = _broadcasting_select_mhlo(
vt = _broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok, select_v_aval,
broadcast_dimensions=range(len(batch_dims))),
select_v_aval,
vt, vt_aval, _nan_like_mhlo(ctx, vt_aval), vt_aval)
vt, vt_aval, _nan_like_hlo(ctx, vt_aval), vt_aval)
result += [u, vt]
return result
@ -1715,9 +1752,14 @@ svd_p.def_abstract_eval(_svd_abstract_eval)
ad.primitive_jvps[svd_p] = _svd_jvp_rule
batching.primitive_batchers[svd_p] = _svd_batching_rule
if xla_client.mlir_api_version < 41:
mlir.register_lowering(
svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo),
platform='cpu')
else:
mlir.register_lowering(
svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_hlo),
platform='cpu')
mlir.register_lowering(
svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.cuda_gesvd),
platform='cuda')
@ -1877,33 +1919,40 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals,
operand_aval, = ctx.avals_in
batch_dims = operand_aval.shape[:-2]
if xla_client.mlir_api_version < 41:
gees_result = lapack.gees_mhlo(operand_aval.dtype, operand,
jobvs=compute_schur_vectors,
sort=sort_eig_vals,
select=select_callable)
else:
gees_result = lapack.gees_hlo(operand_aval.dtype, operand,
jobvs=compute_schur_vectors,
sort=sort_eig_vals,
select=select_callable)
# Number of return values depends on value of sort_eig_vals.
T, vs, *_, info = gees_result
ok = mlir.compare_mhlo(
ok = mlir.compare_hlo(
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
"EQ", "SIGNED")
select_T_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
T = _broadcasting_select_mhlo(
T = _broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok, select_T_aval,
broadcast_dimensions=range(len(batch_dims))),
select_T_aval,
T, ctx.avals_out[0], _nan_like_mhlo(ctx, ctx.avals_out[0]), ctx.avals_out[0])
T, ctx.avals_out[0], _nan_like_hlo(ctx, ctx.avals_out[0]), ctx.avals_out[0])
output = [T]
if compute_schur_vectors:
select_vs_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
vs = _broadcasting_select_mhlo(
vs = _broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok, select_vs_aval,
broadcast_dimensions=range(len(batch_dims))),
select_vs_aval,
vs, ctx.avals_out[1], _nan_like_mhlo(ctx, ctx.avals_out[1]), ctx.avals_out[1])
vs, ctx.avals_out[1], _nan_like_hlo(ctx, ctx.avals_out[1]), ctx.avals_out[1])
output.append(vs)
@ -1983,35 +2032,38 @@ def _hessenberg_batching_rule(batched_args, batch_dims):
batching.primitive_batchers[hessenberg_p] = _hessenberg_batching_rule
def _hessenberg_cpu_mhlo(ctx, a):
def _hessenberg_cpu_hlo(ctx, a):
# TODO(phawkins): remove this test after jaxlib 0.3.25 is the minimum.
if not hasattr(lapack, "gehrd_mhlo"):
if not hasattr(lapack, "gehrd_mhlo") and not hasattr(lapack, "gehrd_hlo"):
raise RuntimeError("Hessenberg reduction on CPU requires jaxlib 0.3.25 or "
"newer")
a_aval, = ctx.avals_in
batch_dims = a_aval.shape[:-2]
if xla_client.mlir_api_version < 41:
a, taus, info = lapack.gehrd_mhlo(a_aval.dtype, a)
ok = mlir.compare_mhlo(
else:
a, taus, info = lapack.gehrd_hlo(a_aval.dtype, a)
ok = mlir.compare_hlo(
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
"EQ", "SIGNED")
select_a_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
select_taus_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
return [
_broadcasting_select_mhlo(
_broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok, select_a_aval,
broadcast_dimensions=range(len(batch_dims))),
select_a_aval,
a, ctx.avals_out[0], _nan_like_mhlo(ctx, ctx.avals_out[0]), ctx.avals_out[0]),
_broadcasting_select_mhlo(
a, ctx.avals_out[0], _nan_like_hlo(ctx, ctx.avals_out[0]), ctx.avals_out[0]),
_broadcasting_select_hlo(
ctx,
mlir.broadcast_in_dim(ctx, ok, select_taus_aval,
broadcast_dimensions=range(len(batch_dims))),
select_taus_aval,
taus, ctx.avals_out[1], _nan_like_mhlo(ctx, ctx.avals_out[1]), ctx.avals_out[1]),
taus, ctx.avals_out[1], _nan_like_hlo(ctx, ctx.avals_out[1]), ctx.avals_out[1]),
]
mlir.register_lowering(hessenberg_p, _hessenberg_cpu_mhlo, platform='cpu')
mlir.register_lowering(hessenberg_p, _hessenberg_cpu_hlo, platform='cpu')
# tridiagonal: Upper Hessenberg reduction
@ -2085,35 +2137,40 @@ def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower):
batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule
def _tridiagonal_cpu_gpu_mhlo(sytrd_impl, ctx, a, *, lower):
def _tridiagonal_cpu_gpu_hlo(sytrd_impl, ctx, a, *, lower):
a_aval, = ctx.avals_in
a, d, e, taus, info = sytrd_impl(a_aval.dtype, a, lower=lower)
return a, d, e, taus, info
if jaxlib_version >= (0, 3, 25):
if xla_client.mlir_api_version < 41:
mlir.register_lowering(
tridiagonal_p, partial(_tridiagonal_cpu_gpu_mhlo, lapack.sytrd_mhlo),
tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, lapack.sytrd_mhlo),
platform='cpu')
else:
mlir.register_lowering(
tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, lapack.sytrd_hlo),
platform='cpu')
mlir.register_lowering(
tridiagonal_p, partial(_tridiagonal_cpu_gpu_mhlo, gpu_solver.cuda_sytrd),
tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.cuda_sytrd),
platform='cuda')
mlir.register_lowering(
tridiagonal_p, partial(_tridiagonal_cpu_gpu_mhlo, gpu_solver.rocm_sytrd),
tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.rocm_sytrd),
platform='rocm')
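The mlir_api_version gate above repeats for getrf, geqrf, orgqr, gesdd, gees, gehrd and sytrd: older jaxlibs export *_mhlo kernels, newer ones export *_hlo, and the shared lowering rule is simply bound to whichever exists. A self-contained sketch of the pattern, with all names illustrative stand-ins rather than real jaxlib exports:

from functools import partial

lowerings = {}  # stand-in for JAX's lowering registry

def register_lowering(prim, rule, platform):
    lowerings[(prim, platform)] = rule

def lowering_rule(impl, ctx, x):  # one rule, parameterized by the kernel
    return impl(x)

def kernel_mhlo(x):  # what pre-rename jaxlibs export
    return x

def kernel_hlo(x):   # what post-rename jaxlibs export
    return x

mlir_api_version = 41  # assumed value for the sketch
impl = kernel_mhlo if mlir_api_version < 41 else kernel_hlo
register_lowering('my_prim', partial(lowering_rule, impl), platform='cpu')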
# Utilities
def _nan_like_mhlo(ctx: mlir.LoweringRuleContext, aval) -> ir.Value:
def _nan_like_hlo(ctx: mlir.LoweringRuleContext, aval) -> ir.Value:
if jnp.issubdtype(aval.dtype, np.complexfloating):
return mlir.full_like_aval(ctx, np.nan + np.nan * 1j, aval)
else:
return mlir.full_like_aval(ctx, np.nan, aval)
def _broadcasting_select_mhlo(ctx, which, which_aval, x, x_aval, y, y_aval) -> ir.Value:
def _broadcasting_select_hlo(ctx, which, which_aval, x, x_aval, y, y_aval) -> ir.Value:
"""Wrapper around XLA `Select` that broadcasts its arguments."""
out_shapes = list(lax_internal.broadcast_shapes(
tuple(which_aval.shape), tuple(x_aval.shape), tuple(y_aval.shape)))
which, x, y = mlir.multi_broadcast_in_dim(ctx, (which, x, y),
(which_aval, x_aval, y_aval),
out_shapes)
return mhlo.SelectOp(which, x, y).result
return hlo.SelectOp(which, x, y).result
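Together these two helpers implement the recurring "NaN on failure" pattern in the lowerings above: compare the per-batch LAPACK info code against zero, broadcast that flag against the result, and select NaN where the factorization failed. A NumPy sketch of the same dataflow, with illustrative shapes:

import numpy as np

batch, n = 2, 3
lu = np.arange(batch * n * n, dtype=np.float32).reshape(batch, n, n)
info = np.array([0, 1], dtype=np.int32)  # nonzero means LAPACK failed
ok = (info == 0)[:, None, None]          # broadcast flag to (batch, 1, 1)
nan_fill = np.full_like(lu, np.nan)      # what _nan_like_hlo builds
result = np.where(ok, lu, nan_fill)      # what _broadcasting_select_hlo does
# result[0] is unchanged; result[1] is all NaN.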

View File

@ -39,7 +39,7 @@ import jax._src.util as util
from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis
from jax._src.lib.mlir import ir
from jax._src.lib import mlir_api_version
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
unsafe_map, map = map, safe_map # type: ignore
@ -219,7 +219,7 @@ def ppermute(x, axis_name, perm):
If ``x`` is a pytree then the result is equivalent to mapping this function to
each leaf in the tree.
This function is an analog of the CollectivePermute XLA HLO.
This function is an analog of the CollectivePermute HLO.
Args:
x: array(s) with a mapped axis named ``axis_name``.
@ -661,7 +661,7 @@ def _replica_groups(axis_env, axis_name, axis_index_groups):
for axis_index_group in axis_index_groups]
return replica_groups
def _replica_groups_mhlo(replica_groups: Sequence[Sequence[int]]
def _replica_groups_hlo(replica_groups: Sequence[Sequence[int]]
) -> ir.DenseIntElementsAttr:
# Uneven replica groups are padded with -1.
groups = np.array(list(itertools.zip_longest(*replica_groups, fillvalue=-1)),
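A quick NumPy check of the -1 padding described in the comment, using the same zip_longest construction as the function body:

import itertools
import numpy as np

replica_groups = [[0, 1, 2], [3, 4]]  # uneven groups
groups = np.array(list(itertools.zip_longest(*replica_groups, fillvalue=-1)),
                  dtype=np.int64).T
print(groups)
# [[ 0  1  2]
#  [ 3  4 -1]]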
@ -711,7 +711,7 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
if not named_axes:
return args
replica_groups = _replica_groups_mhlo(
replica_groups = _replica_groups_hlo(
_replica_groups(ctx.module_context.axis_env, named_axes,
axis_index_groups))
axis_context = ctx.module_context.axis_context
@ -722,12 +722,12 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
if is_spmd:
channel = ctx.module_context.new_channel()
other_args = dict(
channel_handle=mhlo.ChannelHandle.get(
channel_handle=hlo.ChannelHandle.get(
channel, mlir.DEVICE_TO_DEVICE_TYPE),
use_global_device_ids=ir.BoolAttr.get(True))
else:
other_args = {}
op = mhlo.AllReduceOp(
op = hlo.AllReduceOp(
x.type, x, replica_groups=replica_groups, **other_args)
scalar_aval = core.ShapedArray((), aval.dtype)
scalar_type = mlir.aval_to_ir_type(scalar_aval)
@ -738,7 +738,7 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
avals_in=[scalar_aval] * 2, avals_out=[scalar_aval])
out_nodes = lower_reducer(
reducer_ctx, *([a] for a in reducer_block.arguments))
mhlo.ReturnOp(util.flatten(out_nodes))
hlo.ReturnOp(util.flatten(out_nodes))
return op.result
return [all_reduce(aval, x) for aval, x in zip(ctx.avals_in, args)]
@ -849,11 +849,11 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm):
if is_manual and mlir_api_version >= 35:
channel = ctx.module_context.new_channel()
other_args = dict(
channel_handle=mhlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE))
channel_handle=hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE))
else:
other_args = {}
return mhlo.CollectivePermuteOp(
return hlo.CollectivePermuteOp(
x, mlir.dense_int_elements(full_perm), **other_args).results
def _ppermute_transpose_rule(t, x, perm, axis_name):
@ -952,16 +952,16 @@ def _all_to_all_lowering(ctx, x, *,
# of partitions - and XLA is configured with only a single replica.
channel = ctx.module_context.new_channel()
other_args = dict(
channel_handle=mhlo.ChannelHandle.get(channel,
channel_handle=hlo.ChannelHandle.get(channel,
mlir.DEVICE_TO_DEVICE_TYPE))
else:
other_args = {}
return mhlo.AllToAllOp(
return hlo.AllToAllOp(
operand,
split_dimension=mlir.i64_attr(split_axis),
concat_dimension=mlir.i64_attr(concat_axis),
split_count=mlir.i64_attr(split_count),
replica_groups=_replica_groups_mhlo(replica_groups),
replica_groups=_replica_groups_hlo(replica_groups),
**other_args).results
else:
warnings.warn(
@ -1184,7 +1184,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
new_shape = list(x_aval.shape)
new_shape.insert(all_gather_dimension, 1)
broadcast_dimensions = [i for i in range(len(new_shape)) if i != all_gather_dimension]
x = mhlo.BroadcastInDimOp(
x = hlo.BroadcastInDimOp(
mlir.aval_to_ir_type(x_aval.update(shape=new_shape)), x,
mlir.dense_int_elements(broadcast_dimensions))
replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
@ -1195,15 +1195,15 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
# of partitions - and XLA is configured with only a single replica.
channel = ctx.module_context.new_channel()
other_args = dict(
channel_handle=mhlo.ChannelHandle.get(
channel_handle=hlo.ChannelHandle.get(
channel, mlir.DEVICE_TO_DEVICE_TYPE),
use_global_device_ids=ir.BoolAttr.get(True))
else:
other_args = {}
return mhlo.AllGatherOp(
return hlo.AllGatherOp(
mlir.aval_to_ir_type(out_aval),
x, all_gather_dim=mlir.i64_attr(all_gather_dimension),
replica_groups=_replica_groups_mhlo(replica_groups),
replica_groups=_replica_groups_hlo(replica_groups),
**other_args).results
else:
lowering = mlir.lower_fun(_all_gather_via_psum, multiple_results=False)
@ -1328,16 +1328,16 @@ def _reduce_scatter_lowering(prim, reducer, ctx, x,
# of partitions - and XLA is configured with only a single replica.
channel = ctx.module_context.new_channel()
other_args = dict(
channel_handle=mhlo.ChannelHandle.get(
channel_handle=hlo.ChannelHandle.get(
channel, mlir.DEVICE_TO_DEVICE_TYPE),
use_global_device_ids=ir.BoolAttr.get(True))
else:
other_args = {}
op = mhlo.ReduceScatterOp(
op = hlo.ReduceScatterOp(
mlir.aval_to_ir_type(x_aval.update(shape=scatter_out_shape)),
x,
scatter_dimension=mlir.i64_attr(scatter_dimension),
replica_groups=_replica_groups_mhlo(replica_groups),
replica_groups=_replica_groups_hlo(replica_groups),
**other_args)
scalar_type = mlir.aval_to_ir_type(scalar_aval)
reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
@ -1348,12 +1348,12 @@ def _reduce_scatter_lowering(prim, reducer, ctx, x,
avals_out=[scalar_aval])
out_nodes = lower_reducer(
reducer_ctx, *([a] for a in reducer_block.arguments))
mhlo.ReturnOp(util.flatten(out_nodes))
hlo.ReturnOp(util.flatten(out_nodes))
if tiled:
return op.results
else:
return mhlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), op.result).results
return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), op.result).results
else:
return mlir.lower_fun(_reduce_scatter_via_reducer, multiple_results=False)(
ctx, x,
@ -1522,7 +1522,7 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, t
return tree_util.tree_map(bind, x)
def _build_axis_index_lowering_mhlo(ctx, axis_name, axis_env):
def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
if isinstance(axis_name, tuple):
assert axis_name, 'empty axis name'
if len(axis_name) > 1:
@ -1539,20 +1539,20 @@ def _build_axis_index_lowering_mhlo(ctx, axis_name, axis_env):
(mlir.SPMDAxisContext, mlir.ShardingContext))
if is_spmd:
if mlir_api_version >= 39:
device_id = mhlo.PartitionIdOp()
device_id = hlo.PartitionIdOp()
else:
device_id = mhlo.PartitionIdOp(
device_id = hlo.PartitionIdOp(
ir.RankedTensorType.get([], ir.IntegerType.get_unsigned(32)))
else:
device_id = mhlo.ReplicaIdOp()
unsigned_index = mhlo.RemOp(mhlo.DivOp(device_id, div), mod)
return mhlo.ConvertOp(
device_id = hlo.ReplicaIdOp()
unsigned_index = hlo.RemOp(hlo.DivOp(device_id, div), mod)
return hlo.ConvertOp(
ir.RankedTensorType.get([], ir.IntegerType.get_signless(32)),
unsigned_index).result
def _axis_index_lowering(ctx, *, axis_name):
return [
_build_axis_index_lowering_mhlo(ctx, axis_name,
_build_axis_index_lowering_hlo(ctx, axis_name,
ctx.module_context.axis_env)
]
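The div/mod pair emitted above extracts one mesh coordinate from the flat device id. A sketch of the arithmetic, under the assumption of a row-major device mesh (the mesh sizes are illustrative):

def axis_index(device_id, axis_size, inner_size):
    # inner_size: product of the sizes of the axes nested inside this one
    return (device_id // inner_size) % axis_size

# Hypothetical 2x4 mesh over axes ('x', 'y'): device 6 sits at x=1, y=2.
assert axis_index(6, axis_size=2, inner_size=4) == 1  # 'x'
assert axis_index(6, axis_size=4, inner_size=1) == 2  # 'y'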

View File

@ -36,7 +36,7 @@ from jax._src.lax import lax
from jax._src import util
from jax._src.util import safe_map, safe_zip
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.typing import Array, ArrayLike, Shape
@ -956,7 +956,7 @@ mlir.register_lowering(dynamic_slice_p, _dynamic_slice_lower)
# def _getslice_lower(ctx, x, lo, hi):
# aval_out, = ctx.avals_out
# return mhlo.RealDynamicSliceOp(
# return hlo.RealDynamicSliceOp(
# mlir.aval_to_ir_type(aval_out), x,
# mlir.shape_tensor([lo]), mlir.shape_tensor([hi]), mlir.shape_tensor([1])
# ).results
@ -1393,7 +1393,7 @@ def _gather_lower(ctx, operand, indices, *,
assert mode in (GatherScatterMode.PROMISE_IN_BOUNDS,
GatherScatterMode.CLIP), mode
dnums = mhlo.GatherDimensionNumbers.get(
dnums = hlo.GatherDimensionNumbers.get(
collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims),
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
offset_dims=list(dimension_numbers.offset_dims),
@ -1402,7 +1402,7 @@ def _gather_lower(ctx, operand, indices, *,
slice_sizes = mlir.eval_dynamic_shape(ctx, slice_sizes)
# TODO(burmako): Fix overly conservative type inference of DynamicGatherOp.
# For now use the build_generic so that we can specify the result type.
# return mhlo.DynamicGatherOp(
# return hlo.DynamicGatherOp(
# operand, indices, mlir.shape_tensor(slice_sizes),
# dnums, indices_are_sorted=ir.BoolAttr.get(indices_are_sorted)).results
results = [mlir.aval_to_ir_type(aval_out)]
@ -1411,10 +1411,10 @@ def _gather_lower(ctx, operand, indices, *,
"dimension_numbers": dnums,
"indices_are_sorted": ir.BoolAttr.get(indices_are_sorted)
}
return mhlo.DynamicGatherOp.build_generic(
return hlo.DynamicGatherOp.build_generic(
results=results, operands=operands, attributes=attributes).results
else:
return mhlo.GatherOp(
return hlo.GatherOp(
operand,
indices,
dnums,
@ -2019,7 +2019,7 @@ def _scatter_lower(ctx, operand, indices, updates, *,
aval_out, = ctx.avals_out
dnums = dimension_numbers
scatter_dnums = mhlo.ScatterDimensionNumbers.get(
scatter_dnums = hlo.ScatterDimensionNumbers.get(
update_window_dims=list(dnums.update_window_dims),
inserted_window_dims=list(dnums.inserted_window_dims),
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
@ -2027,7 +2027,7 @@ def _scatter_lower(ctx, operand, indices, updates, *,
result = mlir.aval_to_ir_types(aval_out)
operand = [operand]
updates = [updates]
op = mhlo.ScatterOp(
op = hlo.ScatterOp(
result,
operand,
indices,
@ -2045,7 +2045,7 @@ def _scatter_lower(ctx, operand, indices, updates, *,
update_ctx, update_jaxpr, mlir.TokenSet(), update_consts,
(update.arguments[0],), (update.arguments[1],),
dim_var_values=ctx.dim_var_values)
mhlo.ReturnOp(util.flatten(out_nodes))
hlo.ReturnOp(util.flatten(out_nodes))
return op.results
mlir.register_lowering(scatter_p, _scatter_lower)
@ -2076,7 +2076,7 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
aval_out, = ctx.avals_out
dnums = dimension_numbers
scatter_dnums = mhlo.ScatterDimensionNumbers.get(
scatter_dnums = hlo.ScatterDimensionNumbers.get(
update_window_dims=list(dnums.update_window_dims),
inserted_window_dims=list(dnums.inserted_window_dims),
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
@ -2089,7 +2089,7 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
operand_part = [operand_part]
updates_part = [updates_part]
scatter = mhlo.ScatterOp(
scatter = hlo.ScatterOp(
operand_type_part,
operand_part,
indices,
@ -2100,13 +2100,13 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), real_dtype))
reducer = scatter.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer):
add = mhlo.AddOp(*reducer.arguments).result
mhlo.ReturnOp([add])
add = hlo.AddOp(*reducer.arguments).result
hlo.ReturnOp([add])
return scatter.result
real = _scatter(mhlo.RealOp(operand).result, mhlo.RealOp(updates).result)
imag = _scatter(mhlo.ImagOp(operand).result, mhlo.ImagOp(updates).result)
return mhlo.ComplexOp(real, imag).results
real = _scatter(hlo.RealOp(operand).result, hlo.RealOp(updates).result)
imag = _scatter(hlo.ImagOp(operand).result, hlo.ImagOp(updates).result)
return hlo.ComplexOp(real, imag).results
mlir.register_lowering(scatter_add_p, _scatter_add_lower_gpu, platform="gpu")

View File

@ -33,7 +33,7 @@ import jax._src.lax.lax as lax
import jax._src.lax.convolution as convolution
import jax._src.lax.slicing as slicing
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.ufuncs import logaddexp
import jax._src.util as util
@ -316,7 +316,7 @@ def _generic_reduce_window_lower(ctx, *args, jaxpr, consts,
operands, init_values = util.split_list(args, [len(args) // 2])
_, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])
scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
rw = mhlo.ReduceWindowOp(
rw = hlo.ReduceWindowOp(
map(mlir.aval_to_ir_type, ctx.avals_out),
operands,
init_values,
@ -333,7 +333,7 @@ def _generic_reduce_window_lower(ctx, *args, jaxpr, consts,
out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr,
mlir.TokenSet(), consts, *([a] for a in reducer.arguments),
dim_var_values=ctx.dim_var_values)
mhlo.ReturnOp(util.flatten(out_nodes))
hlo.ReturnOp(util.flatten(out_nodes))
return rw.results
mlir.register_lowering(reduce_window_p, _generic_reduce_window_lower)
@ -468,7 +468,7 @@ def _reduce_window_lower(
operand_aval, = ctx.avals_in
scalar_aval = operand_aval.update(shape=())
scalar_type = mlir.aval_to_ir_type(scalar_aval)
rw = mhlo.ReduceWindowOp(
rw = hlo.ReduceWindowOp(
mlir.aval_to_ir_types(aval_out), [operand],
[mlir.full_like_aval(ctx, init_value(scalar_aval.dtype), scalar_aval)],
mlir.dense_int_elements(window_dimensions),
@ -479,15 +479,15 @@ def _reduce_window_lower(
shape=(len(padding), 2)))
reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(reducer):
mhlo.ReturnOp(reduce_op(*reducer.arguments))
hlo.ReturnOp(reduce_op(*reducer.arguments))
return rw.results
mlir.register_lowering(reduce_window_sum_p, partial(
_reduce_window_lower, mhlo.AddOp, lambda _: 0))
_reduce_window_lower, hlo.AddOp, lambda _: 0))
mlir.register_lowering(reduce_window_min_p, partial(
_reduce_window_lower, mlir.min_mhlo, lax._get_min_identity))
_reduce_window_lower, mlir.min_hlo, lax._get_min_identity))
mlir.register_lowering(reduce_window_max_p, partial(
_reduce_window_lower, mlir.max_mhlo, lax._get_max_identity))
_reduce_window_lower, mlir.max_hlo, lax._get_max_identity))
@ -514,7 +514,7 @@ def _select_and_scatter_lower(
aval_out, = ctx.avals_out
scalar_aval = operand_aval.update(shape=())
scalar_type = mlir.aval_to_ir_type(scalar_aval)
op = mhlo.SelectAndScatterOp(
op = hlo.SelectAndScatterOp(
mlir.aval_to_ir_type(aval_out),
operand,
source,
@ -531,7 +531,7 @@ def _select_and_scatter_lower(
mlir.TokenSet(), select_consts,
*([a] for a in select.arguments),
dim_var_values=ctx.dim_var_values)
mhlo.ReturnOp(util.flatten(out_nodes))
hlo.ReturnOp(util.flatten(out_nodes))
scatter = op.scatter.blocks.append(scalar_type, scalar_type)
with ir.InsertionPoint(scatter):
if scatter_jaxpr.effects:
@ -540,7 +540,7 @@ def _select_and_scatter_lower(
mlir.TokenSet(), scatter_consts,
*([a] for a in scatter.arguments),
dim_var_values=ctx.dim_var_values)
mhlo.ReturnOp(util.flatten(out_nodes))
hlo.ReturnOp(util.flatten(out_nodes))
return op.results
mlir.register_lowering(select_and_scatter_p, _select_and_scatter_lower)
@ -670,7 +670,7 @@ def _select_and_gather_add_lowering(
canonicalize_types=False)
def _broadcast(x, dims):
return mhlo.BroadcastOp(x, mlir.dense_int_elements(dims))
return hlo.BroadcastOp(x, mlir.dense_int_elements(dims))
if double_word_reduction:
# TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so
@ -685,28 +685,28 @@ def _select_and_gather_add_lowering(
def pack(a, b):
a_dims = ir.RankedTensorType(a.type).shape
b_dims = ir.RankedTensorType(b.type).shape
a = mhlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a)
b = mhlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b)
a = mhlo.ConvertOp(ir.RankedTensorType.get(a_dims, double_word_type), a)
b = mhlo.ConvertOp(ir.RankedTensorType.get(b_dims, double_word_type), b)
a = mhlo.ShiftLeftOp(a,
a = hlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a)
b = hlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b)
a = hlo.ConvertOp(ir.RankedTensorType.get(a_dims, double_word_type), a)
b = hlo.ConvertOp(ir.RankedTensorType.get(b_dims, double_word_type), b)
a = hlo.ShiftLeftOp(a,
_broadcast(const(double_word_dtype, nbits), a_dims))
return mhlo.OrOp(a, b)
return hlo.OrOp(a, b)
# Unpacks the first element of a tuple.
def fst(t):
dims = ir.RankedTensorType(t.type).shape
st = mhlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits))
return mhlo.BitcastConvertOp(
st = hlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits))
return hlo.BitcastConvertOp(
ir.RankedTensorType.get(dims, etype),
mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), st)).result
hlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), st)).result
# Unpacks the second element of a tuple.
def snd(t):
dims = ir.RankedTensorType(t.type).shape
return mhlo.BitcastConvertOp(
return hlo.BitcastConvertOp(
ir.RankedTensorType.get(dims, etype),
mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), t)).result
hlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), t)).result
else:
# The double-word trick above only works if we have a sufficiently large
@ -729,33 +729,33 @@ def _select_and_gather_add_lowering(
def pack(a, b):
a_dims = ir.RankedTensorType(a.type).shape
b_dims = ir.RankedTensorType(b.type).shape
a = mhlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp),
a = hlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp),
mantissa_bits=mlir.i32_attr(nmant))
b = mhlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp),
b = hlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp),
mantissa_bits=mlir.i32_attr(nmant))
a = mhlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a)
b = mhlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b)
b = mhlo.ShiftRightLogicalOp(
a = hlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a)
b = hlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b)
b = hlo.ShiftRightLogicalOp(
b, _broadcast(const(word_dtype, r_nbits), b_dims))
return mhlo.OrOp(a, b)
return hlo.OrOp(a, b)
# Unpacks the first element of a tuple.
def fst(t):
st = mhlo.AndOp(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits))
return mhlo.BitcastConvertOp(ir.RankedTensorType.get([], etype),
st = hlo.AndOp(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits))
return hlo.BitcastConvertOp(ir.RankedTensorType.get([], etype),
st).result
# Unpacks the second element of a tuple.
def snd(t):
dims = ir.RankedTensorType(t.type).shape
return mhlo.BitcastConvertOp(
return hlo.BitcastConvertOp(
ir.RankedTensorType.get(dims, etype),
mhlo.ShiftLeftOp(t, _broadcast(const(word_dtype, r_nbits), dims))
hlo.ShiftLeftOp(t, _broadcast(const(word_dtype, r_nbits), dims))
).result
assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
init = -np.inf if select_prim is lax.ge_p else np.inf
rw = mhlo.ReduceWindowOp(
rw = hlo.ReduceWindowOp(
[ir.RankedTensorType.get(out_aval.shape, double_word_type)],
pack(operand, tangents),
pack(const(dtype, init), const(dtype, 0)),
@ -771,8 +771,8 @@ def _select_and_gather_add_lowering(
x, y = reducer.arguments
assert select_prim is lax.ge_p or select_prim is lax.le_p
which = "GE" if select_prim is lax.ge_p else "LE"
out = mhlo.SelectOp(mlir.compare_mhlo(fst(x), fst(y), which), x, y)
mhlo.ReturnOp(out)
out = hlo.SelectOp(mlir.compare_hlo(fst(x), fst(y), which), x, y)
hlo.ReturnOp(out)
return [snd(rw.result)]
# TODO(phawkins): use this translation rule on all platforms.
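A NumPy sketch of the double-word trick this lowering relies on: the value and its tangent are bit-packed into one u64 so a single ReduceWindow can carry both, and fst/snd recover the halves afterwards.

import numpy as np

def pack(a, b):
    a = np.float32(a).view(np.uint32).astype(np.uint64)
    b = np.float32(b).view(np.uint32).astype(np.uint64)
    return (a << np.uint64(32)) | b  # value in the high bits, tangent low

def fst(t):  # the half the GE/LE comparison runs on
    return (t >> np.uint64(32)).astype(np.uint32).view(np.float32)

def snd(t):  # the half carried along for the gathered add
    return (t & np.uint64(0xFFFFFFFF)).astype(np.uint32).view(np.float32)

t = pack(3.5, -1.25)
assert fst(t) == np.float32(3.5) and snd(t) == np.float32(-1.25)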

View File

@ -23,3 +23,8 @@ import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor
from jax.lib import xla_client
if xla_client.mlir_api_version >= 37:
import jaxlib.mlir.dialects.stablehlo as stablehlo
# Alias that is set up to abstract away the transition from MHLO to StableHLO.
# At the moment, it points to MHLO, but in the future it will start to
# conditionally and then unconditionally point to StableHLO.
import jaxlib.mlir.dialects.mhlo as hlo
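Once jaxlib ships StableHLO everywhere, this alias can flip in one place while every "import ... as hlo" call site stays untouched. An illustrative sketch of the eventual conditional form (the flag below is hypothetical, not an actual JAX or jaxlib option):

USE_STABLEHLO = False  # placeholder for a future config switch

if USE_STABLEHLO:
    import jaxlib.mlir.dialects.stablehlo as hlo
else:
    import jaxlib.mlir.dialects.mhlo as hlo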

View File

@ -280,7 +280,7 @@ def canonicalize_platform(platform: str) -> str:
In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
hardware is actually present. We want to distinguish "cuda" and "rocm" for
purposes such as MHLO lowering rules, but in many cases we don't want to
purposes such as MLIR lowering rules, but in many cases we don't want to
force users to care.
"""
platforms = _alias_to_platforms.get(platform, None)

View File

@ -40,7 +40,7 @@ from jax._src import dtypes
from jax._src.api import jit, vmap
from jax._src.lax import lax as lax_internal
from jax._src.lax import utils as lax_utils
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy
import jax._src.pretty_printer as pp
from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
@ -443,7 +443,7 @@ class KeyTyRules:
key_shape = aval_out.dtype.impl.key_shape
trailing_dims = [aval_out.ndim + i for i in range(len(key_shape))]
perm = [*permutation, *trailing_dims]
return mhlo.TransposeOp(x, mlir.dense_int_elements(perm)).result
return hlo.TransposeOp(x, mlir.dense_int_elements(perm)).result
@staticmethod
def gather_mlir(ctx, avals_in, aval_out, x, indices, *,
@ -1041,27 +1041,27 @@ def bcast_iotas_to_reshaped_iota(add, mul, shape, iotas):
def iota_2x32_shape_lowering(ctx, *, shape):
def _add(x, y):
return mlir.mhlo.AddOp(x, y).result
return mlir.hlo.AddOp(x, y).result
def _mul(x, y):
x_const = mlir.ir_constant(np.array(x, np.dtype('uint64')),
canonicalize_types=False)
x_bcast = mlir.mhlo.BroadcastOp(x_const, mlir.dense_int_elements(shape))
return mlir.mhlo.MulOp(x_bcast, y).result
x_bcast = mlir.hlo.BroadcastOp(x_const, mlir.dense_int_elements(shape))
return mlir.hlo.MulOp(x_bcast, y).result
assert len(shape) > 0
aval_out, _ = ctx.avals_out
aval_u64 = core.ShapedArray(shape, np.dtype('uint64'))
iotas = [mlir.mhlo.IotaOp(mlir.aval_to_ir_type(aval_u64),
iotas = [mlir.hlo.IotaOp(mlir.aval_to_ir_type(aval_u64),
mlir.i64_attr(dimension)).result
for dimension in range(len(shape))]
counts = bcast_iotas_to_reshaped_iota(_add, _mul, shape, iotas)
shift = mlir.ir_constant(np.array(32, np.dtype('uint64')),
canonicalize_types=False)
shift = mlir.mhlo.BroadcastOp(shift, mlir.dense_int_elements(shape)).result
counts_shifted = mlir.mhlo.ShiftRightLogicalOp(counts, shift).result
counts_lo = mlir.mhlo.ConvertOp(mlir.aval_to_ir_type(aval_out), counts).result
counts_hi = mlir.mhlo.ConvertOp(mlir.aval_to_ir_type(aval_out),
shift = mlir.hlo.BroadcastOp(shift, mlir.dense_int_elements(shape)).result
counts_shifted = mlir.hlo.ShiftRightLogicalOp(counts, shift).result
counts_lo = mlir.hlo.ConvertOp(mlir.aval_to_ir_type(aval_out), counts).result
counts_hi = mlir.hlo.ConvertOp(mlir.aval_to_ir_type(aval_out),
counts_shifted).result
return counts_hi, counts_lo
mlir.register_lowering(iota_2x32_shape_p, iota_2x32_shape_lowering)
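A NumPy sketch of what this lowering computes: build 64-bit counts over the shape, then emit the high and low 32-bit halves as the two outputs.

import numpy as np

shape = (2, 3)
counts = np.arange(np.prod(shape), dtype=np.uint64).reshape(shape)
counts_hi = (counts >> np.uint64(32)).astype(np.uint32)  # high word
counts_lo = counts.astype(np.uint32)  # truncating convert keeps low word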

View File

@ -20,7 +20,7 @@ from jax import tree_util
from jax import linear_util as lu
from jax.experimental import pjit
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir import ir
import jax.interpreters.pxla as pxla
from jax.interpreters import mlir
@ -245,7 +245,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
else:
out_type = [ir.TupleType.get_tuple(mlir_shapes)]
out = mhlo.CustomCallOp(
out = hlo.CustomCallOp(
out_type,
list(values),
call_target_name=ir.StringAttr.get(_CUSTOM_PARTITIONING_CALL_NAME),
@ -259,7 +259,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
return [out.result]
else:
return [
mhlo.GetTupleElementOp(out, mlir.i32_attr(i)).result
hlo.GetTupleElementOp(out, mlir.i32_attr(i)).result
for i in range(len(mlir_shapes))
]

View File

@ -524,7 +524,7 @@ from jax._src.lib import pytree
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
import numpy as np
@ -1143,14 +1143,14 @@ def _outside_call_lowering(
assert has_token
current_token = args[-2]
current_itoken = args[-1]
assert current_token.type == mhlo.TokenType.get(), "The last two arguments must be tokens"
assert current_itoken.type == mhlo.TokenType.get(), "The last two arguments must be tokens"
assert current_token.type == hlo.TokenType.get(), "The last two arguments must be tokens"
assert current_itoken.type == hlo.TokenType.get(), "The last two arguments must be tokens"
args_to_outfeed = args[:-2]
# TODO(necula): this is a weak attempt to get the device. This works
# inside pmap, but does not work when we just execute on a single device,
# because in such executions we always get replica_id == 0.
replica_id = mhlo.ReplicaIdOp()
replica_id = hlo.ReplicaIdOp()
callback_operands = [replica_id, *args_to_outfeed]
callback_operand_avals = [
core.ShapedArray((), np.uint32), *ctx.avals_in[:-2]]

View File

@ -39,7 +39,7 @@ from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_client
from jax.experimental.jax2tf import jax2tf as jax2tf_internal
@ -392,16 +392,16 @@ def _code_generator_and_avals(
captured_ops = tuple(mlir.ir_constant(np.asarray(inp),
canonicalize_types=False)
for inp in captured_inputs)
submodule = mlir.xla_computation_to_mhlo_module(xla_comp)
submodule = mlir.xla_computation_to_mlir_module(xla_comp)
symtab = ir.SymbolTable(submodule.operation)
callee_result_types = symtab["main"].type.results
fn = mlir.merge_mhlo_modules(ctx.module, f"call_tf_{function_flat_tf.name}",
fn = mlir.merge_mlir_modules(ctx.module, f"call_tf_{function_flat_tf.name}",
submodule)
call = func_dialect.CallOp(callee_result_types,
ir.FlatSymbolRefAttr.get(fn),
tuple(args_op) + captured_ops)
if result_shape.is_tuple():
flat_results = [mhlo.GetTupleElementOp(call, mlir.i32_attr(i)).result
flat_results = [hlo.GetTupleElementOp(call, mlir.i32_attr(i)).result
for i in range(len(result_shapes))]
else:
flat_results = call.results
@ -410,7 +410,7 @@ def _code_generator_and_avals(
for op, res_aval, res_shape in zip(flat_results, result_avals,
result_shapes):
if res_aval.dtype != res_shape.numpy_dtype():
op = mhlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result
op = hlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result
outputs.append(op)
return outputs

View File

@ -589,7 +589,7 @@ def _lower_native_and_run(fun_jax: Callable,
Work-in-progress.
Uses JAX native lowering to MHLO, and then wraps the result in an
Uses JAX native lowering to MLIR, and then wraps the result in an
XlaCallModule TF op. This op does not have backward compatibility yet.
Special care must be taken in presence of shape polymorphism.
@ -634,13 +634,13 @@ def _lower_native_and_run(fun_jax: Callable,
fun_jax_lower = fun_jax.lower
lowered = fun_jax_lower(*arg_specs_jax)._lowering
if config.jax2tf_use_stablehlo:
mhlo_module = lowered.stablehlo()
mlir_module = lowered.stablehlo()
xla_call_module_version = 2
else:
mhlo_module = lowered.mhlo()
mlir_module = lowered.mhlo()
xla_call_module_version = 1
mhlo_serialized_module = mlir.module_to_bytecode(mhlo_module)
mlir_serialized_module = mlir.module_to_bytecode(mlir_module)
# Figure out the result types and shapes
if "global_out_avals" in lowered.compile_args:
# This is currently the case for pjit
@ -719,7 +719,7 @@ def _lower_native_and_run(fun_jax: Callable,
args_tf = [atf for i, atf in enumerate(args_tf) if i in module_kept_var_idx]
# Apply the shardings on arguments and results for pjit. This is redundant
# because the mhlo_module_text will already contain the shardings, but it
# because the mlir_module_text will already contain the shardings, but it
# makes it easier for tools like the TPU inference converter to see the
# sharding without digging into the `module` attribute of the `XlaCallModule`
# op, in the same way as it is done for the legacy jax2tf conversion.
@ -728,14 +728,14 @@ def _lower_native_and_run(fun_jax: Callable,
map(_shard_value, args_tf, args_avals, lowered.compile_args["in_shardings"]))
if logging.vlog_is_on(3):
mhlo_module_text = mlir.module_to_string(mhlo_module)
mlir_module_text = mlir.module_to_string(mlir_module)
logging.vlog(3, "XlaCallModule (version=%d, dim_args_spec=%s)\n%s",
xla_call_module_version, ", ".join(dim_args_spec),
mhlo_module_text)
mlir_module_text)
res = tfxla.call_module(
args_tf,
version=xla_call_module_version,
module=mhlo_serialized_module,
module=mlir_serialized_module,
Tout=out_types,
Sout=out_shapes,
dim_args_spec=dim_args_spec)
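For reference, a hedged sketch of the serialization flow above using the same APIs this file already calls (the private _lowering attribute is an implementation detail of this era, not a stable interface):

import jax
import jax.numpy as jnp
from jax.interpreters import mlir

lowered = jax.jit(jnp.sin).lower(jnp.float32(1.0))._lowering
mlir_module = lowered.mhlo()                 # or lowered.stablehlo()
blob = mlir.module_to_bytecode(mlir_module)  # bytes for XlaCallModule
text = mlir.module_to_string(mlir_module)    # human-readable dump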

View File

@ -1479,7 +1479,7 @@ class XlaCallModuleTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(tf.nest.map_structure(lambda t: t.numpy(), res),
jax_res)
@unittest.skip("TODO(necula): 'mhlo.dynamic_iota' op can't be translated to XLA HLO")
@unittest.skip("TODO(necula): 'dynamic_iota' op can't be translated to XLA HLO")
def test_shape_poly_arange(self):
if not config.jax_dynamic_shapes:
raise unittest.SkipTest("jax_dynamic_shapes must be enabled")

View File

@ -1176,7 +1176,7 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
wrap_name(name, "pjit")))
# TODO(b/228598865): inlined calls cannot have shardings set directly on the
# inputs or outputs because they are lost during MHLO->HLO conversion.
# inputs or outputs because they are lost during MLIR->HLO conversion.
# using_sharding_annotation=False means we add an identity operation instead.
func = mlir.lower_jaxpr_to_fun(sub_ctx, f"pjit_{name}", jaxpr, (),
arg_shardings=arg_shardings,
@ -1544,7 +1544,7 @@ sharding_constraint_p.def_abstract_eval(lambda x, **_: x)
ad.deflinear2(sharding_constraint_p,
lambda ct, _, **params: (sharding_constraint_p.bind(ct, **params),))
def _sharding_constraint_mhlo_lowering(ctx, x_node, *, sharding,
def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
resource_env, unconstrained_dims):
aval, = ctx.avals_in
axis_ctx = ctx.module_context.axis_context
@ -1564,7 +1564,7 @@ def _sharding_constraint_mhlo_lowering(ctx, x_node, *, sharding,
unspecified_dims=unconstrained_dims)
]
mlir.register_lowering(sharding_constraint_p,
_sharding_constraint_mhlo_lowering)
_sharding_constraint_hlo_lowering)
def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size,

View File

@ -48,7 +48,7 @@ from jax._src.lax.slicing import GatherDimensionNumbers, GatherScatterMode
from jax._src.lib.mlir import ir
from jax._src.lib import xla_bridge
from jax._src.lib import gpu_sparse
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.setops import _unique
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.util import canonicalize_axis
@ -728,13 +728,13 @@ def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_num
_bcoo_dot_general_default_lowering = mlir.lower_fun(
_bcoo_dot_general_impl, multiple_results=False)
def _collapse_mhlo(x, start, end):
def _collapse_hlo(x, start, end):
x_type = ir.RankedTensorType(x.type)
shape = x_type.shape
shape = (shape[:start]
+ [functools.reduce(operator.mul, shape[start:end + 1])]
+ shape[end + 1:])
return mhlo.ReshapeOp(
return hlo.ReshapeOp(
ir.RankedTensorType.get(shape, x_type.element_type), x).result
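A NumPy analogue of _collapse_hlo, merging axes start..end into a single dimension (the shapes are illustrative):

import numpy as np

def collapse(x, start, end):
    shape = list(x.shape)
    shape = (shape[:start]
             + [int(np.prod(shape[start:end + 1]))]
             + shape[end + 1:])
    return x.reshape(shape)

assert collapse(np.zeros((2, 3, 4)), start=0, end=1).shape == (6, 4)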
def _bcoo_dot_general_cuda_lowering(
@ -766,7 +766,7 @@ def _bcoo_dot_general_cuda_lowering(
elif rhs_ndim == 2:
bcoo_dot_general_fn = coo_matmat_lowering
if rhs_contract[0] == 1:
rhs = mhlo.TransposeOp(
rhs = hlo.TransposeOp(
rhs, permutation=mlir.dense_int_elements([1, 0])).result
else:
raise ValueError(f"rhs has to be 1d or 2d; get {rhs_ndim}d.")
@ -776,7 +776,7 @@ def _bcoo_dot_general_cuda_lowering(
lhs_transpose = False
if props.n_sparse == 1:
# Converts lhs to a row vector.
col = _collapse_mhlo(lhs_indices, start=0, end=1)
col = _collapse_hlo(lhs_indices, start=0, end=1)
row = mlir.full_like_aval(
ctx, 0, core.ShapedArray(ir.RankedTensorType(col.type).shape,
np.dtype(np.int32)))
@ -788,23 +788,23 @@ def _bcoo_dot_general_cuda_lowering(
if rhs_ndim == 1:
# Transforms a single-element array to a scalar.
return [mhlo.ReshapeOp(
return [hlo.ReshapeOp(
ir.RankedTensorType.get(
[], ir.RankedTensorType(dot_product.type).element_type),
dot_product).result]
else:
return [_collapse_mhlo(dot_product, start=0, end=1)]
return [_collapse_hlo(dot_product, start=0, end=1)]
elif props.n_sparse == 2:
lhs_indices_shape = ir.RankedTensorType(lhs_indices.type).shape
row = _collapse_mhlo(
mhlo.SliceOp(
row = _collapse_hlo(
hlo.SliceOp(
lhs_indices,
start_indices=mlir.dense_int_elements([0, 0]),
limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 1]),
strides=mlir.dense_int_elements([1, 1])).result,
start=0, end=1)
col = _collapse_mhlo(
mhlo.SliceOp(
col = _collapse_hlo(
hlo.SliceOp(
lhs_indices,
start_indices=mlir.dense_int_elements([0, 1]),
limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 2]),
@ -833,28 +833,28 @@ def _bcoo_dot_general_cuda_lowering(
lhs_indices_shape[-1])
lhs_data_1d_shape = (np.prod(np.array(lhs_data_shape)), )
lhs_indices_2d = mhlo.ReshapeOp(
lhs_indices_2d = hlo.ReshapeOp(
ir.RankedTensorType.get(
lhs_indices_2d_shape,
ir.RankedTensorType(lhs_indices.type).element_type),
lhs_indices).result
lhs_data_1d = mhlo.ReshapeOp(
lhs_data_1d = hlo.ReshapeOp(
ir.RankedTensorType.get(
lhs_data_1d_shape,
ir.RankedTensorType(lhs_data.type).element_type),
lhs_data).result
row = _collapse_mhlo(
mhlo.SliceOp(
row = _collapse_hlo(
hlo.SliceOp(
lhs_indices_2d,
start_indices=mlir.dense_int_elements([0, 0]),
limit_indices=mlir.dense_int_elements([lhs_indices_2d_shape[0], 1]),
strides=mlir.dense_int_elements([1, 1])).result,
start=0, end=1)
col = _collapse_mhlo(
mhlo.SliceOp(
col = _collapse_hlo(
hlo.SliceOp(
lhs_indices_2d,
start_indices=mlir.dense_int_elements([0, 1]),
limit_indices=mlir.dense_int_elements([lhs_indices_2d_shape[0], 2]),
@ -867,13 +867,13 @@ def _bcoo_dot_general_cuda_lowering(
# The issue (https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643)
# in cusparse library does not allow batch_stride = 0 for a non-batched rhs.
batched_rhs_shape = (batch_count,) + tuple(rhs_shape)
batched_rhs = mhlo.BroadcastInDimOp(
batched_rhs = hlo.BroadcastInDimOp(
ir.RankedTensorType.get(batched_rhs_shape,
ir.RankedTensorType(rhs.type).element_type),
rhs,
broadcast_dimensions=mlir.dense_int_elements([1, 2])).result
batched_rhs_2d_shape = (np.prod(np.array(batched_rhs_shape)[:-1]), batched_rhs_shape[-1])
batched_rhs_2d = mhlo.ReshapeOp(
batched_rhs_2d = hlo.ReshapeOp(
ir.RankedTensorType.get(
batched_rhs_2d_shape,
ir.RankedTensorType(batched_rhs.type).element_type),
@ -1404,12 +1404,12 @@ def _bcoo_sort_indices_jvp(primals, tangents, *, spinfo):
data_dot_out = ad.Zero.from_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot, perm)
return (data_out, indices_out), (data_dot_out, indices_dot_out)
_bcoo_sort_indices_mhlo = mlir.lower_fun(
_bcoo_sort_indices_hlo = mlir.lower_fun(
_bcoo_sort_indices_impl, multiple_results=True)
ad.primitive_jvps[bcoo_sort_indices_p] = _bcoo_sort_indices_jvp
batching.primitive_batchers[bcoo_sort_indices_p] = _bcoo_sort_indices_batching_rule
mlir.register_lowering(bcoo_sort_indices_p, _bcoo_sort_indices_mhlo)
mlir.register_lowering(bcoo_sort_indices_p, _bcoo_sort_indices_hlo)
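The lower_fun pattern above is worth calling out: instead of writing MLIR by hand, the primitive's lowering is produced by tracing an ordinary JAX-traceable implementation. A minimal sketch with an illustrative primitive (not part of JAX):

from jax import core
from jax.interpreters import mlir

double_p = core.Primitive('double')
double_p.def_impl(lambda x: 2 * x)
double_p.def_abstract_eval(lambda x: x)  # shape and dtype are unchanged

mlir.register_lowering(
    double_p, mlir.lower_fun(lambda x: 2 * x, multiple_results=False))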
#----------------------------------------------------------------------
@ -1560,12 +1560,12 @@ def _bcoo_sum_duplicates_jvp(primals, tangents, *, spinfo, nse):
data_dot_out = ad.Zero.from_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot_out, mapping, data_dot)
return (data_out, indices_out), (data_dot_out, indices_dot_out)
_bcoo_sum_duplicates_mhlo = mlir.lower_fun(
_bcoo_sum_duplicates_hlo = mlir.lower_fun(
_bcoo_sum_duplicates_impl, multiple_results=True)
ad.primitive_jvps[bcoo_sum_duplicates_p] = _bcoo_sum_duplicates_jvp
batching.primitive_batchers[bcoo_sum_duplicates_p] = _bcoo_sum_duplicates_batching_rule
mlir.register_lowering(bcoo_sum_duplicates_p, _bcoo_sum_duplicates_mhlo)
mlir.register_lowering(bcoo_sum_duplicates_p, _bcoo_sum_duplicates_hlo)
#----------------------------------------------------------------------
# BCOO functions that maybe should be primitives?

View File

@ -30,7 +30,7 @@ from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import _coo_extract, CuSparseEfficiencyWarning
from jax import tree_util
from jax._src.lax.lax import _const
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import gpu_sparse
from jax._src.numpy.lax_numpy import _promote_dtypes
from jax._src.typing import Array, ArrayLike, DTypeLike
@ -199,7 +199,7 @@ def _coo_todense_abstract_eval(data, row, col, *, spinfo):
_coo_todense_lowering = mlir.lower_fun(
_coo_todense_impl, multiple_results=False)
def _coo_todense_gpu_lowering(coo_todense_mhlo, ctx, data, row, col, *, spinfo):
def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo):
data_aval, row_aval, _ = ctx.avals_in
dtype = data_aval.dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
@ -220,10 +220,10 @@ def _coo_todense_gpu_lowering(coo_todense_mhlo, ctx, data, row, col, *, spinfo):
"back to the default implementation.", CuSparseEfficiencyWarning)
return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)
result = coo_todense_mhlo(
result = coo_todense_hlo(
data, row, col, shape=shape, data_dtype=dtype, index_dtype=row_aval.dtype)
return (
[mhlo.TransposeOp(result, mlir.dense_int_elements([1, 0])).result]
[hlo.TransposeOp(result, mlir.dense_int_elements([1, 0])).result]
if transpose else [result])
@ -318,14 +318,14 @@ def _coo_fromdense_abstract_eval(mat, *, nse, index_dtype):
_coo_fromdense_lowering = mlir.lower_fun(
_coo_fromdense_impl, multiple_results=True)
def _coo_fromdense_gpu_lowering(coo_fromdense_mhlo, ctx, mat, *, nse,
def _coo_fromdense_gpu_lowering(coo_fromdense_hlo, ctx, mat, *, nse,
index_dtype):
dtype = ctx.avals_in[0].dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
warnings.warn(f"coo_fromdense cusparse/hipsparse lowering not available for {dtype=}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype)
data, row, col = coo_fromdense_mhlo(
data, row, col = coo_fromdense_hlo(
mat, nnz=nse,
data_dtype=dtype,
index_dtype=np.dtype(index_dtype),
@ -438,7 +438,7 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, spinfo, transpose):
_coo_matvec_lowering = mlir.lower_fun(
_coo_matvec_impl, multiple_results=False)
def _coo_matvec_gpu_lowering(coo_matvec_mhlo, ctx, data, row, col, v, *, spinfo,
def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo,
transpose):
data_aval, row_aval, _, x_aval = ctx.avals_in
dtype = data_aval.dtype
@ -461,7 +461,7 @@ def _coo_matvec_gpu_lowering(coo_matvec_mhlo, ctx, data, row, col, v, *, spinfo,
return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo,
transpose=transpose)
return [coo_matvec_mhlo(
return [coo_matvec_hlo(
data, row, col, v, shape=shape, transpose=transpose,
index_dtype=row_aval.dtype, data_dtype=dtype, x_dtype=x_aval.dtype)]
@ -561,7 +561,7 @@ def _coo_matmat_abstract_eval(data, row, col, B, *, spinfo, transpose):
_coo_matmat_lowering = mlir.lower_fun(_coo_matmat_impl, multiple_results=False)
def _coo_matmat_gpu_lowering(coo_matmat_mhlo, ctx, data, row, col, B, *, spinfo,
def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo,
transpose):
data_aval, row_aval, _, B_aval = ctx.avals_in
dtype = data_aval.dtype
@ -583,7 +583,7 @@ def _coo_matmat_gpu_lowering(coo_matmat_mhlo, ctx, data, row, col, B, *, spinfo,
return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo,
transpose=transpose)
return [coo_matmat_mhlo(data, row, col, B, shape=shape,
return [coo_matmat_hlo(data, row, col, B, shape=shape,
transpose=transpose, x_dtype=B_aval.dtype,
data_dtype=data_aval.dtype,
index_dtype=row_aval.dtype)]

View File

@ -227,7 +227,7 @@ def _csr_todense_abstract_eval(data, indices, indptr, *, shape):
_csr_todense_lowering = mlir.lower_fun(
_csr_todense_impl, multiple_results=False)
def _csr_todense_gpu_lowering(csr_todense_mhlo, ctx, data, indices, indptr, *,
def _csr_todense_gpu_lowering(csr_todense_hlo, ctx, data, indices, indptr, *,
shape):
data_aval, indices_aval, _ = ctx.avals_in
dtype = data_aval.dtype
@ -235,7 +235,7 @@ def _csr_todense_gpu_lowering(csr_todense_mhlo, ctx, data, indices, indptr, *,
warnings.warn(f"csr_todense cusparse/hipsparse lowering not available for {dtype=}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_todense_lowering(ctx, data, indices, indptr, shape=shape)
return [csr_todense_mhlo(
return [csr_todense_hlo(
data, indices, indptr, shape=shape, data_dtype=dtype,
index_dtype=indices_aval.dtype)]
@ -319,13 +319,13 @@ def _csr_fromdense_abstract_eval(mat, *, nse, index_dtype):
_csr_fromdense_lowering = mlir.lower_fun(_csr_fromdense_impl,
multiple_results=True)
def _csr_fromdense_gpu_lowering(csr_fromdense_mhlo, ctx, mat, *, nse, index_dtype):
def _csr_fromdense_gpu_lowering(csr_fromdense_hlo, ctx, mat, *, nse, index_dtype):
dtype = ctx.avals_in[0].dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
warnings.warn(f"csr_fromdense cusparse/hipsparse lowering not available for {dtype=}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype)
data, indices, indptr = csr_fromdense_mhlo(
data, indices, indptr = csr_fromdense_hlo(
mat, nnz=nse, index_dtype=np.dtype(index_dtype),
data_dtype=dtype, index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype)))
return [data, indices, indptr]
@ -412,7 +412,7 @@ def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose):
_csr_matvec_lowering = mlir.lower_fun(_csr_matvec_impl, multiple_results=False)
def _csr_matvec_gpu_lowering(csr_matvec_mhlo, ctx, data, indices, indptr, v, *,
def _csr_matvec_gpu_lowering(csr_matvec_hlo, ctx, data, indices, indptr, v, *,
shape, transpose):
data_aval, indices_aval, _, v_aval = ctx.avals_in
dtype = data_aval.dtype
@ -421,7 +421,7 @@ def _csr_matvec_gpu_lowering(csr_matvec_mhlo, ctx, data, indices, indptr, v, *,
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_matvec_lowering(ctx, data, indices, indptr, v, shape=shape,
transpose=transpose)
return [csr_matvec_mhlo(
return [csr_matvec_hlo(
data, indices, indptr, v, shape=shape, transpose=transpose,
data_dtype=dtype, index_dtype=indices_aval.dtype, x_dtype=v_aval.dtype)]
@ -504,7 +504,7 @@ def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose):
_csr_matmat_lowering = mlir.lower_fun(_csr_matmat_impl, multiple_results=False)
def _csr_matmat_gpu_lowering(csr_matmat_mhlo, ctx, data, indices, indptr, B, *,
def _csr_matmat_gpu_lowering(csr_matmat_hlo, ctx, data, indices, indptr, B, *,
shape, transpose):
data_aval, indices_aval, _, B_aval = ctx.avals_in
dtype = data_aval.dtype
@ -513,7 +513,7 @@ def _csr_matmat_gpu_lowering(csr_matmat_mhlo, ctx, data, indices, indptr, B, *,
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_matmat_lowering(ctx, data, indices, indptr, B, shape=shape,
transpose=transpose)
return [csr_matmat_mhlo(
return [csr_matmat_hlo(
data, indices, indptr, B, shape=shape, transpose=transpose,
index_dtype=indices_aval.dtype, data_dtype=data_aval.dtype,
B_dtype=B_aval.dtype)]

View File

@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lowering and execution path that converts jaxprs into the MLIR MHLO/CHLO
# dialects.
# Lowering and execution path that converts jaxprs into MLIR.
from __future__ import annotations
import collections
@ -35,8 +34,7 @@ from jax._src import device_array
from jax._src import dtypes
from jax._src.lib import mlir_api_version, xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import can_execute_with_token
from jax._src.lib import xla_bridge as xb
@ -86,12 +84,12 @@ def shape_tensor(sizes: Sequence[Union[int, ir.RankedTensorType]]
if type(d) is int:
return ir_constant(np.array([d], np.int32))
else:
return mhlo.ReshapeOp(int1d, mhlo.ConvertOp(aval_to_ir_type(core.ShapedArray((), np.int32)), d))
return hlo.ReshapeOp(int1d, hlo.ConvertOp(aval_to_ir_type(core.ShapedArray((), np.int32)), d))
d, *ds = map(lower_dim, sizes)
if not ds:
return d
else:
return mhlo.ConcatenateOp([d, *ds], i64_attr(0)).result
return hlo.ConcatenateOp([d, *ds], i64_attr(0)).result
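A NumPy analogue of shape_tensor as defined above: static dimensions become 1-element i32 constants, dynamic dimension values are converted and reshaped to 1-element vectors, and the pieces are concatenated into one rank-1 shape tensor.

import numpy as np

def shape_tensor(sizes):
    parts = [np.array([d], np.int32) if isinstance(d, int)
             else np.asarray(d, np.int32).reshape(1)  # dynamic dim value
             for d in sizes]
    return parts[0] if len(parts) == 1 else np.concatenate(parts)

assert (shape_tensor([2, np.int32(5), 3]) == np.array([2, 5, 3])).all()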
def delegate_lowering(ctx, lowering_fun, *args, **ctx_override_kwargs):
@ -162,7 +160,7 @@ def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[ir.Type]:
ir_type_handlers[core.ShapedArray] = _array_ir_types
ir_type_handlers[core.ConcreteArray] = _array_ir_types
ir_type_handlers[core.AbstractToken] = lambda _: [mhlo.TokenType.get()]
ir_type_handlers[core.AbstractToken] = lambda _: [hlo.TokenType.get()]
ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types
def aval_to_ir_type(aval: core.AbstractValue) -> ir.Type:
@ -239,7 +237,7 @@ def _numpy_array_constant(x: np.ndarray, canonicalize_types
x = x.view(np.uint16)
x = np.ascontiguousarray(x)
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
return (mhlo.ConstantOp(attr).result,)
return (hlo.ConstantOp(attr).result,)
@ -272,7 +270,7 @@ def _ndarray_constant_handler(val: np.ndarray, canonicalize_types
if canonicalize_types:
collapsed_val = np.asarray(
collapsed_val, dtypes.canonicalize_dtype(collapsed_val.dtype))
out = mhlo.BroadcastInDimOp(
out = hlo.BroadcastInDimOp(
ir.RankedTensorType.get(
val.shape, dtype_to_ir_type(collapsed_val.dtype)),
_numpy_array_constant(collapsed_val, canonicalize_types=False)[0],
@ -304,9 +302,9 @@ for t in device_array.device_array_types:
def _token_constant_handler(val, canonicalize_types):
if mlir_api_version < 40:
return [mhlo.CreateTokenOp(mhlo.TokenType.get()).result]
return [hlo.CreateTokenOp(hlo.TokenType.get()).result]
else:
return [mhlo.CreateTokenOp().result]
return [hlo.CreateTokenOp().result]
register_constant_handler(core.Token, _token_constant_handler)
# Source locations
@ -331,12 +329,12 @@ def _source_info_to_location(
# Translation rules
def make_ir_context() -> ir.Context:
"""Creates an MLIR context suitable for JAX IR."""
from jax._src.lib.mlir import dialects
context = ir.Context()
mhlo.register_mhlo_dialect(context)
chlo.register_dialect(context)
dialects.mhlo.register_mhlo_dialect(context)
dialects.chlo.register_dialect(context)
if mlir_api_version >= 37:
from jax._src.lib.mlir.dialects import stablehlo
stablehlo.register_dialect(context)
dialects.stablehlo.register_dialect(context)
return context
@ -581,18 +579,18 @@ class DimPolyEvaluator:
def __add__(self, other: Union[np.int32, DimPolyEvaluator]):
if not isinstance(other, DimPolyEvaluator):
other = DimPolyEvaluator(ir_constant(other))
return DimPolyEvaluator(mhlo.AddOp(self.value, other.value).result)
return DimPolyEvaluator(hlo.AddOp(self.value, other.value).result)
def __radd__(self, other: np.int32):
return DimPolyEvaluator(mhlo.AddOp(ir_constant(other), self.value).result)
return DimPolyEvaluator(hlo.AddOp(ir_constant(other), self.value).result)
def __mul__(self, other: Union[np.int32, DimPolyEvaluator]):
if not isinstance(other, DimPolyEvaluator):
other = DimPolyEvaluator(ir_constant(other))
return DimPolyEvaluator(mhlo.MulOp(self.value, other.value).result)
return DimPolyEvaluator(hlo.MulOp(self.value, other.value).result)
def __rmul__(self, other: np.int32):
return DimPolyEvaluator(mhlo.MulOp(ir_constant(other), self.value).result)
return DimPolyEvaluator(hlo.MulOp(ir_constant(other), self.value).result)
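The operator overloads let a dimension polynomial be evaluated simply by substituting evaluators for its variables; a plain-Python analogue with a hypothetical class name:

import numpy as np

class DimEvalSketch:
    # Wraps a concrete value; + and * build new wrapped values, mirroring
    # how DimPolyEvaluator emits AddOp/MulOp instead of computing eagerly.
    def __init__(self, value):
        self.value = value
    def __add__(self, other):
        other = other.value if isinstance(other, DimEvalSketch) else other
        return DimEvalSketch(self.value + other)
    __radd__ = __add__
    def __mul__(self, other):
        other = other.value if isinstance(other, DimEvalSketch) else other
        return DimEvalSketch(self.value * other)
    __rmul__ = __mul__

# Evaluating 2*n + 3 at n = 5:
n = DimEvalSketch(np.int32(5))
print((2 * n + 3).value)  # 13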
def eval_dynamic_shape(ctx: LoweringRuleContext,
@ -640,7 +638,7 @@ def lower_jaxpr_to_module(
arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None
) -> LoweringResult:
"""Lowers a top-level jaxpr to an MHLO module.
"""Lowers a top-level jaxpr to an MLIR module.
Handles the quirks of the argument/return value passing conventions of the
runtime.
@ -678,7 +676,7 @@ def lower_jaxpr_to_module(
msg = f"Donation is not implemented for {platform}.\n{msg}"
warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}")
# MHLO channels need to start at 1
# HLO channels need to start at 1
channel_iter = itertools.count(1)
# Create a keepalives list that will be mutated during the lowering.
keepalives: List[Any] = []
@ -761,20 +759,20 @@ def _set_up_aliases(avals_in, avals_out, donated_args):
Token = Sequence[ir.Value]
def token_type() -> Sequence[ir.Type]:
return [mhlo.TokenType.get()]
return [hlo.TokenType.get()]
def create_token() -> Token:
if mlir_api_version < 40:
return wrap_singleton_ir_values(
mhlo.CreateTokenOp(mhlo.TokenType.get()).result)
hlo.CreateTokenOp(hlo.TokenType.get()).result)
else:
return wrap_singleton_ir_values(mhlo.CreateTokenOp().result)
return wrap_singleton_ir_values(hlo.CreateTokenOp().result)
class TokenSet:
"""An immutable container of tokens to be used to lower effectful jaxprs. When lowering
effectful jaxprs, we need to thread MHLO tokens to sequence them. Each effect
effectful jaxprs, we need to thread HLO tokens to sequence them. Each effect
will need its own token that will be threaded in and out of the effectful
primitives. A `TokenSet` encapsulates a set of MHLO tokens that will be
primitives. A `TokenSet` encapsulates a set of HLO tokens that will be
used by the lowering rules.
"""
_tokens: typing.OrderedDict[core.Effect, Token]
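The immutable-update discipline can be sketched in plain Python; TokenSetSketch is a hypothetical stand-in, not the class defined here:

from collections import OrderedDict

class TokenSetSketch:
    """Effect -> token map; updates return a new set instead of mutating."""
    def __init__(self, pairs):
        self._tokens = OrderedDict(pairs)
    def get(self, effect):
        return self._tokens[effect]
    def update_tokens(self, effect, token):
        updated = OrderedDict(self._tokens)
        updated[effect] = token
        return TokenSetSketch(updated.items())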
@ -850,18 +848,18 @@ def lower_jaxpr_to_fun(
jaxpr: the jaxpr to lower.
effects: a sequence of `core.Effect`s corresponding to an ordering of tokens
that will be created in or used by the lowered function.
create_tokens: if true, the MHLO will create tokens and ignore dummy input tokens.
create_tokens: if true, the HLO will create tokens and ignore dummy input tokens.
public: if true, the function's visibility is set to "public".
replace_tokens_with_dummy: if true, token arguments/return values are
replaced with bool arrays of size [0].
replicated_args: if present, annotates arguments as replicated.
arg_shardings: sharding annotations for each argument (optional).
result_shardings: sharding annotations for each result (optional).
use_sharding_annotations: if True, use mhlo.sharding annotations on
use_sharding_annotations: if True, use "mhlo.sharding" annotations on
parameters and return values to express sharding. If False, use
mhlo.custom_call operators with sharding annotations.
TODO(b/228598865): remove this option when mhlo.sharding annotations are
propagated on non-entry functions during MHLO->HLO conversion.
hlo.custom_call operators with sharding annotations.
TODO(b/228598865): remove this option when "mhlo.sharding" annotations are
propagated on non-entry functions during MLIR->HLO conversion.
input_output_aliases: optional sequence that maps argument numbers to the
corresponding output that should alias them.
Returns the name of the function.
@ -988,9 +986,9 @@ def lower_jaxpr_to_fun(
for aval, arg in zip(jaxpr.in_avals, unflattened_args):
if replace_tokens_with_dummy and aval is core.abstract_token:
if mlir_api_version < 40:
args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
args.append(hlo.CreateTokenOp(hlo.TokenType.get()).results)
else:
args.append(mhlo.CreateTokenOp().results)
args.append(hlo.CreateTokenOp().results)
else:
args.append(arg)
callee_name_stack = xla.extend_name_stack(ctx.name_stack,
@ -1061,7 +1059,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
*args: Sequence[ir.Value],
dim_var_values: Sequence[ir.Value]
) -> Tuple[Sequence[Sequence[ir.Value]], TokenSet]:
"""Lowers a jaxpr into mHLO, inlined into an existing function.
"""Lowers a jaxpr into MLIR, inlined into an existing function.
Assumes that an MLIR context, location, and insertion point are set.
@ -1269,13 +1267,13 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue,
broadcast_dimensions=broadcast_dimensions)
if not core.is_constant_shape(aval_out.shape): # type: ignore
shape = eval_dynamic_shape(ctx, aval_out.shape) # type: ignore
return mhlo.DynamicBroadcastInDimOp(
return hlo.DynamicBroadcastInDimOp(
aval_to_ir_type(aval_out), op,
shape_tensor(shape),
dense_int_elements(broadcast_dimensions),
).result
else:
return mhlo.BroadcastInDimOp(
return hlo.BroadcastInDimOp(
aval_to_ir_type(aval_out), op,
dense_int_elements(broadcast_dimensions)).result
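At the user level this lowering sits behind lax.broadcast_in_dim; a small runnable example:

import jax.numpy as jnp
from jax import lax

x = jnp.zeros((2, 7), jnp.bool_)
y = lax.broadcast_in_dim(x, (3, 2, 5, 7, 2), (1, 3))
print(y.shape)  # (3, 2, 5, 7, 2)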
@ -1302,12 +1300,12 @@ def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Va
aval_out, = aval_out.dtype._rules.physical_avals(aval_out) # type: ignore
if not core.is_constant_shape(aval_out.shape): # type: ignore
shape = eval_dynamic_shape(ctx, aval_out.shape) # type: ignore
return mhlo.DynamicReshapeOp(
return hlo.DynamicReshapeOp(
aval_to_ir_type(aval_out), op,
shape_tensor(shape),
).result
else:
return mhlo.ReshapeOp(aval_to_ir_type(aval_out), op).result
return hlo.ReshapeOp(aval_to_ir_type(aval_out), op).result
def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
start_indices, limit_indices, strides) -> ir.Value:
@ -1319,13 +1317,13 @@ def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
start_indices = eval_dynamic_shape(ctx, start_indices)
limit_indices = eval_dynamic_shape(ctx, limit_indices)
strides = eval_dynamic_shape(ctx, strides)
return mhlo.RealDynamicSliceOp(aval_to_ir_type(aval_out),
return hlo.RealDynamicSliceOp(aval_to_ir_type(aval_out),
x,
shape_tensor(start_indices),
shape_tensor(limit_indices),
shape_tensor(strides)).result
else:
return mhlo.SliceOp(x,
return hlo.SliceOp(x,
dense_int_elements(start_indices),
dense_int_elements(limit_indices),
dense_int_elements(strides)).result
@ -1338,14 +1336,14 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
slice_sizes = aval_out.shape
if not core.is_constant_shape(slice_sizes):
slice_sizes = eval_dynamic_shape(ctx, slice_sizes)
return mhlo.RealDynamicSliceOp(
return hlo.RealDynamicSliceOp(
aval_to_ir_type(aval_out), x,
shape_tensor(start_indices),
shape_tensor(slice_sizes),
shape_tensor([1] * len(slice_sizes))
).result
else:
return mhlo.DynamicSliceOp(x, start_indices,
return hlo.DynamicSliceOp(x, start_indices,
dense_int_elements(slice_sizes)).result
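The corresponding user-level entry point is lax.dynamic_slice; a small runnable example:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(12).reshape(3, 4)
print(lax.dynamic_slice(x, (1, 2), (2, 2)))  # rows 1..2, cols 2..3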
def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
@ -1356,10 +1354,10 @@ def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
# TODO(necula): handle dynamic shapes
if mlir_api_version < 40:
return mhlo.DynamicUpdateSliceOp(
return hlo.DynamicUpdateSliceOp(
aval_to_ir_type(aval_out), x, update, start_indices).result
else:
return mhlo.DynamicUpdateSliceOp(x, update, start_indices).result
return hlo.DynamicUpdateSliceOp(x, update, start_indices).result
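And the matching update path, lax.dynamic_update_slice:

import jax.numpy as jnp
from jax import lax

x = jnp.zeros((4, 4))
u = jnp.ones((2, 2))
print(lax.dynamic_update_slice(x, u, (1, 1)))  # ones written at offset (1, 1)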
def full_like_aval(ctx: LoweringRuleContext, value, aval: core.ShapedArray) -> ir.Value:
"""Returns an IR constant shaped full of `value` shaped like `aval`."""
@ -1373,14 +1371,14 @@ def zeros_like_lowering(ctx, x):
register_lowering(ad_util.zeros_like_p, zeros_like_lowering)
def add_jaxvals_lowering(ctx, x, y):
return mhlo.AddOp(x, y).results
return hlo.AddOp(x, y).results
register_lowering(ad_util.add_jaxvals_p, add_jaxvals_lowering)
register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x])
def compare_mhlo(x, y, direction: str, comparison_type: Optional[str] = None):
"""Creates mhlo.CompareOp."""
def compare_hlo(x, y, direction: str, comparison_type: Optional[str] = None):
"""Creates CompareOp."""
if comparison_type is None:
elem_type = ir.RankedTensorType(x.type).element_type
if ir.IntegerType.isinstance(elem_type):
@ -1389,34 +1387,34 @@ def compare_mhlo(x, y, direction: str, comparison_type: Optional[str] = None):
else:
comparison_type = "FLOAT"
return mhlo.CompareOp(
return hlo.CompareOp(
x,
y,
mhlo.ComparisonDirectionAttr.get(direction),
compare_type=mhlo.ComparisonTypeAttr.get(comparison_type))
hlo.ComparisonDirectionAttr.get(direction),
compare_type=hlo.ComparisonTypeAttr.get(comparison_type))
def _minmax_mhlo(op, cmp, x, y):
def _minmax_hlo(op, cmp, x, y):
"""Min/max that compares complex values lexicographically as pairs."""
tensor_type = ir.RankedTensorType(x.type)
if ir.ComplexType.isinstance(tensor_type.element_type):
rx = mhlo.RealOp(x).result
ry = mhlo.RealOp(y).result
real_eq = compare_mhlo(rx, ry, "EQ", "FLOAT")
real_cmp = compare_mhlo(rx, ry, cmp, "FLOAT")
imag_cmp = compare_mhlo(
mhlo.ImagOp(x).result,
mhlo.ImagOp(y).result, cmp, "FLOAT")
which = mhlo.SelectOp(real_eq, imag_cmp, real_cmp).result
return mhlo.SelectOp(which, x, y)
rx = hlo.RealOp(x).result
ry = hlo.RealOp(y).result
real_eq = compare_hlo(rx, ry, "EQ", "FLOAT")
real_cmp = compare_hlo(rx, ry, cmp, "FLOAT")
imag_cmp = compare_hlo(
hlo.ImagOp(x).result,
hlo.ImagOp(y).result, cmp, "FLOAT")
which = hlo.SelectOp(real_eq, imag_cmp, real_cmp).result
return hlo.SelectOp(which, x, y)
else:
return op(x, y)
min_mhlo = partial(_minmax_mhlo, mhlo.MinOp, "LT")
max_mhlo = partial(_minmax_mhlo, mhlo.MaxOp, "GT")
min_hlo = partial(_minmax_hlo, hlo.MinOp, "LT")
max_hlo = partial(_minmax_hlo, hlo.MaxOp, "GT")
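A plain-NumPy reference for the lexicographic rule; complex_min_reference is a hypothetical name:

import numpy as np

def complex_min_reference(x, y):
    # Ties on the real part fall through to the imaginary part, matching
    # the select chain in _minmax_hlo above.
    real_eq = x.real == y.real
    take_x = np.where(real_eq, x.imag < y.imag, x.real < y.real)
    return np.where(take_x, x, y)

print(complex_min_reference(np.complex64(3 + 2j), np.complex64(3 + 1j)))  # (3+1j)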
def convert_mhlo(ctx: LoweringRuleContext, x, aval_in, aval_out):
"""Variant of convert that has XLA HLO semantics.
def convert_hlo(ctx: LoweringRuleContext, x, aval_in, aval_out):
"""Variant of convert that has HLO semantics.
In particular, treat casts to boolean as x != 0, rather than truncating
integer values (b/209440332)."""
@ -1428,9 +1426,9 @@ def convert_mhlo(ctx: LoweringRuleContext, x, aval_in, aval_out):
compare_type = "SIGNED"
else:
compare_type = "UNSIGNED"
return compare_mhlo(x, full_like_aval(ctx, 0, aval_in), "NE",
return compare_hlo(x, full_like_aval(ctx, 0, aval_in), "NE",
compare_type).result
return mhlo.ConvertOp(aval_to_ir_type(aval_out), x).result
return hlo.ConvertOp(aval_to_ir_type(aval_out), x).result
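The observable difference is easy to demonstrate: truncation would map 2 to False (its low bit is zero), while the "x != 0" rule maps every nonzero value to True:

import jax.numpy as jnp

print(jnp.array([-2, 0, 2]).astype(bool))  # [ True False  True]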
def _wrap_with_spmd_op(name: str,
result_type: ir.Type,
@ -1444,7 +1442,7 @@ def _wrap_with_spmd_op(name: str,
[str(i) for i in sorted(unspecified_dims)]) + "]"
else:
backend_config = ""
op = mhlo.CustomCallOp([result_type], [x],
op = hlo.CustomCallOp([result_type], [x],
call_target_name=ir.StringAttr.get(name),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(backend_config),
@ -1489,7 +1487,7 @@ def cache_lowering(f):
except TypeError:
# If the parameters aren't hashable, give up on caching.
# TODO(phawkins): switch to requiring hashability, when XLA fallback
# computations have been ported to MHLO.
# computations have been ported to MLIR.
return f(ctx, *args, **params)
if func is None:
func = _emit_lowering_rule_as_fun(partial(f, **params), ctx)
@ -1506,12 +1504,12 @@ def cache_lowering(f):
def xla_computation_to_mhlo_module(xla_computation: xc.XlaComputation
def xla_computation_to_mlir_module(xla_computation: xc.XlaComputation
) -> ir.Module:
module_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
return ir.Module.parse(module_str)
def merge_mhlo_modules(dst_module: ir.Module,
def merge_mlir_modules(dst_module: ir.Module,
sym_name: str,
src_module: ir.Module) -> str:
"""Returns the name of src_module's main() function, after renaming."""
@ -1563,8 +1561,8 @@ def xla_fallback_lowering(prim: core.Primitive):
xla_computation = xla.primitive_subcomputation(
module_ctx.platform, axis_env, prim, ctx.avals_in,
ctx.avals_out, **params)
xla_module = xla_computation_to_mhlo_module(xla_computation)
callee_name = merge_mhlo_modules(
xla_module = xla_computation_to_mlir_module(xla_computation)
callee_name = merge_mlir_modules(
module_ctx.module, f"xla_fallback_{prim.name}", xla_module)
output_types = map(aval_to_ir_types, ctx.avals_out)
flat_output_types = util.flatten(output_types)
@ -1576,7 +1574,7 @@ def xla_fallback_lowering(prim: core.Primitive):
flatten_lowering_ir_args(args)).result
if not prim.multiple_results:
return [call]
flat_results = [mhlo.GetTupleElementOp(call, i32_attr(i)).result
flat_results = [hlo.GetTupleElementOp(call, i32_attr(i)).result
for i in range(len(flat_output_types))]
return util.unflatten(flat_results, map(len, output_types))
@ -1611,15 +1609,15 @@ def _dtype_to_xla_type_string(dtype: np.dtype) -> str:
raise NotImplementedError(dtype)
return _dtype_to_xla_type_string_map[dtype]
def send_to_host(channel: int, token: mhlo.TokenType, operand: Any,
def send_to_host(channel: int, token: hlo.TokenType, operand: Any,
aval: core.ShapedArray, name: str, *,
sharding: Optional[xc.OpSharding] = None) -> ir.Value:
channel_handle = mhlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE)
channel_handle = hlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE)
if mlir_api_version < 40:
send_op = mhlo.SendOp(mhlo.TokenType.get(), [operand], token, channel_handle,
send_op = hlo.SendOp(hlo.TokenType.get(), [operand], token, channel_handle,
is_host_transfer=ir.BoolAttr.get(True))
else:
send_op = mhlo.SendOp([operand], token, channel_handle,
send_op = hlo.SendOp([operand], token, channel_handle,
is_host_transfer=ir.BoolAttr.get(True))
dtype_str = _dtype_to_xla_type_string(aval.dtype)
if dtype_str in {"f64", "s64", "u64", "c64", "c128"}:
@ -1634,12 +1632,12 @@ def send_to_host(channel: int, token: mhlo.TokenType, operand: Any,
return send_op.result
def receive_from_host(channel: int, token: mhlo.TokenType,
def receive_from_host(channel: int, token: hlo.TokenType,
out_aval: core.ShapedArray, name: str, *,
sharding: Optional[xc.OpSharding] = None) -> ir.Value:
channel_handle = mhlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE)
recv_op = mhlo.RecvOp([aval_to_ir_type(out_aval),
mhlo.TokenType.get()], token, channel_handle,
channel_handle = hlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE)
recv_op = hlo.RecvOp([aval_to_ir_type(out_aval),
hlo.TokenType.get()], token, channel_handle,
is_host_transfer=ir.BoolAttr.get(True))
dtype_str = _dtype_to_xla_type_string(out_aval.dtype)
if dtype_str in {"f64", "s64", "u64", "c64", "c128"}:
@ -1670,9 +1668,9 @@ def _emit_tpu_python_callback(
sharding: Optional[xc.OpSharding] = None
) -> Tuple[List[ir.Value], Any, Any]:
if mlir_api_version < 40:
token = token or mhlo.CreateTokenOp(mhlo.TokenType.get()).result
token = token or hlo.CreateTokenOp(hlo.TokenType.get()).result
else:
token = token or mhlo.CreateTokenOp().result
token = token or hlo.CreateTokenOp().result
_wrapped_callback = callback
send_channels = []
@ -1680,7 +1678,7 @@ def _emit_tpu_python_callback(
# If there are no operands to the callback, we need to insert a dummy send
# op or the callback will never be triggered!
# TODO(sharadmv,chky): Enable this fix in the runtime as opposed to in
# MHLO builder.
# MLIR builder.
callback_without_args = _wrapped_callback
def _wrapped_callback(*args): # pylint: disable=function-redefined
del args
@ -1761,7 +1759,7 @@ def emit_python_callback(
operand_layouts: Optional[Sequence[Optional[Sequence[int]]]] = None,
result_layouts: Optional[Sequence[Optional[Sequence[int]]]] = None,
) -> Tuple[List[ir.Value], Any, Any]:
"""Emits MHLO that calls back to a provided Python function."""
"""Emits MLIR that calls back to a provided Python function."""
platform = ctx.module_context.platform
if platform not in {"cpu", "cuda", "rocm", "tpu"}:
raise ValueError(
@ -1842,7 +1840,7 @@ def emit_python_callback(
result_type = ir.TupleType.get_tuple(result_types)
call_target_name = ("xla_python_gpu_callback"
if platform in {"cuda", "rocm"} else "xla_python_cpu_callback")
result = mhlo.CustomCallOp(
result = hlo.CustomCallOp(
[result_type],
callback_operands,
call_target_name=ir.StringAttr.get(call_target_name),
@ -1859,7 +1857,7 @@ def emit_python_callback(
if sharding is not None:
set_sharding(result, sharding)
results = [
mhlo.GetTupleElementOp(result, i32_attr(i)).result
hlo.GetTupleElementOp(result, i32_attr(i)).result
for i in range(len(result_types))
]
if token:
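One public entry point that reaches this machinery is jax.pure_callback (assuming a JAX build recent enough to expose it); a minimal sketch:

import jax
import numpy as np

def host_sin(x):
    return np.sin(x)  # ordinary NumPy, executed on the host

@jax.jit
def f(x):
    return jax.pure_callback(host_sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

print(f(np.float32(1.0)))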

@ -77,7 +77,7 @@ from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
new_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log,
@ -2211,14 +2211,14 @@ ad.call_transpose_param_updaters[xla_pmap_p] = \
ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
def _unravel_index_mhlo(axis_env):
def _unravel_index_hlo(axis_env):
div = mlir.ir_constant(
np.array(axis_env.nreps // util.prod(axis_env.sizes), np.uint32))
mod = mlir.ir_constant(np.array(axis_env.sizes[-1], np.uint32))
return mhlo.RemOp(
mhlo.DivOp(mhlo.ReplicaIdOp().result, div).result, mod).result
return hlo.RemOp(
hlo.DivOp(hlo.ReplicaIdOp().result, div).result, mod).result
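In plain arithmetic the emitted ops compute (replica_id // (nreps // prod(sizes))) % sizes[-1]; a NumPy reference with a hypothetical helper name:

import numpy as np

def unravel_index_reference(replica_id, sizes, nreps):
    div = nreps // int(np.prod(sizes))
    return (replica_id // div) % sizes[-1]

print([unravel_index_reference(r, (2, 4), 8) for r in range(8)])  # [0, 1, 2, 3, 0, 1, 2, 3]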
def _mhlo_shard(aval, axis_env, xs, in_axis):
def _hlo_shard(aval, axis_env, xs, in_axis):
if aval is core.abstract_token:
return xs
elif isinstance(aval, core.ShapedArray):
@ -2226,20 +2226,20 @@ def _mhlo_shard(aval, axis_env, xs, in_axis):
dims = list(aval.shape)
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
idxs = [zero] * len(dims)
idxs.insert(in_axis, _unravel_index_mhlo(axis_env))
idxs.insert(in_axis, _unravel_index_hlo(axis_env))
dims_unsqueezed = dims.copy()
dims_unsqueezed.insert(in_axis, 1)
dynamic_slice_result = mhlo.DynamicSliceOp(
dynamic_slice_result = hlo.DynamicSliceOp(
x, idxs, mlir.dense_int_elements(dims_unsqueezed)).result
return [
mhlo.ReshapeOp(mlir.aval_to_ir_type(aval), dynamic_slice_result).result
hlo.ReshapeOp(mlir.aval_to_ir_type(aval), dynamic_slice_result).result
]
else:
raise TypeError(aval)
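A NumPy mirror of the DynamicSlice-plus-Reshape pattern above; shard_reference is a hypothetical name:

import numpy as np

def shard_reference(x, in_axis, idx):
    # Take this replica's length-1 slice along in_axis, then drop that axis.
    return np.take(x, [idx], axis=in_axis).squeeze(axis=in_axis)

print(shard_reference(np.arange(8).reshape(4, 2), 0, 2))  # [4 5]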
# TODO(b/110096942): more efficient gather
def _mhlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, platform):
def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, platform):
if aval is core.abstract_token:
return xs
elif isinstance(aval, core.ShapedArray):
@ -2249,23 +2249,23 @@ def _mhlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, p
and platform in ('cpu', 'gpu'))
if convert_bool:
aval = aval.update(dtype=np.dtype(np.float32))
x = mhlo.ConvertOp(mlir.aval_to_ir_type(aval), x).result
x = hlo.ConvertOp(mlir.aval_to_ir_type(aval), x).result
dims = list(aval.shape)
padded_aval = aval.update(shape=[axis_env.sizes[-1]] + dims)
padded = mlir.full_like_aval(ctx, 0, padded_aval)
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
idxs = [_unravel_index_mhlo(axis_env)] + [zero] * len(dims)
broadcast_result = mhlo.BroadcastOp(
idxs = [_unravel_index_hlo(axis_env)] + [zero] * len(dims)
broadcast_result = hlo.BroadcastOp(
x, mlir.dense_int_elements([1])).result
if xc.mlir_api_version < 40:
padded = mhlo.DynamicUpdateSliceOp(padded.type, padded, broadcast_result,
padded = hlo.DynamicUpdateSliceOp(padded.type, padded, broadcast_result,
idxs).result
else:
padded = mhlo.DynamicUpdateSliceOp(padded, broadcast_result, idxs).result
padded = hlo.DynamicUpdateSliceOp(padded, broadcast_result, idxs).result
replica_groups = mlir.dense_int_elements(
xla.axis_groups(axis_env, axis_env.names[-1]))
out = mhlo.CrossReplicaSumOp(padded, replica_groups).result
out = hlo.CrossReplicaSumOp(padded, replica_groups).result
if out_axis != 0:
# TODO(apaszke,mattjj): Change the indices to DynamicUpdateSlice instead
perm = list(range(1, len(dims)))
@ -2273,16 +2273,16 @@ def _mhlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, p
transposed_dims = list(dims)
transposed_dims.insert(out_axis, axis_env.sizes[-1])
aval = aval.update(shape=transposed_dims)
out = mhlo.TransposeOp(out, mlir.dense_int_elements(perm)).result
out = hlo.TransposeOp(out, mlir.dense_int_elements(perm)).result
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
if convert_bool:
float_zero = mlir.full_like_aval(ctx, 0, padded_aval)
out = mhlo.CompareOp(
out = hlo.CompareOp(
out,
float_zero,
mhlo.ComparisonDirectionAttr.get("NE"),
compare_type=mhlo.ComparisonTypeAttr.get("FLOAT")).result
hlo.ComparisonDirectionAttr.get("NE"),
compare_type=hlo.ComparisonTypeAttr.get("FLOAT")).result
return out
else:
raise TypeError(aval)
@ -2305,7 +2305,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
# Shard the in_nodes that are mapped
in_avals = [v.aval for v in call_jaxpr.invars]
in_nodes_sharded = (
_mhlo_shard(aval, new_env, mlir.wrap_singleton_ir_values(in_node), in_axis)
_hlo_shard(aval, new_env, mlir.wrap_singleton_ir_values(in_node), in_axis)
if in_axis is not None else mlir.wrap_singleton_ir_values(in_node)
for aval, in_node, in_axis in zip(in_avals, in_nodes, in_axes))
@ -2318,7 +2318,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
*in_nodes_sharded,
dim_var_values=ctx.dim_var_values)
out_avals = [v.aval for v in call_jaxpr.outvars]
outs = [_mhlo_unshard(ctx, aval, new_env, out_axis, shard,
outs = [_hlo_unshard(ctx, aval, new_env, out_axis, shard,
platform=ctx.module_context.platform)
for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)]
return outs

@ -586,6 +586,6 @@ def lower_fun(fun: Callable, *, multiple_results: bool, backend=None,
def f(*args, **kw):
raise RuntimeError("XLA translation rules are deprecated and "
"jax.interpreters.xla.lower_fun is no longer supported. "
"Add an MLIR (MHLO) lowering via jax.interpreters.mlir "
"Add an MLIR lowering via jax.interpreters.mlir "
"instead.")
return f

@ -34,9 +34,9 @@ py_library(
"gpu_rnn.py",
"gpu_solver.py",
"gpu_sparse.py",
"hlo_helpers.py",
"init.py",
"lapack.py",
"mhlo_helpers.py",
":version",
":xla_client",
],

@ -15,10 +15,10 @@
from typing import List
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo
import jaxlib.mlir.dialects.mhlo as hlo
from .mhlo_helpers import custom_call
from .hlo_helpers import custom_call
from .cpu import _ducc_fft
import numpy as np
@ -107,7 +107,11 @@ def _ducc_fft_descriptor(shape: List[int], dtype, fft_type: FftType,
return descriptor, out_dtype, out_shape
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def ducc_fft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
return ducc_fft_hlo(a, dtype, fft_type=fft_type, fft_lengths=fft_lengths)
def ducc_fft_hlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
"""DUCC FFT kernel for CPU."""
a_type = ir.RankedTensorType(a.type)
n = len(a_type.shape)
@ -128,15 +132,15 @@ def ducc_fft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
raise ValueError(f"Unknown output type {out_dtype}")
if 0 in a_type.shape or 0 in out_shape:
zero = mhlo.ConstantOp(
zero = hlo.ConstantOp(
ir.DenseElementsAttr.get(
np.array(0, dtype=out_dtype), type=out_type))
return mhlo.BroadcastOp(
return hlo.BroadcastOp(
zero,
ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result
u8_type = ir.IntegerType.get_unsigned(8)
descriptor = mhlo.ConstantOp(
descriptor = hlo.ConstantOp(
ir.DenseElementsAttr.get(
np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
layout = tuple(range(n - 1, -1, -1))
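End to end, this file provides the kernel behind jnp.fft on CPU at this revision; a quick check against NumPy:

import numpy as np
import jax.numpy as jnp

x = np.linspace(0.0, 1.0, 8).astype(np.float32)
# jnp.fft.fft returns complex64 here, so compare with a loose tolerance.
print(np.allclose(jnp.fft.fft(x), np.fft.fft(x), atol=1e-4))  # True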

@ -18,7 +18,7 @@ import operator
import jaxlib.mlir.ir as ir
from .mhlo_helpers import custom_call
from .hlo_helpers import custom_call
from jaxlib import xla_client
@ -39,7 +39,7 @@ except ImportError:
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
def _lu_pivots_to_permutation_mhlo(platform, gpu_linalg, pivots, *, permutation_size):
def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_size):
"""Kernel for the transformation of pivots to permutations on GPU."""
typ = ir.RankedTensorType(pivots.type)
dims = typ.shape
@ -65,7 +65,7 @@ def _lu_pivots_to_permutation_mhlo(platform, gpu_linalg, pivots, *, permutation_
operand_layouts=[pivots_layout],
result_layouts=[permutations_layout])
cuda_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_mhlo, "cu",
cuda_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_hlo, "cu",
_cuda_linalg)
hip_lu_pivots_to_permutation = partial(
_lu_pivots_to_permutation_mhlo, "hip", _hip_linalg)
_lu_pivots_to_permutation_hlo, "hip", _hip_linalg)
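The user-visible counterpart is jax.lax.linalg.lu_pivots_to_permutation; for example, LAPACK-style pivots [2, 2, 2] applied as sequential row swaps to the identity give [2, 0, 1]:

import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

pivots = jnp.array([2, 2, 2], jnp.int32)
print(lax_linalg.lu_pivots_to_permutation(pivots, 3))  # [2 0 1]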

@ -22,7 +22,7 @@ import jaxlib.mlir.ir as ir
from jaxlib import xla_client
from .mhlo_helpers import custom_call
from .hlo_helpers import custom_call
try:
from .cuda import _prng as _cuda_prng

@ -13,7 +13,7 @@
# limitations under the License.
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo
import jaxlib.mlir.dialects.mhlo as hlo
import numpy as np
@ -61,7 +61,7 @@ def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *,
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
out = hlo.CustomCallOp(
[
ir.TupleType.get_tuple([
output_type, h_0.type, c_0.type, workspace_type,
@ -76,13 +76,13 @@ def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *,
called_computations=ir.ArrayAttr.get([]),
)
return [
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
hlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
for i in range(5)
]
def _mhlo_zeros_f32(shape):
return mhlo.ConstantOp(
def _hlo_zeros_f32(shape):
return hlo.ConstantOp(
ir.DenseElementsAttr.get(
np.zeros(shape, dtype=np.float32), type=ir.F32Type.get())).result
@ -102,8 +102,8 @@ def cudnn_rnn_bwd_lowering(ctx, dy, dhn, dcn, x, h0, c0, w, y, workspace,
reserve_space_shape[0])
i32_type = ir.IntegerType.get_signless(32)
zeroed_dw = _mhlo_zeros_f32(ctx.avals_out[3].shape)
out = mhlo.CustomCallOp(
zeroed_dw = _hlo_zeros_f32(ctx.avals_out[3].shape)
out = hlo.CustomCallOp(
[ir.TupleType.get_tuple([x.type, h0.type, c0.type, w.type])], [
dy, dhn, dcn, x, h0, c0, w, y, workspace, reserve_space, zeroed_dw,
seq_lengths
@ -114,12 +114,12 @@ def cudnn_rnn_bwd_lowering(ctx, dy, dhn, dcn, x, h0, c0, w, y, workspace,
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
output_operand_aliases=ir.ArrayAttr.get([
mhlo.OutputOperandAlias.get(
hlo.OutputOperandAlias.get(
output_tuple_indices=[3],
operand_index=10,
operand_tuple_indices=[])
]))
return [
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
hlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
for i in range(4)
]

@ -18,13 +18,13 @@ from functools import partial
import operator
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo
import jaxlib.mlir.dialects.mhlo as hlo
import numpy as np
from jaxlib import xla_client
from .mhlo_helpers import custom_call
from .hlo_helpers import custom_call
try:
from .cuda import _blas as _cublas
@ -63,7 +63,7 @@ def _real_type(dtype):
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
def _getrf_mhlo(platform, gpu_blas, gpu_solver, dtype, a):
def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a):
"""LU decomposition."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -106,11 +106,11 @@ def _getrf_mhlo(platform, gpu_blas, gpu_solver, dtype, a):
operand_output_aliases={0: 0})
return out[:3]
cuda_getrf = partial(_getrf_mhlo, "cu", _cublas, _cusolver)
rocm_getrf = partial(_getrf_mhlo, "hip", _hipblas, _hipsolver)
cuda_getrf = partial(_getrf_hlo, "cu", _cublas, _cusolver)
rocm_getrf = partial(_getrf_hlo, "hip", _hipblas, _hipsolver)
def _geqrf_mhlo(platform, gpu_solver, dtype, a):
def _geqrf_hlo(platform, gpu_solver, dtype, a):
"""QR decomposition."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -145,10 +145,10 @@ def _geqrf_mhlo(platform, gpu_solver, dtype, a):
operand_output_aliases={0: 0})
return out[:3]
cuda_geqrf = partial(_geqrf_mhlo, "cu", _cusolver)
rocm_geqrf = partial(_geqrf_mhlo, "hip", _hipsolver)
cuda_geqrf = partial(_geqrf_hlo, "cu", _cusolver)
rocm_geqrf = partial(_geqrf_hlo, "hip", _hipsolver)
def _geqrf_batched_mhlo(platform, gpu_blas, dtype, a):
def _geqrf_batched_hlo(platform, gpu_blas, dtype, a):
"""Batched QR decomposition."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -183,11 +183,11 @@ def _geqrf_batched_mhlo(platform, gpu_blas, dtype, a):
)
return out[:2]
cuda_geqrf_batched = partial(_geqrf_batched_mhlo, "cu", _cublas)
rocm_geqrf_batched = partial(_geqrf_batched_mhlo, "hip", _hipblas)
cuda_geqrf_batched = partial(_geqrf_batched_hlo, "cu", _cublas)
rocm_geqrf_batched = partial(_geqrf_batched_hlo, "hip", _hipblas)
def _csrlsvqr_mhlo(platform, gpu_solver, dtype, data,
def _csrlsvqr_hlo(platform, gpu_solver, dtype, data,
indices, indptr, b, tol, reorder):
"""Sparse solver via QR decomposition. CUDA only."""
b_type = ir.RankedTensorType(b.type)
@ -209,10 +209,10 @@ def _csrlsvqr_mhlo(platform, gpu_solver, dtype, data,
)
return [out]
cuda_csrlsvqr = partial(_csrlsvqr_mhlo, "cu", _cusolver)
cuda_csrlsvqr = partial(_csrlsvqr_hlo, "cu", _cusolver)
def _orgqr_mhlo(platform, gpu_solver, dtype, a, tau):
def _orgqr_hlo(platform, gpu_solver, dtype, a, tau):
"""Product of elementary Householder reflections."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -252,11 +252,11 @@ def _orgqr_mhlo(platform, gpu_solver, dtype, a, tau):
operand_output_aliases={0: 0})
return out[:2]
cuda_orgqr = partial(_orgqr_mhlo, "cu", _cusolver)
rocm_orgqr = partial(_orgqr_mhlo, "hip", _hipsolver)
cuda_orgqr = partial(_orgqr_hlo, "cu", _cusolver)
rocm_orgqr = partial(_orgqr_hlo, "hip", _hipsolver)
def _syevd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
def _syevd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
lower=False):
"""Symmetric (Hermitian) eigendecomposition."""
a_type = ir.RankedTensorType(a.type)
@ -304,11 +304,11 @@ def _syevd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
operand_output_aliases={0: 0})
return out[:3]
cuda_syevd = partial(_syevd_mhlo, "cu", _cusolver, True)
rocm_syevd = partial(_syevd_mhlo, "hip", _hipsolver, True)
cuda_syevd = partial(_syevd_hlo, "cu", _cusolver, True)
rocm_syevd = partial(_syevd_hlo, "hip", _hipsolver, True)
def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
full_matrices=True, compute_uv=True):
"""Singular value decomposition."""
a_type = ir.RankedTensorType(a.type)
@ -358,18 +358,18 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
[0],
],
operand_output_aliases={0: 0})
vt = mhlo.TransposeOp(
vt = hlo.TransposeOp(
v,
ir.DenseIntElementsAttr.get(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))).result
if np.issubdtype(dtype, np.complexfloating):
vt = mhlo.ComplexOp(mhlo.RealOp(vt), mhlo.NegOp(mhlo.ImagOp(vt))).result
vt = hlo.ComplexOp(hlo.RealOp(vt), hlo.NegOp(hlo.ImagOp(vt))).result
if not full_matrices and not econ:
u = mhlo.SliceOp(
u = hlo.SliceOp(
u,
ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
ir.DenseIntElementsAttr.get(np.array(batch_dims + (m, min(m, n)))),
ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))).result
vt = mhlo.SliceOp(
vt = hlo.SliceOp(
vt,
ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
ir.DenseIntElementsAttr.get(np.array(batch_dims + (min(m, n), n))),
@ -430,11 +430,11 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
operand_output_aliases={0: 0})
return s, u, vt, info
cuda_gesvd = partial(_gesvd_mhlo, "cu", _cusolver, True)
rocm_gesvd = partial(_gesvd_mhlo, "hip", _hipsolver, False)
cuda_gesvd = partial(_gesvd_hlo, "cu", _cusolver, True)
rocm_gesvd = partial(_gesvd_hlo, "hip", _hipsolver, False)
def _sytrd_mhlo(platform, gpu_solver, dtype, a, *, lower):
def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower):
"""sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -490,20 +490,20 @@ def _sytrd_mhlo(platform, gpu_solver, dtype, a, *, lower):
if not lower and platform == "cu" and m > 1:
start = (0,) * len(batch_dims) + (0,)
end = batch_dims + (1,)
s = mhlo.SliceOp(e, intattr(start), intattr(end), intattr([1] * len(start)))
s = hlo.SliceOp(e, intattr(start), intattr(end), intattr([1] * len(start)))
s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type)
s = mhlo.BroadcastInDimOp(s_type, s, intattr(range(len(dims) - 1)))
s = hlo.BroadcastInDimOp(s_type, s, intattr(range(len(dims) - 1)))
# The diagonals are always real; convert to complex if needed.
s = mhlo.ConvertOp(
s = hlo.ConvertOp(
ir.RankedTensorType.get(s_type.shape, a_type.element_type), s)
offsets = tuple(mhlo.ConstantOp(intattr(i))
offsets = tuple(hlo.ConstantOp(intattr(i))
for i in ((0,) * len(batch_dims) + (0, 1)))
if xla_client.mlir_api_version < 40:
a = mhlo.DynamicUpdateSliceOp(a.type, a, s, offsets).result
a = hlo.DynamicUpdateSliceOp(a.type, a, s, offsets).result
else:
a = mhlo.DynamicUpdateSliceOp(a, s, offsets).result
a = hlo.DynamicUpdateSliceOp(a, s, offsets).result
return a, d, e, taus, info
cuda_sytrd = partial(_sytrd_mhlo, "cu", _cusolver)
rocm_sytrd = partial(_sytrd_mhlo, "hip", _hipsolver)
cuda_sytrd = partial(_sytrd_hlo, "cu", _cusolver)
rocm_sytrd = partial(_sytrd_hlo, "hip", _hipsolver)
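These kernels sit behind the usual jnp.linalg entry points; for example QR, which on CUDA/ROCm reaches the geqrf/orgqr calls above (and their LAPACK counterparts on CPU):

import jax.numpy as jnp

a = jnp.arange(6.0, dtype=jnp.float32).reshape(3, 2)
q, r = jnp.linalg.qr(a)
print(q.shape, r.shape)  # (3, 2) (2, 2)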

@ -23,7 +23,7 @@ import numpy as np
from jaxlib import xla_client
from .mhlo_helpers import custom_call
from .hlo_helpers import custom_call
try:
from .cuda import _sparse as _cusparse
@ -46,7 +46,7 @@ cuda_is_supported : bool = _cusparse and _cusparse.sparse_supported
rocm_is_supported : bool = _hipsparse and _hipsparse.sparse_supported
def _validate_csr_mhlo(data, indices, indptr, shape):
def _validate_csr_hlo(data, indices, indptr, shape):
data_type = ir.RankedTensorType(data.type)
indices_type = ir.RankedTensorType(indices.type)
indptr_type = ir.RankedTensorType(indptr.type)
@ -57,7 +57,7 @@ def _validate_csr_mhlo(data, indices, indptr, shape):
assert indptr_type.shape == [shape[0] + 1]
return data_type.element_type, indices_type.element_type, nnz
def _validate_coo_mhlo(data, row, col):
def _validate_coo_hlo(data, row, col):
data_type = ir.RankedTensorType(data.type)
row_type = ir.RankedTensorType(row.type)
col_type = ir.RankedTensorType(col.type)
@ -69,10 +69,10 @@ def _validate_coo_mhlo(data, row, col):
return data_type.element_type, row_type.element_type, nnz
def _csr_todense_mhlo(platform, gpu_sparse, data, indices, indptr, *, shape,
def _csr_todense_hlo(platform, gpu_sparse, data, indices, indptr, *, shape,
data_dtype, index_dtype):
"""CSR to dense matrix."""
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape)
rows, cols = shape
buffer_size, opaque = gpu_sparse.build_csr_todense_descriptor(
@ -91,11 +91,11 @@ def _csr_todense_mhlo(platform, gpu_sparse, data, indices, indptr, *, shape,
result_layouts=[[1, 0], [0]])
return out[0]
cuda_csr_todense = partial(_csr_todense_mhlo, "cu", _cusparse)
rocm_csr_todense = partial(_csr_todense_mhlo, "hip", _hipsparse)
cuda_csr_todense = partial(_csr_todense_hlo, "cu", _cusparse)
rocm_csr_todense = partial(_csr_todense_hlo, "hip", _hipsparse)
def _csr_fromdense_mhlo(platform, gpu_sparse, mat, *, nnz, index_dtype,
def _csr_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, index_dtype,
data_dtype, index_type):
"""CSR from dense matrix."""
mat_type = ir.RankedTensorType(mat.type)
@ -119,15 +119,15 @@ def _csr_fromdense_mhlo(platform, gpu_sparse, mat, *, nnz, index_dtype,
result_layouts=[[0]] * 4)
return out[:3]
cuda_csr_fromdense = partial(_csr_fromdense_mhlo, "cu", _cusparse)
rocm_csr_fromdense = partial(_csr_fromdense_mhlo, "hip", _hipsparse)
cuda_csr_fromdense = partial(_csr_fromdense_hlo, "cu", _cusparse)
rocm_csr_fromdense = partial(_csr_fromdense_hlo, "hip", _hipsparse)
def _csr_matvec_mhlo(platform, gpu_sparse, data, indices, indptr, x, *, shape,
def _csr_matvec_hlo(platform, gpu_sparse, data, indices, indptr, x, *, shape,
transpose=False, compute_dtype=None, compute_type=None,
data_dtype, index_dtype, x_dtype):
"""CSR matrix/vector multiply."""
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape)
rows, cols = shape
if compute_dtype is None:
@ -152,15 +152,15 @@ def _csr_matvec_mhlo(platform, gpu_sparse, data, indices, indptr, x, *, shape,
result_layouts=[[0]] * 2)
return out[0]
cuda_csr_matvec = partial(_csr_matvec_mhlo, "cu", _cusparse)
rocm_csr_matvec = partial(_csr_matvec_mhlo, "hip", _hipsparse)
cuda_csr_matvec = partial(_csr_matvec_hlo, "cu", _cusparse)
rocm_csr_matvec = partial(_csr_matvec_hlo, "hip", _hipsparse)
def _csr_matmat_mhlo(platform, gpu_sparse, data, indices, indptr, B, *, shape,
def _csr_matmat_hlo(platform, gpu_sparse, data, indices, indptr, B, *, shape,
transpose=False, compute_dtype=None, compute_type=None,
index_dtype, data_dtype, B_dtype):
"""CSR from dense matrix."""
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape)
rows, cols = shape
B_shape = ir.RankedTensorType(B.type).shape
_, Ccols = B_shape
@ -187,14 +187,14 @@ def _csr_matmat_mhlo(platform, gpu_sparse, data, indices, indptr, B, *, shape,
result_layouts=[[1, 0], [0]])
return out[0]
cuda_csr_matmat = partial(_csr_matmat_mhlo, "cu", _cusparse)
rocm_csr_matmat = partial(_csr_matmat_mhlo, "hip", _hipsparse)
cuda_csr_matmat = partial(_csr_matmat_hlo, "cu", _cusparse)
rocm_csr_matmat = partial(_csr_matmat_hlo, "hip", _hipsparse)
def _coo_todense_mhlo(platform, gpu_sparse, data, row, col, *, shape,
def _coo_todense_hlo(platform, gpu_sparse, data, row, col, *, shape,
data_dtype, index_dtype):
"""COO to dense matrix."""
data_type, _, nnz = _validate_coo_mhlo(data, row, col)
data_type, _, nnz = _validate_coo_hlo(data, row, col)
rows, cols = shape
buffer_size, opaque = gpu_sparse.build_coo_todense_descriptor(
@ -213,11 +213,11 @@ def _coo_todense_mhlo(platform, gpu_sparse, data, row, col, *, shape,
result_layouts=[[1, 0], [0]])
return out[0]
cuda_coo_todense = partial(_coo_todense_mhlo, "cu", _cusparse)
rocm_coo_todense = partial(_coo_todense_mhlo, "hip", _hipsparse)
cuda_coo_todense = partial(_coo_todense_hlo, "cu", _cusparse)
rocm_coo_todense = partial(_coo_todense_hlo, "hip", _hipsparse)
def _coo_fromdense_mhlo(platform, gpu_sparse, mat, *, nnz, data_dtype,
def _coo_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, data_dtype,
index_dtype, index_type):
"""COO from dense matrix."""
mat_type = ir.RankedTensorType(mat.type)
@ -241,15 +241,15 @@ def _coo_fromdense_mhlo(platform, gpu_sparse, mat, *, nnz, data_dtype,
result_layouts=[[0]] * 4)
return out[:3]
cuda_coo_fromdense = partial(_coo_fromdense_mhlo, "cu", _cusparse)
rocm_coo_fromdense = partial(_coo_fromdense_mhlo, "hip", _hipsparse)
cuda_coo_fromdense = partial(_coo_fromdense_hlo, "cu", _cusparse)
rocm_coo_fromdense = partial(_coo_fromdense_hlo, "hip", _hipsparse)
def _coo_matvec_mhlo(platform, gpu_sparse, data, row, col, x, *, shape,
def _coo_matvec_hlo(platform, gpu_sparse, data, row, col, x, *, shape,
transpose=False, compute_dtype=None, compute_type=None,
index_dtype, data_dtype, x_dtype):
"""COO matrix/vector multiply."""
data_type, _, nnz = _validate_coo_mhlo(data, row, col)
data_type, _, nnz = _validate_coo_hlo(data, row, col)
rows, cols = shape
if compute_dtype is None:
@ -274,15 +274,15 @@ def _coo_matvec_mhlo(platform, gpu_sparse, data, row, col, x, *, shape,
result_layouts=[[0]] * 2)
return out[0]
cuda_coo_matvec = partial(_coo_matvec_mhlo, "cu", _cusparse)
rocm_coo_matvec = partial(_coo_matvec_mhlo, "hip", _hipsparse)
cuda_coo_matvec = partial(_coo_matvec_hlo, "cu", _cusparse)
rocm_coo_matvec = partial(_coo_matvec_hlo, "hip", _hipsparse)
def _coo_matmat_mhlo(platform, gpu_sparse, data, row, col, B, *, shape,
def _coo_matmat_hlo(platform, gpu_sparse, data, row, col, B, *, shape,
transpose=False, compute_dtype=None, compute_type=None,
x_dtype, data_dtype, index_dtype):
"""COO from dense matrix."""
data_type, _, nnz = _validate_coo_mhlo(data, row, col)
data_type, _, nnz = _validate_coo_hlo(data, row, col)
is_batched_matmat = False
batch_count = 1
if len(shape) == 2:
@ -334,11 +334,11 @@ def _coo_matmat_mhlo(platform, gpu_sparse, data, row, col, B, *, shape,
result_layouts=[out_layout, [0]])
return out[0]
cuda_coo_matmat = partial(_coo_matmat_mhlo, "cu", _cusparse)
rocm_coo_matmat = partial(_coo_matmat_mhlo, "hip", _hipsparse)
cuda_coo_matmat = partial(_coo_matmat_hlo, "cu", _cusparse)
rocm_coo_matmat = partial(_coo_matmat_hlo, "hip", _hipsparse)
def _gtsv2_mhlo(platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t):
def _gtsv2_hlo(platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t):
"""Calls `cusparse<t>gtsv2(dl, d, du, B, m, n, ldb)`."""
f32 = (t == np.float32)
if f32:
@ -360,5 +360,5 @@ def _gtsv2_mhlo(platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t):
operand_output_aliases={3: 0})
return out[0]
cuda_gtsv2 = partial(_gtsv2_mhlo, "cu", _cusparse)
rocm_gtsv2 = partial(_gtsv2_mhlo, "hip", _hipsparse)
cuda_gtsv2 = partial(_gtsv2_hlo, "cu", _cusparse)
rocm_gtsv2 = partial(_gtsv2_hlo, "hip", _hipsparse)
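From user code these kernels are reached through jax.experimental.sparse; a small matvec sketch (the cusparse/hipsparse paths above require a GPU, elsewhere JAX falls back to a dense-style lowering):

import jax.numpy as jnp
from jax.experimental import sparse

m = sparse.BCOO.fromdense(jnp.array([[1.0, 0.0], [0.0, 2.0]]))
v = jnp.array([3.0, 4.0])
print(m @ v)  # [3. 8.]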

@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Helpers for building MHLO operators
# Helpers for building MLIR operators
from typing import Dict, Optional, Sequence, Union
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo
import jaxlib.mlir.dialects.mhlo as hlo
import numpy as np
@ -31,7 +31,7 @@ def custom_call(
api_version: int = 2,
operand_output_aliases: Dict[int, int] = {},
) -> Union[ir.Value, Sequence[ir.Value]]:
"""Less-verbose helper for building an MHLO custom call op.
"""Less-verbose helper for building a CustomCallOp.
Once https://github.com/llvm/llvm-project/issues/54932 is fixed, this helper
may be able to go away.
@ -42,7 +42,7 @@ def custom_call(
that must alias.
"""
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
out = hlo.CustomCallOp(
(out_types
if len(out_types) == 1 else [ir.TupleType.get_tuple(out_types)]),
operands,
@ -63,7 +63,7 @@ def custom_call(
type=ir.IndexType.get()) for l in result_layouts
]),
output_operand_aliases=ir.ArrayAttr.get([
mhlo.OutputOperandAlias.get(
hlo.OutputOperandAlias.get(
output_tuple_indices=[] if len(out_types) == 1 else [output],
operand_index=input,
operand_tuple_indices=[])
@ -73,6 +73,6 @@ def custom_call(
return out.result
else:
return [
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
hlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
for i in range(len(out_types))
]

@ -16,12 +16,12 @@
# via CustomCallWithLayout.
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo
import jaxlib.mlir.dialects.mhlo as hlo
import numpy as np
from jaxlib import xla_client
from .mhlo_helpers import custom_call
from .hlo_helpers import custom_call
from .cpu import _lapack
for _name, _value in _lapack.registrations().items():
@ -32,23 +32,29 @@ for _name, _value in _lapack.registrations().items():
_initialize = _lapack.initialize
def _mhlo_u8(x):
return mhlo.ConstantOp(
def _hlo_u8(x):
return hlo.ConstantOp(
ir.DenseElementsAttr.get(
np.array(x, dtype=np.uint8),
type=ir.IntegerType.get_unsigned(8))).result
def _mhlo_s32(x):
return mhlo.ConstantOp(
def _hlo_s32(x):
return hlo.ConstantOp(
ir.DenseElementsAttr.get(
np.array(x, dtype=np.int32),
type=ir.IntegerType.get_signless(32))).result
# TODO(phawkins): it would be nice to avoid duplicating code for each type.
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
conj_a=False, diag=False):
return trsm_hlo(dtype, alpha, a, b, left_side=left_side, lower=lower,
trans_a=trans_a, conj_a=conj_a, diag=diag)
# ?trsm(left_side, lower, trans_a, diag, m, n, alpha, a, b):
# triangular solve
def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
def trsm_hlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
conj_a=False, diag=False):
_initialize()
a_type = ir.RankedTensorType(a.type)
@ -87,9 +93,9 @@ def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
return custom_call(
fn,
[b.type],
[_mhlo_s32(int(left_side)), _mhlo_s32(int(lower)),
_mhlo_s32((2 if conj_a else 1) if trans_a else 0), _mhlo_s32(int(diag)),
_mhlo_s32(m), _mhlo_s32(n), _mhlo_s32(num_b),
[_hlo_s32(int(left_side)), _hlo_s32(int(lower)),
_hlo_s32((2 if conj_a else 1) if trans_a else 0), _hlo_s32(int(diag)),
_hlo_s32(m), _hlo_s32(n), _hlo_s32(num_b),
alpha, a, b],
operand_layouts=[scalar_layout] * 8 + [layout] * 2,
result_layouts=[layout],
@ -99,7 +105,11 @@ def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
# # ?getrf: LU decomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def getrf_mhlo(dtype, a):
return getrf_hlo(dtype, a)
def getrf_hlo(dtype, a):
_initialize()
dims = ir.RankedTensorType(a.type).shape
assert len(dims) >= 2
@ -131,7 +141,7 @@ def getrf_mhlo(dtype, a):
ir.RankedTensorType.get(batch_dims + (min(m, n),), i32_type),
ir.RankedTensorType.get(batch_dims, i32_type),
],
[_mhlo_s32(int(b)), _mhlo_s32(m), _mhlo_s32(n), a],
[_hlo_s32(int(b)), _hlo_s32(m), _hlo_s32(n), a],
operand_layouts=[scalar_layout] * 3 + [layout],
result_layouts=[
layout,
@ -144,7 +154,11 @@ def getrf_mhlo(dtype, a):
# # ?geqrf: QR decomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def geqrf_mhlo(dtype, a):
return geqrf_hlo(dtype, a)
def geqrf_hlo(dtype, a):
_initialize()
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -182,7 +196,7 @@ def geqrf_mhlo(dtype, a):
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
[_mhlo_s32(int(b)), _mhlo_s32(m), _mhlo_s32(n), _mhlo_s32(lwork), a],
[_hlo_s32(int(b)), _hlo_s32(m), _hlo_s32(n), _hlo_s32(lwork), a],
operand_layouts=[scalar_layout] * 4 + [layout],
result_layouts=[
layout,
@ -197,7 +211,11 @@ def geqrf_mhlo(dtype, a):
# # ?orgqr: product of elementary Householder reflectors:
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def orgqr_mhlo(dtype, a, tau):
return orgqr_hlo(dtype, a, tau)
def orgqr_hlo(dtype, a, tau):
_initialize()
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -238,8 +256,8 @@ def orgqr_mhlo(dtype, a, tau):
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
[_mhlo_s32(int(b)), _mhlo_s32(m), _mhlo_s32(n), _mhlo_s32(k),
_mhlo_s32(lwork), a, tau],
[_hlo_s32(int(b)), _hlo_s32(m), _hlo_s32(n), _hlo_s32(k),
_hlo_s32(lwork), a, tau],
operand_layouts=[scalar_layout] * 5 + [
layout,
tuple(range(num_bd, -1, -1)),
@ -256,7 +274,11 @@ def orgqr_mhlo(dtype, a, tau):
# ?potrf: Cholesky decomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def potrf_mhlo(dtype, a, lower=False):
return potrf_hlo(dtype, a, lower=lower)
def potrf_hlo(dtype, a, lower=False):
_initialize()
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -286,7 +308,7 @@ def potrf_mhlo(dtype, a, lower=False):
fn,
[a.type,
ir.RankedTensorType.get(batch_dims, ir.IntegerType.get_signless(32))],
[_mhlo_s32(int(lower)), _mhlo_s32(b), _mhlo_s32(n), a],
[_hlo_s32(int(lower)), _hlo_s32(b), _hlo_s32(n), a],
operand_layouts=[scalar_layout] * 3 + [layout],
result_layouts=[layout, info_layout],
operand_output_aliases={3: 0},
@ -297,7 +319,11 @@ def potrf_mhlo(dtype, a, lower=False):
# # ?gesdd: Singular value decomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def gesdd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
return gesdd_hlo(dtype, a, full_matrices=full_matrices, compute_uv=compute_uv)
def gesdd_hlo(dtype, a, full_matrices=True, compute_uv=True):
_initialize()
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -370,8 +396,8 @@ def gesdd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
a_type.element_type),
ir.RankedTensorType.get(batch_dims, i32_type),
] + workspace,
[_mhlo_s32(int(full_matrices)), _mhlo_s32(int(compute_uv)), _mhlo_s32(b),
_mhlo_s32(m), _mhlo_s32(n), _mhlo_s32(lwork), a],
[_hlo_s32(int(full_matrices)), _hlo_s32(int(compute_uv)), _hlo_s32(b),
_hlo_s32(m), _hlo_s32(n), _hlo_s32(lwork), a],
operand_layouts=[scalar_layout] * 6 + [layout],
result_layouts=[
layout,
@ -387,7 +413,11 @@ def gesdd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
# # syevd: Symmetric eigendecomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def syevd_mhlo(dtype, a, lower=False):
return syevd_hlo(dtype, a, lower=lower)
def syevd_hlo(dtype, a, lower=False):
_initialize()
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -452,7 +482,7 @@ def syevd_mhlo(dtype, a, lower=False):
ir.RankedTensorType.get(batch_dims + (n,), eigvals_type),
ir.RankedTensorType.get(batch_dims, i32_type),
] + workspace,
[_mhlo_s32(1 if lower else 0), _mhlo_s32(b), _mhlo_s32(n), a],
[_hlo_s32(1 if lower else 0), _hlo_s32(b), _hlo_s32(n), a],
operand_layouts=[scalar_layout] * 3 + [layout],
result_layouts=[
layout,
@ -466,7 +496,11 @@ def syevd_mhlo(dtype, a, lower=False):
# # geev: Nonsymmetric eigendecomposition
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def geev_mhlo(dtype, a, jobvl=True, jobvr=True):
return geev_hlo(dtype, a, jobvl=jobvl, jobvr=jobvr)
def geev_hlo(dtype, a, jobvl=True, jobvr=True):
_initialize()
dims = ir.RankedTensorType(a.type).shape
assert len(dims) >= 2
@ -539,19 +573,23 @@ def geev_mhlo(dtype, a, jobvl=True, jobvr=True):
ir.RankedTensorType.get(dims, eigvecs_type),
ir.RankedTensorType.get(batch_dims, i32_type),
],
[_mhlo_s32(b), _mhlo_s32(n), _mhlo_u8(jobvl_c), _mhlo_u8(jobvr_c), a],
[_hlo_s32(b), _hlo_s32(n), _hlo_u8(jobvl_c), _hlo_u8(jobvr_c), a],
operand_layouts=[scalar_layout] * 4 + [layout],
result_layouts=(workspace_layouts + eigvals_layouts + [layout] * 2 +
[info_layout])
)
if real:
return (mhlo.ComplexOp(out[3], out[4]).result, out[5], out[6], out[7])
return (hlo.ComplexOp(out[3], out[4]).result, out[5], out[6], out[7])
else:
return out[2:6]
# # gees : Schur factorization
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
return gees_hlo(dtype, a, jobvs=jobvs, sort=sort, select=select)
def gees_hlo(dtype, a, jobvs=True, sort=False, select=None):
_initialize()
a_type = ir.RankedTensorType(a.type)
etype = a_type.element_type
@ -609,10 +647,10 @@ def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
ir.RankedTensorType.get(batch_dims, i32_type),
],
[
_mhlo_s32(b),
_mhlo_s32(n),
_mhlo_u8(np.uint8(jobvs)),
_mhlo_u8(np.uint8(sort)),
_hlo_s32(b),
_hlo_s32(n),
_hlo_u8(np.uint8(jobvs)),
_hlo_u8(np.uint8(sort)),
# TODO: figure out how to put the callable select function here
a
],
@ -630,8 +668,12 @@ def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
return (out[0], out[3], out[5])
# gehrd: Reduction of a non-symmetric square matrix to upper Hessenberg form.
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def gehrd_mhlo(dtype, a):
return gehrd_hlo(dtype, a)
# gehrd: Reduction of a non-symmetric square matrix to upper Hessenberg form.
def gehrd_hlo(dtype, a):
_initialize()
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -669,8 +711,8 @@ def gehrd_mhlo(dtype, a):
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
[_mhlo_s32(n), _mhlo_s32(1), _mhlo_s32(n), _mhlo_s32(n), _mhlo_s32(b),
_mhlo_s32(lwork), a],
[_hlo_s32(n), _hlo_s32(1), _hlo_s32(n), _hlo_s32(n), _hlo_s32(b),
_hlo_s32(lwork), a],
operand_layouts=[[]] * 6 + [layout],
result_layouts=[
layout,
@ -683,8 +725,12 @@ def gehrd_mhlo(dtype, a):
return out[:3]
# sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form.
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def sytrd_mhlo(dtype, a, *, lower):
return sytrd_hlo(dtype, a, lower=lower)
# sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form.
def sytrd_hlo(dtype, a, *, lower):
_initialize()
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -728,8 +774,8 @@ def sytrd_mhlo(dtype, a, *, lower):
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
[_mhlo_s32(n), _mhlo_s32(1 if lower else 0), _mhlo_s32(max(1, n)),
_mhlo_s32(b), _mhlo_s32(lwork), a],
[_hlo_s32(n), _hlo_s32(1 if lower else 0), _hlo_s32(max(1, n)),
_hlo_s32(b), _hlo_s32(lwork), a],
operand_layouts=[[]] * 5 + [layout],
result_layouts=[
layout,
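For example, on CPU jnp.linalg.cholesky bottoms out in the potrf custom call defined above:

import numpy as np
import jax.numpy as jnp

a = jnp.array([[4.0, 1.0], [1.0, 3.0]])
l = jnp.linalg.cholesky(a)
print(np.allclose(l @ l.T, a))  # True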

@ -155,10 +155,10 @@ def _sp_indices_abstract_eval(mat):
# Note: cannot use lower_fun to define attribute access primitives
# because it leads to infinite recursion.
def _sp_indices_mhlo_lowering(ctx, data_and_indices):
def _sp_indices_hlo_lowering(ctx, data_and_indices):
return [data_and_indices[1]]
mlir.register_lowering(sp_indices_p, _sp_indices_mhlo_lowering)
mlir.register_lowering(sp_indices_p, _sp_indices_hlo_lowering)
sp_data_p = core.Primitive('sp_data')
@ -173,10 +173,10 @@ def _sp_data_abstract_eval(mat):
# Note: cannot use lower_fun to define attribute access primitives
# because it leads to infinite recursion.
def _sp_data_mhlo_lowering(ctx, data_and_indices):
def _sp_data_hlo_lowering(ctx, data_and_indices):
return [data_and_indices[0]]
mlir.register_lowering(sp_data_p, _sp_data_mhlo_lowering)
mlir.register_lowering(sp_data_p, _sp_data_hlo_lowering)
def identity(x):
return identity_p.bind(x)

@ -767,7 +767,7 @@ class DynamicShapeTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=True)
# TODO(mattjj,dougalm,phawkins): debug iree failure, "failed to legalize
# operation 'mhlo.while' that was explicitly marked illegal"
# operation 'while' that was explicitly marked illegal"
@unittest.skip("revising slicing logic")
def test_scan_basic(self):
def cumsum(x):
@ -1275,8 +1275,8 @@ class DynamicShapeTest(jtu.JaxTestCase):
return x.sum()
f_lowered = f.lower(np.arange(3, dtype='int32'))
mhlo = f_lowered.compiler_ir('mhlo')
self.assertIn('tensor<?xi32>', str(mhlo))
mlir_str = f_lowered.compiler_ir()
self.assertIn('tensor<?xi32>', str(mlir_str))
def test_lower_abstracted_axes_shapedtypestruct(self):
@partial(jax.jit, abstracted_axes=('n',))
@ -1284,8 +1284,8 @@ class DynamicShapeTest(jtu.JaxTestCase):
return x.sum()
f_lowered = f.lower(jax.ShapeDtypeStruct((3,), np.int32))
mhlo = f_lowered.compiler_ir('mhlo')
self.assertIn('tensor<?xi32>', str(mhlo))
mlir_str = f_lowered.compiler_ir()
self.assertIn('tensor<?xi32>', str(mlir_str))
def test_vmap_abstracted_axis(self):
def foo(x, y):

@ -1,6 +1,6 @@
This directory contains LLVM
[FileCheck](https://llvm.org/docs/CommandGuide/FileCheck.html) tests that verify
that JAX primitives can be lowered to MHLO.
that JAX primitives can be lowered to MLIR.
These tests are intended to be a quick and easy-to-understand way to catch
regressions from changes due to the MLIR Python bindings and from changes to the

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Tests for lowering of array origami ops into MHLO.
# Tests for lowering of array origami ops into MLIR.
# RUN: %PYTHON %s | FileCheck %s
@ -32,62 +32,62 @@ jax.config.update("jax_enable_x64", True)
def main(_):
# CHECK-LABEL: TEST: concatenate bool[2,7] bool[2,5]
# CHECK: mhlo.concatenate
# CHECK: hlo.concatenate
# CHECK-SAME: tensor<2x12xi1>
print_ir([np.empty([2, 7], np.bool_), np.empty([2, 5], np.bool_)])(
partial(lax.concatenate, dimension=1))
# CHECK-LABEL: TEST: broadcast_in_dim bool[2,7]
# CHECK: mhlo.broadcast_in_dim
# CHECK: hlo.broadcast_in_dim
# CHECK-SAME: tensor<3x2x5x7x2xi1>
print_ir(np.empty([2, 7], np.bool_))(
partial(lax.broadcast_in_dim, shape=(3, 2, 5, 7, 2),
broadcast_dimensions=(1, 3)))
# CHECK-LABEL: TEST: iota
# CHECK: mhlo.iota
# CHECK: hlo.iota
# CHECK-SAME: tensor<10xf32>
print_ir()(partial(lax.iota, dtype=np.float32, size=10))
# CHECK-LABEL: TEST: pad int32[2,7]
# CHECK: mhlo.pad
# CHECK: hlo.pad
# CHECK-SAME: tensor<11x52xi32>
print_ir(np.empty([2, 7], np.int32))(
partial(lax.pad, padding_value=np.int32(7),
padding_config=((2, 3, 4), (4, 5, 6))))
# CHECK-LABEL: TEST: _reduce_sum int32[2,3,7]
# CHECK: mhlo.reduce
# CHECK: mhlo.add
# CHECK: hlo.reduce
# CHECK: hlo.add
# CHECK: tensor<3xi32>
print_ir(np.empty([2, 3, 7], np.int32))(
partial(lax_internal._reduce_sum, axes=(0, 2)))
# CHECK-LABEL: TEST: reshape int32[2,3,7]
# CHECK: mhlo.reshape
# CHECK: hlo.reshape
# CHECK-SAME: tensor<42xi32>
print_ir(np.empty([2, 3, 7], np.int32))(
partial(lax.reshape, new_sizes=(42,)))
# CHECK-LABEL: TEST: rev int32[2,7]
# CHECK: mhlo.rev
# CHECK: hlo.rev
# CHECK-SAME: tensor<2x7xi32>
print_ir(np.empty([2, 7], np.int32))(
partial(lax.rev, dimensions=(0, 1)))
# CHECK-LABEL: TEST: select bool[2,7] int32[2,7] int32[2,7]
# CHECK: mhlo.select
# CHECK: hlo.select
# CHECK-SAME: tensor<2x7xi1>, tensor<2x7xi32>
print_ir(np.empty([2, 7], np.bool_), np.empty([2, 7], np.int32),
np.empty([2, 7], np.int32))(lax.select)
# CHECK-LABEL: TEST: sort int32[2,7]
# CHECK: mhlo.sort
# CHECK: hlo.sort
# CHECK: tensor<2x7xi32>
print_ir(np.empty([2, 7], np.int32))(lax.sort)
# CHECK-LABEL: TEST: squeeze int32[2,1,7]
# CHECK: mhlo.reshape
# CHECK: hlo.reshape
# CHECK-SAME: tensor<2x7xi32>
print_ir(np.empty([2, 1, 7], np.int32))(
partial(lax.squeeze, dimensions=(1,)))
@ -98,7 +98,7 @@ def main(_):
print_ir(np.empty([2, 7], np.int32))(partial(lax.top_k, k=7))
# CHECK-LABEL: TEST: transpose int32[2,7]
# CHECK: mhlo.transpose
# CHECK: hlo.transpose
# CHECK-SAME: tensor<7x2xi32>
print_ir(np.empty([2, 7], np.int32))(
partial(lax.transpose, permutation=(1, 0)))

View File

@ -20,7 +20,7 @@ import numpy as np
def print_ir(*prototypes):
def lower(f):
"""Prints the MHLO IR that results from lowering `f`.
"""Prints the MLIR IR that results from lowering `f`.
The arguments to `f` are taken to be arrays shaped like `prototypes`."""
inputs = tree_util.tree_map(np.array, prototypes)
@ -29,5 +29,5 @@ def print_ir(*prototypes):
for x in flat_inputs])
name = f.func.__name__ if hasattr(f, "func") else f.__name__
print(f"\nTEST: {name} {shape_strs}")
print(jax.jit(f).lower(*inputs).compiler_ir(dialect="mhlo"))
print(jax.jit(f).lower(*inputs).compiler_ir())
return lower
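A usage sketch for this helper (assuming it is in scope as in the tests above); note it also works as a decorator, which some tests rely on:

# Usage sketch; print_ir is the helper defined above.
import numpy as np
from jax import lax

# Function form: prints "TEST: neg int32[2,3]" followed by the lowered
# module, which should contain a line like "hlo.negate ... tensor<2x3xi32>".
print_ir(np.empty([2, 3], np.int32))(lax.neg)

# Decorator form: prints "TEST: cube float32[]" at definition time.
@print_ir(np.float32(1))
def cube(x):
  return lax.integer_pow(x, 3)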

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Tests for lowerings of elementwise ops to MHLO.
# Tests for lowerings of elementwise ops to MLIR.
# RUN: %PYTHON %s | FileCheck %s
@ -31,17 +31,17 @@ jax.config.update("jax_enable_x64", True)
def main(_):
# CHECK-LABEL: TEST: abs int32[]
# CHECK: mhlo.abs
# CHECK: hlo.abs
# CHECK-SAME: tensor<i32>
print_ir(np.int32(0))(lax.abs)
# CHECK-LABEL: TEST: add float32[] float32[]
# CHECK: mhlo.add
# CHECK: hlo.add
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.add)
# CHECK-LABEL: TEST: acos float32[]
# CHECK: mhlo.atan2
# CHECK: hlo.atan2
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1))(lax.acos)
@ -71,7 +71,7 @@ def main(_):
print_ir(np.float32(0))(lax.atanh)
# CHECK-LABEL: TEST: atan2 float64[] float64[]
# CHECK: mhlo.atan2
# CHECK: hlo.atan2
# CHECK-SAME: tensor<f64>
print_ir(np.float64(1), np.float64(2))(lax.atan2)
@ -91,93 +91,93 @@ def main(_):
print_ir(np.float32(0), np.float32(0), np.float32(0))(lax.betainc)
# CHECK-LABEL: TEST: bitcast_convert_type uint32[7]
# CHECK: mhlo.bitcast_convert
# CHECK: hlo.bitcast_convert
# CHECK-SAME: tensor<7xui32>
# CHECK-SAME: tensor<7xf32>
print_ir(np.empty((7,), np.uint32))(
partial(lax.bitcast_convert_type, new_dtype=np.float32))
# CHECK-LABEL: TEST: bitwise_and int32[] int32[]
# CHECK: mhlo.and
# CHECK: hlo.and
# CHECK-SAME: tensor<i32>
print_ir(np.int32(1), np.int32(2))(lax.bitwise_and)
# CHECK-LABEL: TEST: bitwise_and bool[] bool[]
# CHECK: mhlo.and
# CHECK: hlo.and
# CHECK-SAME: tensor<i1>
print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_and)
# CHECK-LABEL: TEST: bitwise_or int32[] int32[]
# CHECK: mhlo.or
# CHECK: hlo.or
# CHECK-SAME: tensor<i32>
print_ir(np.int32(1), np.int32(2))(lax.bitwise_or)
# CHECK-LABEL: TEST: bitwise_or bool[] bool[]
# CHECK: mhlo.or
# CHECK: hlo.or
# CHECK-SAME: tensor<i1>
print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_or)
# CHECK-LABEL: TEST: bitwise_xor int32[] int32[]
# CHECK: mhlo.xor
# CHECK: hlo.xor
# CHECK-SAME: tensor<i32>
print_ir(np.int32(1), np.int32(2))(lax.bitwise_xor)
# CHECK-LABEL: TEST: bitwise_xor bool[] bool[]
# CHECK: mhlo.xor
# CHECK: hlo.xor
# CHECK-SAME: tensor<i1>
print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_xor)
# CHECK-LABEL: TEST: cbrt bfloat16[]
# CHECK: mhlo.cbrt
# CHECK: hlo.cbrt
# CHECK-SAME: tensor<bf16>
print_ir(jnp.bfloat16(0))(lax.cbrt)
# CHECK-LABEL: TEST: clamp bfloat16[] bfloat16[] bfloat16[]
# CHECK: mhlo.clamp
# CHECK: hlo.clamp
# CHECK-SAME: tensor<bf16>
print_ir(jnp.bfloat16(0), jnp.bfloat16(0), jnp.bfloat16(0))(lax.clamp)
# CHECK-LABEL: TEST: ceil float16[7]
# CHECK: mhlo.ceil
# CHECK: hlo.ceil
# CHECK-SAME: tensor<7xf16>
print_ir(np.empty((7,), np.float16))(lax.ceil)
# CHECK-LABEL: TEST: convert_element_type float16[7]
# CHECK: mhlo.convert
# CHECK: hlo.convert
# CHECK-SAME: tensor<7xf16>
# CHECK-SAME: tensor<7xf32>
print_ir(np.empty((7,), np.float16))(
partial(lax.convert_element_type, new_dtype=np.float32))
# CHECK-LABEL: TEST: convert_element_type complex64[7]
# CHECK: mhlo.real
# CHECK: hlo.real
# CHECK-SAME: tensor<7xcomplex<f32>>
# CHECK-SAME: tensor<7xf32>
print_ir(np.empty((7,), np.complex64))(
partial(lax.convert_element_type, new_dtype=np.float32))
# CHECK-LABEL: TEST: convert_element_type float32[7]
# CHECK: mhlo.compare
# CHECK: hlo.compare
# CHECK-SAME: tensor<7xf32>
# CHECK-SAME: tensor<7xi1>
print_ir(np.empty((7,), np.float32))(
partial(lax.convert_element_type, new_dtype=np.bool_))
# CHECK-LABEL: TEST: clz uint32[]
# CHECK: mhlo.count_leading_zeros
# CHECK: hlo.count_leading_zeros
# CHECK-SAME: tensor<ui32>
print_ir(np.uint32(0))(lax.clz)
# CHECK-LABEL: TEST: conj complex64[]
# CHECK-DAG: mhlo.real
# CHECK-DAG: mhlo.imag
# CHECK-DAG: mhlo.neg
# CHECK-DAG: mhlo.complex
# CHECK-DAG: hlo.real
# CHECK-DAG: hlo.imag
# CHECK-DAG: hlo.neg
# CHECK-DAG: hlo.complex
# CHECK-SAME: tensor<complex<f32>>
print_ir(np.complex64(0))(lax.conj)
# CHECK-LABEL: TEST: cos float32[]
# CHECK: mhlo.cos
# CHECK: hlo.cos
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.cos)
@ -192,30 +192,30 @@ def main(_):
print_ir(np.float32(0))(lax.digamma)
# CHECK-LABEL: TEST: div float32[] float32[]
# CHECK: mhlo.div
# CHECK: hlo.div
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.div)
# CHECK-LABEL: TEST: eq float32[] float32[]
# CHECK: mhlo.compare EQ
# CHECK: hlo.compare EQ
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.eq)
# CHECK-LABEL: TEST: eq complex128[] complex128[]
# CHECK: mhlo.compare EQ
# CHECK: hlo.compare EQ
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<complex<f64>>
print_ir(np.complex128(1), np.complex128(2))(lax.eq)
# CHECK-LABEL: TEST: eq int64[] int64[]
# CHECK: mhlo.compare EQ
# CHECK: hlo.compare EQ
# CHECK-SAME: SIGNED
# CHECK-SAME: tensor<i64>
print_ir(np.int64(1), np.int64(2))(lax.eq)
# CHECK-LABEL: TEST: eq uint16[] uint16[]
# CHECK: mhlo.compare EQ
# CHECK: hlo.compare EQ
# CHECK-SAME: UNSIGNED
# CHECK-SAME: tensor<ui16>
print_ir(np.uint16(1), np.uint16(2))(lax.eq)
@ -236,28 +236,28 @@ def main(_):
print_ir(np.float32(0))(lax.erf_inv)
# CHECK-LABEL: TEST: exp float16[]
# CHECK: mhlo.exp
# CHECK: hlo.exp
# CHECK-SAME: tensor<f16>
print_ir(np.float16(0))(lax.exp)
# CHECK-LABEL: TEST: expm1 bfloat16[]
# CHECK: mhlo.exponential_minus_one
# CHECK: hlo.exponential_minus_one
# CHECK-SAME: tensor<bf16>
print_ir(jnp.bfloat16(0))(lax.expm1)
# CHECK-LABEL: TEST: floor bfloat16[2,3]
# CHECK: mhlo.floor
# CHECK: hlo.floor
# CHECK-SAME: tensor<2x3xbf16>
print_ir(np.empty((2, 3), jnp.bfloat16))(lax.floor)
# CHECK-LABEL: TEST: ge float32[] float32[]
# CHECK: mhlo.compare GE
# CHECK: hlo.compare GE
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.ge)
# CHECK-LABEL: TEST: gt float32[] float32[]
# CHECK: mhlo.compare GT
# CHECK: hlo.compare GT
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.gt)
@ -278,23 +278,23 @@ def main(_):
print_ir(np.float32(0), np.float32(0))(lax.igamma_grad_a)
# CHECK-LABEL: TEST: imag complex64[]
# CHECK: mhlo.imag
# CHECK: hlo.imag
# CHECK-SAME: tensor<complex<f32>>
print_ir(np.complex64(0))(lax.imag)
# CHECK-LABEL: TEST: integer_pow float32[]
# CHECK-DAG: mhlo.mul
# CHECK-DAG: hlo.mul
# CHECK-SAME: tensor<f32>
@print_ir(np.float32(1))
def integer_pow(x): return lax.integer_pow(x, 3)
# CHECK-LABEL: TEST: is_finite float64[]
# CHECK: mhlo.is_finite
# CHECK: hlo.is_finite
# CHECK-SAME: tensor<f64>
print_ir(np.float64(0))(lax.is_finite)
# CHECK-LABEL: TEST: le float32[] float32[]
# CHECK: mhlo.compare LE
# CHECK: hlo.compare LE
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.le)
@ -305,44 +305,44 @@ def main(_):
print_ir(np.float32(0))(lax.lgamma)
# CHECK-LABEL: TEST: log float32[]
# CHECK: mhlo.log
# CHECK: hlo.log
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.log)
# CHECK-LABEL: TEST: log1p float32[]
# CHECK: mhlo.log_plus_one
# CHECK: hlo.log_plus_one
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.log1p)
# CHECK-LABEL: TEST: lt float32[] float32[]
# CHECK: mhlo.compare LT
# CHECK: hlo.compare LT
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.lt)
# CHECK-LABEL: TEST: max float32[] float32[]
# CHECK: mhlo.max
# CHECK: hlo.max
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.max)
# CHECK-LABEL: TEST: min float32[] float32[]
# CHECK: mhlo.min
# CHECK: hlo.min
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.min)
# CHECK-LABEL: TEST: mul float32[] float32[]
# CHECK: mhlo.mul
# CHECK: hlo.mul
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.mul)
# CHECK-LABEL: TEST: ne float32[] float32[]
# CHECK: mhlo.compare NE
# CHECK: hlo.compare NE
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.ne)
# CHECK-LABEL: TEST: neg int64[]
# CHECK: mhlo.negate
# CHECK: hlo.negate
# CHECK-SAME: tensor<i64>
print_ir(np.int64(0))(lax.neg)
@ -352,22 +352,22 @@ def main(_):
print_ir(np.float32(0), np.float32(0))(lax.nextafter)
# CHECK-LABEL: TEST: bitwise_not int64[]
# CHECK: mhlo.not
# CHECK: hlo.not
# CHECK-SAME: tensor<i64>
print_ir(np.int64(0))(lax.bitwise_not)
# CHECK-LABEL: TEST: bitwise_not bool[]
# CHECK: mhlo.not
# CHECK: hlo.not
# CHECK-SAME: tensor<i1>
print_ir(np.bool_(0))(lax.bitwise_not)
# CHECK-LABEL: TEST: population_count uint32[]
# CHECK: mhlo.popcnt
# CHECK: hlo.popcnt
# CHECK-SAME: tensor<ui32>
print_ir(np.uint32(0))(lax.population_count)
# CHECK-LABEL: TEST: pow float32[] float32[]
# CHECK: mhlo.power
# CHECK: hlo.power
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.pow)
@ -377,59 +377,59 @@ def main(_):
print_ir(np.float32(0), np.float32(0))(lax.random_gamma_grad)
# CHECK-LABEL: TEST: real complex128[]
# CHECK: mhlo.real
# CHECK: hlo.real
# CHECK-SAME: tensor<complex<f64>>
print_ir(np.complex128(0))(lax.real)
# CHECK-LABEL: TEST: reduce_precision bfloat16[]
# CHECK: mhlo.reduce_precision
# CHECK: hlo.reduce_precision
# CHECK-SAME: tensor<bf16>
print_ir(jnp.bfloat16(0))(
partial(lax.reduce_precision, exponent_bits=2, mantissa_bits=2))
# CHECK-LABEL: TEST: rem float32[] float32[]
# CHECK: mhlo.rem
# CHECK: hlo.rem
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.rem)
# CHECK-LABEL: TEST: round float64[7,1]
# CHECK: mhlo.round
# CHECK: hlo.round
# CHECK-SAME: tensor<7x1xf64>
print_ir(np.empty((7,1), np.float64))(
partial(lax.round, rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO))
# CHECK-LABEL: TEST: rsqrt complex64[]
# CHECK: mhlo.rsqrt
# CHECK: hlo.rsqrt
# CHECK-SAME: tensor<complex<f32>>
print_ir(jnp.complex64(0))(lax.rsqrt)
# CHECK-LABEL: TEST: shift_left uint32[] uint32[]
# CHECK: mhlo.shift_left
# CHECK: hlo.shift_left
# CHECK-SAME: tensor<ui32>
print_ir(np.uint32(0), np.uint32(0))(lax.shift_left)
# CHECK-LABEL: TEST: shift_right_arithmetic uint8[] uint8[]
# CHECK: mhlo.shift_right_arithmetic
# CHECK: hlo.shift_right_arithmetic
# CHECK-SAME: tensor<ui8>
print_ir(np.uint8(0), np.uint8(0))(lax.shift_right_arithmetic)
# CHECK-LABEL: TEST: shift_right_logical uint16[] uint16[]
# CHECK: mhlo.shift_right_logical
# CHECK: hlo.shift_right_logical
# CHECK-SAME: tensor<ui16>
print_ir(np.uint16(0), np.uint16(0))(lax.shift_right_logical)
# CHECK-LABEL: TEST: sign int64[]
# CHECK: mhlo.sign
# CHECK: hlo.sign
# CHECK-SAME: tensor<i64>
print_ir(np.int64(0))(lax.sign)
# CHECK-LABEL: TEST: sign uint32[]
# CHECK: mhlo.compare
# CHECK: hlo.compare
# CHECK-SAME: tensor<ui32>
print_ir(np.uint32(0))(lax.sign)
# CHECK-LABEL: TEST: sin float32[]
# CHECK: mhlo.sin
# CHECK: hlo.sin
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.sin)
@ -439,12 +439,12 @@ def main(_):
print_ir(np.float32(0))(lax.sinh)
# CHECK-LABEL: TEST: sub float32[] float32[]
# CHECK: mhlo.sub
# CHECK: hlo.sub
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.sub)
# CHECK-LABEL: TEST: sqrt bfloat16[]
# CHECK: mhlo.sqrt
# CHECK: hlo.sqrt
# CHECK-SAME: tensor<bf16>
print_ir(jnp.bfloat16(0))(lax.sqrt)
@ -454,7 +454,7 @@ def main(_):
print_ir(np.float16(0))(lax.tan)
# CHECK-LABEL: TEST: tanh float32[]
# CHECK: mhlo.tanh
# CHECK: hlo.tanh
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.tanh)

View File

@ -30,77 +30,77 @@ jax.config.update("jax_enable_x64", True)
def main(_):
# CHECK-LABEL: TEST: bitwise_not bool[7]
# CHECK: mhlo.not
# CHECK: hlo.not
# CHECK-SAME: tensor<7xi1>
print_ir(np.empty([7], np.bool_))(lax.bitwise_not)
# CHECK-LABEL: TEST: neg int8[]
# CHECK: mhlo.negate
# CHECK: hlo.negate
# CHECK-SAME: tensor<i8>
print_ir(np.int8(0))(lax.neg)
# CHECK-LABEL: TEST: neg int16[0]
# CHECK: mhlo.negate
# CHECK: hlo.negate
# CHECK-SAME: tensor<0xi16>
print_ir(np.empty([0], np.int16))(lax.neg)
# CHECK-LABEL: TEST: neg int32[2,3]
# CHECK: mhlo.negate
# CHECK: hlo.negate
# CHECK-SAME: tensor<2x3xi32>
print_ir(np.empty([2, 3], np.int32))(lax.neg)
# CHECK-LABEL: TEST: neg int64[2,3,4]
# CHECK: mhlo.negate
# CHECK: hlo.negate
# CHECK-SAME: tensor<2x3x4xi64>
print_ir(np.empty([2,3,4], np.int64))(lax.neg)
# CHECK-LABEL: TEST: add uint8[4,0,1] uint8[4,0,1]
# CHECK: mhlo.add
# CHECK: hlo.add
# CHECK-SAME: tensor<4x0x1xui8>
print_ir(np.empty([4,0,1], np.uint8), np.empty([4,0,1], np.uint8))(lax.add)
# CHECK-LABEL: TEST: add uint16[] uint16[]
# CHECK: mhlo.add
# CHECK: hlo.add
# CHECK-SAME: tensor<ui16>
print_ir(np.uint16(0), np.uint16(0))(lax.add)
# CHECK-LABEL: TEST: add uint32[] uint32[]
# CHECK: mhlo.add
# CHECK: hlo.add
# CHECK-SAME: tensor<ui32>
print_ir(np.uint32(0), np.uint32(0))(lax.add)
# CHECK-LABEL: TEST: add uint64[] uint64[]
# CHECK: mhlo.add
# CHECK: hlo.add
# CHECK-SAME: tensor<ui64>
print_ir(np.uint64(0), np.uint64(0))(lax.add)
# CHECK-LABEL: TEST: sin float16[]
# CHECK: mhlo.sine
# CHECK: hlo.sine
# CHECK-SAME: tensor<f16>
print_ir(np.float16(0))(lax.sin)
# CHECK-LABEL: TEST: sin bfloat16[]
# CHECK: mhlo.sine
# CHECK: hlo.sine
# CHECK-SAME: tensor<bf16>
print_ir(jnp.bfloat16(0))(lax.sin)
# CHECK-LABEL: TEST: sin float32[]
# CHECK: mhlo.sine
# CHECK: hlo.sine
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.sin)
# CHECK-LABEL: TEST: sin float64[]
# CHECK: mhlo.sine
# CHECK: hlo.sine
# CHECK-SAME: tensor<f64>
print_ir(np.float64(0))(lax.sin)
# CHECK-LABEL: TEST: cos complex64[]
# CHECK: mhlo.cosine
# CHECK: hlo.cosine
# CHECK-SAME: tensor<complex<f32>>
print_ir(np.complex64(0))(lax.cos)
# CHECK-LABEL: TEST: cos complex128[]
# CHECK: mhlo.cosine
# CHECK: hlo.cosine
# CHECK-SAME: tensor<complex<f64>>
print_ir(np.complex128(0))(lax.cos)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Tests for lowering of array origami ops into MHLO.
# Tests for lowering of array origami ops into MLIR.
# RUN: %PYTHON %s | FileCheck %s
@ -65,7 +65,7 @@ def main(_):
with m1.context:
# Reparse m2 in m1's context.
m2_copy = ir.Module.parse(m2)
mlir.merge_mhlo_modules(m1, "m2_main_renamed", m2_copy)
mlir.merge_mlir_modules(m1, "m2_main_renamed", m2_copy)
print("\nTEST: merge_modules")
print(str(m1))
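One detail worth spelling out about the reparse above: MLIR modules are owned by their context, so `m2` has to be re-created under `m1`'s context before the merge. A sketch of the whole pattern, assuming both modules define a `func.func @main` as in this test:

# Sketch of the merge pattern; m1 and m2 are modules as set up in this test.
with m1.context:
  m2_copy = ir.Module.parse(str(m2))   # modules cannot cross MLIR contexts
  mlir.merge_mlir_modules(m1, "m2_main_renamed", m2_copy)
# m1 now holds both programs; m2's @main is reachable as @m2_main_renamed.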

View File

@ -370,43 +370,43 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def f(x):
effect_p.bind(effect='foo')
return x + 1.
mhlo = f.lower(2.).compiler_ir()
main = mhlo.body.operations[0]
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
first_op = main.body.blocks[0].operations[0]
self.assertEqual(first_op.operation.name, "mhlo.create_token")
self.assertIn('hlo.create_token', first_op.operation.name)
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='foo2')
return x + 1.
mhlo = f.lower(2.).compiler_ir()
main = mhlo.body.operations[0]
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
first_op = main.body.blocks[0].operations[0]
self.assertEqual(first_op.operation.name, "mhlo.create_token")
self.assertIn('hlo.create_token', first_op.operation.name)
second_op = main.body.blocks[0].operations[1]
self.assertEqual(second_op.operation.name, "mhlo.create_token")
self.assertIn('hlo.create_token', second_op.operation.name)
@jax.jit
def f(x):
effect_p.bind(effect='foo')
return x + 1.
mhlo = f.lower(2.).compiler_ir()
main = mhlo.body.operations[0]
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
first_op = main.body.blocks[0].operations[0]
self.assertEqual(first_op.operation.name, "mhlo.create_token")
self.assertIn('hlo.create_token', first_op.operation.name)
@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='foo2')
return x + 1.
mhlo = f.lower(2.).compiler_ir()
main = mhlo.body.operations[0]
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
first_op = main.body.blocks[0].operations[0]
self.assertEqual(first_op.operation.name, "mhlo.create_token")
self.assertIn('hlo.create_token', first_op.operation.name)
second_op = main.body.blocks[0].operations[1]
self.assertEqual(second_op.operation.name, "mhlo.create_token")
self.assertIn('hlo.create_token', second_op.operation.name)
def test_nontrivial_lowering_with_ordered_effect_should_consume_token(self):
@ -416,19 +416,18 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def f(x):
effect_p.bind(effect='foo')
return x + 1.
mhlo = f.lower(2.).compiler_ir()
main = mhlo.body.operations[0]
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
first_op = main.body.blocks[0].operations[0]
self.assertEqual(first_op.operation.name, "mhlo.create_token")
self.assertIn('hlo.create_token', first_op.operation.name)
second_op = main.body.blocks[0].operations[1]
self.assertEqual(second_op.operation.name, "func.call")
self.assertEqual(str(second_op.attributes["callee"]), "@effect")
self.assertEqual(second_op.operands[0].owner, first_op)
func = mhlo.body.operations[1]
func = module.body.operations[1]
self.assertEqual(func.name.value, "effect")
self.assertEqual(str(func.type.inputs[0]), "!mhlo.token")
self.assertEqual(str(func.type.results[0]), "!mhlo.token")
self.assertIn('hlo.token', str(func.type.inputs[0]))
self.assertIn('hlo.token', str(func.type.results[0]))
def test_nontrivial_lowering_with_unordered_effect_should_consume_token(self):
@ -438,14 +437,13 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def f(x):
effect_p.bind(effect='bar')
return x + 1.
mhlo = f.lower(2.).compiler_ir()
main = mhlo.body.operations[0]
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
first_op = main.body.blocks[0].operations[0]
self.assertEqual(first_op.operation.name, "func.call")
self.assertEqual(str(first_op.attributes["callee"]), "@effect")
self.assertLen(list(first_op.operands), 0)
func = mhlo.body.operations[1]
func = module.body.operations[1]
self.assertEqual(func.name.value, "effect")
self.assertLen(list(func.type.inputs), 0)
self.assertLen(list(func.type.results), 0)
@ -455,13 +453,13 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def f(x):
effect_p.bind(effect='bar')
return x + 1.
mhlo = f.lower(1.).compiler_ir(dialect='mhlo')
input_types = mhlo.body.operations[0].type.inputs
module = f.lower(1.).compiler_ir()
input_types = module.body.operations[0].type.inputs
self.assertLen(list(input_types), 1)
self.assertEqual(str(input_types[0]), 'tensor<f32>')
# First output should be output token
result_types = mhlo.body.operations[0].type.results
result_types = module.body.operations[0].type.results
if not can_execute_with_token:
self.assertLen(list(result_types), 2)
self.assertEqual(str(result_types[0]), 'tensor<0xi1>')
@ -476,14 +474,14 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def f(x):
effect_p.bind(effect='foo')
return x + 1.
mhlo = f.lower(1.).compiler_ir(dialect='mhlo')
input_types = mhlo.body.operations[0].type.inputs
module = f.lower(1.).compiler_ir()
input_types = module.body.operations[0].type.inputs
# First argument should be dummy token
self.assertLen(list(input_types), 2)
self.assertEqual(str(input_types[0]), 'tensor<0xi1>')
# First output should be dummy token
result_types = mhlo.body.operations[0].type.results
result_types = module.body.operations[0].type.results
self.assertLen(list(result_types), 2)
self.assertEqual(str(result_types[0]), 'tensor<0xi1>')
@ -493,15 +491,15 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
effect_p.bind(effect='foo')
effect_p.bind(effect='foo2')
return x + 1.
mhlo = f.lower(1.).compiler_ir(dialect='mhlo')
input_types = mhlo.body.operations[0].type.inputs
module = f.lower(1.).compiler_ir()
input_types = module.body.operations[0].type.inputs
# First two arguments should be dummy values
self.assertLen(list(input_types), 3)
self.assertEqual(str(input_types[0]), 'tensor<0xi1>')
self.assertEqual(str(input_types[1]), 'tensor<0xi1>')
# First two outputs should be dummy values
result_types = mhlo.body.operations[0].type.results
result_types = module.body.operations[0].type.results
self.assertLen(list(result_types), 3)
self.assertEqual(str(result_types[0]), 'tensor<0xi1>')
self.assertEqual(str(result_types[1]), 'tensor<0xi1>')
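A hedged summary of what these type assertions check, as a sketch (assuming the `effect_p` primitive defined earlier in this test file): ordered effects are sequenced by threading a token through the lowered module, and since the runtime here cannot pass real token values across the entry function boundary, the lowering substitutes `tensor<0xi1>` dummies at that boundary.

# Sketch; effect_p is the test primitive defined earlier in this file.
@jax.jit
def g(x):
  effect_p.bind(effect='foo')   # ordered effect: forces token threading
  return x + 1.

module = g.lower(1.).compiler_ir()
main = module.body.operations[0]
print(str(main.type.inputs[0]))    # expect the dummy placeholder tensor<0xi1>
print(str(main.type.results[0]))   # expect tensor<0xi1> here as well
# Internally, an hlo create_token op makes the real token that is passed to
# the private @effect function, whose signature uses the real token type.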

View File

@ -40,7 +40,7 @@ from jax.interpreters import mlir
from jax.interpreters import batching
from jax.interpreters import pxla
from jax._src import array
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src import dispatch
from jax._src import dtypes
from jax._src import test_util as jtu
@ -2867,7 +2867,7 @@ class FooTyRules:
start_indices = (*start_indices, 0)
limit_indices = (*limit_indices, 2)
strides = (*strides, 1)
return mhlo.SliceOp(x,
return hlo.SliceOp(x,
mlir.dense_int_elements(start_indices),
mlir.dense_int_elements(limit_indices),
mlir.dense_int_elements(strides)).result
@ -2877,7 +2877,7 @@ class FooTyRules:
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
start_indices = (*start_indices, mlir.ir_constant(np.array(0, dtype=dtype)))
slice_sizes_ = mlir.dense_int_elements((*aval_out.shape, 2))
return mhlo.DynamicSliceOp(x, start_indices, slice_sizes_).result
return hlo.DynamicSliceOp(x, start_indices, slice_sizes_).result
@staticmethod
def dynamic_update_slice_mlir(ctx, aval_out, x, update, *start_indices):
@ -2885,22 +2885,22 @@ class FooTyRules:
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
start_indices = (*start_indices, mlir.ir_constant(np.array(0, dtype=dtype)))
if xc.mlir_api_version < 40:
return mhlo.DynamicUpdateSliceOp(
return hlo.DynamicUpdateSliceOp(
mlir.aval_to_ir_type(aval_out), x, update, start_indices).result
else:
return mhlo.DynamicUpdateSliceOp(x, update, start_indices).result
return hlo.DynamicUpdateSliceOp(x, update, start_indices).result
@staticmethod
def broadcast_in_dim_mlir(ctx, aval_out, x, broadcast_dimensions):
broadcast_dimensions = [*broadcast_dimensions, aval_out.ndim]
return mhlo.BroadcastInDimOp(
return hlo.BroadcastInDimOp(
mlir.aval_to_ir_type(aval_out), x,
mlir.dense_int_elements(broadcast_dimensions)).result
@staticmethod
def transpose_mlir(ctx, aval_out, x, *, permutation):
perm = [*permutation, len(permutation)]
return mhlo.TransposeOp(x, mlir.dense_int_elements(perm)).result
return hlo.TransposeOp(x, mlir.dense_int_elements(perm)).result
@staticmethod
def gather_mlir(ctx, avals_in, aval_out, x, indices, *,

View File

@ -501,9 +501,9 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertLen(actual[0]['a'].device_buffers, 4)
mhlo_str = str(f.lower(x).compiler_ir(dialect="mhlo"))
self.assertIn("unspecified_dims=[0]", mhlo_str)
self.assertIn("unspecified_dims=[1]", mhlo_str)
mlir_str = str(f.lower(x).compiler_ir())
self.assertIn("unspecified_dims=[0]", mlir_str)
self.assertIn("unspecified_dims=[1]", mlir_str)
@jtu.with_mesh([('x', 2), ('y', 2)])
def testShardingConstraintPyTreeVmapWithUnconstrainedDims(self):
@ -521,9 +521,9 @@ class PJitTest(jtu.BufferDonationTestCase):
v = np.arange(prod(shape)).reshape(shape)
x = [{'a': v, 'b': v * 2}, v * 3]
mhlo_str = str(f.lower(x).compiler_ir(dialect="mhlo"))
self.assertIn("unspecified_dims=[0,1]", mhlo_str)
self.assertIn("unspecified_dims=[0,2]", mhlo_str)
mlir_str = str(f.lower(x).compiler_ir())
self.assertIn("unspecified_dims=[0,1]", mlir_str)
self.assertIn("unspecified_dims=[0,2]", mlir_str)
def testCaching(self):
def f(x):