Mirror of https://github.com/ROCm/jax.git
(NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose, in
order to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc.).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one
     sentence, so both are necessary to disambiguate.

PiperOrigin-RevId: 495771153
Commit b8ae8e3fa1 (parent 523c6f7a53)
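The diff below is mechanical: call sites that referenced the `mhlo` Python binding of the MLIR dialect now reference the `hlo` module name. As a hedged illustration (not part of this commit), the shape of the rename for a lowering-rule registration is sketched here; `neg_p`, `_nary_lower_hlo`, and the dialect modules are JAX internals taken from the diff and may move between versions.

# Illustrative sketch only; assumes a JAX checkout at roughly this commit.
from functools import partial

from jax.interpreters import mlir
from jax._src.lax.lax import neg_p, _nary_lower_hlo   # internal, not stable API
from jax._src.lib.mlir.dialects import hlo            # previously: ... import mhlo

# Before this change the rule was registered against the mhlo binding:
#   mlir.register_lowering(neg_p, partial(_nary_lower_mhlo, mhlo.NegOp))
# After the rename the same registration reads:
mlir.register_lowering(neg_p, partial(_nary_lower_hlo, hlo.NegOp))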
@@ -169,7 +169,7 @@ def prepare_wheel(sources_path):
   copy_to_jaxlib("__main__/jaxlib/init.py", dst_filename="__init__.py")
   copy_to_jaxlib(f"__main__/jaxlib/cpu_feature_guard.{pyext}")
   copy_to_jaxlib("__main__/jaxlib/lapack.py")
-  copy_to_jaxlib("__main__/jaxlib/mhlo_helpers.py")
+  copy_to_jaxlib("__main__/jaxlib/hlo_helpers.py")
   copy_to_jaxlib("__main__/jaxlib/ducc_fft.py")
   copy_to_jaxlib("__main__/jaxlib/gpu_prng.py")
   copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py")
@@ -34,7 +34,7 @@ from jax._src.api_util import flatten_fun, shaped_abstractify
 from jax._src.lax import lax as lax_internal
 from jax._src.lax import convolution as lax_convolution
 from jax._src.lib import xla_client as xc
-from jax._src.lib.mlir.dialects import mhlo
+from jax._src.lib.mlir.dialects import hlo
 from jax._src.traceback_util import api_boundary
 from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
                            safe_zip, merge_lists, weakref_lru_cache)
@@ -623,9 +623,9 @@ def _optimization_barrier_lowering_rule(ctx, *args):
   flat_barrier_types = util.flatten(barrier_types)
   flat_args = mlir.flatten_lowering_ir_args(args)
   if xc.mlir_api_version < 40:
-    barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
+    barrier_op = hlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
   else:
-    barrier_op = mhlo.OptimizationBarrierOp(flat_args)
+    barrier_op = hlo.OptimizationBarrierOp(flat_args)
   return util.unflatten(barrier_op.results, map(len, barrier_types))

 def _optimization_barrier(arg):
@@ -1205,8 +1205,8 @@ unreachable_p.def_impl(unreachable_impl)

 # Translation raises an exception
 # TODO(frostig,mattjj): We have no good way to translate a function
-# that errs. Since MHLO lowering over-approximates concrete evaluation,
-# we err on MHLO lowering for the time being.
+# that errs. Since MLIR lowering over-approximates concrete evaluation,
+# we err on MLIR lowering for the time being.
 mlir.register_lowering(unreachable_p, unreachable_impl)

 # Abstract evaluation proceeds without issue, to allow for staging
@@ -38,7 +38,7 @@ from jax._src import util
 from jax._src.lax import control_flow as lcf
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
-from jax._src.lib.mlir.dialects import mhlo
+from jax._src.lib.mlir.dialects import hlo
 import jax.numpy as jnp

 import numpy as np
@@ -335,7 +335,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
   # partitioner runs so we keep it alive by attaching it to the executable.
   ctx.module_context.add_keepalive(sharding_callback_info)

-  mhlo.CustomCallOp([value.type], [value],
+  hlo.CustomCallOp([value.type], [value],
                     call_target_name=ir.StringAttr.get(
                         _INSPECT_SHARDING_CALL_NAME),
                     has_side_effect=ir.BoolAttr.get(True),
@@ -67,7 +67,7 @@ FLAGS = flags.FLAGS

 flags.DEFINE_string(
     'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
-    help="Path to which HLO/MHLO IR that is emitted by JAX as input to the "
+    help="Path to which the IR that is emitted by JAX as input to the "
          "compiler should be dumped as text files. Optional. If omitted, JAX "
          "will not dump IR.")

@@ -41,7 +41,7 @@ from jax._src.traceback_util import api_boundary
 from jax._src.util import (safe_map, extend_name_stack, split_list,
                            partition_list)
 from jax._src.lib.mlir import ir
-from jax._src.lib.mlir.dialects import mhlo
+from jax._src.lib.mlir.dialects import hlo
 import numpy as np

 from jax._src.lax.control_flow.common import (
@@ -806,10 +806,10 @@ def _cond_lowering(ctx, index, *args, branches, linear):
                   *output_token_types, *map(mlir.aval_to_ir_types, ctx.avals_out)]
   flat_output_types = util.flatten(output_types)

-  # mhlo.CaseOp takes a single argument 'index' and the corresponding blocks
+  # CaseOp takes a single argument 'index' and the corresponding blocks
   # have no arguments; the computation within the block uses implicit
   # captures.
-  case_op = mhlo.CaseOp(flat_output_types, index=index,
+  case_op = hlo.CaseOp(flat_output_types, index=index,
                         num_branches=len(branches))
   name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')
   for i, jaxpr in enumerate(branches):
@@ -824,7 +824,7 @@ def _cond_lowering(ctx, index, *args, branches, linear):
         dim_var_values=ctx.dim_var_values)
     out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
     out_vals = [*out_tokens, *out_vals]
-    mhlo.ReturnOp(util.flatten(out_vals))
+    hlo.ReturnOp(util.flatten(out_vals))

   tokens_and_outputs = util.unflatten(case_op.results, map(len, output_types))
   tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
@@ -43,7 +43,7 @@ from jax._src.lax import slicing
 from jax._src.lax import windowed_reductions
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
-from jax._src.lib.mlir.dialects import mhlo
+from jax._src.lib.mlir.dialects import hlo
 from jax._src.numpy.ufuncs import logaddexp
 from jax._src.traceback_util import api_boundary
 from jax._src.util import (
@@ -146,7 +146,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
     output arrays. (None is actually an empty pytree.)

   Also unlike that Python version, :func:`~scan` is a JAX primitive and is
-  lowered to a single XLA While HLO. That makes it useful for reducing
+  lowered to a single WhileOp. That makes it useful for reducing
   compilation times for JIT-compiled functions, since native Python
   loop constructs in an :func:`~jax.jit` function are unrolled, leading to large
   XLA computations.
@@ -1041,7 +1041,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
       return val

   Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
-  to a single XLA While HLO. That makes it useful for reducing compilation times
+  to a single WhileOp. That makes it useful for reducing compilation times
   for jit-compiled functions, since native Python loop constructs in an ``@jit``
   function are unrolled, leading to large XLA computations.

@@ -1420,7 +1420,7 @@ def _while_transpose_error(*_, **kwargs):
 #     break
 #   token, x = body(token, x)
 # ```
-# Unfortunately, with an MHLO while we can't (1) return multiple values
+# Unfortunately, with a WhileOp we can't (1) return multiple values
 # from a `cond` and (2) can't break a while loop. We thus adopt the
 # following rewrite strategy:
 # ```
@@ -1471,7 +1471,7 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
   args = [*tokens, *args]

   flat_args = mlir.flatten_lowering_ir_args(args)
-  while_op = mhlo.WhileOp(flat_loop_carry_types, flat_args)
+  while_op = hlo.WhileOp(flat_loop_carry_types, flat_args)

   # Loop condition
   cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
@@ -1498,12 +1498,12 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
         tokens_in=mlir.TokenSet(),
         tokens_out=None)
     pred, = lax._unary_reduce_lower(
-        mhlo.OrOp,
+        hlo.OrOp,
         lambda dtype: np.array(False, dtype),
         pred_ctx,
         pred,
         axes=tuple(range(len(pred_aval.shape))))
-    mhlo.ReturnOp([pred])
+    hlo.ReturnOp([pred])

   # Loop body
   body_block = while_op.regions[1].blocks.append(*flat_loop_carry_types)
@@ -1531,10 +1531,10 @@ def _while_lowering(ctx, *args, cond_jaxpr, body_jaxpr, cond_nconsts,
         _map(mlir.ir_constants, cond_jaxpr.consts),
         *(x + z), dim_var_values=ctx.dim_var_values)
     new_z = _map(
-        partial(_pred_bcast_select_mhlo, ctx, pred_aval, body_pred), new_z, z,
+        partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z,
         body_jaxpr.out_avals)

-    mhlo.ReturnOp([*util.flatten(out_tokens), *util.flatten(x),
+    hlo.ReturnOp([*util.flatten(out_tokens), *util.flatten(x),
                    *util.flatten(y), *util.flatten(new_z)])

   outputs = util.unflatten(while_op.results, _map(len, loop_carry_types))
@@ -1566,16 +1566,16 @@ mlir.register_lowering(while_p, _while_lowering)
 core.custom_typechecks[while_p] = _while_typecheck


-def _pred_bcast_select_mhlo(ctx,
+def _pred_bcast_select_hlo(ctx,
     pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
     ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
   if x_y_aval is core.abstract_token:
     x, = xs
     y, = ys
     if xc.mlir_api_version < 40:
-      return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
+      return [hlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
     else:
-      return [mhlo.AfterAllOp([x, y]).result]
+      return [hlo.AfterAllOp([x, y]).result]
   else:
     assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
     x, = xs
@@ -1589,7 +1589,7 @@ def _pred_bcast_select_mhlo(ctx,
     x_y_shape = x_y_aval.shape
     bcast_pred = mlir.broadcast_in_dim(ctx, pred, core.DShapedArray(x_y_shape, np.dtype(np.bool_)),
                                        broadcast_dimensions=list(range(len(pred_aval.shape))))
-    return mhlo.SelectOp(bcast_pred, x, y).results
+    return hlo.SelectOp(bcast_pred, x, y).results

 ### fori_loop

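A small, hedged aside (not part of the commit) illustrating the claim in the scan / while_loop docstrings above: both lower to a single While op, which is why they keep jit compilation times small compared to unrolled Python loops.

# Sketch under the assumption of a post-rename JAX install; prints whether a
# while construct appears in the lowered IR text for a scanned cumulative sum.
import jax
import jax.numpy as jnp

def cumsum_with_scan(xs):
    def step(carry, x):
        new = carry + x
        return new, new   # (next carry, per-step output)
    _, ys = jax.lax.scan(step, jnp.float32(0.0), xs)
    return ys

lowered = jax.jit(cumsum_with_scan).lower(jnp.ones(8, jnp.float32))
print("while" in lowered.as_text())   # expected: True, a single loop op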
@@ -26,7 +26,7 @@ from jax.interpreters import ad
 from jax.interpreters import batching
 from jax.interpreters import mlir
 from jax._src import util
-from jax._src.lib.mlir.dialects import mhlo
+from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib import xla_client

 _max = builtins.max
@@ -688,7 +688,7 @@ def _conv_general_dilated_lower(
     return complex_conv(ctx, lhs, rhs)

   lhs_spec, rhs_spec, out_spec = dimension_numbers
-  dnums = mhlo.ConvDimensionNumbers.get(
+  dnums = hlo.ConvDimensionNumbers.get(
       input_batch_dimension=lhs_spec[0],
       input_feature_dimension=lhs_spec[1],
       input_spatial_dimensions=list(lhs_spec[2:]),
@@ -703,7 +703,7 @@ def _conv_general_dilated_lower(
     padding = np.zeros((0, 2), dtype=np.int64)
   window_reversal = mlir.dense_bool_elements([False] * num_spatial_dims)
   return [
-      mhlo.ConvolutionOp(
+      hlo.ConvolutionOp(
           mlir.aval_to_ir_type(aval_out),
           lhs,
           rhs,
@@ -26,7 +26,7 @@ from jax._src.util import prod
 from jax import lax
 from jax.interpreters import ad
 from jax.interpreters import batching
-from jax._src.lib.mlir.dialects import mhlo
+from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib import xla_client
 from jax._src.lib import ducc_fft
 from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact
@@ -104,7 +104,7 @@ def fft_abstract_eval(x, fft_type, fft_lengths):

 def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
   return [
-      mhlo.FftOp(x, mhlo.FftTypeAttr.get(fft_type.name),
+      hlo.FftOp(x, hlo.FftTypeAttr.get(fft_type.name),
                  mlir.dense_int_elements(fft_lengths)).result
   ]

@@ -113,8 +113,12 @@ def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths):
   if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
     raise NotImplementedError("Shape polymorphism for custom call is not implemented (fft); b/261671778")
   x_aval, = ctx.avals_in
-  return [ducc_fft.ducc_fft_mhlo(x, x_aval.dtype, fft_type=fft_type,
-                                 fft_lengths=fft_lengths)]
+  if xla_client.mlir_api_version < 41:
+    return [ducc_fft.ducc_fft_mhlo(x, x_aval.dtype, fft_type=fft_type,
+                                   fft_lengths=fft_lengths)]
+  else:
+    return [ducc_fft.ducc_fft_hlo(x, x_aval.dtype, fft_type=fft_type,
+                                  fft_lengths=fft_lengths)]

 def _naive_rfft(x, fft_lengths):
   y = fft(x, xla_client.FftType.FFT, fft_lengths)
@@ -58,7 +58,7 @@ from jax._src.lib import xla_bridge
 from jax._src.lib import xla_client
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import chlo
-from jax._src.lib.mlir.dialects import mhlo
+from jax._src.lib.mlir.dialects import hlo
 from jax._src.lax.utils import (
   _input_dtype,
   standard_abstract_eval,
@@ -1635,10 +1635,10 @@ def _maybe_broadcast(target_shape, x):
   squeeze_shape = [x_shape[i] for i in dims]
   return broadcast_in_dim(reshape(x, squeeze_shape), target_shape, dims)

-def broadcast_mhlo(
+def broadcast_hlo(
     aval_out: core.ShapedArray, avals: Sequence[core.ShapedArray],
     args: Sequence[ir.Value]) -> Sequence[ir.Value]:
-  """Broadcasts MHLO values with broadcast-compatible shapes to the same shape.
+  """Broadcasts HLO values with broadcast-compatible shapes to the same shape.
   """
   out = []
   for aval, arg in zip(avals, args):
@@ -1647,24 +1647,23 @@ def broadcast_mhlo(
       dims = mlir.dense_int_elements(
           range(len(aval_out.shape) - len(aval.shape), len(aval_out.shape)))
       if any(isinstance(d, ir.Value) for d in aval_out.shape):
-        arg = mhlo.DynamicBroadcastInDimOp(
+        arg = hlo.DynamicBroadcastInDimOp(
            mlir.aval_to_ir_type(aval_out), arg,
            mlir.shape_tensor(aval_out.shape), dims).result
       else:
-        arg = mhlo.BroadcastInDimOp(
+        arg = hlo.BroadcastInDimOp(
            mlir.aval_to_ir_type(aval.update(shape=aval_out.shape)), arg,
            dims).result
     out.append(arg)
   return out

-def _nary_lower_mhlo(op: Callable, ctx,
+def _nary_lower_hlo(op: Callable, ctx,
                      *args: Union[ir.Value, Sequence[ir.Value]],
                      explicit_type=False, **params):
-  """Lowers an elementwise operator to its MHLO/CHLO equivalent.
+  """Lowers an elementwise operator to its MLIR equivalent.

   Args:
-    explicit_type: does the MHLO/CHLO operator require its output type to be
-      provided?
+    explicit_type: does the MLIR op require its output type to be provided?
   """
   del params
   avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out
@@ -1696,76 +1695,76 @@ _ordered = _int | _float | _bool

 neg_p = standard_unop(_num, 'neg')
 ad.deflinear2(neg_p, lambda t, operand: [neg(t)])
-mlir.register_lowering(neg_p, partial(_nary_lower_mhlo, mhlo.NegOp))
+mlir.register_lowering(neg_p, partial(_nary_lower_hlo, hlo.NegOp))

 sign_p = standard_unop(_num, 'sign')
 ad.defjvp_zero(sign_p)

-def _sign_lower_mhlo(ctx, x):
+def _sign_lower_hlo(ctx, x):
   x_aval, = ctx.avals_in
   if dtypes.issubdtype(x_aval.dtype, np.unsignedinteger):
-    return mhlo.SelectOp(
-        mlir.compare_mhlo(x, mlir.full_like_aval(ctx, 0, x_aval), 'EQ',
+    return hlo.SelectOp(
+        mlir.compare_hlo(x, mlir.full_like_aval(ctx, 0, x_aval), 'EQ',
                           'UNSIGNED').result,
         mlir.full_like_aval(ctx, 0, x_aval),
         mlir.full_like_aval(ctx, 1, x_aval)).results
-  return mhlo.SignOp(x).results
+  return hlo.SignOp(x).results

-mlir.register_lowering(sign_p, _sign_lower_mhlo)
+mlir.register_lowering(sign_p, _sign_lower_hlo)

 nextafter_p = standard_naryop([_float, _float], 'nextafter')
-mlir.register_lowering(nextafter_p, partial(_nary_lower_mhlo, chlo.NextAfterOp))
+mlir.register_lowering(nextafter_p, partial(_nary_lower_hlo, chlo.NextAfterOp))

 floor_p = standard_unop(_float, 'floor')
 ad.defjvp_zero(floor_p)
-mlir.register_lowering(floor_p, partial(_nary_lower_mhlo, mhlo.FloorOp))
+mlir.register_lowering(floor_p, partial(_nary_lower_hlo, hlo.FloorOp))

 ceil_p = standard_unop(_float, 'ceil')
 ad.defjvp_zero(ceil_p)
-mlir.register_lowering(ceil_p, partial(_nary_lower_mhlo, mhlo.CeilOp))
+mlir.register_lowering(ceil_p, partial(_nary_lower_hlo, hlo.CeilOp))

 round_p = standard_unop(_float, 'round')
 ad.defjvp_zero(round_p)

 def _round_lower(ctx, x, *, rounding_method):
   if rounding_method is RoundingMethod.AWAY_FROM_ZERO:
-    return mhlo.RoundOp(x).results
+    return hlo.RoundOp(x).results
   else:
     assert rounding_method is RoundingMethod.TO_NEAREST_EVEN
-    return mhlo.RoundNearestEvenOp(x).results
+    return hlo.RoundNearestEvenOp(x).results
 mlir.register_lowering(round_p, _round_lower)

 is_finite_p = unop(_fixed_dtype(np.bool_), _float, 'is_finite')
 ad.defjvp_zero(is_finite_p)
-mlir.register_lowering(is_finite_p, partial(_nary_lower_mhlo, mhlo.IsFiniteOp))
+mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.IsFiniteOp))

 exp_p = standard_unop(_float | _complex, 'exp')
 ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans))
 # For exp_p it is more efficient to use the reconstructed output for the vjp
 # rule instead of computing it again from the input.
-mlir.register_lowering(exp_p, partial(_nary_lower_mhlo, mhlo.ExpOp))
+mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.ExpOp))

 log_p = standard_unop(_float | _complex, 'log')
 ad.defjvp(log_p, lambda g, x: div(g, x))
-mlir.register_lowering(log_p, partial(_nary_lower_mhlo, mhlo.LogOp))
+mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.LogOp))

 expm1_p = standard_unop(_float | _complex, 'expm1')
 ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans))))
-mlir.register_lowering(expm1_p, partial(_nary_lower_mhlo, mhlo.Expm1Op))
+mlir.register_lowering(expm1_p, partial(_nary_lower_hlo, hlo.Expm1Op))

 log1p_p = standard_unop(_float | _complex, 'log1p')
 ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x))))
-mlir.register_lowering(log1p_p, partial(_nary_lower_mhlo, mhlo.Log1pOp))
+mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.Log1pOp))

 tanh_p = standard_unop(_float | _complex, 'tanh')
 ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)),
                                          sub(_one(x), ans)))
-mlir.register_lowering(tanh_p, partial(_nary_lower_mhlo, mhlo.TanhOp))
+mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.TanhOp))

 logistic_p = standard_unop(_float | _complex, 'logistic')
 ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans))))
-# TODO(phawkins): switch to mhlo.logistic lowering; debug numerical problems.
-# mlir.register_lowering(logistic_p, partial(_nary_lower_mhlo, mhlo.LogisticOp))
+# TODO(phawkins): switch to LogisticOp lowering; debug numerical problems.
+# mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.LogisticOp))

 def logistic_impl(x):
   one = _const(x, 1)
|
||||
|
||||
sin_p = standard_unop(_float | _complex, 'sin')
|
||||
ad.defjvp(sin_p, lambda g, x: mul(g, cos(x)))
|
||||
mlir.register_lowering(sin_p, partial(_nary_lower_mhlo, mhlo.SineOp))
|
||||
mlir.register_lowering(sin_p, partial(_nary_lower_hlo, hlo.SineOp))
|
||||
|
||||
cos_p = standard_unop(_float | _complex, 'cos')
|
||||
ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
|
||||
mlir.register_lowering(cos_p, partial(_nary_lower_mhlo, mhlo.CosineOp))
|
||||
mlir.register_lowering(cos_p, partial(_nary_lower_hlo, hlo.CosineOp))
|
||||
|
||||
@_upcast_fp16_for_computation
|
||||
def _tan_impl(x):
|
||||
@ -1788,7 +1787,7 @@ def _tan_impl(x):
|
||||
|
||||
tan_p = standard_unop(_float | _complex, 'tan')
|
||||
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
|
||||
mlir.register_lowering(tan_p, partial(_nary_lower_mhlo, chlo.TanOp))
|
||||
mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.TanOp))
|
||||
|
||||
def asin_impl(x):
|
||||
if dtypes.issubdtype(_dtype(x), np.complexfloating):
|
||||
@ -1799,7 +1798,7 @@ def asin_impl(x):
|
||||
|
||||
asin_p = standard_unop(_float | _complex, 'asin')
|
||||
ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x))))
|
||||
mlir.register_lowering(asin_p, partial(_nary_lower_mhlo, chlo.AsinOp))
|
||||
mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.AsinOp))
|
||||
|
||||
def acos_impl(x):
|
||||
if dtypes.issubdtype(_dtype(x), np.complexfloating):
|
||||
@ -1828,35 +1827,35 @@ def atan_impl(x):
|
||||
|
||||
atan_p = standard_unop(_float | _complex, 'atan')
|
||||
ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x)))
|
||||
mlir.register_lowering(atan_p, partial(_nary_lower_mhlo, chlo.AtanOp))
|
||||
mlir.register_lowering(atan_p, partial(_nary_lower_hlo, chlo.AtanOp))
|
||||
|
||||
atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2')
|
||||
ad.defjvp(atan2_p,
|
||||
lambda g, x, y: g * (y / (square(x) + square(y))),
|
||||
lambda g, x, y: g * -x / (square(x) + square(y)))
|
||||
mlir.register_lowering(atan2_p, partial(_nary_lower_mhlo, mhlo.Atan2Op))
|
||||
mlir.register_lowering(atan2_p, partial(_nary_lower_hlo, hlo.Atan2Op))
|
||||
|
||||
sinh_p = standard_unop(_float | _complex, 'sinh')
|
||||
ad.defjvp(sinh_p, lambda g, x: mul(g, cosh(x)))
|
||||
mlir.register_lowering(sinh_p, partial(_nary_lower_mhlo, chlo.SinhOp))
|
||||
mlir.register_lowering(sinh_p, partial(_nary_lower_hlo, chlo.SinhOp))
|
||||
|
||||
cosh_p = standard_unop(_float | _complex, 'cosh')
|
||||
ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x)))
|
||||
mlir.register_lowering(cosh_p, partial(_nary_lower_mhlo, chlo.CoshOp))
|
||||
mlir.register_lowering(cosh_p, partial(_nary_lower_hlo, chlo.CoshOp))
|
||||
|
||||
asinh_p = standard_unop(_float | _complex, 'asinh')
|
||||
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x))))
|
||||
mlir.register_lowering(asinh_p, partial(_nary_lower_mhlo, chlo.AsinhOp))
|
||||
mlir.register_lowering(asinh_p, partial(_nary_lower_hlo, chlo.AsinhOp))
|
||||
|
||||
acosh_p = standard_unop(_float | _complex, 'acosh')
|
||||
ad.defjvp(acosh_p,
|
||||
lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x)))))
|
||||
mlir.register_lowering(acosh_p, partial(_nary_lower_mhlo, chlo.AcoshOp))
|
||||
mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.AcoshOp))
|
||||
|
||||
atanh_p = standard_unop(_float | _complex, 'atanh')
|
||||
ad.defjvp(atanh_p,
|
||||
lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x))))
|
||||
mlir.register_lowering(atanh_p, partial(_nary_lower_mhlo, chlo.AtanhOp))
|
||||
mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.AtanhOp))
|
||||
|
||||
regularized_incomplete_beta_p = standard_naryop(
|
||||
[_float, _float, _float], 'regularized_incomplete_beta')
|
||||
@ -1880,10 +1879,10 @@ ad.defjvp(regularized_incomplete_beta_p,
|
||||
|
||||
lgamma_p = standard_unop(_float, 'lgamma')
|
||||
ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x)))
|
||||
mlir.register_lowering(lgamma_p, partial(_nary_lower_mhlo, chlo.LgammaOp))
|
||||
mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.LgammaOp))
|
||||
|
||||
digamma_p = standard_unop(_float, 'digamma')
|
||||
mlir.register_lowering(digamma_p, partial(_nary_lower_mhlo, chlo.DigammaOp))
|
||||
mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.DigammaOp))
|
||||
|
||||
igamma_p = standard_naryop([_float, _float], 'igamma')
|
||||
xla.register_translation(igamma_p, partial(_broadcast_translate, xops.Igamma))
|
||||
@ -1919,7 +1918,7 @@ ad.defjvp2(bessel_i0e_p, lambda g, y, x: g * (bessel_i1e(x) - sign(x) * y))
|
||||
|
||||
bessel_i1e_p = standard_unop(_float, 'bessel_i1e')
|
||||
mlir.register_lowering(bessel_i1e_p,
|
||||
partial(_nary_lower_mhlo, chlo.BesselI1eOp))
|
||||
partial(_nary_lower_hlo, chlo.BesselI1eOp))
|
||||
|
||||
def _bessel_i1e_jvp(g, y, x):
|
||||
eps = dtypes.finfo(_dtype(x)).eps
|
||||
@ -1933,12 +1932,12 @@ ad.defjvp2(bessel_i1e_p, _bessel_i1e_jvp)
|
||||
erf_p = standard_unop(_float, 'erf')
|
||||
ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)),
|
||||
mul(g, exp(neg(square(x))))))
|
||||
mlir.register_lowering(erf_p, partial(_nary_lower_mhlo, chlo.ErfOp))
|
||||
mlir.register_lowering(erf_p, partial(_nary_lower_hlo, chlo.ErfOp))
|
||||
|
||||
erfc_p = standard_unop(_float, 'erfc')
|
||||
ad.defjvp(erfc_p, lambda g, x: mul(_const(x, -2. / np.sqrt(np.pi)),
|
||||
mul(g, exp(neg(square(x))))))
|
||||
mlir.register_lowering(erfc_p, partial(_nary_lower_mhlo, chlo.ErfcOp))
|
||||
mlir.register_lowering(erfc_p, partial(_nary_lower_hlo, chlo.ErfcOp))
|
||||
|
||||
erf_inv_p = standard_unop(_float, 'erf_inv')
|
||||
ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.),
|
||||
@ -1947,11 +1946,11 @@ xla.register_translation(erf_inv_p, standard_translate(erf_inv_p))
|
||||
|
||||
real_p = unop(_complex_basetype, _complex, 'real')
|
||||
ad.deflinear2(real_p, lambda t, _: [complex(t, np.zeros((), _dtype(t)))])
|
||||
mlir.register_lowering(real_p, partial(_nary_lower_mhlo, mhlo.RealOp))
|
||||
mlir.register_lowering(real_p, partial(_nary_lower_hlo, hlo.RealOp))
|
||||
|
||||
imag_p = unop(_complex_basetype, _complex, 'imag')
|
||||
ad.deflinear2(imag_p, lambda t, _: [complex(np.zeros((), _dtype(t)), neg(t))])
|
||||
mlir.register_lowering(imag_p, partial(_nary_lower_mhlo, mhlo.ImagOp))
|
||||
mlir.register_lowering(imag_p, partial(_nary_lower_hlo, hlo.ImagOp))
|
||||
|
||||
|
||||
def _complex_transpose_rule(t, x, y):
|
||||
@ -1976,7 +1975,7 @@ _complex_dtype = lambda dtype, *args: (np.zeros((), dtype) + np.zeros((), np.com
|
||||
complex_p = naryop(_complex_dtype, [_complex_elem_types, _complex_elem_types],
|
||||
'complex')
|
||||
ad.deflinear2(complex_p, _complex_transpose_rule)
|
||||
mlir.register_lowering(complex_p, partial(_nary_lower_mhlo, mhlo.ComplexOp))
|
||||
mlir.register_lowering(complex_p, partial(_nary_lower_hlo, hlo.ComplexOp))
|
||||
|
||||
conj_p = unop(_complex_dtype, _complex_elem_types | _complex, 'conj')
|
||||
|
||||
@ -2001,7 +2000,7 @@ ad.primitive_jvps[conj_p] = partial(ad.linear_jvp, conj_p)
|
||||
ad.primitive_transposes[conj_p] = _conj_transpose_rule
|
||||
|
||||
abs_p = unop(_complex_basetype, _num, 'abs')
|
||||
mlir.register_lowering(abs_p, partial(_nary_lower_mhlo, mhlo.AbsOp))
|
||||
mlir.register_lowering(abs_p, partial(_nary_lower_hlo, hlo.AbsOp))
|
||||
|
||||
def _abs_jvp_rule(g, ans, x):
|
||||
if _iscomplex(x):
|
||||
@ -2015,18 +2014,18 @@ _maybe_real = lambda x: real(x) if _iscomplex(x) else x
|
||||
|
||||
sqrt_p = standard_unop(_float | _complex, 'sqrt')
|
||||
ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans)))
|
||||
mlir.register_lowering(sqrt_p, partial(_nary_lower_mhlo, mhlo.SqrtOp))
|
||||
mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.SqrtOp))
|
||||
|
||||
rsqrt_p = standard_unop(_float | _complex, 'rsqrt')
|
||||
ad.defjvp2(rsqrt_p,
|
||||
lambda g, ans, x:
|
||||
mul(g, mul(_const(x, -0.5), div(ans, x))))
|
||||
mlir.register_lowering(rsqrt_p, partial(_nary_lower_mhlo, mhlo.RsqrtOp))
|
||||
mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.RsqrtOp))
|
||||
|
||||
cbrt_p = standard_unop(_float, 'cbrt')
|
||||
ad.defjvp2(cbrt_p,
|
||||
lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2))))
|
||||
mlir.register_lowering(cbrt_p, partial(_nary_lower_mhlo, mhlo.CbrtOp))
|
||||
mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.CbrtOp))
|
||||
|
||||
pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow')
|
||||
|
||||
@ -2037,7 +2036,7 @@ def _pow_jvp_rhs(g, ans, x, y):
|
||||
return mul(g, mul(log(_replace_zero(x)), ans))
|
||||
|
||||
ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs)
|
||||
mlir.register_lowering(pow_p, partial(_nary_lower_mhlo, mhlo.PowOp))
|
||||
mlir.register_lowering(pow_p, partial(_nary_lower_hlo, hlo.PowOp))
|
||||
|
||||
|
||||
def _integer_pow_dtype_rule(x, *, y):
|
||||
@ -2077,8 +2076,8 @@ def _integer_pow(x, *, y):
|
||||
def _integer_pow_lowering(ctx, x, *, y):
|
||||
lowering = mlir.lower_fun(_integer_pow, multiple_results=False)
|
||||
# TODO(b/217551391): emitting an out-of-line call leads to a large
|
||||
# expansion when the MHLO is lowered to HLO, because the HLO lowering
|
||||
# clones the callee. Consider unconditionally caching when the MHLO->HLO
|
||||
# expansion when the MLIR is lowered to HLO, because the HLO lowering
|
||||
# clones the callee. Consider unconditionally caching when the MLIR->HLO
|
||||
# lowering doesn't expand the program.
|
||||
if y >= 4:
|
||||
lowering = mlir.cache_lowering(lowering)
|
||||
@ -2090,26 +2089,26 @@ _replace_zero = lambda x: select(eq(x, _const(x, 0)), _ones(x), x)
|
||||
|
||||
not_p = standard_unop(_bool_or_int, 'not')
|
||||
ad.defjvp_zero(not_p)
|
||||
mlir.register_lowering(not_p, partial(_nary_lower_mhlo, mhlo.NotOp))
|
||||
mlir.register_lowering(not_p, partial(_nary_lower_hlo, hlo.NotOp))
|
||||
|
||||
and_p = standard_naryop([_bool_or_int, _bool_or_int], 'and')
|
||||
ad.defjvp_zero(and_p)
|
||||
mlir.register_lowering(and_p, partial(_nary_lower_mhlo, mhlo.AndOp))
|
||||
mlir.register_lowering(and_p, partial(_nary_lower_hlo, hlo.AndOp))
|
||||
|
||||
or_p = standard_naryop([_bool_or_int, _bool_or_int], 'or')
|
||||
ad.defjvp_zero(or_p)
|
||||
mlir.register_lowering(or_p, partial(_nary_lower_mhlo, mhlo.OrOp))
|
||||
mlir.register_lowering(or_p, partial(_nary_lower_hlo, hlo.OrOp))
|
||||
|
||||
xor_p = standard_naryop([_bool_or_int, _bool_or_int], 'xor')
|
||||
ad.defjvp_zero(xor_p)
|
||||
mlir.register_lowering(xor_p, partial(_nary_lower_mhlo, mhlo.XorOp))
|
||||
mlir.register_lowering(xor_p, partial(_nary_lower_hlo, hlo.XorOp))
|
||||
|
||||
population_count_p = standard_unop(_int, 'population_count')
|
||||
mlir.register_lowering(population_count_p,
|
||||
partial(_nary_lower_mhlo, mhlo.PopulationCountOp))
|
||||
partial(_nary_lower_hlo, hlo.PopulationCountOp))
|
||||
|
||||
clz_p = standard_unop(_int, 'clz')
|
||||
mlir.register_lowering(clz_p, partial(_nary_lower_mhlo, mhlo.ClzOp))
|
||||
mlir.register_lowering(clz_p, partial(_nary_lower_hlo, hlo.ClzOp))
|
||||
|
||||
def _add_jvp(primals, tangents):
|
||||
x, y = primals
|
||||
@ -2145,7 +2144,7 @@ def _add_inverse(r, x, y):
|
||||
add_p: Primitive = standard_naryop([_num, _num], 'add')
|
||||
ad.primitive_jvps[add_p] = _add_jvp
|
||||
ad.primitive_transposes[add_p] = _add_transpose
|
||||
mlir.register_lowering(add_p, partial(_nary_lower_mhlo, mhlo.AddOp))
|
||||
mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.AddOp))
|
||||
|
||||
def _sub_jvp(primals, tangents):
|
||||
x, y = primals
|
||||
@ -2174,7 +2173,7 @@ def _sub_transpose(t, x, y):
|
||||
sub_p = standard_naryop([_num, _num], 'sub')
|
||||
ad.primitive_jvps[sub_p] = _sub_jvp
|
||||
ad.primitive_transposes[sub_p] = _sub_transpose
|
||||
mlir.register_lowering(sub_p, partial(_nary_lower_mhlo, mhlo.SubtractOp))
|
||||
mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.SubtractOp))
|
||||
|
||||
|
||||
def _mul_transpose(ct, x, y):
|
||||
@ -2200,7 +2199,7 @@ ad.defjvp(mul_p,
|
||||
lambda xdot, x, y: mul(xdot, y),
|
||||
lambda ydot, x, y: mul(x, ydot))
|
||||
ad.primitive_transposes[mul_p] = _mul_transpose
|
||||
mlir.register_lowering(mul_p, partial(_nary_lower_mhlo, mhlo.MulOp))
|
||||
mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.MulOp))
|
||||
|
||||
def _div_transpose_rule(cotangent, x, y):
|
||||
assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y)
|
||||
@ -2213,14 +2212,14 @@ ad.defjvp(div_p,
|
||||
lambda g, x, y: div(g, y),
|
||||
lambda g, x, y: mul(mul(neg(g), x), integer_pow(y, -2)))
|
||||
ad.primitive_transposes[div_p] = _div_transpose_rule
|
||||
mlir.register_lowering(div_p, partial(_nary_lower_mhlo, mhlo.DivOp))
|
||||
mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.DivOp))
|
||||
|
||||
rem_p = standard_naryop([_int | _float, _int | _float], 'rem')
|
||||
ad.defjvp(
|
||||
rem_p,
|
||||
lambda g, x, y: _maybe_broadcast(broadcast_shapes(np.shape(x), np.shape(y)), g),
|
||||
lambda g, x, y: mul(neg(g), mul(sign(div(x, y)), floor(abs(div(x, y))))))
|
||||
mlir.register_lowering(rem_p, partial(_nary_lower_mhlo, mhlo.RemOp))
|
||||
mlir.register_lowering(rem_p, partial(_nary_lower_hlo, hlo.RemOp))
|
||||
|
||||
def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x):
|
||||
result_shape = broadcast_shapes(np.shape(x), np.shape(y))
|
||||
@ -2236,29 +2235,29 @@ max_p: core.Primitive = standard_naryop([_any, _any], 'max')
|
||||
ad.defjvp2(max_p,
|
||||
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
|
||||
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
|
||||
mlir.register_lowering(max_p, partial(_nary_lower_mhlo, mlir.max_mhlo))
|
||||
mlir.register_lowering(max_p, partial(_nary_lower_hlo, mlir.max_hlo))
|
||||
|
||||
min_p: core.Primitive = standard_naryop([_any, _any], 'min')
|
||||
ad.defjvp2(min_p,
|
||||
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
|
||||
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
|
||||
mlir.register_lowering(min_p, partial(_nary_lower_mhlo, mlir.min_mhlo))
|
||||
mlir.register_lowering(min_p, partial(_nary_lower_hlo, mlir.min_hlo))
|
||||
|
||||
shift_left_p = standard_naryop([_int, _int], 'shift_left')
|
||||
ad.defjvp_zero(shift_left_p)
|
||||
mlir.register_lowering(shift_left_p, partial(_nary_lower_mhlo, mhlo.ShiftLeftOp))
|
||||
mlir.register_lowering(shift_left_p, partial(_nary_lower_hlo, hlo.ShiftLeftOp))
|
||||
|
||||
shift_right_arithmetic_p = standard_naryop([_int, _int], 'shift_right_arithmetic')
|
||||
ad.defjvp_zero(shift_right_arithmetic_p)
|
||||
mlir.register_lowering(shift_right_arithmetic_p,
|
||||
partial(_nary_lower_mhlo, mhlo.ShiftRightArithmeticOp))
|
||||
partial(_nary_lower_hlo, hlo.ShiftRightArithmeticOp))
|
||||
|
||||
shift_right_logical_p = standard_naryop([_int, _int], 'shift_right_logical')
|
||||
ad.defjvp_zero(shift_right_logical_p)
|
||||
mlir.register_lowering(shift_right_logical_p,
|
||||
partial(_nary_lower_mhlo, mhlo.ShiftRightLogicalOp))
|
||||
partial(_nary_lower_hlo, hlo.ShiftRightLogicalOp))
|
||||
|
||||
def _compare_lower_mhlo(direction: str, ctx, x, y):
|
||||
def _compare_lower_hlo(direction: str, ctx, x, y):
|
||||
avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out
|
||||
x_dtype = avals_in[0].dtype
|
||||
x, y = mlir.multi_broadcast_in_dim(ctx, (x, y), avals_in, aval_out.shape)
|
||||
@ -2269,31 +2268,31 @@ def _compare_lower_mhlo(direction: str, ctx, x, y):
|
||||
compare_type = "SIGNED"
|
||||
else:
|
||||
compare_type = "UNSIGNED"
|
||||
return mlir.compare_mhlo(x, y, direction, compare_type).results
|
||||
return mlir.compare_hlo(x, y, direction, compare_type).results
|
||||
|
||||
eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq')
|
||||
ad.defjvp_zero(eq_p)
|
||||
mlir.register_lowering(eq_p, partial(_compare_lower_mhlo, "EQ"))
|
||||
mlir.register_lowering(eq_p, partial(_compare_lower_hlo, "EQ"))
|
||||
|
||||
ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne')
|
||||
ad.defjvp_zero(ne_p)
|
||||
mlir.register_lowering(ne_p, partial(_compare_lower_mhlo, "NE"))
|
||||
mlir.register_lowering(ne_p, partial(_compare_lower_hlo, "NE"))
|
||||
|
||||
ge_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'ge')
|
||||
ad.defjvp_zero(ge_p)
|
||||
mlir.register_lowering(ge_p, partial(_compare_lower_mhlo, "GE"))
|
||||
mlir.register_lowering(ge_p, partial(_compare_lower_hlo, "GE"))
|
||||
|
||||
gt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'gt')
|
||||
ad.defjvp_zero(gt_p)
|
||||
mlir.register_lowering(gt_p, partial(_compare_lower_mhlo, "GT"))
|
||||
mlir.register_lowering(gt_p, partial(_compare_lower_hlo, "GT"))
|
||||
|
||||
le_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'le')
|
||||
ad.defjvp_zero(le_p)
|
||||
mlir.register_lowering(le_p, partial(_compare_lower_mhlo, "LE"))
|
||||
mlir.register_lowering(le_p, partial(_compare_lower_hlo, "LE"))
|
||||
|
||||
lt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'lt')
|
||||
ad.defjvp_zero(lt_p)
|
||||
mlir.register_lowering(lt_p, partial(_compare_lower_mhlo, "LT"))
|
||||
mlir.register_lowering(lt_p, partial(_compare_lower_hlo, "LT"))
|
||||
|
||||
|
||||
def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type):
|
||||
@ -2392,9 +2391,9 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type):
|
||||
aval_out, = ctx.avals_out
|
||||
if (dtypes.issubdtype(aval_in.dtype, np.complexfloating) and
|
||||
not dtypes.issubdtype(new_dtype, np.complexfloating)):
|
||||
operand = mhlo.RealOp(operand).result
|
||||
operand = hlo.RealOp(operand).result
|
||||
aval_in = aval_in.update(dtype=_real_dtype(aval_in.dtype))
|
||||
return [mlir.convert_mhlo(ctx, operand, aval_in, aval_out)]
|
||||
return [mlir.convert_hlo(ctx, operand, aval_in, aval_out)]
|
||||
|
||||
mlir.register_lowering(convert_element_type_p, _convert_element_type_lower)
|
||||
|
||||
@ -2419,7 +2418,7 @@ batching.defvectorized(bitcast_convert_type_p)
|
||||
|
||||
def _bitcast_convert_type_lower(ctx, operand, *, new_dtype):
|
||||
aval_out, = ctx.avals_out
|
||||
return mhlo.BitcastConvertOp(mlir.aval_to_ir_type(aval_out), operand).results
|
||||
return hlo.BitcastConvertOp(mlir.aval_to_ir_type(aval_out), operand).results
|
||||
|
||||
mlir.register_lowering(bitcast_convert_type_p, _bitcast_convert_type_lower)
|
||||
|
||||
@ -2708,7 +2707,7 @@ def precision_attr(precision: PrecisionType) -> ir.ArrayAttr:
|
||||
else:
|
||||
full_precision = precision
|
||||
return ir.ArrayAttr.get(
|
||||
[mhlo.PrecisionAttr.get(str(p)) for p in full_precision])
|
||||
[hlo.PrecisionAttr.get(str(p)) for p in full_precision])
|
||||
|
||||
|
||||
|
||||
@ -2723,19 +2722,19 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
|
||||
if ctx.module_context.platform == "cpu":
|
||||
if lhs_aval.dtype == np.float16:
|
||||
f32 = mlir.dtype_to_ir_type(np.dtype(np.float32))
|
||||
lhs = mhlo.ConvertOp(ir.RankedTensorType.get(lhs_aval.shape, f32),
|
||||
lhs = hlo.ConvertOp(ir.RankedTensorType.get(lhs_aval.shape, f32),
|
||||
lhs).result
|
||||
if rhs_aval.dtype == np.float16:
|
||||
f32 = mlir.dtype_to_ir_type(np.dtype(np.float32))
|
||||
rhs = mhlo.ConvertOp(ir.RankedTensorType.get(rhs_aval.shape, f32),
|
||||
rhs = hlo.ConvertOp(ir.RankedTensorType.get(rhs_aval.shape, f32),
|
||||
rhs).result
|
||||
dot_dnums = mhlo.DotDimensionNumbers.get(
|
||||
dot_dnums = hlo.DotDimensionNumbers.get(
|
||||
lhs_batching_dimensions=list(lhs_batch),
|
||||
rhs_batching_dimensions=list(rhs_batch),
|
||||
lhs_contracting_dimensions=list(lhs_contracting),
|
||||
rhs_contracting_dimensions=list(rhs_contracting))
|
||||
return [
|
||||
mhlo.DotGeneralOp(
|
||||
hlo.DotGeneralOp(
|
||||
mlir.aval_to_ir_type(aval_out),
|
||||
lhs,
|
||||
rhs,
|
||||
@ -3005,7 +3004,7 @@ ad.defjvp(clamp_p,
|
||||
select(lt(max, operand), g, _zeros(operand)))
|
||||
batching.primitive_batchers[clamp_p] = _clamp_batch_rule
|
||||
mlir.register_lowering(
|
||||
clamp_p, partial(_nary_lower_mhlo, mhlo.ClampOp))
|
||||
clamp_p, partial(_nary_lower_hlo, hlo.ClampOp))
|
||||
pe.def_trivial_padding(clamp_p)
|
||||
|
||||
def _concatenate_shape_rule(*operands, **kwargs):
|
||||
@ -3083,7 +3082,7 @@ batching.primitive_batchers[concatenate_p] = _concatenate_batch_rule
|
||||
pe.padding_rules[concatenate_p] = _concatenate_pad_rule
|
||||
|
||||
def _concatenate_lower(ctx, *xs, dimension):
|
||||
return mhlo.ConcatenateOp(xs, mlir.i64_attr(dimension)).results
|
||||
return hlo.ConcatenateOp(xs, mlir.i64_attr(dimension)).results
|
||||
mlir.register_lowering(concatenate_p, _concatenate_lower)
|
||||
|
||||
|
||||
@ -3158,7 +3157,7 @@ batching.primitive_batchers[pad_p] = _pad_batch_rule
|
||||
|
||||
def _pad_lower(ctx, x, padding_value, *, padding_config):
|
||||
low, high, interior = util.unzip3(padding_config)
|
||||
return mhlo.PadOp(x, padding_value,
|
||||
return hlo.PadOp(x, padding_value,
|
||||
mlir.dense_int_elements(low),
|
||||
mlir.dense_int_elements(high),
|
||||
mlir.dense_int_elements(interior)).results
|
||||
@ -3302,7 +3301,7 @@ def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions):
|
||||
def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions):
|
||||
aval_out, = ctx.avals_out
|
||||
if dimensions is not None:
|
||||
x = mhlo.TransposeOp(x, mlir.dense_int_elements(dimensions)).result
|
||||
x = hlo.TransposeOp(x, mlir.dense_int_elements(dimensions)).result
|
||||
if dyn_shape:
|
||||
aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape))
|
||||
return [mlir.reshape(ctx, x, aval_out)]
|
||||
@ -3346,7 +3345,7 @@ ad.deflinear2(rev_p, lambda t, _, dimensions: [rev(t, dimensions)])
|
||||
batching.primitive_batchers[rev_p] = _rev_batch_rule
|
||||
|
||||
def _rev_lower(ctx, x, *, dimensions):
|
||||
return mhlo.ReverseOp(x, mlir.dense_int_elements(dimensions)).results
|
||||
return hlo.ReverseOp(x, mlir.dense_int_elements(dimensions)).results
|
||||
mlir.register_lowering(rev_p, _rev_lower)
|
||||
|
||||
|
||||
@ -3370,7 +3369,7 @@ def _transpose_lower(ctx, x, *, permutation):
|
||||
aval_out, = ctx.avals_out
|
||||
if core.is_opaque_dtype(aval_out.dtype):
|
||||
return [aval_out.dtype._rules.transpose_mlir(ctx, aval_out, x, permutation=permutation)]
|
||||
return mhlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results
|
||||
return hlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results
|
||||
|
||||
transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
|
||||
'transpose')
|
||||
@ -3470,12 +3469,12 @@ def _select_jvp(primals, tangents):
|
||||
out_dot = select_n(which, *case_tangents)
|
||||
return out, out_dot
|
||||
|
||||
def _select_mhlo_lowering(ctx, which, *cases):
|
||||
def _select_hlo_lowering(ctx, which, *cases):
|
||||
which_aval = ctx.avals_in[0]
|
||||
if which_aval.dtype == np.dtype(np.bool_):
|
||||
assert len(cases) <= 2
|
||||
if len(cases) == 1: return cases
|
||||
return mhlo.SelectOp(which, cases[1], cases[0]).results
|
||||
return hlo.SelectOp(which, cases[1], cases[0]).results
|
||||
|
||||
if dtypes.issubdtype(which_aval.dtype, np.signedinteger):
|
||||
compare_type = 'SIGNED'
|
||||
@ -3488,10 +3487,10 @@ def _select_mhlo_lowering(ctx, which, *cases):
|
||||
if len(cases) == 1:
|
||||
return cases[0]
|
||||
mid = len(cases) // 2
|
||||
pred = mlir.compare_mhlo(which,
|
||||
pred = mlir.compare_hlo(which,
|
||||
mlir.full_like_aval(ctx, offset + mid, which_aval),
|
||||
lt, compare_type)
|
||||
return mhlo.SelectOp(pred, _select(offset, cases[:mid]),
|
||||
return hlo.SelectOp(pred, _select(offset, cases[:mid]),
|
||||
_select(offset + mid, cases[mid:])).result
|
||||
|
||||
return [_select(0, cases)]
|
||||
@ -3502,7 +3501,7 @@ select_n_p = standard_primitive(
|
||||
ad.primitive_jvps[select_n_p] = _select_jvp
|
||||
ad.primitive_transposes[select_n_p] = _select_transpose_rule
|
||||
batching.primitive_batchers[select_n_p] = _select_batch_rule
|
||||
mlir.register_lowering(select_n_p, _select_mhlo_lowering)
|
||||
mlir.register_lowering(select_n_p, _select_hlo_lowering)
|
||||
pe.def_trivial_padding(select_n_p)
|
||||
|
||||
|
||||
@ -3622,7 +3621,7 @@ def _reduce_lower(ctx, *values, computation, jaxpr, consts, dimensions):
|
||||
assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in
|
||||
operands, init_values = util.split_list(values, [len(values) // 2])
|
||||
init_value_avals = ctx.avals_in[len(values) // 2:]
|
||||
op = mhlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
|
||||
op = hlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
|
||||
operands, init_values, mlir.dense_int_elements(dimensions))
|
||||
ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
|
||||
reducer = op.regions[0].blocks.append(*(ir_types + ir_types))
|
||||
@ -3633,7 +3632,7 @@ def _reduce_lower(ctx, *values, computation, jaxpr, consts, dimensions):
|
||||
out_nodes, _ = mlir.jaxpr_subcomp(reducer_ctx, jaxpr, mlir.TokenSet(), consts,
|
||||
*([a] for a in reducer.arguments),
|
||||
dim_var_values=ctx.dim_var_values)
|
||||
mhlo.ReturnOp(util.flatten(out_nodes))
|
||||
hlo.ReturnOp(util.flatten(out_nodes))
|
||||
return op.results
|
||||
|
||||
mlir.register_lowering(reduce_p, _reduce_lower)
|
||||
@ -3818,29 +3817,29 @@ batching.defreducer(reduce_xor_p)
|
||||
def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes):
|
||||
aval_out, = ctx.avals_out
|
||||
dtype = aval_out.dtype
|
||||
op = mhlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x],
|
||||
op = hlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x],
|
||||
mlir.ir_constants(unit_factory(aval_out.dtype)),
|
||||
mlir.dense_int_elements(axes))
|
||||
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype))
|
||||
reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(reducer_region):
|
||||
add = reducer(*reducer_region.arguments)
|
||||
mhlo.ReturnOp(add.results)
|
||||
hlo.ReturnOp(add.results)
|
||||
return op.results
|
||||
|
||||
mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, mhlo.AddOp,
|
||||
mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, hlo.AddOp,
|
||||
_get_sum_identity))
|
||||
mlir.register_lowering(reduce_prod_p, partial(_unary_reduce_lower, mhlo.MulOp,
|
||||
mlir.register_lowering(reduce_prod_p, partial(_unary_reduce_lower, hlo.MulOp,
|
||||
_get_prod_identity))
|
||||
mlir.register_lowering(reduce_or_p, partial(_unary_reduce_lower, mhlo.OrOp,
|
||||
mlir.register_lowering(reduce_or_p, partial(_unary_reduce_lower, hlo.OrOp,
|
||||
_get_bitwise_or_identity))
|
||||
mlir.register_lowering(reduce_and_p, partial(_unary_reduce_lower, mhlo.AndOp,
|
||||
mlir.register_lowering(reduce_and_p, partial(_unary_reduce_lower, hlo.AndOp,
|
||||
_get_bitwise_and_identity))
|
||||
mlir.register_lowering(reduce_xor_p, partial(_unary_reduce_lower, mhlo.XorOp,
|
||||
mlir.register_lowering(reduce_xor_p, partial(_unary_reduce_lower, hlo.XorOp,
|
||||
_get_bitwise_or_identity))
|
||||
mlir.register_lowering(reduce_min_p, partial(_unary_reduce_lower, mlir.min_mhlo,
|
||||
mlir.register_lowering(reduce_min_p, partial(_unary_reduce_lower, mlir.min_hlo,
|
||||
_get_min_identity))
|
||||
mlir.register_lowering(reduce_max_p, partial(_unary_reduce_lower, mlir.max_mhlo,
|
||||
mlir.register_lowering(reduce_max_p, partial(_unary_reduce_lower, mlir.max_hlo,
|
||||
_get_max_identity))
|
||||
|
||||
|
||||
@ -3864,7 +3863,7 @@ batching.defvectorized(reduce_precision_p)
|
||||
|
||||
def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits):
|
||||
aval_out, = ctx.avals_out
|
||||
return mhlo.ReducePrecisionOp(operand, mlir.i32_attr(exponent_bits),
|
||||
return hlo.ReducePrecisionOp(operand, mlir.i32_attr(exponent_bits),
|
||||
mlir.i32_attr(mantissa_bits)).results
|
||||
|
||||
mlir.register_lowering(reduce_precision_p, _reduce_precision_lower)
|
||||
@ -4014,7 +4013,7 @@ batching.primitive_batchers[sort_p] = _sort_batch_rule
|
||||
|
||||
def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
|
||||
assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in
|
||||
sort = mhlo.SortOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
|
||||
sort = hlo.SortOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
|
||||
mlir.flatten_lowering_ir_args(operands),
|
||||
dimension=mlir.i64_attr(dimension),
|
||||
is_stable=ir.BoolAttr.get(is_stable))
|
||||
@ -4031,7 +4030,7 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys):
|
||||
|
||||
out = lower_comparator(sub_ctx, *[[a] for a in comparator.arguments],
|
||||
num_keys=num_keys)
|
||||
mhlo.ReturnOp(util.flatten(out))
|
||||
hlo.ReturnOp(util.flatten(out))
|
||||
return sort.results
|
||||
|
||||
mlir.register_lowering(sort_p, _sort_lower)
|
||||
@ -4134,10 +4133,9 @@ create_token_p.def_abstract_eval(lambda *_: abstract_token)
|
||||
def _create_token_lowering(ctx, *operands):
|
||||
aval_out, = ctx.avals_out
|
||||
if xc.mlir_api_version < 40:
|
||||
return mhlo.CreateTokenOp(mlir.aval_to_ir_type(aval_out)).results
|
||||
return hlo.CreateTokenOp(mlir.aval_to_ir_type(aval_out)).results
|
||||
else:
|
||||
return mhlo.CreateTokenOp().results
|
||||
|
||||
return hlo.CreateTokenOp().results
|
||||
mlir.register_lowering(create_token_p, _create_token_lowering)
|
||||
|
||||
|
||||
@ -4160,10 +4158,9 @@ after_all_p.def_abstract_eval(_after_all_abstract_eval)
|
||||
def _after_all_lowering(ctx, *operands):
|
||||
aval_out, = ctx.avals_out
|
||||
if xc.mlir_api_version < 40:
|
||||
return mhlo.AfterAllOp(mlir.aval_to_ir_type(aval_out), operands).results
|
||||
return hlo.AfterAllOp(mlir.aval_to_ir_type(aval_out), operands).results
|
||||
else:
|
||||
return mhlo.AfterAllOp(operands).results
|
||||
|
||||
return hlo.AfterAllOp(operands).results
|
||||
mlir.register_lowering(after_all_p, _after_all_lowering)
|
||||
|
||||
|
||||
@ -4215,8 +4212,8 @@ def _infeed_lowering(ctx, token, *, shapes, partitions):
|
||||
for i in range(len(aval.shape) - 1, -1, -1)])
|
||||
for aval in shapes
|
||||
])
|
||||
infeed = mhlo.InfeedOp(
|
||||
flat_output_types + [mhlo.TokenType.get()],
|
||||
infeed = hlo.InfeedOp(
|
||||
flat_output_types + [hlo.TokenType.get()],
|
||||
token,
|
||||
infeed_config=ir.StringAttr.get(''),
|
||||
layout=layouts)
|
||||
@ -4259,13 +4256,13 @@ mlir.lowerable_effects.add(InOutFeedEffect.Outfeed)
|
||||
def _outfeed_lowering(ctx, token, *xs, partitions):
|
||||
token_aval = ctx.avals_in[0]
|
||||
if xc.mlir_api_version < 40:
|
||||
outfeed = mhlo.OutfeedOp(
|
||||
outfeed = hlo.OutfeedOp(
|
||||
mlir.aval_to_ir_type(token_aval),
|
||||
mlir.flatten_lowering_ir_args(xs),
|
||||
token,
|
||||
outfeed_config=ir.StringAttr.get(''))
|
||||
else:
|
||||
outfeed = mhlo.OutfeedOp(
|
||||
outfeed = hlo.OutfeedOp(
|
||||
mlir.flatten_lowering_ir_args(xs),
|
||||
token,
|
||||
outfeed_config=ir.StringAttr.get(''))
|
||||
@ -4308,8 +4305,8 @@ def _rng_uniform_lowering(ctx, a, b, *, shape):
|
||||
aval_out, = ctx.avals_out
|
||||
shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64),
|
||||
canonicalize_types=False)
|
||||
return mhlo.RngOp(a, b, shape,
|
||||
mhlo.RngDistributionAttr.get('UNIFORM')).results
|
||||
return hlo.RngOp(a, b, shape,
|
||||
hlo.RngDistributionAttr.get('UNIFORM')).results
|
||||
|
||||
mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering)
|
||||
|
||||
@ -4331,11 +4328,11 @@ RandomAlgorithm.__str__ = lambda algorithm: algorithm.name # type: ignore[assig
|
||||
|
||||
def _rng_algorithm(algorithm: RandomAlgorithm):
|
||||
if algorithm == RandomAlgorithm.RNG_THREE_FRY:
|
||||
return mhlo.RngAlgorithmAttr.get("THREE_FRY")
|
||||
return hlo.RngAlgorithmAttr.get("THREE_FRY")
|
||||
elif algorithm == RandomAlgorithm.RNG_PHILOX:
|
||||
return mhlo.RngAlgorithmAttr.get("PHILOX")
|
||||
return hlo.RngAlgorithmAttr.get("PHILOX")
|
||||
elif algorithm == RandomAlgorithm.RNG_DEFAULT:
|
||||
return mhlo.RngAlgorithmAttr.get("DEFAULT")
|
||||
return hlo.RngAlgorithmAttr.get("DEFAULT")
|
||||
else:
|
||||
assert False
|
||||
|
||||
@ -4362,21 +4359,21 @@ def _rng_bit_generator_lowering(
|
||||
else:
|
||||
rbg_etype = u32_type
|
||||
if key_etype == u32_type:
|
||||
key = mhlo.BitcastConvertOp(
|
||||
key = hlo.BitcastConvertOp(
|
||||
ir.RankedTensorType.get([2], u64_type),
|
||||
mhlo.ReshapeOp(ir.RankedTensorType.get([2, 2], u32_type), key)).result
|
||||
hlo.ReshapeOp(ir.RankedTensorType.get([2, 2], u32_type), key)).result
|
||||
algorithm_attr = _rng_algorithm(algorithm)
|
||||
out_key, out_vals = mhlo.RngBitGeneratorOp(
|
||||
out_key, out_vals = hlo.RngBitGeneratorOp(
|
||||
key.type,
|
||||
ir.RankedTensorType.get(shape, rbg_etype),
|
||||
algorithm_attr, key).results
|
||||
if key_etype == u32_type:
|
||||
out_key = mhlo.ReshapeOp(
|
||||
out_key = hlo.ReshapeOp(
|
||||
ir.RankedTensorType.get([4], u32_type),
|
||||
mhlo.BitcastConvertOp(
|
||||
hlo.BitcastConvertOp(
|
||||
ir.RankedTensorType.get([2, 2], u32_type), out_key)).result
|
||||
if rbg_etype != etype:
|
||||
out_vals = mhlo.ConvertOp(
|
||||
out_vals = hlo.ConvertOp(
|
||||
ir.RankedTensorType.get(ir.RankedTensorType(out_vals.type).shape, etype),
|
||||
out_vals).result
|
||||
return [out_key, out_vals]
|
||||
@ -4523,13 +4520,13 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension):
|
||||
aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape))
|
||||
if not core.is_constant_shape(aval_out.shape):
|
||||
shape = mlir.eval_dynamic_shape(ctx, aval_out.shape)
|
||||
return mhlo.DynamicIotaOp(
|
||||
return hlo.DynamicIotaOp(
|
||||
mlir.aval_to_ir_type(aval_out),
|
||||
mlir.shape_tensor(shape),
|
||||
mlir.i64_attr(dimension),
|
||||
).results
|
||||
else:
|
||||
return mhlo.IotaOp(mlir.aval_to_ir_type(aval_out),
|
||||
return hlo.IotaOp(mlir.aval_to_ir_type(aval_out),
|
||||
mlir.i64_attr(dimension)).results
|
||||
mlir.register_lowering(iota_p, _iota_lower)
|
||||
|
||||
|
@ -52,7 +52,7 @@ from jax._src.lib import xla_client
|
||||
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import chlo
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
|
||||
xops = xla_client.ops
|
||||
@ -418,7 +418,7 @@ ad.primitive_jvps[cholesky_p] = _cholesky_jvp_rule
|
||||
batching.primitive_batchers[cholesky_p] = _cholesky_batching_rule
|
||||
|
||||
def _cholesky_lowering(ctx, x):
|
||||
return mhlo.CholeskyOp(x, lower=ir.BoolAttr.get(True)).results
|
||||
return hlo.CholeskyOp(x, lower=ir.BoolAttr.get(True)).results
|
||||
|
||||
mlir.register_lowering(cholesky_p, _cholesky_lowering)
|
||||
|
||||
@ -429,22 +429,28 @@ def _cholesky_cpu_gpu_lowering(potrf_impl, ctx, operand):
|
||||
out_aval, = ctx.avals_out
|
||||
batch_dims = operand_aval.shape[:-2]
|
||||
result, info = potrf_impl(operand_aval.dtype, operand, lower=True)
|
||||
ok = mlir.compare_mhlo(
|
||||
ok = mlir.compare_hlo(
|
||||
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
|
||||
"EQ", "SIGNED")
|
||||
select_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
||||
return [_broadcasting_select_mhlo(
|
||||
return [_broadcasting_select_hlo(
|
||||
ctx,
|
||||
mlir.broadcast_in_dim(ctx, ok,
|
||||
select_aval,
|
||||
broadcast_dimensions=range(len(batch_dims))),
|
||||
select_aval,
|
||||
result, out_aval, _nan_like_mhlo(ctx, out_aval), out_aval)]
|
||||
result, out_aval, _nan_like_hlo(ctx, out_aval), out_aval)]
|
||||
|
||||
if xla_client.mlir_api_version < 41:
|
||||
mlir.register_lowering(
|
||||
cholesky_p,
|
||||
partial(_cholesky_cpu_gpu_lowering, lapack.potrf_mhlo),
|
||||
platform='cpu')
|
||||
else:
|
||||
mlir.register_lowering(
|
||||
cholesky_p,
|
||||
partial(_cholesky_cpu_gpu_lowering, lapack.potrf_hlo),
|
||||
platform='cpu')
|
||||
|
||||
# Asymmetric eigendecomposition
|
||||
|
||||
@ -491,42 +497,47 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors,
|
||||
out_aval = ctx.avals_out[0]
|
||||
batch_dims = operand_aval.shape[:-2]
|
||||
|
||||
if xla_client.mlir_api_version < 41:
|
||||
w, vl, vr, info = lapack.geev_mhlo(operand_aval.dtype, operand,
|
||||
jobvl=compute_left_eigenvectors,
|
||||
jobvr=compute_right_eigenvectors)
|
||||
else:
|
||||
w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand,
|
||||
jobvl=compute_left_eigenvectors,
|
||||
jobvr=compute_right_eigenvectors)
|
||||
|
||||
ok = mlir.compare_mhlo(
|
||||
ok = mlir.compare_hlo(
|
||||
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
|
||||
"EQ", "SIGNED")
|
||||
select_w_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
|
||||
w = _broadcasting_select_mhlo(
|
||||
w = _broadcasting_select_hlo(
|
||||
ctx,
|
||||
mlir.broadcast_in_dim(ctx, ok, select_w_aval,
|
||||
broadcast_dimensions=range(len(batch_dims))),
|
||||
select_w_aval,
|
||||
w, out_aval, _nan_like_mhlo(ctx, out_aval), out_aval)
|
||||
w, out_aval, _nan_like_hlo(ctx, out_aval), out_aval)
|
||||
output = [w]
|
||||
|
||||
if compute_left_eigenvectors:
|
||||
aval = ctx.avals_out[len(output)]
|
||||
select_vl_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
||||
vl = _broadcasting_select_mhlo(
|
||||
vl = _broadcasting_select_hlo(
|
||||
ctx,
|
||||
mlir.broadcast_in_dim(ctx, ok, select_vl_aval,
|
||||
broadcast_dimensions=range(len(batch_dims))),
|
||||
select_vl_aval,
|
||||
vl, aval, _nan_like_mhlo(ctx, aval), aval)
|
||||
vl, aval, _nan_like_hlo(ctx, aval), aval)
|
||||
output.append(vl)
|
||||
|
||||
if compute_right_eigenvectors:
|
||||
aval = ctx.avals_out[len(output)]
|
||||
select_vr_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
||||
vr = _broadcasting_select_mhlo(
|
||||
vr = _broadcasting_select_hlo(
|
||||
ctx,
|
||||
mlir.broadcast_in_dim(ctx, ok, select_vr_aval,
|
||||
broadcast_dimensions=range(len(batch_dims))),
|
||||
select_vr_aval,
|
||||
vr, aval, _nan_like_mhlo(ctx, aval), aval)
|
||||
vr, aval, _nan_like_hlo(ctx, aval), aval)
|
||||
output.append(vr)
|
||||
|
||||
return output
|
||||
@ -645,21 +656,21 @@ def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower,
|
||||
batch_dims = operand_aval.shape[:-2]
|
||||
v, w, info = syevd_impl(operand_aval.dtype, operand, lower=lower)
|
||||
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
|
||||
ok = mlir.compare_mhlo(info, zeros, "EQ", "SIGNED")
|
||||
ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
|
||||
select_v_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
||||
v = _broadcasting_select_mhlo(
|
||||
v = _broadcasting_select_hlo(
|
||||
ctx,
|
||||
mlir.broadcast_in_dim(ctx, ok, select_v_aval,
|
||||
broadcast_dimensions=range(len(batch_dims))),
|
||||
select_v_aval,
|
||||
v, v_aval, _nan_like_mhlo(ctx, v_aval), v_aval)
|
||||
v, v_aval, _nan_like_hlo(ctx, v_aval), v_aval)
|
||||
select_w_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
|
||||
w = _broadcasting_select_mhlo(
|
||||
w = _broadcasting_select_hlo(
|
||||
ctx,
|
||||
mlir.broadcast_in_dim(ctx, ok, select_w_aval,
|
||||
broadcast_dimensions=range(len(batch_dims))),
|
||||
select_w_aval,
|
||||
w, w_aval, _nan_like_mhlo(ctx, w_aval), w_aval)
|
||||
w, w_aval, _nan_like_hlo(ctx, w_aval), w_aval)
|
||||
return [v, w]
|
||||
|
||||
def _eigh_tpu_impl(x, *, lower, sort_eigenvalues):
|
||||
@ -742,9 +753,14 @@ eigh_p.def_abstract_eval(_eigh_abstract_eval)
|
||||
ad.primitive_jvps[eigh_p] = _eigh_jvp_rule
|
||||
batching.primitive_batchers[eigh_p] = _eigh_batching_rule
|
||||
|
||||
if xla_client.mlir_api_version < 41:
|
||||
mlir.register_lowering(
|
||||
eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_mhlo),
|
||||
platform='cpu')
|
||||
else:
|
||||
mlir.register_lowering(
|
||||
eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_hlo),
|
||||
platform='cpu')
|
||||
|
||||
if gpu_solver is not None:
|
||||
mlir.register_lowering(
|
||||
@ -882,15 +898,15 @@ def _triangular_solve_lowering(
|
||||
else:
|
||||
transpose = "ADJOINT" if conjugate_a else "TRANSPOSE"
|
||||
if mlir_api_version < 36:
|
||||
return mhlo.TriangularSolveOp(
|
||||
return hlo.TriangularSolveOp(
|
||||
mlir.aval_to_ir_type(out_aval), a, b, ir.BoolAttr.get(left_side),
|
||||
ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
|
||||
mhlo.TransposeAttr.get(transpose)).results
|
||||
hlo.TransposeAttr.get(transpose)).results
|
||||
else:
|
||||
return mhlo.TriangularSolveOp(
|
||||
return hlo.TriangularSolveOp(
|
||||
a, b, ir.BoolAttr.get(left_side),
|
||||
ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal),
|
||||
mhlo.TransposeAttr.get(transpose)).results
|
||||
hlo.TransposeAttr.get(transpose)).results
|
||||
|
||||
mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering)
|
||||
|
||||
@ -904,9 +920,14 @@ def _triangular_solve_cpu_lower(
|
||||
conjugate_a = False
|
||||
if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types:
|
||||
alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype))
|
||||
if xla_client.mlir_api_version < 41:
|
||||
return [lapack.trsm_mhlo(
|
||||
a_aval.dtype, alpha,
|
||||
a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)]
|
||||
else:
|
||||
return [lapack.trsm_hlo(
|
||||
a_aval.dtype, alpha,
|
||||
a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)]
|
||||
else:
|
||||
# Fall back to the HLO implementation for unsupported types or batching.
|
||||
# TODO: Consider swapping XLA for LAPACK in batched case
|
||||
@ -915,15 +936,15 @@ def _triangular_solve_cpu_lower(
|
||||
else:
|
||||
transpose = "NO_TRANSPOSE"
|
||||
if mlir_api_version < 36:
|
||||
return mhlo.TriangularSolveOp(b.type, a, b, ir.BoolAttr.get(left_side),
|
||||
return hlo.TriangularSolveOp(b.type, a, b, ir.BoolAttr.get(left_side),
|
||||
ir.BoolAttr.get(lower),
|
||||
ir.BoolAttr.get(unit_diagonal),
|
||||
mhlo.TransposeAttr.get(transpose)).results
|
||||
hlo.TransposeAttr.get(transpose)).results
|
||||
else:
|
||||
return mhlo.TriangularSolveOp(a, b, ir.BoolAttr.get(left_side),
|
||||
return hlo.TriangularSolveOp(a, b, ir.BoolAttr.get(left_side),
|
||||
ir.BoolAttr.get(lower),
|
||||
ir.BoolAttr.get(unit_diagonal),
|
||||
mhlo.TransposeAttr.get(transpose)).results
|
||||
hlo.TransposeAttr.get(transpose)).results
|
||||
|
||||
mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower,
|
||||
platform='cpu')
|
||||
@ -1186,17 +1207,17 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand):
|
||||
m = operand_aval.shape[-2]
|
||||
lu, pivot, info = getrf_impl(operand_aval.dtype, operand)
|
||||
# Subtract 1 from the pivot to get 0-based indices.
|
||||
pivot = mhlo.SubtractOp(pivot, mlir.full_like_aval(ctx, 1, pivot_aval)).result
|
||||
ok = mlir.compare_mhlo(
|
||||
pivot = hlo.SubtractOp(pivot, mlir.full_like_aval(ctx, 1, pivot_aval)).result
|
||||
ok = mlir.compare_hlo(
|
||||
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
|
||||
"GE", "SIGNED")
|
||||
select_lu_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
||||
lu = _broadcasting_select_mhlo(
|
||||
lu = _broadcasting_select_hlo(
|
||||
ctx,
|
||||
mlir.broadcast_in_dim(ctx, ok, select_lu_aval,
|
||||
broadcast_dimensions=range(len(batch_dims))),
|
||||
select_lu_aval,
|
||||
lu, out_aval, _nan_like_mhlo(ctx, out_aval), out_aval)
|
||||
lu, out_aval, _nan_like_hlo(ctx, out_aval), out_aval)
|
||||
sub_ctx = ctx.replace(primitive=None, avals_in=[pivot_aval], avals_out=[perm_aval])
|
||||
perm_fn = mlir.lower_fun(lambda x: lu_pivots_to_permutation(x, m),
|
||||
multiple_results=False)
|
||||
@ -1216,9 +1237,14 @@ mlir.register_lowering(lu_p, mlir.lower_fun(_lu_python, multiple_results=True))
|
||||
ad.primitive_jvps[lu_p] = _lu_jvp_rule
|
||||
batching.primitive_batchers[lu_p] = _lu_batching_rule
|
||||
|
||||
if xla_client.mlir_api_version < 41:
|
||||
mlir.register_lowering(lu_p,
|
||||
partial(_lu_cpu_gpu_lowering, lapack.getrf_mhlo),
|
||||
platform='cpu')
|
||||
else:
|
||||
mlir.register_lowering(lu_p,
|
||||
partial(_lu_cpu_gpu_lowering, lapack.getrf_hlo),
|
||||
platform='cpu')
|
||||
|
||||
mlir.register_lowering(
|
||||
lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.cuda_getrf),
|
||||
@ -1339,15 +1365,15 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a):
|
||||
else:
|
||||
a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a)
|
||||
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
|
||||
ok = mlir.compare_mhlo(info_geqrf, zeros, "EQ", "SIGNED")
|
||||
ok = mlir.compare_hlo(info_geqrf, zeros, "EQ", "SIGNED")
|
||||
select_ok_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_))
|
||||
ok_a = mlir.broadcast_in_dim(ctx, ok, select_ok_a_aval,
|
||||
broadcast_dimensions=range(len(batch_dims)))
|
||||
a_out = _broadcasting_select_mhlo(ctx, ok_a, select_ok_a_aval, a_out, a_aval, _nan_like_mhlo(ctx, a_aval), a_aval)
|
||||
a_out = _broadcasting_select_hlo(ctx, ok_a, select_ok_a_aval, a_out, a_aval, _nan_like_hlo(ctx, a_aval), a_aval)
|
||||
select_ok_taus_aval = ShapedArray(batch_dims + [1], np.dtype(np.bool_))
|
||||
ok_taus = mlir.broadcast_in_dim(ctx, ok, select_ok_taus_aval,
|
||||
broadcast_dimensions=range(len(batch_dims)))
|
||||
taus = _broadcasting_select_mhlo(ctx, ok_taus, select_ok_taus_aval, taus, taus_aval, _nan_like_mhlo(ctx, taus_aval), taus_aval)
|
||||
taus = _broadcasting_select_hlo(ctx, ok_taus, select_ok_taus_aval, taus, taus_aval, _nan_like_hlo(ctx, taus_aval), taus_aval)
|
||||
return a_out, taus
|
||||
|
||||
geqrf_p = Primitive('geqrf')
|
||||
@ -1357,9 +1383,14 @@ geqrf_p.def_abstract_eval(_geqrf_abstract_eval)
|
||||
batching.primitive_batchers[geqrf_p] = _geqrf_batching_rule
|
||||
xla.register_translation(geqrf_p, _geqrf_translation_rule)
|
||||
|
||||
if xla_client.mlir_api_version < 41:
|
||||
mlir.register_lowering(
|
||||
geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_mhlo, None),
|
||||
platform='cpu')
|
||||
else:
|
||||
mlir.register_lowering(
|
||||
geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_hlo, None),
|
||||
platform='cpu')
|
||||
mlir.register_lowering(
|
||||
geqrf_p,
|
||||
partial(_geqrf_cpu_gpu_lowering, gpu_solver.cuda_geqrf,
|
||||
@ -1425,11 +1456,11 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus):
|
||||
|
||||
a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus)
|
||||
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
|
||||
ok = mlir.compare_mhlo(info_orgqr, zeros, "EQ", "SIGNED")
|
||||
ok = mlir.compare_hlo(info_orgqr, zeros, "EQ", "SIGNED")
|
||||
select_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_))
|
||||
ok = mlir.broadcast_in_dim(ctx, ok, select_a_aval,
|
||||
broadcast_dimensions=range(len(batch_dims)))
|
||||
a = _broadcasting_select_mhlo(ctx, ok, select_a_aval, a, a_aval, _nan_like_mhlo(ctx, a_aval), a_aval)
|
||||
a = _broadcasting_select_hlo(ctx, ok, select_a_aval, a, a_aval, _nan_like_hlo(ctx, a_aval), a_aval)
|
||||
return [a]
|
||||
|
||||
|
||||
@ -1439,10 +1470,16 @@ householder_product_p.def_abstract_eval(_householder_product_abstract_eval)
|
||||
batching.primitive_batchers[householder_product_p] = _householder_product_batching_rule
|
||||
xla.register_translation(householder_product_p, _householder_product_translation_rule)
|
||||
|
||||
if xla_client.mlir_api_version < 41:
|
||||
mlir.register_lowering(
|
||||
householder_product_p,
|
||||
partial(_householder_product_cpu_gpu_lowering, lapack.orgqr_mhlo),
|
||||
platform='cpu')
|
||||
else:
|
||||
mlir.register_lowering(
|
||||
householder_product_p,
|
||||
partial(_householder_product_cpu_gpu_lowering, lapack.orgqr_hlo),
|
||||
platform='cpu')
|
||||
mlir.register_lowering(
|
||||
householder_product_p,
|
||||
partial(_householder_product_cpu_gpu_lowering, gpu_solver.cuda_orgqr),
|
||||
@ -1642,32 +1679,32 @@ def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices,
|
||||
full_matrices=full_matrices,
|
||||
compute_uv=compute_uv)
|
||||
zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32)))
|
||||
ok = mlir.compare_mhlo(info, zeros, "EQ", "SIGNED")
|
||||
ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED")
|
||||
select_s_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
|
||||
s = _broadcasting_select_mhlo(
|
||||
s = _broadcasting_select_hlo(
|
||||
ctx,
|
||||
mlir.broadcast_in_dim(ctx, ok, select_s_aval,
|
||||
broadcast_dimensions=range(len(batch_dims))),
|
||||
select_s_aval,
|
||||
s, s_aval, _nan_like_mhlo(ctx, s_aval), s_aval)
|
||||
s, s_aval, _nan_like_hlo(ctx, s_aval), s_aval)
|
||||
result = [s]
|
||||
|
||||
if compute_uv:
|
||||
u_aval, vt_aval = ctx.avals_out[1:]
|
||||
select_u_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
||||
u = _broadcasting_select_mhlo(
|
||||
u = _broadcasting_select_hlo(
|
||||
ctx,
|
||||
mlir.broadcast_in_dim(ctx, ok, select_u_aval,
|
||||
broadcast_dimensions=range(len(batch_dims))),
|
||||
select_u_aval,
|
||||
u, u_aval, _nan_like_mhlo(ctx, u_aval), u_aval)
|
||||
u, u_aval, _nan_like_hlo(ctx, u_aval), u_aval)
|
||||
select_v_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
||||
vt = _broadcasting_select_mhlo(
|
||||
vt = _broadcasting_select_hlo(
|
||||
ctx,
|
||||
mlir.broadcast_in_dim(ctx, ok, select_v_aval,
|
||||
broadcast_dimensions=range(len(batch_dims))),
|
||||
select_v_aval,
|
||||
vt, vt_aval, _nan_like_mhlo(ctx, vt_aval), vt_aval)
|
||||
vt, vt_aval, _nan_like_hlo(ctx, vt_aval), vt_aval)
|
||||
result += [u, vt]
|
||||
|
||||
return result
|
||||
@ -1715,9 +1752,14 @@ svd_p.def_abstract_eval(_svd_abstract_eval)
|
||||
ad.primitive_jvps[svd_p] = _svd_jvp_rule
|
||||
batching.primitive_batchers[svd_p] = _svd_batching_rule
|
||||
|
||||
if xla_client.mlir_api_version < 41:
|
||||
mlir.register_lowering(
|
||||
svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo),
|
||||
platform='cpu')
|
||||
else:
|
||||
mlir.register_lowering(
|
||||
svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_hlo),
|
||||
platform='cpu')
|
||||
mlir.register_lowering(
|
||||
svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.cuda_gesvd),
|
||||
platform='cuda')
|
||||
@ -1877,33 +1919,40 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals,
|
||||
operand_aval, = ctx.avals_in
|
||||
batch_dims = operand_aval.shape[:-2]
|
||||
|
||||
if xla_client.mlir_api_version < 41:
|
||||
gees_result = lapack.gees_mhlo(operand_aval.dtype, operand,
|
||||
jobvs=compute_schur_vectors,
|
||||
sort=sort_eig_vals,
|
||||
select=select_callable)
|
||||
else:
|
||||
gees_result = lapack.gees_hlo(operand_aval.dtype, operand,
|
||||
jobvs=compute_schur_vectors,
|
||||
sort=sort_eig_vals,
|
||||
select=select_callable)
|
||||
|
||||
# Number of return values depends on value of sort_eig_vals.
|
||||
T, vs, *_, info = gees_result
|
||||
|
||||
ok = mlir.compare_mhlo(
|
||||
ok = mlir.compare_hlo(
|
||||
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
|
||||
"EQ", "SIGNED")
|
||||
|
||||
select_T_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
||||
T = _broadcasting_select_mhlo(
|
||||
T = _broadcasting_select_hlo(
|
||||
ctx,
|
||||
mlir.broadcast_in_dim(ctx, ok, select_T_aval,
|
||||
broadcast_dimensions=range(len(batch_dims))),
|
||||
select_T_aval,
|
||||
T, ctx.avals_out[0],_nan_like_mhlo(ctx, ctx.avals_out[0]), ctx.avals_out[0])
|
||||
T, ctx.avals_out[0],_nan_like_hlo(ctx, ctx.avals_out[0]), ctx.avals_out[0])
|
||||
output = [T]
|
||||
if compute_schur_vectors:
|
||||
select_vs_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
||||
vs = _broadcasting_select_mhlo(
|
||||
vs = _broadcasting_select_hlo(
|
||||
ctx,
|
||||
mlir.broadcast_in_dim(ctx, ok, select_vs_aval,
|
||||
broadcast_dimensions=range(len(batch_dims))),
|
||||
select_vs_aval,
|
||||
vs, ctx.avals_out[1], _nan_like_mhlo(ctx, ctx.avals_out[1]), ctx.avals_out[1])
|
||||
vs, ctx.avals_out[1], _nan_like_hlo(ctx, ctx.avals_out[1]), ctx.avals_out[1])
|
||||
|
||||
output.append(vs)
|
||||
|
||||
@ -1983,35 +2032,38 @@ def _hessenberg_batching_rule(batched_args, batch_dims):
|
||||
|
||||
batching.primitive_batchers[hessenberg_p] = _hessenberg_batching_rule
|
||||
|
||||
def _hessenberg_cpu_mhlo(ctx, a):
|
||||
def _hessenberg_cpu_hlo(ctx, a):
|
||||
# TODO(phawkins): remove this test after jaxlib 0.3.25 is the minimum.
|
||||
if not hasattr(lapack, "gehrd_mhlo"):
|
||||
if not hasattr(lapack, "gehrd_mhlo") and not hasattr(lapack, "gehrd_hlo"):
|
||||
raise RuntimeError("Hessenberg reduction on CPU requires jaxlib 0.3.25 or "
|
||||
"newer")
|
||||
a_aval, = ctx.avals_in
|
||||
batch_dims = a_aval.shape[:-2]
|
||||
if xla_client.mlir_api_version < 41:
|
||||
a, taus, info = lapack.gehrd_mhlo(a_aval.dtype, a)
|
||||
ok = mlir.compare_mhlo(
|
||||
else:
|
||||
a, taus, info = lapack.gehrd_hlo(a_aval.dtype, a)
|
||||
ok = mlir.compare_hlo(
|
||||
info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))),
|
||||
"EQ", "SIGNED")
|
||||
select_a_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_))
|
||||
select_taus_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_))
|
||||
return [
|
||||
_broadcasting_select_mhlo(
|
||||
_broadcasting_select_hlo(
|
||||
ctx,
|
||||
mlir.broadcast_in_dim(ctx, ok, select_a_aval,
|
||||
broadcast_dimensions=range(len(batch_dims))),
|
||||
select_a_aval,
|
||||
a, ctx.avals_out[0], _nan_like_mhlo(ctx, ctx.avals_out[0]), ctx.avals_out[0]),
|
||||
_broadcasting_select_mhlo(
|
||||
a, ctx.avals_out[0], _nan_like_hlo(ctx, ctx.avals_out[0]), ctx.avals_out[0]),
|
||||
_broadcasting_select_hlo(
|
||||
ctx,
|
||||
mlir.broadcast_in_dim(ctx, ok, select_taus_aval,
|
||||
broadcast_dimensions=range(len(batch_dims))),
|
||||
select_taus_aval,
|
||||
taus, ctx.avals_out[1], _nan_like_mhlo(ctx, ctx.avals_out[1]), ctx.avals_out[1]),
|
||||
taus, ctx.avals_out[1], _nan_like_hlo(ctx, ctx.avals_out[1]), ctx.avals_out[1]),
|
||||
]
|
||||
|
||||
mlir.register_lowering(hessenberg_p, _hessenberg_cpu_mhlo, platform='cpu')
|
||||
mlir.register_lowering(hessenberg_p, _hessenberg_cpu_hlo, platform='cpu')
|
||||
|
||||
|
||||
# tridiagonal: Upper Hessenberg reduction
|
||||
@ -2085,35 +2137,40 @@ def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower):
|
||||
|
||||
batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule
|
||||
|
||||
def _tridiagonal_cpu_gpu_mhlo(sytrd_impl, ctx, a, *, lower):
|
||||
def _tridiagonal_cpu_gpu_hlo(sytrd_impl, ctx, a, *, lower):
|
||||
a_aval, = ctx.avals_in
|
||||
a, d, e, taus, info = sytrd_impl(a_aval.dtype, a, lower=lower)
|
||||
return a, d, e, taus, info
|
||||
|
||||
if jaxlib_version >= (0, 3, 25):
|
||||
if xla_client.mlir_api_version < 41:
|
||||
mlir.register_lowering(
|
||||
tridiagonal_p, partial(_tridiagonal_cpu_gpu_mhlo, lapack.sytrd_mhlo),
|
||||
tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, lapack.sytrd_mhlo),
|
||||
platform='cpu')
|
||||
else:
|
||||
mlir.register_lowering(
|
||||
tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, lapack.sytrd_hlo),
|
||||
platform='cpu')
|
||||
mlir.register_lowering(
|
||||
tridiagonal_p, partial(_tridiagonal_cpu_gpu_mhlo, gpu_solver.cuda_sytrd),
|
||||
tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.cuda_sytrd),
|
||||
platform='cuda')
|
||||
mlir.register_lowering(
|
||||
tridiagonal_p, partial(_tridiagonal_cpu_gpu_mhlo, gpu_solver.rocm_sytrd),
|
||||
tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.rocm_sytrd),
|
||||
platform='rocm')

# Utilities

def _nan_like_mhlo(ctx: mlir.LoweringRuleContext, aval) -> ir.Value:
def _nan_like_hlo(ctx: mlir.LoweringRuleContext, aval) -> ir.Value:
if jnp.issubdtype(aval.dtype, np.complexfloating):
return mlir.full_like_aval(ctx, np.nan + np.nan * 1j, aval)
else:
return mlir.full_like_aval(ctx, np.nan, aval)

def _broadcasting_select_mhlo(ctx, which, which_aval, x, x_aval, y, y_aval) -> ir.Value:
def _broadcasting_select_hlo(ctx, which, which_aval, x, x_aval, y, y_aval) -> ir.Value:
"""Wrapper around XLA `Select` that broadcasts its arguments."""
out_shapes = list(lax_internal.broadcast_shapes(
tuple(which_aval.shape), tuple(x_aval.shape), tuple(y_aval.shape)))
which, x, y = mlir.multi_broadcast_in_dim(ctx, (which, x, y),
(which_aval, x_aval, y_aval),
out_shapes)
return mhlo.SelectOp(which, x, y).result
return hlo.SelectOp(which, x, y).result

@ -39,7 +39,7 @@ import jax._src.util as util
|
||||
from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib import mlir_api_version
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
|
||||
unsafe_map, map = map, safe_map # type: ignore
|
||||
|
||||
@ -219,7 +219,7 @@ def ppermute(x, axis_name, perm):
If ``x`` is a pytree then the result is equivalent to mapping this function to
each leaf in the tree.

This function is an analog of the CollectivePermute XLA HLO.
This function is an analog of the CollectivePermute HLO.

Args:
x: array(s) with a mapped axis named ``axis_name``.
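Illustrative usage, not part of the change: under pmap over an axis named 'i'
on three devices, the call below rotates each shard to the next device (x is a
placeholder operand).

  from jax import lax
  lax.ppermute(x, axis_name='i', perm=[(0, 1), (1, 2), (2, 0)])
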
@ -661,7 +661,7 @@ def _replica_groups(axis_env, axis_name, axis_index_groups):
for axis_index_group in axis_index_groups]
return replica_groups

def _replica_groups_mhlo(replica_groups: Sequence[Sequence[int]]
def _replica_groups_hlo(replica_groups: Sequence[Sequence[int]]
) -> ir.DenseIntElementsAttr:
# Uneven replica groups are padded with -1.
groups = np.array(list(itertools.zip_longest(*replica_groups, fillvalue=-1)),
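Worked micro-example of the padding noted above, with hypothetical groups:

  import itertools
  list(itertools.zip_longest(*[[0, 1, 2], [3, 4]], fillvalue=-1))
  # -> [(0, 3), (1, 4), (2, -1)]; uneven groups are right-padded with -1 before
  # the helper packs them into the DenseIntElementsAttr.
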
@ -711,7 +711,7 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
|
||||
if not named_axes:
|
||||
return args
|
||||
|
||||
replica_groups = _replica_groups_mhlo(
|
||||
replica_groups = _replica_groups_hlo(
|
||||
_replica_groups(ctx.module_context.axis_env, named_axes,
|
||||
axis_index_groups))
|
||||
axis_context = ctx.module_context.axis_context
|
||||
@ -722,12 +722,12 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
|
||||
if is_spmd:
|
||||
channel = ctx.module_context.new_channel()
|
||||
other_args = dict(
|
||||
channel_handle=mhlo.ChannelHandle.get(
|
||||
channel_handle=hlo.ChannelHandle.get(
|
||||
channel, mlir.DEVICE_TO_DEVICE_TYPE),
|
||||
use_global_device_ids=ir.BoolAttr.get(True))
|
||||
else:
|
||||
other_args = {}
|
||||
op = mhlo.AllReduceOp(
|
||||
op = hlo.AllReduceOp(
|
||||
x.type, x, replica_groups=replica_groups, **other_args)
|
||||
scalar_aval = core.ShapedArray((), aval.dtype)
|
||||
scalar_type = mlir.aval_to_ir_type(scalar_aval)
|
||||
@ -738,7 +738,7 @@ def _allreduce_lowering(prim, pos_fn, ctx, *args, axes, axis_index_groups):
|
||||
avals_in=[scalar_aval] * 2, avals_out=[scalar_aval])
|
||||
out_nodes = lower_reducer(
|
||||
reducer_ctx, *([a] for a in reducer_block.arguments))
|
||||
mhlo.ReturnOp(util.flatten(out_nodes))
|
||||
hlo.ReturnOp(util.flatten(out_nodes))
|
||||
return op.result
|
||||
|
||||
return [all_reduce(aval, x) for aval, x in zip(ctx.avals_in, args)]
|
||||
@ -849,11 +849,11 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm):
|
||||
if is_manual and mlir_api_version >= 35:
|
||||
channel = ctx.module_context.new_channel()
|
||||
other_args = dict(
|
||||
channel_handle=mhlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE))
|
||||
channel_handle=hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE))
|
||||
else:
|
||||
other_args = {}
|
||||
|
||||
return mhlo.CollectivePermuteOp(
|
||||
return hlo.CollectivePermuteOp(
|
||||
x, mlir.dense_int_elements(full_perm), **other_args).results
|
||||
|
||||
def _ppermute_transpose_rule(t, x, perm, axis_name):
|
||||
@ -952,16 +952,16 @@ def _all_to_all_lowering(ctx, x, *,
|
||||
# of partitions - and XLA is configured with only a single replica.
|
||||
channel = ctx.module_context.new_channel()
|
||||
other_args = dict(
|
||||
channel_handle=mhlo.ChannelHandle.get(channel,
|
||||
channel_handle=hlo.ChannelHandle.get(channel,
|
||||
mlir.DEVICE_TO_DEVICE_TYPE))
|
||||
else:
|
||||
other_args = {}
|
||||
return mhlo.AllToAllOp(
|
||||
return hlo.AllToAllOp(
|
||||
operand,
|
||||
split_dimension=mlir.i64_attr(split_axis),
|
||||
concat_dimension=mlir.i64_attr(concat_axis),
|
||||
split_count=mlir.i64_attr(split_count),
|
||||
replica_groups=_replica_groups_mhlo(replica_groups),
|
||||
replica_groups=_replica_groups_hlo(replica_groups),
|
||||
**other_args).results
|
||||
else:
|
||||
warnings.warn(
|
||||
@ -1184,7 +1184,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
|
||||
new_shape = list(x_aval.shape)
|
||||
new_shape.insert(all_gather_dimension, 1)
|
||||
broadcast_dimensions = [i for i in range(len(new_shape)) if i != all_gather_dimension]
|
||||
x = mhlo.BroadcastInDimOp(
|
||||
x = hlo.BroadcastInDimOp(
|
||||
mlir.aval_to_ir_type(x_aval.update(shape=new_shape)), x,
|
||||
mlir.dense_int_elements(broadcast_dimensions))
|
||||
replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name,
|
||||
@ -1195,15 +1195,15 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name,
|
||||
# of partitions - and XLA is configured with only a single replica.
|
||||
channel = ctx.module_context.new_channel()
|
||||
other_args = dict(
|
||||
channel_handle=mhlo.ChannelHandle.get(
|
||||
channel_handle=hlo.ChannelHandle.get(
|
||||
channel, mlir.DEVICE_TO_DEVICE_TYPE),
|
||||
use_global_device_ids=ir.BoolAttr.get(True))
|
||||
else:
|
||||
other_args = {}
|
||||
return mhlo.AllGatherOp(
|
||||
return hlo.AllGatherOp(
|
||||
mlir.aval_to_ir_type(out_aval),
|
||||
x, all_gather_dim=mlir.i64_attr(all_gather_dimension),
|
||||
replica_groups=_replica_groups_mhlo(replica_groups),
|
||||
replica_groups=_replica_groups_hlo(replica_groups),
|
||||
**other_args).results
|
||||
else:
|
||||
lowering = mlir.lower_fun(_all_gather_via_psum, multiple_results=False)
|
||||
@ -1328,16 +1328,16 @@ def _reduce_scatter_lowering(prim, reducer, ctx, x,
|
||||
# of partitions - and XLA is configured with only a single replica.
|
||||
channel = ctx.module_context.new_channel()
|
||||
other_args = dict(
|
||||
channel_handle=mhlo.ChannelHandle.get(
|
||||
channel_handle=hlo.ChannelHandle.get(
|
||||
channel, mlir.DEVICE_TO_DEVICE_TYPE),
|
||||
use_global_device_ids=ir.BoolAttr.get(True))
|
||||
else:
|
||||
other_args = {}
|
||||
op = mhlo.ReduceScatterOp(
|
||||
op = hlo.ReduceScatterOp(
|
||||
mlir.aval_to_ir_type(x_aval.update(shape=scatter_out_shape)),
|
||||
x,
|
||||
scatter_dimension=mlir.i64_attr(scatter_dimension),
|
||||
replica_groups=_replica_groups_mhlo(replica_groups),
|
||||
replica_groups=_replica_groups_hlo(replica_groups),
|
||||
**other_args)
|
||||
scalar_type = mlir.aval_to_ir_type(scalar_aval)
|
||||
reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type)
|
||||
@ -1348,12 +1348,12 @@ def _reduce_scatter_lowering(prim, reducer, ctx, x,
|
||||
avals_out=[scalar_aval])
|
||||
out_nodes = lower_reducer(
|
||||
reducer_ctx, *([a] for a in reducer_block.arguments))
|
||||
mhlo.ReturnOp(util.flatten(out_nodes))
|
||||
hlo.ReturnOp(util.flatten(out_nodes))
|
||||
|
||||
if tiled:
|
||||
return op.results
|
||||
else:
|
||||
return mhlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), op.result).results
|
||||
return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), op.result).results
|
||||
else:
|
||||
return mlir.lower_fun(_reduce_scatter_via_reducer, multiple_results=False)(
|
||||
ctx, x,
|
||||
@ -1522,7 +1522,7 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, t
|
||||
return tree_util.tree_map(bind, x)
|
||||
|
||||
|
||||
def _build_axis_index_lowering_mhlo(ctx, axis_name, axis_env):
|
||||
def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env):
|
||||
if isinstance(axis_name, tuple):
|
||||
assert axis_name, 'empty axis name'
|
||||
if len(axis_name) > 1:
|
||||
@ -1539,20 +1539,20 @@ def _build_axis_index_lowering_mhlo(ctx, axis_name, axis_env):
|
||||
(mlir.SPMDAxisContext, mlir.ShardingContext))
|
||||
if is_spmd:
|
||||
if mlir_api_version >= 39:
|
||||
device_id = mhlo.PartitionIdOp()
|
||||
device_id = hlo.PartitionIdOp()
|
||||
else:
|
||||
device_id = mhlo.PartitionIdOp(
|
||||
device_id = hlo.PartitionIdOp(
|
||||
ir.RankedTensorType.get([], ir.IntegerType.get_unsigned(32)))
|
||||
else:
|
||||
device_id = mhlo.ReplicaIdOp()
|
||||
unsigned_index = mhlo.RemOp(mhlo.DivOp(device_id, div), mod)
|
||||
return mhlo.ConvertOp(
|
||||
device_id = hlo.ReplicaIdOp()
|
||||
unsigned_index = hlo.RemOp(hlo.DivOp(device_id, div), mod)
|
||||
return hlo.ConvertOp(
|
||||
ir.RankedTensorType.get([], ir.IntegerType.get_signless(32)),
|
||||
unsigned_index).result
|
||||
|
||||
def _axis_index_lowering(ctx, *, axis_name):
|
||||
return [
|
||||
_build_axis_index_lowering_mhlo(ctx, axis_name,
|
||||
_build_axis_index_lowering_hlo(ctx, axis_name,
|
||||
ctx.module_context.axis_env)
|
||||
]
|
||||
|
||||
|
@ -36,7 +36,7 @@ from jax._src.lax import lax
from jax._src import util
from jax._src.util import safe_map, safe_zip
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.typing import Array, ArrayLike, Shape
@ -956,7 +956,7 @@ mlir.register_lowering(dynamic_slice_p, _dynamic_slice_lower)
|
||||
|
||||
# def _getslice_lower(ctx, x, lo, hi):
|
||||
# aval_out, = ctx.avals_out
|
||||
# return mhlo.RealDynamicSliceOp(
|
||||
# return hlo.RealDynamicSliceOp(
|
||||
# mlir.aval_to_ir_type(aval_out), x,
|
||||
# mlir.shape_tensor([lo]), mlir.shape_tensor([hi]), mlir.shape_tensor([1])
|
||||
# ).results
|
||||
@ -1393,7 +1393,7 @@ def _gather_lower(ctx, operand, indices, *,
|
||||
|
||||
assert mode in (GatherScatterMode.PROMISE_IN_BOUNDS,
|
||||
GatherScatterMode.CLIP), mode
|
||||
dnums = mhlo.GatherDimensionNumbers.get(
|
||||
dnums = hlo.GatherDimensionNumbers.get(
|
||||
collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims),
|
||||
index_vector_dim=len(ctx.avals_in[1].shape) - 1,
|
||||
offset_dims=list(dimension_numbers.offset_dims),
|
||||
@ -1402,7 +1402,7 @@ def _gather_lower(ctx, operand, indices, *,
|
||||
slice_sizes = mlir.eval_dynamic_shape(ctx, slice_sizes)
|
||||
# TODO(burmako): Fix overly conservative type inference of DynamicGatherOp.
|
||||
# For now use the build_generic so that we can specify the result type.
|
||||
# return mhlo.DynamicGatherOp(
|
||||
# return hlo.DynamicGatherOp(
|
||||
# operand, indices, mlir.shape_tensor(slice_sizes),
|
||||
# dnums, indices_are_sorted=ir.BoolAttr.get(indices_are_sorted)).results
|
||||
results = [mlir.aval_to_ir_type(aval_out)]
|
||||
@ -1411,10 +1411,10 @@ def _gather_lower(ctx, operand, indices, *,
|
||||
"dimension_numbers": dnums,
|
||||
"indices_are_sorted": ir.BoolAttr.get(indices_are_sorted)
|
||||
}
|
||||
return mhlo.DynamicGatherOp.build_generic(
|
||||
return hlo.DynamicGatherOp.build_generic(
|
||||
results=results, operands=operands, attributes=attributes).results
|
||||
else:
|
||||
return mhlo.GatherOp(
|
||||
return hlo.GatherOp(
|
||||
operand,
|
||||
indices,
|
||||
dnums,
|
||||
@ -2019,7 +2019,7 @@ def _scatter_lower(ctx, operand, indices, updates, *,
|
||||
|
||||
aval_out, = ctx.avals_out
|
||||
dnums = dimension_numbers
|
||||
scatter_dnums = mhlo.ScatterDimensionNumbers.get(
|
||||
scatter_dnums = hlo.ScatterDimensionNumbers.get(
|
||||
update_window_dims=list(dnums.update_window_dims),
|
||||
inserted_window_dims=list(dnums.inserted_window_dims),
|
||||
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
|
||||
@ -2027,7 +2027,7 @@ def _scatter_lower(ctx, operand, indices, updates, *,
|
||||
result = mlir.aval_to_ir_types(aval_out)
|
||||
operand = [operand]
|
||||
updates = [updates]
|
||||
op = mhlo.ScatterOp(
|
||||
op = hlo.ScatterOp(
|
||||
result,
|
||||
operand,
|
||||
indices,
|
||||
@ -2045,7 +2045,7 @@ def _scatter_lower(ctx, operand, indices, updates, *,
|
||||
update_ctx, update_jaxpr, mlir.TokenSet(), update_consts,
|
||||
(update.arguments[0],), (update.arguments[1],),
|
||||
dim_var_values=ctx.dim_var_values)
|
||||
mhlo.ReturnOp(util.flatten(out_nodes))
|
||||
hlo.ReturnOp(util.flatten(out_nodes))
|
||||
return op.results
|
||||
|
||||
mlir.register_lowering(scatter_p, _scatter_lower)
|
||||
@ -2076,7 +2076,7 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
|
||||
|
||||
aval_out, = ctx.avals_out
|
||||
dnums = dimension_numbers
|
||||
scatter_dnums = mhlo.ScatterDimensionNumbers.get(
|
||||
scatter_dnums = hlo.ScatterDimensionNumbers.get(
|
||||
update_window_dims=list(dnums.update_window_dims),
|
||||
inserted_window_dims=list(dnums.inserted_window_dims),
|
||||
scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims),
|
||||
@ -2089,7 +2089,7 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
|
||||
operand_part = [operand_part]
|
||||
updates_part = [updates_part]
|
||||
|
||||
scatter = mhlo.ScatterOp(
|
||||
scatter = hlo.ScatterOp(
|
||||
operand_type_part,
|
||||
operand_part,
|
||||
indices,
|
||||
@ -2100,13 +2100,13 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates,
|
||||
scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), real_dtype))
|
||||
reducer = scatter.regions[0].blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(reducer):
|
||||
add = mhlo.AddOp(*reducer.arguments).result
|
||||
mhlo.ReturnOp([add])
|
||||
add = hlo.AddOp(*reducer.arguments).result
|
||||
hlo.ReturnOp([add])
|
||||
return scatter.result
|
||||
|
||||
real = _scatter(mhlo.RealOp(operand).result, mhlo.RealOp(updates).result)
|
||||
imag = _scatter(mhlo.ImagOp(operand).result, mhlo.ImagOp(updates).result)
|
||||
return mhlo.ComplexOp(real, imag).results
|
||||
real = _scatter(hlo.RealOp(operand).result, hlo.RealOp(updates).result)
|
||||
imag = _scatter(hlo.ImagOp(operand).result, hlo.ImagOp(updates).result)
|
||||
return hlo.ComplexOp(real, imag).results
|
||||
|
||||
mlir.register_lowering(scatter_add_p, _scatter_add_lower_gpu, platform="gpu")
|
||||
|
||||
|
@ -33,7 +33,7 @@ import jax._src.lax.lax as lax
|
||||
import jax._src.lax.convolution as convolution
|
||||
import jax._src.lax.slicing as slicing
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.numpy.ufuncs import logaddexp
|
||||
import jax._src.util as util
|
||||
|
||||
@ -316,7 +316,7 @@ def _generic_reduce_window_lower(ctx, *args, jaxpr, consts,
|
||||
operands, init_values = util.split_list(args, [len(args) // 2])
|
||||
_, init_value_avals = util.split_list(ctx.avals_in, [len(operands)])
|
||||
scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals]
|
||||
rw = mhlo.ReduceWindowOp(
|
||||
rw = hlo.ReduceWindowOp(
|
||||
map(mlir.aval_to_ir_type, ctx.avals_out),
|
||||
operands,
|
||||
init_values,
|
||||
@ -333,7 +333,7 @@ def _generic_reduce_window_lower(ctx, *args, jaxpr, consts,
|
||||
out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr,
|
||||
mlir.TokenSet(), consts, *([a] for a in reducer.arguments),
|
||||
dim_var_values=ctx.dim_var_values)
|
||||
mhlo.ReturnOp(util.flatten(out_nodes))
|
||||
hlo.ReturnOp(util.flatten(out_nodes))
|
||||
return rw.results
|
||||
|
||||
mlir.register_lowering(reduce_window_p, _generic_reduce_window_lower)
|
||||
@ -468,7 +468,7 @@ def _reduce_window_lower(
|
||||
operand_aval, = ctx.avals_in
|
||||
scalar_aval = operand_aval.update(shape=())
|
||||
scalar_type = mlir.aval_to_ir_type(scalar_aval)
|
||||
rw = mhlo.ReduceWindowOp(
|
||||
rw = hlo.ReduceWindowOp(
|
||||
mlir.aval_to_ir_types(aval_out), [operand],
|
||||
[mlir.full_like_aval(ctx, init_value(scalar_aval.dtype), scalar_aval)],
|
||||
mlir.dense_int_elements(window_dimensions),
|
||||
@ -479,15 +479,15 @@ def _reduce_window_lower(
|
||||
shape=(len(padding), 2)))
|
||||
reducer = rw.regions[0].blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(reducer):
|
||||
mhlo.ReturnOp(reduce_op(*reducer.arguments))
|
||||
hlo.ReturnOp(reduce_op(*reducer.arguments))
|
||||
return rw.results
|
||||
|
||||
mlir.register_lowering(reduce_window_sum_p, partial(
_reduce_window_lower, mhlo.AddOp, lambda _: 0))
_reduce_window_lower, hlo.AddOp, lambda _: 0))
mlir.register_lowering(reduce_window_min_p, partial(
_reduce_window_lower, mlir.min_mhlo, lax._get_min_identity))
_reduce_window_lower, mlir.min_hlo, lax._get_min_identity))
mlir.register_lowering(reduce_window_max_p, partial(
_reduce_window_lower, mlir.max_mhlo, lax._get_max_identity))
_reduce_window_lower, mlir.max_hlo, lax._get_max_identity))

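Illustrative correspondence only: a windowed sum written as below typically
routes through reduce_window_sum_p and hence through _reduce_window_lower with
hlo.AddOp and identity 0 (x is a placeholder operand).

  from jax import lax
  lax.reduce_window(x, 0., lax.add, window_dimensions=(2, 2),
                    window_strides=(2, 2), padding='VALID')
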
@ -514,7 +514,7 @@ def _select_and_scatter_lower(
|
||||
aval_out, = ctx.avals_out
|
||||
scalar_aval = operand_aval.update(shape=())
|
||||
scalar_type = mlir.aval_to_ir_type(scalar_aval)
|
||||
op = mhlo.SelectAndScatterOp(
|
||||
op = hlo.SelectAndScatterOp(
|
||||
mlir.aval_to_ir_type(aval_out),
|
||||
operand,
|
||||
source,
|
||||
@ -531,7 +531,7 @@ def _select_and_scatter_lower(
|
||||
mlir.TokenSet(), select_consts,
|
||||
*([a] for a in select.arguments),
|
||||
dim_var_values=ctx.dim_var_values)
|
||||
mhlo.ReturnOp(util.flatten(out_nodes))
|
||||
hlo.ReturnOp(util.flatten(out_nodes))
|
||||
scatter = op.scatter.blocks.append(scalar_type, scalar_type)
|
||||
with ir.InsertionPoint(scatter):
|
||||
if scatter_jaxpr.effects:
|
||||
@ -540,7 +540,7 @@ def _select_and_scatter_lower(
|
||||
mlir.TokenSet(), scatter_consts,
|
||||
*([a] for a in scatter.arguments),
|
||||
dim_var_values=ctx.dim_var_values)
|
||||
mhlo.ReturnOp(util.flatten(out_nodes))
|
||||
hlo.ReturnOp(util.flatten(out_nodes))
|
||||
return op.results
|
||||
|
||||
mlir.register_lowering(select_and_scatter_p, _select_and_scatter_lower)
|
||||
@ -670,7 +670,7 @@ def _select_and_gather_add_lowering(
|
||||
canonicalize_types=False)
|
||||
|
||||
def _broadcast(x, dims):
|
||||
return mhlo.BroadcastOp(x, mlir.dense_int_elements(dims))
|
||||
return hlo.BroadcastOp(x, mlir.dense_int_elements(dims))
|
||||
|
||||
if double_word_reduction:
|
||||
# TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so
|
||||
@ -685,28 +685,28 @@ def _select_and_gather_add_lowering(
|
||||
def pack(a, b):
|
||||
a_dims = ir.RankedTensorType(a.type).shape
|
||||
b_dims = ir.RankedTensorType(b.type).shape
|
||||
a = mhlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a)
|
||||
b = mhlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b)
|
||||
a = mhlo.ConvertOp(ir.RankedTensorType.get(a_dims, double_word_type), a)
|
||||
b = mhlo.ConvertOp(ir.RankedTensorType.get(b_dims, double_word_type), b)
|
||||
a = mhlo.ShiftLeftOp(a,
|
||||
a = hlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a)
|
||||
b = hlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b)
|
||||
a = hlo.ConvertOp(ir.RankedTensorType.get(a_dims, double_word_type), a)
|
||||
b = hlo.ConvertOp(ir.RankedTensorType.get(b_dims, double_word_type), b)
|
||||
a = hlo.ShiftLeftOp(a,
|
||||
_broadcast(const(double_word_dtype, nbits), a_dims))
|
||||
return mhlo.OrOp(a, b)
|
||||
return hlo.OrOp(a, b)
|
||||
|
||||
# Unpacks the first element of a tuple.
|
||||
def fst(t):
|
||||
dims = ir.RankedTensorType(t.type).shape
|
||||
st = mhlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits))
|
||||
return mhlo.BitcastConvertOp(
|
||||
st = hlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits))
|
||||
return hlo.BitcastConvertOp(
|
||||
ir.RankedTensorType.get(dims, etype),
|
||||
mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), st)).result
|
||||
hlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), st)).result
|
||||
|
||||
# Unpacks the second element of a tuple.
|
||||
def snd(t):
|
||||
dims = ir.RankedTensorType(t.type).shape
|
||||
return mhlo.BitcastConvertOp(
|
||||
return hlo.BitcastConvertOp(
|
||||
ir.RankedTensorType.get(dims, etype),
|
||||
mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), t)).result
|
||||
hlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), t)).result
|
||||
|
||||
else:
|
||||
# The double-word trick above only works if we have a sufficiently large
|
||||
@ -729,33 +729,33 @@ def _select_and_gather_add_lowering(
|
||||
def pack(a, b):
|
||||
a_dims = ir.RankedTensorType(a.type).shape
|
||||
b_dims = ir.RankedTensorType(b.type).shape
|
||||
a = mhlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp),
|
||||
a = hlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp),
|
||||
mantissa_bits=mlir.i32_attr(nmant))
|
||||
b = mhlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp),
|
||||
b = hlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp),
|
||||
mantissa_bits=mlir.i32_attr(nmant))
|
||||
a = mhlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a)
|
||||
b = mhlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b)
|
||||
b = mhlo.ShiftRightLogicalOp(
|
||||
a = hlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a)
|
||||
b = hlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b)
|
||||
b = hlo.ShiftRightLogicalOp(
|
||||
b, _broadcast(const(word_dtype, r_nbits), b_dims))
|
||||
return mhlo.OrOp(a, b)
|
||||
return hlo.OrOp(a, b)
|
||||
|
||||
# Unpacks the first element of a tuple.
|
||||
def fst(t):
|
||||
st = mhlo.AndOp(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits))
|
||||
return mhlo.BitcastConvertOp(ir.RankedTensorType.get([], etype),
|
||||
st = hlo.AndOp(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits))
|
||||
return hlo.BitcastConvertOp(ir.RankedTensorType.get([], etype),
|
||||
st).result
|
||||
|
||||
# Unpacks the second element of a tuple.
|
||||
def snd(t):
|
||||
dims = ir.RankedTensorType(t.type).shape
|
||||
return mhlo.BitcastConvertOp(
|
||||
return hlo.BitcastConvertOp(
|
||||
ir.RankedTensorType.get(dims, etype),
|
||||
mhlo.ShiftLeftOp(t, _broadcast(const(word_dtype, r_nbits), dims))
|
||||
hlo.ShiftLeftOp(t, _broadcast(const(word_dtype, r_nbits), dims))
|
||||
).result
|
||||
|
||||
assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim
|
||||
init = -np.inf if select_prim is lax.ge_p else np.inf
|
||||
rw = mhlo.ReduceWindowOp(
|
||||
rw = hlo.ReduceWindowOp(
|
||||
[ir.RankedTensorType.get(out_aval.shape, double_word_type)],
|
||||
pack(operand, tangents),
|
||||
pack(const(dtype, init), const(dtype, 0)),
|
||||
@ -771,8 +771,8 @@ def _select_and_gather_add_lowering(
|
||||
x, y = reducer.arguments
|
||||
assert select_prim is lax.ge_p or select_prim is lax.le_p
|
||||
which = "GE" if select_prim is lax.ge_p else "LE"
|
||||
out = mhlo.SelectOp(mlir.compare_mhlo(fst(x), fst(y), which), x, y)
|
||||
mhlo.ReturnOp(out)
|
||||
out = hlo.SelectOp(mlir.compare_hlo(fst(x), fst(y), which), x, y)
|
||||
hlo.ReturnOp(out)
|
||||
return [snd(rw.result)]
|
||||
|
||||
# TODO(phawkins): use this translation rule on all platforms.

@ -23,3 +23,8 @@ import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor
from jax.lib import xla_client
if xla_client.mlir_api_version >= 37:
import jaxlib.mlir.dialects.stablehlo as stablehlo

# Alias that is set up to abstract away the transition from MHLO to StableHLO.
# At the moment, it points to MHLO, but in the future it will start to
# conditionally and then unconditionally point to StableHLO.
import jaxlib.mlir.dialects.mhlo as hlo
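A minimal sketch of how the alias is consumed, assuming only what this diff
shows: lowering rules import the dialect through JAX's wrapper module, so call
sites stay unchanged when the alias is repointed at StableHLO.

  from jax._src.lib.mlir.dialects import hlo  # resolves to MHLO today, StableHLO later
  out = hlo.ReshapeOp(result_type, x).result  # hypothetical call site; unaffected by the switch
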
@ -280,7 +280,7 @@ def canonicalize_platform(platform: str) -> str:

In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
hardware is actually present. We want to distinguish "cuda" and "rocm" for
purposes such as MHLO lowering rules, but in many cases we don't want to
purposes such as MLIR lowering rules, but in many cases we don't want to
force users to care.
"""
platforms = _alias_to_platforms.get(platform, None)
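Hedged usage sketch (some_prim and some_rule are placeholders; the alias table
itself is not shown in this hunk):

  platform = canonicalize_platform("gpu")  # -> "cuda" or "rocm", per the docstring above
  mlir.register_lowering(some_prim, some_rule, platform=platform)
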
@ -40,7 +40,7 @@ from jax._src import dtypes
from jax._src.api import jit, vmap
from jax._src.lax import lax as lax_internal
from jax._src.lax import utils as lax_utils
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy
import jax._src.pretty_printer as pp
from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
@ -443,7 +443,7 @@ class KeyTyRules:
|
||||
key_shape = aval_out.dtype.impl.key_shape
|
||||
trailing_dims = [aval_out.ndim + i for i in range(len(key_shape))]
|
||||
perm = [*permutation, *trailing_dims]
|
||||
return mhlo.TransposeOp(x, mlir.dense_int_elements(perm)).result
|
||||
return hlo.TransposeOp(x, mlir.dense_int_elements(perm)).result
|
||||
|
||||
@staticmethod
|
||||
def gather_mlir(ctx, avals_in, aval_out, x, indices, *,
|
||||
@ -1041,27 +1041,27 @@ def bcast_iotas_to_reshaped_iota(add, mul, shape, iotas):
|
||||
|
||||
def iota_2x32_shape_lowering(ctx, *, shape):
|
||||
def _add(x, y):
|
||||
return mlir.mhlo.AddOp(x, y).result
|
||||
return mlir.hlo.AddOp(x, y).result
|
||||
|
||||
def _mul(x, y):
|
||||
x_const = mlir.ir_constant(np.array(x, np.dtype('uint64')),
|
||||
canonicalize_types=False)
|
||||
x_bcast = mlir.mhlo.BroadcastOp(x_const, mlir.dense_int_elements(shape))
|
||||
return mlir.mhlo.MulOp(x_bcast, y).result
|
||||
x_bcast = mlir.hlo.BroadcastOp(x_const, mlir.dense_int_elements(shape))
|
||||
return mlir.hlo.MulOp(x_bcast, y).result
|
||||
|
||||
assert len(shape) > 0
|
||||
aval_out, _ = ctx.avals_out
|
||||
aval_u64 = core.ShapedArray(shape, np.dtype('uint64'))
|
||||
iotas = [mlir.mhlo.IotaOp(mlir.aval_to_ir_type(aval_u64),
|
||||
iotas = [mlir.hlo.IotaOp(mlir.aval_to_ir_type(aval_u64),
|
||||
mlir.i64_attr(dimension)).result
|
||||
for dimension in range(len(shape))]
|
||||
counts = bcast_iotas_to_reshaped_iota(_add, _mul, shape, iotas)
|
||||
shift = mlir.ir_constant(np.array(32, np.dtype('uint64')),
|
||||
canonicalize_types=False)
|
||||
shift = mlir.mhlo.BroadcastOp(shift, mlir.dense_int_elements(shape)).result
|
||||
counts_shifted = mlir.mhlo.ShiftRightLogicalOp(counts, shift).result
|
||||
counts_lo = mlir.mhlo.ConvertOp(mlir.aval_to_ir_type(aval_out), counts).result
|
||||
counts_hi = mlir.mhlo.ConvertOp(mlir.aval_to_ir_type(aval_out),
|
||||
shift = mlir.hlo.BroadcastOp(shift, mlir.dense_int_elements(shape)).result
|
||||
counts_shifted = mlir.hlo.ShiftRightLogicalOp(counts, shift).result
|
||||
counts_lo = mlir.hlo.ConvertOp(mlir.aval_to_ir_type(aval_out), counts).result
|
||||
counts_hi = mlir.hlo.ConvertOp(mlir.aval_to_ir_type(aval_out),
|
||||
counts_shifted).result
|
||||
return counts_hi, counts_lo
|
||||
mlir.register_lowering(iota_2x32_shape_p, iota_2x32_shape_lowering)
|
||||
|
@ -20,7 +20,7 @@ from jax import tree_util
from jax import linear_util as lu
from jax.experimental import pjit

from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir import ir
import jax.interpreters.pxla as pxla
from jax.interpreters import mlir
@ -245,7 +245,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
|
||||
else:
|
||||
out_type = [ir.TupleType.get_tuple(mlir_shapes)]
|
||||
|
||||
out = mhlo.CustomCallOp(
|
||||
out = hlo.CustomCallOp(
|
||||
out_type,
|
||||
list(values),
|
||||
call_target_name=ir.StringAttr.get(_CUSTOM_PARTITIONING_CALL_NAME),
|
||||
@ -259,7 +259,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values,
|
||||
return [out.result]
|
||||
else:
|
||||
return [
|
||||
mhlo.GetTupleElementOp(out, mlir.i32_attr(i)).result
|
||||
hlo.GetTupleElementOp(out, mlir.i32_attr(i)).result
|
||||
for i in range(len(mlir_shapes))
|
||||
]
|
||||
|
||||
|
@ -524,7 +524,7 @@ from jax._src.lib import pytree
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -1143,14 +1143,14 @@ def _outside_call_lowering(
|
||||
assert has_token
|
||||
current_token = args[-2]
|
||||
current_itoken = args[-1]
|
||||
assert current_token.type == mhlo.TokenType.get(), "The last two arguments must be tokens"
|
||||
assert current_itoken.type == mhlo.TokenType.get(), "The last two arguments must be tokens"
|
||||
assert current_token.type == hlo.TokenType.get(), "The last two arguments must be tokens"
|
||||
assert current_itoken.type == hlo.TokenType.get(), "The last two arguments must be tokens"
|
||||
|
||||
args_to_outfeed = args[:-2]
|
||||
# TODO(necula): this is a weak attempt to get the device. This works
|
||||
# inside pmap, but does not work when we just execute on a single device,
|
||||
# because in such executions we always get replica_id == 0.
|
||||
replica_id = mhlo.ReplicaIdOp()
|
||||
replica_id = hlo.ReplicaIdOp()
|
||||
callback_operands = [replica_id, *args_to_outfeed]
|
||||
callback_operand_avals = [
|
||||
core.ShapedArray((), np.uint32), *ctx.avals_in[:-2]]
|
||||
|
@ -39,7 +39,7 @@ from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib import xla_client
|
||||
from jax.experimental.jax2tf import jax2tf as jax2tf_internal
|
||||
|
||||
@ -392,16 +392,16 @@ def _code_generator_and_avals(
|
||||
captured_ops = tuple(mlir.ir_constant(np.asarray(inp),
|
||||
canonicalize_types=False)
|
||||
for inp in captured_inputs)
|
||||
submodule = mlir.xla_computation_to_mhlo_module(xla_comp)
|
||||
submodule = mlir.xla_computation_to_mlir_module(xla_comp)
|
||||
symtab = ir.SymbolTable(submodule.operation)
|
||||
callee_result_types = symtab["main"].type.results
|
||||
fn = mlir.merge_mhlo_modules(ctx.module, f"call_tf_{function_flat_tf.name}",
|
||||
fn = mlir.merge_mlir_modules(ctx.module, f"call_tf_{function_flat_tf.name}",
|
||||
submodule)
|
||||
call = func_dialect.CallOp(callee_result_types,
|
||||
ir.FlatSymbolRefAttr.get(fn),
|
||||
tuple(args_op) + captured_ops)
|
||||
if result_shape.is_tuple():
|
||||
flat_results = [mhlo.GetTupleElementOp(call, mlir.i32_attr(i)).result
|
||||
flat_results = [hlo.GetTupleElementOp(call, mlir.i32_attr(i)).result
|
||||
for i in range(len(result_shapes))]
|
||||
else:
|
||||
flat_results = call.results
|
||||
@ -410,7 +410,7 @@ def _code_generator_and_avals(
|
||||
for op, res_aval, res_shape in zip(flat_results, result_avals,
|
||||
result_shapes):
|
||||
if res_aval.dtype != res_shape.numpy_dtype():
|
||||
op = mhlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result
|
||||
op = hlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result
|
||||
outputs.append(op)
|
||||
return outputs
|
||||

@ -589,7 +589,7 @@ def _lower_native_and_run(fun_jax: Callable,

Work-in-progress.

Uses JAX native lowering to MHLO, and then wraps the result in a
Uses JAX native lowering to MLIR, and then wraps the result in a
XlaCallModule TF op. This op does not have backward-compatibility yet.

Special care must be taken in presence of shape polymorphism.
@ -634,13 +634,13 @@ def _lower_native_and_run(fun_jax: Callable,
fun_jax_lower = fun_jax.lower
lowered = fun_jax_lower(*arg_specs_jax)._lowering
if config.jax2tf_use_stablehlo:
mhlo_module = lowered.stablehlo()
mlir_module = lowered.stablehlo()
xla_call_module_version = 2
else:
mhlo_module = lowered.mhlo()
mlir_module = lowered.mhlo()
xla_call_module_version = 1

mhlo_serialized_module = mlir.module_to_bytecode(mhlo_module)
mlir_serialized_module = mlir.module_to_bytecode(mlir_module)
# Figure out the result types and shapes
if "global_out_avals" in lowered.compile_args:
# This is currently the case for pjit
@ -719,7 +719,7 @@ def _lower_native_and_run(fun_jax: Callable,
args_tf = [atf for i, atf in enumerate(args_tf) if i in module_kept_var_idx]

# Apply the shardings on arguments and results for pjit. This is redundant
# because the mhlo_module_text will already contain the shardings, but it
# because the mlir_module_text will already contain the shardings, but it
# makes it easier for tools like the TPU inference converter to see the
# sharding without digging into the `module` attribute of the `XlaCallModule`
# op, in the same way as it is done for the legacy jax2tf conversion.
@ -728,14 +728,14 @@ def _lower_native_and_run(fun_jax: Callable,
map(_shard_value, args_tf, args_avals, lowered.compile_args["in_shardings"]))

if logging.vlog_is_on(3):
mhlo_module_text = mlir.module_to_string(mhlo_module)
mlir_module_text = mlir.module_to_string(mlir_module)
logging.vlog(3, "XlaCallModule (version=%d, dim_args_spec=%s)\n%s",
xla_call_module_version, ", ".join(dim_args_spec),
mhlo_module_text)
mlir_module_text)
res = tfxla.call_module(
args_tf,
version=xla_call_module_version,
module=mhlo_serialized_module,
module=mlir_serialized_module,
Tout=out_types,
Sout=out_shapes,
dim_args_spec=dim_args_spec)

@ -1479,7 +1479,7 @@ class XlaCallModuleTest(tf_test_util.JaxToTfTestCase):
self.assertAllClose(tf.nest.map_structure(lambda t: t.numpy(), res),
jax_res)

@unittest.skip("TODO(necula): 'mhlo.dynamic_iota' op can't be translated to XLA HLO")
@unittest.skip("TODO(necula): 'dynamic_iota' op can't be translated to XLA HLO")
def test_shape_poly_arange(self):
if not config.jax_dynamic_shapes:
raise unittest.SkipTest("jax_dynamic_shapes must be enabled")

@ -1176,7 +1176,7 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings,
name_stack=xla.extend_name_stack(ctx.module_context.name_stack,
wrap_name(name, "pjit")))
# TODO(b/228598865): inlined calls cannot have shardings set directly on the
# inputs or outputs because they are lost during MHLO->HLO conversion.
# inputs or outputs because they are lost during MLIR->HLO conversion.
# using_sharding_annotation=False means we add an identity operation instead.
func = mlir.lower_jaxpr_to_fun(sub_ctx, f"pjit_{name}", jaxpr, (),
arg_shardings=arg_shardings,
@ -1544,7 +1544,7 @@ sharding_constraint_p.def_abstract_eval(lambda x, **_: x)
|
||||
ad.deflinear2(sharding_constraint_p,
|
||||
lambda ct, _, **params: (sharding_constraint_p.bind(ct, **params),))
|
||||
|
||||
def _sharding_constraint_mhlo_lowering(ctx, x_node, *, sharding,
|
||||
def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding,
|
||||
resource_env, unconstrained_dims):
|
||||
aval, = ctx.avals_in
|
||||
axis_ctx = ctx.module_context.axis_context
|
||||
@ -1564,7 +1564,7 @@ def _sharding_constraint_mhlo_lowering(ctx, x_node, *, sharding,
|
||||
unspecified_dims=unconstrained_dims)
|
||||
]
|
||||
mlir.register_lowering(sharding_constraint_p,
|
||||
_sharding_constraint_mhlo_lowering)
|
||||
_sharding_constraint_hlo_lowering)
|
||||
|
||||
|
||||
def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size,
|
||||
|
@ -48,7 +48,7 @@ from jax._src.lax.slicing import GatherDimensionNumbers, GatherScatterMode
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.lib import gpu_sparse
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.numpy.setops import _unique
|
||||
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
|
||||
from jax._src.util import canonicalize_axis
|
||||
@ -728,13 +728,13 @@ def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_num
|
||||
_bcoo_dot_general_default_lowering = mlir.lower_fun(
|
||||
_bcoo_dot_general_impl, multiple_results=False)
|
||||
|
||||
def _collapse_mhlo(x, start, end):
|
||||
def _collapse_hlo(x, start, end):
|
||||
x_type = ir.RankedTensorType(x.type)
|
||||
shape = x_type.shape
|
||||
shape = (shape[:start]
|
||||
+ [functools.reduce(operator.mul, shape[start:end + 1])]
|
||||
+ shape[end + 1:])
|
||||
return mhlo.ReshapeOp(
|
||||
return hlo.ReshapeOp(
|
||||
ir.RankedTensorType.get(shape, x_type.element_type), x).result
|
||||
|
||||
def _bcoo_dot_general_cuda_lowering(
|
||||
@ -766,7 +766,7 @@ def _bcoo_dot_general_cuda_lowering(
|
||||
elif rhs_ndim == 2:
bcoo_dot_general_fn = coo_matmat_lowering
if rhs_contract[0] == 1:
rhs = mhlo.TransposeOp(
rhs = hlo.TransposeOp(
rhs, permutation=mlir.dense_int_elements([1, 0])).result
else:
raise ValueError(f"rhs has to be 1d or 2d; get {rhs_ndim}d.")
|
@ -776,7 +776,7 @@ def _bcoo_dot_general_cuda_lowering(
|
||||
lhs_transpose = False
if props.n_sparse == 1:
# Converts lhs to a row vector.
col = _collapse_mhlo(lhs_indices, start=0, end=1)
col = _collapse_hlo(lhs_indices, start=0, end=1)
row = mlir.full_like_aval(
ctx, 0, core.ShapedArray(ir.RankedTensorType(col.type).shape,
np.dtype(np.int32)))
|
@ -788,23 +788,23 @@ def _bcoo_dot_general_cuda_lowering(
|
||||
|
if rhs_ndim == 1:
# Transforms a single-element array to a scalar.
return [mhlo.ReshapeOp(
return [hlo.ReshapeOp(
ir.RankedTensorType.get(
[], ir.RankedTensorType(dot_product.type).element_type),
dot_product).result]
else:
return [_collapse_mhlo(dot_product, start=0, end=1)]
return [_collapse_hlo(dot_product, start=0, end=1)]
elif props.n_sparse == 2:
lhs_indices_shape = ir.RankedTensorType(lhs_indices.type).shape
row = _collapse_mhlo(
mhlo.SliceOp(
row = _collapse_hlo(
hlo.SliceOp(
lhs_indices,
start_indices=mlir.dense_int_elements([0, 0]),
limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 1]),
strides=mlir.dense_int_elements([1, 1])).result,
start=0, end=1)
col = _collapse_mhlo(
mhlo.SliceOp(
col = _collapse_hlo(
hlo.SliceOp(
lhs_indices,
start_indices=mlir.dense_int_elements([0, 1]),
limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 2]),
|
@ -833,28 +833,28 @@ def _bcoo_dot_general_cuda_lowering(
|
||||
lhs_indices_shape[-1])
lhs_data_1d_shape = (np.prod(np.array(lhs_data_shape)), )
|
lhs_indices_2d = mhlo.ReshapeOp(
lhs_indices_2d = hlo.ReshapeOp(
ir.RankedTensorType.get(
lhs_indices_2d_shape,
ir.RankedTensorType(lhs_indices.type).element_type),
lhs_indices).result
|
lhs_data_1d = mhlo.ReshapeOp(
lhs_data_1d = hlo.ReshapeOp(
ir.RankedTensorType.get(
lhs_data_1d_shape,
ir.RankedTensorType(lhs_data.type).element_type),
lhs_data).result
|
row = _collapse_mhlo(
mhlo.SliceOp(
row = _collapse_hlo(
hlo.SliceOp(
lhs_indices_2d,
start_indices=mlir.dense_int_elements([0, 0]),
limit_indices=mlir.dense_int_elements([lhs_indices_2d_shape[0], 1]),
strides=mlir.dense_int_elements([1, 1])).result,
start=0, end=1)
|
col = _collapse_mhlo(
mhlo.SliceOp(
col = _collapse_hlo(
hlo.SliceOp(
lhs_indices_2d,
start_indices=mlir.dense_int_elements([0, 1]),
limit_indices=mlir.dense_int_elements([lhs_indices_2d_shape[0], 2]),
|
@ -867,13 +867,13 @@ def _bcoo_dot_general_cuda_lowering(
|
||||
# The issue (https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643)
# in cusparse library does not allow batch_stride = 0 for a non-batched rhs.
batched_rhs_shape = (batch_count,) + tuple(rhs_shape)
batched_rhs = mhlo.BroadcastInDimOp(
batched_rhs = hlo.BroadcastInDimOp(
ir.RankedTensorType.get(batched_rhs_shape,
ir.RankedTensorType(rhs.type).element_type),
rhs,
broadcast_dimensions=mlir.dense_int_elements([1, 2])).result
batched_rhs_2d_shape = (np.prod(np.array(batched_rhs_shape)[:-1]), batched_rhs_shape[-1])
batched_rhs_2d = mhlo.ReshapeOp(
batched_rhs_2d = hlo.ReshapeOp(
ir.RankedTensorType.get(
batched_rhs_2d_shape,
ir.RankedTensorType(batched_rhs.type).element_type),
|
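The broadcast-then-reshape above works around the linked cusparse limitation by materializing rhs once per batch element and flattening the result into a 2-D operand. A rough NumPy equivalent of that data movement (illustrative only):

import numpy as np

def materialize_batched_rhs(rhs, batch_count):
    # Tile rhs along a new leading batch axis, then flatten to 2-D, so every
    # batch element has a real (non-zero-stride) copy of the operand.
    batched = np.broadcast_to(rhs, (batch_count,) + rhs.shape)
    return batched.reshape(-1, rhs.shape[-1])

assert materialize_batched_rhs(np.ones((4, 5)), 3).shape == (12, 5)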
@ -1404,12 +1404,12 @@ def _bcoo_sort_indices_jvp(primals, tangents, *, spinfo):
|
||||
data_dot_out = ad.Zero.from_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot, perm)
return (data_out, indices_out), (data_dot_out, indices_dot_out)
|
_bcoo_sort_indices_mhlo = mlir.lower_fun(
_bcoo_sort_indices_hlo = mlir.lower_fun(
_bcoo_sort_indices_impl, multiple_results=True)
|
ad.primitive_jvps[bcoo_sort_indices_p] = _bcoo_sort_indices_jvp
batching.primitive_batchers[bcoo_sort_indices_p] = _bcoo_sort_indices_batching_rule
mlir.register_lowering(bcoo_sort_indices_p, _bcoo_sort_indices_mhlo)
mlir.register_lowering(bcoo_sort_indices_p, _bcoo_sort_indices_hlo)
|
#----------------------------------------------------------------------
|
@ -1560,12 +1560,12 @@ def _bcoo_sum_duplicates_jvp(primals, tangents, *, spinfo, nse):
|
||||
data_dot_out = ad.Zero.from_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot_out, mapping, data_dot)
return (data_out, indices_out), (data_dot_out, indices_dot_out)
|
_bcoo_sum_duplicates_mhlo = mlir.lower_fun(
_bcoo_sum_duplicates_hlo = mlir.lower_fun(
_bcoo_sum_duplicates_impl, multiple_results=True)
|
ad.primitive_jvps[bcoo_sum_duplicates_p] = _bcoo_sum_duplicates_jvp
batching.primitive_batchers[bcoo_sum_duplicates_p] = _bcoo_sum_duplicates_batching_rule
mlir.register_lowering(bcoo_sum_duplicates_p, _bcoo_sum_duplicates_mhlo)
mlir.register_lowering(bcoo_sum_duplicates_p, _bcoo_sum_duplicates_hlo)
|
#----------------------------------------------------------------------
# BCOO functions that maybe should be primitives?
|
@ -30,7 +30,7 @@ from jax.experimental.sparse._base import JAXSparse
|
||||
from jax.experimental.sparse.util import _coo_extract, CuSparseEfficiencyWarning
from jax import tree_util
from jax._src.lax.lax import _const
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import gpu_sparse
from jax._src.numpy.lax_numpy import _promote_dtypes
from jax._src.typing import Array, ArrayLike, DTypeLike
|
@ -199,7 +199,7 @@ def _coo_todense_abstract_eval(data, row, col, *, spinfo):
|
||||
_coo_todense_lowering = mlir.lower_fun(
_coo_todense_impl, multiple_results=False)
|
def _coo_todense_gpu_lowering(coo_todense_mhlo, ctx, data, row, col, *, spinfo):
def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo):
data_aval, row_aval, _ = ctx.avals_in
dtype = data_aval.dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
|
@ -220,10 +220,10 @@ def _coo_todense_gpu_lowering(coo_todense_mhlo, ctx, data, row, col, *, spinfo):
|
||||
"back to the default implementation.", CuSparseEfficiencyWarning)
return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo)
|
result = coo_todense_mhlo(
result = coo_todense_hlo(
data, row, col, shape=shape, data_dtype=dtype, index_dtype=row_aval.dtype)
return (
[mhlo.TransposeOp(result, mlir.dense_int_elements([1, 0])).result]
[hlo.TransposeOp(result, mlir.dense_int_elements([1, 0])).result]
if transpose else [result])
|
@ -318,14 +318,14 @@ def _coo_fromdense_abstract_eval(mat, *, nse, index_dtype):
|
||||
_coo_fromdense_lowering = mlir.lower_fun(
_coo_fromdense_impl, multiple_results=True)
|
def _coo_fromdense_gpu_lowering(coo_fromdense_mhlo, ctx, mat, *, nse,
def _coo_fromdense_gpu_lowering(coo_fromdense_hlo, ctx, mat, *, nse,
index_dtype):
dtype = ctx.avals_in[0].dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
warnings.warn(f"coo_fromdense cusparse/hipsparse lowering not available for {dtype=}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _coo_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype)
data, row, col = coo_fromdense_mhlo(
data, row, col = coo_fromdense_hlo(
mat, nnz=nse,
data_dtype=dtype,
index_dtype=np.dtype(index_dtype),
|
@ -438,7 +438,7 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, spinfo, transpose):
|
||||
_coo_matvec_lowering = mlir.lower_fun(
_coo_matvec_impl, multiple_results=False)
|
def _coo_matvec_gpu_lowering(coo_matvec_mhlo, ctx, data, row, col, v, *, spinfo,
def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo,
transpose):
data_aval, row_aval, _, x_aval = ctx.avals_in
dtype = data_aval.dtype
|
@ -461,7 +461,7 @@ def _coo_matvec_gpu_lowering(coo_matvec_mhlo, ctx, data, row, col, v, *, spinfo,
|
||||
return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo,
transpose=transpose)
|
return [coo_matvec_mhlo(
return [coo_matvec_hlo(
data, row, col, v, shape=shape, transpose=transpose,
index_dtype=row_aval.dtype, data_dtype=dtype, x_dtype=x_aval.dtype)]
|
@ -561,7 +561,7 @@ def _coo_matmat_abstract_eval(data, row, col, B, *, spinfo, transpose):
|
||||
|
_coo_matmat_lowering = mlir.lower_fun(_coo_matmat_impl, multiple_results=False)
|
def _coo_matmat_gpu_lowering(coo_matmat_mhlo, ctx, data, row, col, B, *, spinfo,
def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo,
transpose):
data_aval, row_aval, _, B_aval = ctx.avals_in
dtype = data_aval.dtype
|
@ -583,7 +583,7 @@ def _coo_matmat_gpu_lowering(coo_matmat_mhlo, ctx, data, row, col, B, *, spinfo,
|
||||
return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo,
transpose=transpose)
|
return [coo_matmat_mhlo(data, row, col, B, shape=shape,
return [coo_matmat_hlo(data, row, col, B, shape=shape,
transpose=transpose, x_dtype=B_aval.dtype,
data_dtype=data_aval.dtype,
index_dtype=row_aval.dtype)]
|
@ -227,7 +227,7 @@ def _csr_todense_abstract_eval(data, indices, indptr, *, shape):
|
||||
_csr_todense_lowering = mlir.lower_fun(
_csr_todense_impl, multiple_results=False)
|
def _csr_todense_gpu_lowering(csr_todense_mhlo, ctx, data, indices, indptr, *,
def _csr_todense_gpu_lowering(csr_todense_hlo, ctx, data, indices, indptr, *,
shape):
data_aval, indices_aval, _ = ctx.avals_in
dtype = data_aval.dtype
|
@ -235,7 +235,7 @@ def _csr_todense_gpu_lowering(csr_todense_mhlo, ctx, data, indices, indptr, *,
|
||||
warnings.warn(f"csr_todense cusparse/hipsparse lowering not available for {dtype=}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_todense_lowering(ctx, data, indices, indptr, shape=shape)
return [csr_todense_mhlo(
return [csr_todense_hlo(
data, indices, indptr, shape=shape, data_dtype=dtype,
index_dtype=indices_aval.dtype)]
|
@ -319,13 +319,13 @@ def _csr_fromdense_abstract_eval(mat, *, nse, index_dtype):
|
||||
_csr_fromdense_lowering = mlir.lower_fun(_csr_fromdense_impl,
multiple_results=True)
|
def _csr_fromdense_gpu_lowering(csr_fromdense_mhlo, ctx, mat, *, nse, index_dtype):
def _csr_fromdense_gpu_lowering(csr_fromdense_hlo, ctx, mat, *, nse, index_dtype):
dtype = ctx.avals_in[0].dtype
if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)):
warnings.warn(f"csr_fromdense cusparse/hipsparse lowering not available for {dtype=}. "
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype)
data, indices, indptr = csr_fromdense_mhlo(
data, indices, indptr = csr_fromdense_hlo(
mat, nnz=nse, index_dtype=np.dtype(index_dtype),
data_dtype=dtype, index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype)))
return [data, indices, indptr]
|
@ -412,7 +412,7 @@ def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose):
|
||||
|
_csr_matvec_lowering = mlir.lower_fun(_csr_matvec_impl, multiple_results=False)
|
def _csr_matvec_gpu_lowering(csr_matvec_mhlo, ctx, data, indices, indptr, v, *,
def _csr_matvec_gpu_lowering(csr_matvec_hlo, ctx, data, indices, indptr, v, *,
shape, transpose):
data_aval, indices_aval, _, v_aval = ctx.avals_in
dtype = data_aval.dtype
|
@ -421,7 +421,7 @@ def _csr_matvec_gpu_lowering(csr_matvec_mhlo, ctx, data, indices, indptr, v, *,
|
||||
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_matvec_lowering(ctx, data, indices, indptr, v, shape=shape,
transpose=transpose)
return [csr_matvec_mhlo(
return [csr_matvec_hlo(
data, indices, indptr, v, shape=shape, transpose=transpose,
data_dtype=dtype, index_dtype=indices_aval.dtype, x_dtype=v_aval.dtype)]
|
@ -504,7 +504,7 @@ def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose):
|
||||
|
_csr_matmat_lowering = mlir.lower_fun(_csr_matmat_impl, multiple_results=False)
|
def _csr_matmat_gpu_lowering(csr_matmat_mhlo, ctx, data, indices, indptr, B, *,
def _csr_matmat_gpu_lowering(csr_matmat_hlo, ctx, data, indices, indptr, B, *,
shape, transpose):
data_aval, indices_aval, _, B_aval = ctx.avals_in
dtype = data_aval.dtype
|
@ -513,7 +513,7 @@ def _csr_matmat_gpu_lowering(csr_matmat_mhlo, ctx, data, indices, indptr, B, *,
|
||||
"Falling back to default implementation.", CuSparseEfficiencyWarning)
return _csr_matmat_lowering(ctx, data, indices, indptr, B, shape=shape,
transpose=transpose)
return [csr_matmat_mhlo(
return [csr_matmat_hlo(
data, indices, indptr, B, shape=shape, transpose=transpose,
index_dtype=indices_aval.dtype, data_dtype=data_aval.dtype,
B_dtype=B_aval.dtype)]
|
@ -12,8 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
# limitations under the License.
|
# Lowering and execution path that converts jaxprs into the MLIR MHLO/CHLO
# dialects.
# Lowering and execution path that converts jaxprs into MLIR.
from __future__ import annotations
|
import collections
|
@ -35,8 +34,7 @@ from jax._src import device_array
|
||||
from jax._src import dtypes
from jax._src.lib import mlir_api_version, xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import can_execute_with_token
from jax._src.lib import xla_bridge as xb
|
@ -86,12 +84,12 @@ def shape_tensor(sizes: Sequence[Union[int, ir.RankedTensorType]]
|
||||
if type(d) is int:
return ir_constant(np.array([d], np.int32))
else:
return mhlo.ReshapeOp(int1d, mhlo.ConvertOp(aval_to_ir_type(core.ShapedArray((), np.int32)), d))
return hlo.ReshapeOp(int1d, hlo.ConvertOp(aval_to_ir_type(core.ShapedArray((), np.int32)), d))
d, *ds = map(lower_dim, sizes)
if not ds:
return d
else:
return mhlo.ConcatenateOp([d, *ds], i64_attr(0)).result
return hlo.ConcatenateOp([d, *ds], i64_attr(0)).result
|
|
def delegate_lowering(ctx, lowering_fun, *args, **ctx_override_kwargs):
|
@ -162,7 +160,7 @@ def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[ir.Type]:
|
||||
|
ir_type_handlers[core.ShapedArray] = _array_ir_types
ir_type_handlers[core.ConcreteArray] = _array_ir_types
ir_type_handlers[core.AbstractToken] = lambda _: [mhlo.TokenType.get()]
ir_type_handlers[core.AbstractToken] = lambda _: [hlo.TokenType.get()]
ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types
|
def aval_to_ir_type(aval: core.AbstractValue) -> ir.Type:
|
@ -239,7 +237,7 @@ def _numpy_array_constant(x: np.ndarray, canonicalize_types
|
||||
x = x.view(np.uint16)
x = np.ascontiguousarray(x)
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
return (mhlo.ConstantOp(attr).result,)
return (hlo.ConstantOp(attr).result,)
|
|
@ -272,7 +270,7 @@ def _ndarray_constant_handler(val: np.ndarray, canonicalize_types
|
||||
if canonicalize_types:
collapsed_val = np.asarray(
collapsed_val, dtypes.canonicalize_dtype(collapsed_val.dtype))
out = mhlo.BroadcastInDimOp(
out = hlo.BroadcastInDimOp(
ir.RankedTensorType.get(
val.shape, dtype_to_ir_type(collapsed_val.dtype)),
_numpy_array_constant(collapsed_val, canonicalize_types=False)[0],
|
@ -304,9 +302,9 @@ for t in device_array.device_array_types:
|
||||
|
def _token_constant_handler(val, canonicalize_types):
if mlir_api_version < 40:
return [mhlo.CreateTokenOp(mhlo.TokenType.get()).result]
return [hlo.CreateTokenOp(hlo.TokenType.get()).result]
else:
return [mhlo.CreateTokenOp().result]
return [hlo.CreateTokenOp().result]
register_constant_handler(core.Token, _token_constant_handler)
|
# Source locations
|
@ -331,12 +329,12 @@ def _source_info_to_location(
|
||||
# Translation rules
def make_ir_context() -> ir.Context:
"""Creates an MLIR context suitable for JAX IR."""
from jax._src.lib.mlir import dialects
context = ir.Context()
mhlo.register_mhlo_dialect(context)
chlo.register_dialect(context)
dialects.mhlo.register_mhlo_dialect(context)
dialects.chlo.register_dialect(context)
if mlir_api_version >= 37:
from jax._src.lib.mlir.dialects import stablehlo
stablehlo.register_dialect(context)
dialects.stablehlo.register_dialect(context)
return context
|
|
@ -581,18 +579,18 @@ class DimPolyEvaluator:
|
||||
def __add__(self, other: Union[np.int32, DimPolyEvaluator]):
if not isinstance(other, DimPolyEvaluator):
other = DimPolyEvaluator(ir_constant(other))
return DimPolyEvaluator(mhlo.AddOp(self.value, other.value).result)
return DimPolyEvaluator(hlo.AddOp(self.value, other.value).result)
|
def __radd__(self, other: np.int32):
return DimPolyEvaluator(mhlo.AddOp(ir_constant(other), self.value).result)
return DimPolyEvaluator(hlo.AddOp(ir_constant(other), self.value).result)
|
def __mul__(self, other: Union[np.int32, DimPolyEvaluator]):
if not isinstance(other, DimPolyEvaluator):
other = DimPolyEvaluator(ir_constant(other))
return DimPolyEvaluator(mhlo.MulOp(self.value, other.value).result)
return DimPolyEvaluator(hlo.MulOp(self.value, other.value).result)
|
def __rmul__(self, other: np.int32):
return DimPolyEvaluator(mhlo.MulOp(ir_constant(other), self.value).result)
return DimPolyEvaluator(hlo.MulOp(ir_constant(other), self.value).result)
|
|
def eval_dynamic_shape(ctx: LoweringRuleContext,
|
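DimPolyEvaluator above relies on operator overloading: each + or * applied while evaluating a dimension polynomial emits one IR op instead of computing a number. A self-contained sketch of that pattern with plain Python values standing in for ir.Values (names here are illustrative):

class PolyEval:
    def __init__(self, value):
        self.value = value
    def __add__(self, other):
        other = other if isinstance(other, PolyEval) else PolyEval(other)
        return PolyEval(("add", self.value, other.value))
    __radd__ = __add__
    def __mul__(self, other):
        other = other if isinstance(other, PolyEval) else PolyEval(other)
        return PolyEval(("mul", self.value, other.value))
    __rmul__ = __mul__

# Evaluating 2*d + 3 builds a nested op tree, just as the real evaluator
# builds hlo.MulOp / hlo.AddOp results.
print((2 * PolyEval("d") + 3).value)  # ('add', ('mul', 'd', 2), 3)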
@ -640,7 +638,7 @@ def lower_jaxpr_to_module(
|
||||
arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None
) -> LoweringResult:
"""Lowers a top-level jaxpr to an MHLO module.
"""Lowers a top-level jaxpr to an MLIR module.
|
Handles the quirks of the argument/return value passing conventions of the
runtime.
|
@ -678,7 +676,7 @@ def lower_jaxpr_to_module(
|
||||
msg = f"Donation is not implemented for {platform}.\n{msg}"
warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}")
|
# MHLO channels need to start at 1
# HLO channels need to start at 1
channel_iter = itertools.count(1)
# Create a keepalives list that will be mutated during the lowering.
keepalives: List[Any] = []
|
@ -761,20 +759,20 @@ def _set_up_aliases(avals_in, avals_out, donated_args):
|
||||
Token = Sequence[ir.Value]
|
def token_type() -> Sequence[ir.Type]:
return [mhlo.TokenType.get()]
return [hlo.TokenType.get()]
|
def create_token() -> Token:
if mlir_api_version < 40:
return wrap_singleton_ir_values(
mhlo.CreateTokenOp(mhlo.TokenType.get()).result)
hlo.CreateTokenOp(hlo.TokenType.get()).result)
else:
return wrap_singleton_ir_values(mhlo.CreateTokenOp().result)
return wrap_singleton_ir_values(hlo.CreateTokenOp().result)
|
class TokenSet:
"""An immutable container of tokens to be used to lower effectful jaxprs. When lowering
effectful jaxprs, we need to thread MHLO tokens to sequence them. Each effect
effectful jaxprs, we need to thread HLO tokens to sequence them. Each effect
will need its own token that will be threaded in and out of the effectful
primitives. A `TokenSet` encapsulates a set of MHLO tokens that will be
primitives. A `TokenSet` encapsulates a set of HLO tokens that will be
used by the lowering rules.
"""
_tokens: typing.OrderedDict[core.Effect, Token]
|
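As the TokenSet docstring above describes, each effect owns one token that every effectful op consumes and re-emits, which is what forces the compiler to keep those ops in order. A small sketch of the threading discipline with strings standing in for IR tokens (illustrative, not the real API):

from collections import OrderedDict

tokens = OrderedDict(print_effect="token0")

def lower_effectful_op(effect, op_name, tokens):
    in_token = tokens[effect]
    out_token = f"{op_name}({in_token})"  # the op consumes and produces a token
    tokens[effect] = out_token            # thread the fresh token forward
    return out_token

lower_effectful_op("print_effect", "callback_a", tokens)
lower_effectful_op("print_effect", "callback_b", tokens)
print(tokens["print_effect"])  # callback_b(callback_a(token0))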
@ -850,18 +848,18 @@ def lower_jaxpr_to_fun(
|
||||
jaxpr: the jaxpr to lower.
effects: a sequence of `core.Effect`s corresponding to an ordering of tokens
that will be created in or used by the lowered function.
create_tokens: if true, the MHLO will create tokens and ignore dummy input tokens.
create_tokens: if true, the HLO will create tokens and ignore dummy input tokens.
public: if true, the function's visibility is set to "public".
replace_tokens_with_dummy: if true, token arguments/return values are
replaced with bool arrays of size [0].
replicated_args: if present, annotates arguments as replicated.
arg_shardings: sharding annotations for each argument (optional).
result_shardings: sharding annotations for each argument (optional).
use_sharding_annotations: if True, use mhlo.sharding annotations on
use_sharding_annotations: if True, use "mhlo.sharding" annotations on
parameters and return values to express sharding. If False, use
mhlo.custom_call operators with sharding annotations.
TODO(b/228598865): remove this option when mhlo.sharding annotations are
propagated on non-entry functions during MHLO->HLO conversion.
hlo.custom_call operators with sharding annotations.
TODO(b/228598865): remove this option when "mhlo.sharding" annotations are
propagated on non-entry functions during MLIR->HLO conversion.
input_output_aliases: optional sequence that maps argument numbers to the
corresponding output that should alias them.
Returns the name of the function.
|
@ -988,9 +986,9 @@ def lower_jaxpr_to_fun(
|
||||
for aval, arg in zip(jaxpr.in_avals, unflattened_args):
if replace_tokens_with_dummy and aval is core.abstract_token:
if mlir_api_version < 40:
args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results)
args.append(hlo.CreateTokenOp(hlo.TokenType.get()).results)
else:
args.append(mhlo.CreateTokenOp().results)
args.append(hlo.CreateTokenOp().results)
else:
args.append(arg)
callee_name_stack = xla.extend_name_stack(ctx.name_stack,
|
@ -1061,7 +1059,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
|
||||
*args: Sequence[ir.Value],
dim_var_values: Sequence[ir.Value]
) -> Tuple[Sequence[Sequence[ir.Value]], TokenSet]:
"""Lowers a jaxpr into mHLO, inlined into an existing function.
"""Lowers a jaxpr into MLIR, inlined into an existing function.
|
Assumes that an MLIR context, location, and insertion point are set.
|
@ -1269,13 +1267,13 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue,
|
||||
broadcast_dimensions=broadcast_dimensions)
if not core.is_constant_shape(aval_out.shape): # type: ignore
shape = eval_dynamic_shape(ctx, aval_out.shape) # type: ignore
return mhlo.DynamicBroadcastInDimOp(
return hlo.DynamicBroadcastInDimOp(
aval_to_ir_type(aval_out), op,
shape_tensor(shape),
dense_int_elements(broadcast_dimensions),
).result
else:
return mhlo.BroadcastInDimOp(
return hlo.BroadcastInDimOp(
aval_to_ir_type(aval_out), op,
dense_int_elements(broadcast_dimensions)).result
|
@ -1302,12 +1300,12 @@ def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Va
|
||||
aval_out, = aval_out.dtype._rules.physical_avals(aval_out) # type: ignore
if not core.is_constant_shape(aval_out.shape): # type: ignore
shape = eval_dynamic_shape(ctx, aval_out.shape) # type: ignore
return mhlo.DynamicReshapeOp(
return hlo.DynamicReshapeOp(
aval_to_ir_type(aval_out), op,
shape_tensor(shape),
).result
else:
return mhlo.ReshapeOp(aval_to_ir_type(aval_out), op).result
return hlo.ReshapeOp(aval_to_ir_type(aval_out), op).result
|
def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
start_indices, limit_indices, strides) -> ir.Value:
|
@ -1319,13 +1317,13 @@ def slice_op(ctx: LoweringRuleContext, x, aval_out, *,
|
||||
start_indices = eval_dynamic_shape(ctx, start_indices)
limit_indices = eval_dynamic_shape(ctx, limit_indices)
strides = eval_dynamic_shape(ctx, strides)
return mhlo.RealDynamicSliceOp(aval_to_ir_type(aval_out),
return hlo.RealDynamicSliceOp(aval_to_ir_type(aval_out),
x,
shape_tensor(start_indices),
shape_tensor(limit_indices),
shape_tensor(strides)).result
else:
return mhlo.SliceOp(x,
return hlo.SliceOp(x,
dense_int_elements(start_indices),
dense_int_elements(limit_indices),
dense_int_elements(strides)).result
|
@ -1338,14 +1336,14 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *,
|
||||
slice_sizes = aval_out.shape
if not core.is_constant_shape(slice_sizes):
slice_sizes = eval_dynamic_shape(ctx, slice_sizes)
return mhlo.RealDynamicSliceOp(
return hlo.RealDynamicSliceOp(
aval_to_ir_type(aval_out), x,
shape_tensor(start_indices),
shape_tensor(slice_sizes),
shape_tensor([1] * len(slice_sizes))
).result
else:
return mhlo.DynamicSliceOp(x, start_indices,
return hlo.DynamicSliceOp(x, start_indices,
dense_int_elements(slice_sizes)).result
|
def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
|
@ -1356,10 +1354,10 @@ def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *,
|
||||
|
# TODO(necula): handle dynamic shapes
if mlir_api_version < 40:
return mhlo.DynamicUpdateSliceOp(
return hlo.DynamicUpdateSliceOp(
aval_to_ir_type(aval_out), x, update, start_indices).result
else:
return mhlo.DynamicUpdateSliceOp(x, update, start_indices).result
return hlo.DynamicUpdateSliceOp(x, update, start_indices).result
|
def full_like_aval(ctx: LoweringRuleContext, value, aval: core.ShapedArray) -> ir.Value:
"""Returns an IR constant shaped full of `value` shaped like `aval`."""
|
@ -1373,14 +1371,14 @@ def zeros_like_lowering(ctx, x):
|
||||
register_lowering(ad_util.zeros_like_p, zeros_like_lowering)
|
def add_jaxvals_lowering(ctx, x, y):
return mhlo.AddOp(x, y).results
return hlo.AddOp(x, y).results
register_lowering(ad_util.add_jaxvals_p, add_jaxvals_lowering)
|
register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x])
|
|
def compare_mhlo(x, y, direction: str, comparison_type: Optional[str] = None):
"""Creates mhlo.CompareOp."""
def compare_hlo(x, y, direction: str, comparison_type: Optional[str] = None):
"""Creates CompareOp."""
if comparison_type is None:
elem_type = ir.RankedTensorType(x.type).element_type
if ir.IntegerType.isinstance(elem_type):
|
@ -1389,34 +1387,34 @@ def compare_mhlo(x, y, direction: str, comparison_type: Optional[str] = None):
|
||||
else:
comparison_type = "FLOAT"
|
return mhlo.CompareOp(
return hlo.CompareOp(
x,
y,
mhlo.ComparisonDirectionAttr.get(direction),
compare_type=mhlo.ComparisonTypeAttr.get(comparison_type))
hlo.ComparisonDirectionAttr.get(direction),
compare_type=hlo.ComparisonTypeAttr.get(comparison_type))
|
def _minmax_mhlo(op, cmp, x, y):
def _minmax_hlo(op, cmp, x, y):
"""Min/max that compares complex values lexicographically as pairs."""
tensor_type = ir.RankedTensorType(x.type)
if ir.ComplexType.isinstance(tensor_type.element_type):
rx = mhlo.RealOp(x).result
ry = mhlo.RealOp(y).result
real_eq = compare_mhlo(rx, ry, "EQ", "FLOAT")
real_cmp = compare_mhlo(rx, ry, cmp, "FLOAT")
imag_cmp = compare_mhlo(
mhlo.ImagOp(x).result,
mhlo.ImagOp(y).result, cmp, "FLOAT")
which = mhlo.SelectOp(real_eq, imag_cmp, real_cmp).result
return mhlo.SelectOp(which, x, y)
rx = hlo.RealOp(x).result
ry = hlo.RealOp(y).result
real_eq = compare_hlo(rx, ry, "EQ", "FLOAT")
real_cmp = compare_hlo(rx, ry, cmp, "FLOAT")
imag_cmp = compare_hlo(
hlo.ImagOp(x).result,
hlo.ImagOp(y).result, cmp, "FLOAT")
which = hlo.SelectOp(real_eq, imag_cmp, real_cmp).result
return hlo.SelectOp(which, x, y)
else:
return op(x, y)
|
min_mhlo = partial(_minmax_mhlo, mhlo.MinOp, "LT")
max_mhlo = partial(_minmax_mhlo, mhlo.MaxOp, "GT")
min_hlo = partial(_minmax_hlo, hlo.MinOp, "LT")
max_hlo = partial(_minmax_hlo, hlo.MaxOp, "GT")
|
|
def convert_mhlo(ctx: LoweringRuleContext, x, aval_in, aval_out):
"""Variant of convert that has XLA HLO semantics.
def convert_hlo(ctx: LoweringRuleContext, x, aval_in, aval_out):
"""Variant of convert that has HLO semantics.
|
In particular, treat casts to boolean as x != 0, rather than truncating
integer values (b/209440332)."""
|
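The docstring above is the key semantic point: a cast to boolean is lowered as a comparison against zero rather than as a bit-truncating integer convert. A NumPy illustration of the difference (values are illustrative):

import numpy as np

x = np.array([0, 1, 2, 4], dtype=np.int32)
truncating = (x & 1).astype(bool)  # keep only the low bit: 2 and 4 become False
compare    = x != 0                # what convert_hlo emits: compare with zero
print(truncating)  # [False  True False False]
print(compare)     # [False  True  True  True]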
@ -1428,9 +1426,9 @@ def convert_mhlo(ctx: LoweringRuleContext, x, aval_in, aval_out):
|
||||
compare_type = "SIGNED"
else:
compare_type = "UNSIGNED"
return compare_mhlo(x, full_like_aval(ctx, 0, aval_in), "NE",
return compare_hlo(x, full_like_aval(ctx, 0, aval_in), "NE",
compare_type).result
return mhlo.ConvertOp(aval_to_ir_type(aval_out), x).result
return hlo.ConvertOp(aval_to_ir_type(aval_out), x).result
|
def _wrap_with_spmd_op(name: str,
result_type: ir.Type,
|
@ -1444,7 +1442,7 @@ def _wrap_with_spmd_op(name: str,
|
||||
[str(i) for i in sorted(unspecified_dims)]) + "]"
else:
backend_config = ""
op = mhlo.CustomCallOp([result_type], [x],
op = hlo.CustomCallOp([result_type], [x],
call_target_name=ir.StringAttr.get(name),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(backend_config),
|
@ -1489,7 +1487,7 @@ def cache_lowering(f):
|
||||
except TypeError:
# If the parameters aren't hashable, give up on caching.
# TODO(phawkins): switch to requiring hashability, when XLA fallback
# computations have been ported to MHLO.
# computations have been ported to MLIR.
return f(ctx, *args, **params)
if func is None:
func = _emit_lowering_rule_as_fun(partial(f, **params), ctx)
|
@ -1506,12 +1504,12 @@ def cache_lowering(f):
|
||||
|
|
def xla_computation_to_mhlo_module(xla_computation: xc.XlaComputation
def xla_computation_to_mlir_module(xla_computation: xc.XlaComputation
) -> ir.Module:
module_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
return ir.Module.parse(module_str)
|
def merge_mhlo_modules(dst_module: ir.Module,
def merge_mlir_modules(dst_module: ir.Module,
sym_name: str,
src_module: ir.Module) -> str:
"""Returns the name of src_module's main() function, after renaming."""
|
@ -1563,8 +1561,8 @@ def xla_fallback_lowering(prim: core.Primitive):
|
||||
xla_computation = xla.primitive_subcomputation(
module_ctx.platform, axis_env, prim, ctx.avals_in,
ctx.avals_out, **params)
xla_module = xla_computation_to_mhlo_module(xla_computation)
callee_name = merge_mhlo_modules(
xla_module = xla_computation_to_mlir_module(xla_computation)
callee_name = merge_mlir_modules(
module_ctx.module, f"xla_fallback_{prim.name}", xla_module)
output_types = map(aval_to_ir_types, ctx.avals_out)
flat_output_types = util.flatten(output_types)
|
@ -1576,7 +1574,7 @@ def xla_fallback_lowering(prim: core.Primitive):
|
||||
flatten_lowering_ir_args(args)).result
if not prim.multiple_results:
return [call]
flat_results = [mhlo.GetTupleElementOp(call, i32_attr(i)).result
flat_results = [hlo.GetTupleElementOp(call, i32_attr(i)).result
for i in range(len(flat_output_types))]
|
return util.unflatten(flat_results, map(len, output_types))
|
@ -1611,15 +1609,15 @@ def _dtype_to_xla_type_string(dtype: np.dtype) -> str:
|
||||
raise NotImplementedError(dtype)
return _dtype_to_xla_type_string_map[dtype]
|
def send_to_host(channel: int, token: mhlo.TokenType, operand: Any,
def send_to_host(channel: int, token: hlo.TokenType, operand: Any,
aval: core.ShapedArray, name: str, *,
sharding: Optional[xc.OpSharding] = None) -> ir.Value:
channel_handle = mhlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE)
channel_handle = hlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE)
if mlir_api_version < 40:
send_op = mhlo.SendOp(mhlo.TokenType.get(), [operand], token, channel_handle,
send_op = hlo.SendOp(hlo.TokenType.get(), [operand], token, channel_handle,
is_host_transfer=ir.BoolAttr.get(True))
else:
send_op = mhlo.SendOp([operand], token, channel_handle,
send_op = hlo.SendOp([operand], token, channel_handle,
is_host_transfer=ir.BoolAttr.get(True))
dtype_str = _dtype_to_xla_type_string(aval.dtype)
if dtype_str in {"f64", "s64", "u64", "c64", "c128"}:
|
@ -1634,12 +1632,12 @@ def send_to_host(channel: int, token: mhlo.TokenType, operand: Any,
|
||||
return send_op.result
|
|
def receive_from_host(channel: int, token: mhlo.TokenType,
def receive_from_host(channel: int, token: hlo.TokenType,
out_aval: core.ShapedArray, name: str, *,
sharding: Optional[xc.OpSharding] = None) -> ir.Value:
channel_handle = mhlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE)
recv_op = mhlo.RecvOp([aval_to_ir_type(out_aval),
mhlo.TokenType.get()], token, channel_handle,
channel_handle = hlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE)
recv_op = hlo.RecvOp([aval_to_ir_type(out_aval),
hlo.TokenType.get()], token, channel_handle,
is_host_transfer=ir.BoolAttr.get(True))
dtype_str = _dtype_to_xla_type_string(out_aval.dtype)
if dtype_str in {"f64", "s64", "u64", "c64", "c128"}:
|
@ -1670,9 +1668,9 @@ def _emit_tpu_python_callback(
|
||||
sharding: Optional[xc.OpSharding] = None
) -> Tuple[List[ir.Value], Any, Any]:
if mlir_api_version < 40:
token = token or mhlo.CreateTokenOp(mhlo.TokenType.get()).result
token = token or hlo.CreateTokenOp(hlo.TokenType.get()).result
else:
token = token or mhlo.CreateTokenOp().result
token = token or hlo.CreateTokenOp().result
_wrapped_callback = callback
|
send_channels = []
|
@ -1680,7 +1678,7 @@ def _emit_tpu_python_callback(
|
||||
# If there are no operands to the callback, we need to insert a dummy send
# op or the callback will never be triggered!
# TODO(sharadmv,chky): Enable this fix in the runtime as opposed to in
# MHLO builder.
# MLIR builder.
callback_without_args = _wrapped_callback
def _wrapped_callback(*args): # pylint: disable=function-redefined
del args
|
@ -1761,7 +1759,7 @@ def emit_python_callback(
|
||||
operand_layouts: Optional[Sequence[Optional[Sequence[int]]]] = None,
result_layouts: Optional[Sequence[Optional[Sequence[int]]]] = None,
) -> Tuple[List[ir.Value], Any, Any]:
"""Emits MHLO that calls back to a provided Python function."""
"""Emits MLIR that calls back to a provided Python function."""
platform = ctx.module_context.platform
if platform not in {"cpu", "cuda", "rocm", "tpu"}:
raise ValueError(
|
@ -1842,7 +1840,7 @@ def emit_python_callback(
|
||||
result_type = ir.TupleType.get_tuple(result_types)
call_target_name = ("xla_python_gpu_callback"
if platform in {"cuda", "rocm"} else "xla_python_cpu_callback")
result = mhlo.CustomCallOp(
result = hlo.CustomCallOp(
[result_type],
callback_operands,
call_target_name=ir.StringAttr.get(call_target_name),
|
@ -1859,7 +1857,7 @@ def emit_python_callback(
|
||||
if sharding is not None:
set_sharding(result, sharding)
results = [
mhlo.GetTupleElementOp(result, i32_attr(i)).result
hlo.GetTupleElementOp(result, i32_attr(i)).result
for i in range(len(result_types))
]
if token:
|
@ -77,7 +77,7 @@ from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
new_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log,
|
@ -2211,14 +2211,14 @@ ad.call_transpose_param_updaters[xla_pmap_p] = \
|
||||
ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
|
|
def _unravel_index_mhlo(axis_env):
def _unravel_index_hlo(axis_env):
div = mlir.ir_constant(
np.array(axis_env.nreps // util.prod(axis_env.sizes), np.uint32))
mod = mlir.ir_constant(np.array(axis_env.sizes[-1], np.uint32))
return mhlo.RemOp(
mhlo.DivOp(mhlo.ReplicaIdOp().result, div).result, mod).result
return hlo.RemOp(
hlo.DivOp(hlo.ReplicaIdOp().result, div).result, mod).result
|
def _mhlo_shard(aval, axis_env, xs, in_axis):
def _hlo_shard(aval, axis_env, xs, in_axis):
if aval is core.abstract_token:
return xs
elif isinstance(aval, core.ShapedArray):
|
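_unravel_index_hlo above recovers a replica's position along the innermost mapped axis as (replica_id // div) % sizes[-1]. A worked example of that arithmetic (the numbers are illustrative):

import numpy as np

nreps, sizes = 8, (2, 4)                # 8 replicas over a 2 x 4 mesh
div = nreps // int(np.prod(sizes))      # 1: any extra replication factor
for replica_id in range(nreps):
    print(replica_id, (replica_id // div) % sizes[-1])
# replicas 0..7 land at inner-axis positions 0 1 2 3 0 1 2 3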
@ -2226,20 +2226,20 @@ def _mhlo_shard(aval, axis_env, xs, in_axis):
|
||||
dims = list(aval.shape)
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
idxs = [zero] * len(dims)
idxs.insert(in_axis, _unravel_index_mhlo(axis_env))
idxs.insert(in_axis, _unravel_index_hlo(axis_env))
dims_unsqueezed = dims.copy()
dims_unsqueezed.insert(in_axis, 1)
dynamic_slice_result = mhlo.DynamicSliceOp(
dynamic_slice_result = hlo.DynamicSliceOp(
x, idxs, mlir.dense_int_elements(dims_unsqueezed)).result
return [
mhlo.ReshapeOp(mlir.aval_to_ir_type(aval), dynamic_slice_result).result
hlo.ReshapeOp(mlir.aval_to_ir_type(aval), dynamic_slice_result).result
]
else:
raise TypeError(aval)
|
|
# TODO(b/110096942): more efficient gather
def _mhlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, platform):
def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, platform):
if aval is core.abstract_token:
return xs
elif isinstance(aval, core.ShapedArray):
|
@ -2249,23 +2249,23 @@ def _mhlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, p
|
||||
and platform in ('cpu', 'gpu'))
if convert_bool:
aval = aval.update(dtype=np.dtype(np.float32))
x = mhlo.ConvertOp(mlir.aval_to_ir_type(aval), x).result
x = hlo.ConvertOp(mlir.aval_to_ir_type(aval), x).result
|
dims = list(aval.shape)
padded_aval = aval.update(shape=[axis_env.sizes[-1]] + dims)
padded = mlir.full_like_aval(ctx, 0, padded_aval)
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
idxs = [_unravel_index_mhlo(axis_env)] + [zero] * len(dims)
broadcast_result = mhlo.BroadcastOp(
idxs = [_unravel_index_hlo(axis_env)] + [zero] * len(dims)
broadcast_result = hlo.BroadcastOp(
x, mlir.dense_int_elements([1])).result
if xc.mlir_api_version < 40:
padded = mhlo.DynamicUpdateSliceOp(padded.type, padded, broadcast_result,
padded = hlo.DynamicUpdateSliceOp(padded.type, padded, broadcast_result,
idxs).result
else:
padded = mhlo.DynamicUpdateSliceOp(padded, broadcast_result, idxs).result
padded = hlo.DynamicUpdateSliceOp(padded, broadcast_result, idxs).result
replica_groups = mlir.dense_int_elements(
xla.axis_groups(axis_env, axis_env.names[-1]))
out = mhlo.CrossReplicaSumOp(padded, replica_groups).result
out = hlo.CrossReplicaSumOp(padded, replica_groups).result
if out_axis != 0:
# TODO(apaszke,mattjj): Change the indices to DynamicUpdateSlice instead
perm = list(range(1, len(dims)))
|
@ -2273,16 +2273,16 @@ def _mhlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, p
|
||||
transposed_dims = list(dims)
transposed_dims.insert(out_axis, axis_env.sizes[-1])
aval = aval.update(shape=transposed_dims)
out = mhlo.TransposeOp(out, mlir.dense_int_elements(perm)).result
out = hlo.TransposeOp(out, mlir.dense_int_elements(perm)).result
|
# TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU
if convert_bool:
float_zero = mlir.full_like_aval(ctx, 0, padded_aval)
out = mhlo.CompareOp(
out = hlo.CompareOp(
out,
float_zero,
mhlo.ComparisonDirectionAttr.get("NE"),
compare_type=mhlo.ComparisonTypeAttr.get("FLOAT")).result
hlo.ComparisonDirectionAttr.get("NE"),
compare_type=hlo.ComparisonTypeAttr.get("FLOAT")).result
return out
else:
raise TypeError(aval)
|
@ -2305,7 +2305,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
|
||||
# Shard the in_nodes that are mapped
in_avals = [v.aval for v in call_jaxpr.invars]
in_nodes_sharded = (
_mhlo_shard(aval, new_env, mlir.wrap_singleton_ir_values(in_node), in_axis)
_hlo_shard(aval, new_env, mlir.wrap_singleton_ir_values(in_node), in_axis)
if in_axis is not None else mlir.wrap_singleton_ir_values(in_node)
for aval, in_node, in_axis in zip(in_avals, in_nodes, in_axes))
|
@ -2318,7 +2318,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name,
|
||||
*in_nodes_sharded,
dim_var_values=ctx.dim_var_values)
out_avals = [v.aval for v in call_jaxpr.outvars]
outs = [_mhlo_unshard(ctx, aval, new_env, out_axis, shard,
outs = [_hlo_unshard(ctx, aval, new_env, out_axis, shard,
platform=ctx.module_context.platform)
for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)]
return outs
|
@ -586,6 +586,6 @@ def lower_fun(fun: Callable, *, multiple_results: bool, backend=None,
|
||||
def f(*args, **kw):
raise RuntimeError("XLA translation rules are deprecated and "
"jax.interpreters.xla.lower_fun is no longer supported. "
"Add an MLIR (MHLO) lowering via jax.interpreters.mlir "
"Add an MLIR lowering via jax.interpreters.mlir "
"instead.")
return f
|
@ -34,9 +34,9 @@ py_library(
|
||||
"gpu_rnn.py",
"gpu_solver.py",
"gpu_sparse.py",
"hlo_helpers.py",
"init.py",
"lapack.py",
"mhlo_helpers.py",
":version",
":xla_client",
],
|
@ -15,10 +15,10 @@
|
||||
from typing import List
|
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo
import jaxlib.mlir.dialects.mhlo as hlo
|
|
from .mhlo_helpers import custom_call
from .hlo_helpers import custom_call
from .cpu import _ducc_fft
import numpy as np
|
@ -107,7 +107,11 @@ def _ducc_fft_descriptor(shape: List[int], dtype, fft_type: FftType,
|
||||
return descriptor, out_dtype, out_shape
|
|
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def ducc_fft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
return ducc_fft_hlo(a, dtype, fft_type=fft_type, fft_lengths=fft_lengths)
|
def ducc_fft_hlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
"""DUCC FFT kernel for CPU."""
a_type = ir.RankedTensorType(a.type)
n = len(a_type.shape)
|
@ -128,15 +132,15 @@ def ducc_fft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
|
||||
raise ValueError(f"Unknown output type {out_dtype}")
|
if 0 in a_type.shape or 0 in out_shape:
zero = mhlo.ConstantOp(
zero = hlo.ConstantOp(
ir.DenseElementsAttr.get(
np.array(0, dtype=out_dtype), type=out_type))
return mhlo.BroadcastOp(
return hlo.BroadcastOp(
zero,
ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result
|
u8_type = ir.IntegerType.get_unsigned(8)
descriptor = mhlo.ConstantOp(
descriptor = hlo.ConstantOp(
ir.DenseElementsAttr.get(
np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type))
layout = tuple(range(n - 1, -1, -1))
|
@ -18,7 +18,7 @@ import operator
|
||||
|
import jaxlib.mlir.ir as ir
|
from .mhlo_helpers import custom_call
from .hlo_helpers import custom_call
|
from jaxlib import xla_client
|
@ -39,7 +39,7 @@ except ImportError:
|
||||
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
|
|
def _lu_pivots_to_permutation_mhlo(platform, gpu_linalg, pivots, *, permutation_size):
def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_size):
"""Kernel for the transformation of pivots to permutations on GPU."""
typ = ir.RankedTensorType(pivots.type)
dims = typ.shape
|
@ -65,7 +65,7 @@ def _lu_pivots_to_permutation_mhlo(platform, gpu_linalg, pivots, *, permutation_
|
||||
operand_layouts=[pivots_layout],
result_layouts=[permutations_layout])
|
cuda_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_mhlo, "cu",
cuda_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_hlo, "cu",
_cuda_linalg)
hip_lu_pivots_to_permutation = partial(
_lu_pivots_to_permutation_mhlo, "hip", _hip_linalg)
_lu_pivots_to_permutation_hlo, "hip", _hip_linalg)
|
@ -22,7 +22,7 @@ import jaxlib.mlir.ir as ir
|
||||
|
from jaxlib import xla_client
|
from .mhlo_helpers import custom_call
from .hlo_helpers import custom_call
|
try:
from .cuda import _prng as _cuda_prng
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo
import jaxlib.mlir.dialects.mhlo as hlo
|
import numpy as np
|
@ -61,7 +61,7 @@ def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *,
|
||||
|
i32_type = ir.IntegerType.get_signless(32)
|
out = mhlo.CustomCallOp(
out = hlo.CustomCallOp(
[
ir.TupleType.get_tuple([
output_type, h_0.type, c_0.type, workspace_type,
|
@ -76,13 +76,13 @@ def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *,
|
||||
called_computations=ir.ArrayAttr.get([]),
)
return [
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
hlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
for i in range(5)
]
|
|
def _mhlo_zeros_f32(shape):
return mhlo.ConstantOp(
def _hlo_zeros_f32(shape):
return hlo.ConstantOp(
ir.DenseElementsAttr.get(
np.zeros(shape, dtype=np.float32), type=ir.F32Type.get())).result
|
@ -102,8 +102,8 @@ def cudnn_rnn_bwd_lowering(ctx, dy, dhn, dcn, x, h0, c0, w, y, workspace,
|
||||
reserve_space_shape[0])
|
i32_type = ir.IntegerType.get_signless(32)
zeroed_dw = _mhlo_zeros_f32(ctx.avals_out[3].shape)
out = mhlo.CustomCallOp(
zeroed_dw = _hlo_zeros_f32(ctx.avals_out[3].shape)
out = hlo.CustomCallOp(
[ir.TupleType.get_tuple([x.type, h0.type, c0.type, w.type])], [
dy, dhn, dcn, x, h0, c0, w, y, workspace, reserve_space, zeroed_dw,
seq_lengths
|
@ -114,12 +114,12 @@ def cudnn_rnn_bwd_lowering(ctx, dy, dhn, dcn, x, h0, c0, w, y, workspace,
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
output_operand_aliases=ir.ArrayAttr.get([
mhlo.OutputOperandAlias.get(
hlo.OutputOperandAlias.get(
output_tuple_indices=[3],
operand_index=10,
operand_tuple_indices=[])
]))
return [
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
hlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
for i in range(4)
]
|
@ -18,13 +18,13 @@ from functools import partial
|
||||
import operator
|
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo
import jaxlib.mlir.dialects.mhlo as hlo
|
import numpy as np
|
from jaxlib import xla_client
|
from .mhlo_helpers import custom_call
from .hlo_helpers import custom_call
|
try:
from .cuda import _blas as _cublas
|
@ -63,7 +63,7 @@ def _real_type(dtype):
|
||||
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
|
|
def _getrf_mhlo(platform, gpu_blas, gpu_solver, dtype, a):
def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a):
"""LU decomposition."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
|
@ -106,11 +106,11 @@ def _getrf_mhlo(platform, gpu_blas, gpu_solver, dtype, a):
|
||||
operand_output_aliases={0: 0})
return out[:3]
|
cuda_getrf = partial(_getrf_mhlo, "cu", _cublas, _cusolver)
rocm_getrf = partial(_getrf_mhlo, "hip", _hipblas, _hipsolver)
cuda_getrf = partial(_getrf_hlo, "cu", _cublas, _cusolver)
rocm_getrf = partial(_getrf_hlo, "hip", _hipblas, _hipsolver)
|
|
def _geqrf_mhlo(platform, gpu_solver, dtype, a):
def _geqrf_hlo(platform, gpu_solver, dtype, a):
"""QR decomposition."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
|
@ -145,10 +145,10 @@ def _geqrf_mhlo(platform, gpu_solver, dtype, a):
|
||||
operand_output_aliases={0: 0})
return out[:3]
|
cuda_geqrf = partial(_geqrf_mhlo, "cu", _cusolver)
rocm_geqrf = partial(_geqrf_mhlo, "hip", _hipsolver)
cuda_geqrf = partial(_geqrf_hlo, "cu", _cusolver)
rocm_geqrf = partial(_geqrf_hlo, "hip", _hipsolver)
|
def _geqrf_batched_mhlo(platform, gpu_blas, dtype, a):
def _geqrf_batched_hlo(platform, gpu_blas, dtype, a):
"""Batched QR decomposition."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
|
@ -183,11 +183,11 @@ def _geqrf_batched_mhlo(platform, gpu_blas, dtype, a):
|
||||
)
return out[:2]
|
cuda_geqrf_batched = partial(_geqrf_batched_mhlo, "cu", _cublas)
rocm_geqrf_batched = partial(_geqrf_batched_mhlo, "hip", _hipblas)
cuda_geqrf_batched = partial(_geqrf_batched_hlo, "cu", _cublas)
rocm_geqrf_batched = partial(_geqrf_batched_hlo, "hip", _hipblas)
|
|
def _csrlsvqr_mhlo(platform, gpu_solver, dtype, data,
def _csrlsvqr_hlo(platform, gpu_solver, dtype, data,
indices, indptr, b, tol, reorder):
"""Sparse solver via QR decomposition. CUDA only."""
b_type = ir.RankedTensorType(b.type)
|
@ -209,10 +209,10 @@ def _csrlsvqr_mhlo(platform, gpu_solver, dtype, data,
|
||||
)
return [out]
|
cuda_csrlsvqr = partial(_csrlsvqr_mhlo, "cu", _cusolver)
cuda_csrlsvqr = partial(_csrlsvqr_hlo, "cu", _cusolver)
|
|
def _orgqr_mhlo(platform, gpu_solver, dtype, a, tau):
def _orgqr_hlo(platform, gpu_solver, dtype, a, tau):
"""Product of elementary Householder reflections."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
|
@ -252,11 +252,11 @@ def _orgqr_mhlo(platform, gpu_solver, dtype, a, tau):
|
||||
operand_output_aliases={0: 0})
return out[:2]
|
cuda_orgqr = partial(_orgqr_mhlo, "cu", _cusolver)
rocm_orgqr = partial(_orgqr_mhlo, "hip", _hipsolver)
cuda_orgqr = partial(_orgqr_hlo, "cu", _cusolver)
rocm_orgqr = partial(_orgqr_hlo, "hip", _hipsolver)
|
|
def _syevd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
def _syevd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
lower=False):
"""Symmetric (Hermitian) eigendecomposition."""
a_type = ir.RankedTensorType(a.type)
|
@ -304,11 +304,11 @@ def _syevd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
|
||||
operand_output_aliases={0: 0})
return out[:3]
|
cuda_syevd = partial(_syevd_mhlo, "cu", _cusolver, True)
rocm_syevd = partial(_syevd_mhlo, "hip", _hipsolver, True)
cuda_syevd = partial(_syevd_hlo, "cu", _cusolver, True)
rocm_syevd = partial(_syevd_hlo, "hip", _hipsolver, True)
|
|
def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
full_matrices=True, compute_uv=True):
"""Singular value decomposition."""
a_type = ir.RankedTensorType(a.type)
|
@ -358,18 +358,18 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
|
||||
[0],
],
operand_output_aliases={0: 0})
vt = mhlo.TransposeOp(
vt = hlo.TransposeOp(
v,
ir.DenseIntElementsAttr.get(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))).result
if np.issubdtype(dtype, np.complexfloating):
vt = mhlo.ComplexOp(mhlo.RealOp(vt), mhlo.NegOp(mhlo.ImagOp(vt))).result
vt = hlo.ComplexOp(hlo.RealOp(vt), hlo.NegOp(hlo.ImagOp(vt))).result
if not full_matrices and not econ:
u = mhlo.SliceOp(
u = hlo.SliceOp(
u,
ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
ir.DenseIntElementsAttr.get(np.array(batch_dims + (m, min(m, n)))),
ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))).result
vt = mhlo.SliceOp(
vt = hlo.SliceOp(
vt,
ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)),
ir.DenseIntElementsAttr.get(np.array(batch_dims + (min(m, n), n))),
|
@ -430,11 +430,11 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a,
|
||||
operand_output_aliases={0: 0})
return s, u, vt, info
|
cuda_gesvd = partial(_gesvd_mhlo, "cu", _cusolver, True)
rocm_gesvd = partial(_gesvd_mhlo, "hip", _hipsolver, False)
cuda_gesvd = partial(_gesvd_hlo, "cu", _cusolver, True)
rocm_gesvd = partial(_gesvd_hlo, "hip", _hipsolver, False)
|
|
def _sytrd_mhlo(platform, gpu_solver, dtype, a, *, lower):
def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower):
"""sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form."""
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
|
@ -490,20 +490,20 @@ def _sytrd_mhlo(platform, gpu_solver, dtype, a, *, lower):
|
||||
if not lower and platform == "cu" and m > 1:
start = (0,) * len(batch_dims) + (0,)
end = batch_dims + (1,)
s = mhlo.SliceOp(e, intattr(start), intattr(end), intattr([1] * len(start)))
s = hlo.SliceOp(e, intattr(start), intattr(end), intattr([1] * len(start)))
s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type)
s = mhlo.BroadcastInDimOp(s_type, s, intattr(range(len(dims) - 1)))
s = hlo.BroadcastInDimOp(s_type, s, intattr(range(len(dims) - 1)))
# The diagonals are always real; convert to complex if needed.
s = mhlo.ConvertOp(
s = hlo.ConvertOp(
ir.RankedTensorType.get(s_type.shape, a_type.element_type), s)
offsets = tuple(mhlo.ConstantOp(intattr(i))
offsets = tuple(hlo.ConstantOp(intattr(i))
for i in ((0,) * len(batch_dims) + (0, 1)))
if xla_client.mlir_api_version < 40:
a = mhlo.DynamicUpdateSliceOp(a.type, a, s, offsets).result
a = hlo.DynamicUpdateSliceOp(a.type, a, s, offsets).result
else:
a = mhlo.DynamicUpdateSliceOp(a, s, offsets).result
a = hlo.DynamicUpdateSliceOp(a, s, offsets).result
|
return a, d, e, taus, info
|
cuda_sytrd = partial(_sytrd_mhlo, "cu", _cusolver)
rocm_sytrd = partial(_sytrd_mhlo, "hip", _hipsolver)
cuda_sytrd = partial(_sytrd_hlo, "cu", _cusolver)
rocm_sytrd = partial(_sytrd_hlo, "hip", _hipsolver)
|
@ -23,7 +23,7 @@ import numpy as np

from jaxlib import xla_client

from .mhlo_helpers import custom_call
from .hlo_helpers import custom_call

try:
from .cuda import _sparse as _cusparse
@ -46,7 +46,7 @@ cuda_is_supported : bool = _cusparse and _cusparse.sparse_supported
rocm_is_supported : bool = _hipsparse and _hipsparse.sparse_supported


def _validate_csr_mhlo(data, indices, indptr, shape):
def _validate_csr_hlo(data, indices, indptr, shape):
data_type = ir.RankedTensorType(data.type)
indices_type = ir.RankedTensorType(indices.type)
indptr_type = ir.RankedTensorType(indptr.type)
@ -57,7 +57,7 @@ def _validate_csr_mhlo(data, indices, indptr, shape):
assert indptr_type.shape == [shape[0] + 1]
return data_type.element_type, indices_type.element_type, nnz

def _validate_coo_mhlo(data, row, col):
def _validate_coo_hlo(data, row, col):
data_type = ir.RankedTensorType(data.type)
row_type = ir.RankedTensorType(row.type)
col_type = ir.RankedTensorType(col.type)
@ -69,10 +69,10 @@ def _validate_coo_mhlo(data, row, col):
return data_type.element_type, row_type.element_type, nnz


def _csr_todense_mhlo(platform, gpu_sparse, data, indices, indptr, *, shape,
def _csr_todense_hlo(platform, gpu_sparse, data, indices, indptr, *, shape,
data_dtype, index_dtype):
"""CSR to dense matrix."""
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape)
rows, cols = shape

buffer_size, opaque = gpu_sparse.build_csr_todense_descriptor(
@ -91,11 +91,11 @@ def _csr_todense_mhlo(platform, gpu_sparse, data, indices, indptr, *, shape,
result_layouts=[[1, 0], [0]])
return out[0]

cuda_csr_todense = partial(_csr_todense_mhlo, "cu", _cusparse)
rocm_csr_todense = partial(_csr_todense_mhlo, "hip", _hipsparse)
cuda_csr_todense = partial(_csr_todense_hlo, "cu", _cusparse)
rocm_csr_todense = partial(_csr_todense_hlo, "hip", _hipsparse)


def _csr_fromdense_mhlo(platform, gpu_sparse, mat, *, nnz, index_dtype,
def _csr_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, index_dtype,
data_dtype, index_type):
"""CSR from dense matrix."""
mat_type = ir.RankedTensorType(mat.type)
@ -119,15 +119,15 @@ def _csr_fromdense_mhlo(platform, gpu_sparse, mat, *, nnz, index_dtype,
result_layouts=[[0]] * 4)
return out[:3]

cuda_csr_fromdense = partial(_csr_fromdense_mhlo, "cu", _cusparse)
rocm_csr_fromdense = partial(_csr_fromdense_mhlo, "hip", _hipsparse)
cuda_csr_fromdense = partial(_csr_fromdense_hlo, "cu", _cusparse)
rocm_csr_fromdense = partial(_csr_fromdense_hlo, "hip", _hipsparse)


def _csr_matvec_mhlo(platform, gpu_sparse, data, indices, indptr, x, *, shape,
def _csr_matvec_hlo(platform, gpu_sparse, data, indices, indptr, x, *, shape,
transpose=False, compute_dtype=None, compute_type=None,
data_dtype, index_dtype, x_dtype):
"""CSR matrix/vector multiply."""
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape)
rows, cols = shape

if compute_dtype is None:
@ -152,15 +152,15 @@ def _csr_matvec_mhlo(platform, gpu_sparse, data, indices, indptr, x, *, shape,
result_layouts=[[0]] * 2)
return out[0]

cuda_csr_matvec = partial(_csr_matvec_mhlo, "cu", _cusparse)
rocm_csr_matvec = partial(_csr_matvec_mhlo, "hip", _hipsparse)
cuda_csr_matvec = partial(_csr_matvec_hlo, "cu", _cusparse)
rocm_csr_matvec = partial(_csr_matvec_hlo, "hip", _hipsparse)


def _csr_matmat_mhlo(platform, gpu_sparse, data, indices, indptr, B, *, shape,
def _csr_matmat_hlo(platform, gpu_sparse, data, indices, indptr, B, *, shape,
transpose=False, compute_dtype=None, compute_type=None,
index_dtype, data_dtype, B_dtype):
"""CSR from dense matrix."""
data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape)
data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape)
rows, cols = shape
B_shape = ir.RankedTensorType(B.type).shape
_, Ccols = B_shape
@ -187,14 +187,14 @@ def _csr_matmat_mhlo(platform, gpu_sparse, data, indices, indptr, B, *, shape,
result_layouts=[[1, 0], [0]])
return out[0]

cuda_csr_matmat = partial(_csr_matmat_mhlo, "cu", _cusparse)
rocm_csr_matmat = partial(_csr_matmat_mhlo, "hip", _hipsparse)
cuda_csr_matmat = partial(_csr_matmat_hlo, "cu", _cusparse)
rocm_csr_matmat = partial(_csr_matmat_hlo, "hip", _hipsparse)


def _coo_todense_mhlo(platform, gpu_sparse, data, row, col, *, shape,
def _coo_todense_hlo(platform, gpu_sparse, data, row, col, *, shape,
data_dtype, index_dtype):
"""COO to dense matrix."""
data_type, _, nnz = _validate_coo_mhlo(data, row, col)
data_type, _, nnz = _validate_coo_hlo(data, row, col)
rows, cols = shape

buffer_size, opaque = gpu_sparse.build_coo_todense_descriptor(
@ -213,11 +213,11 @@ def _coo_todense_mhlo(platform, gpu_sparse, data, row, col, *, shape,
result_layouts=[[1, 0], [0]])
return out[0]

cuda_coo_todense = partial(_coo_todense_mhlo, "cu", _cusparse)
rocm_coo_todense = partial(_coo_todense_mhlo, "hip", _hipsparse)
cuda_coo_todense = partial(_coo_todense_hlo, "cu", _cusparse)
rocm_coo_todense = partial(_coo_todense_hlo, "hip", _hipsparse)


def _coo_fromdense_mhlo(platform, gpu_sparse, mat, *, nnz, data_dtype,
def _coo_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, data_dtype,
index_dtype, index_type):
"""COO from dense matrix."""
mat_type = ir.RankedTensorType(mat.type)
@ -241,15 +241,15 @@ def _coo_fromdense_mhlo(platform, gpu_sparse, mat, *, nnz, data_dtype,
result_layouts=[[0]] * 4)
return out[:3]

cuda_coo_fromdense = partial(_coo_fromdense_mhlo, "cu", _cusparse)
rocm_coo_fromdense = partial(_coo_fromdense_mhlo, "hip", _hipsparse)
cuda_coo_fromdense = partial(_coo_fromdense_hlo, "cu", _cusparse)
rocm_coo_fromdense = partial(_coo_fromdense_hlo, "hip", _hipsparse)


def _coo_matvec_mhlo(platform, gpu_sparse, data, row, col, x, *, shape,
def _coo_matvec_hlo(platform, gpu_sparse, data, row, col, x, *, shape,
transpose=False, compute_dtype=None, compute_type=None,
index_dtype, data_dtype, x_dtype):
"""COO matrix/vector multiply."""
data_type, _, nnz = _validate_coo_mhlo(data, row, col)
data_type, _, nnz = _validate_coo_hlo(data, row, col)
rows, cols = shape

if compute_dtype is None:
@ -274,15 +274,15 @@ def _coo_matvec_mhlo(platform, gpu_sparse, data, row, col, x, *, shape,
result_layouts=[[0]] * 2)
return out[0]

cuda_coo_matvec = partial(_coo_matvec_mhlo, "cu", _cusparse)
rocm_coo_matvec = partial(_coo_matvec_mhlo, "hip", _hipsparse)
cuda_coo_matvec = partial(_coo_matvec_hlo, "cu", _cusparse)
rocm_coo_matvec = partial(_coo_matvec_hlo, "hip", _hipsparse)


def _coo_matmat_mhlo(platform, gpu_sparse, data, row, col, B, *, shape,
def _coo_matmat_hlo(platform, gpu_sparse, data, row, col, B, *, shape,
transpose=False, compute_dtype=None, compute_type=None,
x_dtype, data_dtype, index_dtype):
"""COO from dense matrix."""
data_type, _, nnz = _validate_coo_mhlo(data, row, col)
data_type, _, nnz = _validate_coo_hlo(data, row, col)
is_batched_matmat = False
batch_count = 1
if len(shape) == 2:
@ -334,11 +334,11 @@ def _coo_matmat_mhlo(platform, gpu_sparse, data, row, col, B, *, shape,
result_layouts=[out_layout, [0]])
return out[0]

cuda_coo_matmat = partial(_coo_matmat_mhlo, "cu", _cusparse)
rocm_coo_matmat = partial(_coo_matmat_mhlo, "hip", _hipsparse)
cuda_coo_matmat = partial(_coo_matmat_hlo, "cu", _cusparse)
rocm_coo_matmat = partial(_coo_matmat_hlo, "hip", _hipsparse)


def _gtsv2_mhlo(platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t):
def _gtsv2_hlo(platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t):
"""Calls `cusparse<t>gtsv2(dl, d, du, B, m, n, ldb)`."""
f32 = (t == np.float32)
if f32:
@ -360,5 +360,5 @@ def _gtsv2_mhlo(platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t):
operand_output_aliases={3: 0})
return out[0]

cuda_gtsv2 = partial(_gtsv2_mhlo, "cu", _cusparse)
rocm_gtsv2 = partial(_gtsv2_mhlo, "hip", _hipsparse)
cuda_gtsv2 = partial(_gtsv2_hlo, "cu", _cusparse)
rocm_gtsv2 = partial(_gtsv2_hlo, "hip", _hipsparse)

@ -12,11 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Helpers for building MHLO operators
|
||||
# Helpers for building MLIR operators
|
||||
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
import jaxlib.mlir.ir as ir
|
||||
import jaxlib.mlir.dialects.mhlo as mhlo
|
||||
import jaxlib.mlir.dialects.mhlo as hlo
|
||||
import numpy as np
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ def custom_call(
|
||||
api_version: int = 2,
|
||||
operand_output_aliases: Dict[int, int] = {},
|
||||
) -> Union[ir.Value, Sequence[ir.Value]]:
|
||||
"""Less-verbose helper for building an MHLO custom call op.
|
||||
"""Less-verbose helper for building a CustomCallOp.
|
||||
|
||||
Once https://github.com/llvm/llvm-project/issues/54932 is fixed, this helper
|
||||
may be able to go away.
|
||||
@ -42,7 +42,7 @@ def custom_call(
|
||||
that must alias.
|
||||
"""
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = mhlo.CustomCallOp(
|
||||
out = hlo.CustomCallOp(
|
||||
(out_types
|
||||
if len(out_types) == 1 else [ir.TupleType.get_tuple(out_types)]),
|
||||
operands,
|
||||
@ -63,7 +63,7 @@ def custom_call(
|
||||
type=ir.IndexType.get()) for l in result_layouts
|
||||
]),
|
||||
output_operand_aliases=ir.ArrayAttr.get([
|
||||
mhlo.OutputOperandAlias.get(
|
||||
hlo.OutputOperandAlias.get(
|
||||
output_tuple_indices=[] if len(out_types) == 1 else [output],
|
||||
operand_index=input,
|
||||
operand_tuple_indices=[])
|
||||
@ -73,6 +73,6 @@ def custom_call(
|
||||
return out.result
|
||||
else:
|
||||
return [
|
||||
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
|
||||
hlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
|
||||
for i in range(len(out_types))
|
||||
]
|
jaxlib/lapack.py
@ -16,12 +16,12 @@
# via CustomCallWithLayout.

import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo
import jaxlib.mlir.dialects.mhlo as hlo

import numpy as np
from jaxlib import xla_client

from .mhlo_helpers import custom_call
from .hlo_helpers import custom_call
from .cpu import _lapack

for _name, _value in _lapack.registrations().items():
@ -32,23 +32,29 @@ for _name, _value in _lapack.registrations().items():
_initialize = _lapack.initialize


def _mhlo_u8(x):
return mhlo.ConstantOp(
def _hlo_u8(x):
return hlo.ConstantOp(
ir.DenseElementsAttr.get(
np.array(x, dtype=np.uint8),
type=ir.IntegerType.get_unsigned(8))).result

def _mhlo_s32(x):
return mhlo.ConstantOp(
def _hlo_s32(x):
return hlo.ConstantOp(
ir.DenseElementsAttr.get(
np.array(x, dtype=np.int32),
type=ir.IntegerType.get_signless(32))).result

# TODO(phawkins): it would be nice to avoid duplicating code for each type.

# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
conj_a=False, diag=False):
return trsm_hlo(dtype, alpha, a, b, left_side=left_side, lower=lower,
trans_a=trans_a, conj_a=conj_a, diag=diag)

# ?trsm(left_side, lower, trans_a, diag, m, n, alpha, a, b):
# triangular solve
def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
def trsm_hlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
conj_a=False, diag=False):
_initialize()
a_type = ir.RankedTensorType(a.type)
@ -87,9 +93,9 @@ def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,
return custom_call(
fn,
[b.type],
[_mhlo_s32(int(left_side)), _mhlo_s32(int(lower)),
_mhlo_s32((2 if conj_a else 1) if trans_a else 0), _mhlo_s32(int(diag)),
_mhlo_s32(m), _mhlo_s32(n), _mhlo_s32(num_b),
[_hlo_s32(int(left_side)), _hlo_s32(int(lower)),
_hlo_s32((2 if conj_a else 1) if trans_a else 0), _hlo_s32(int(diag)),
_hlo_s32(m), _hlo_s32(n), _hlo_s32(num_b),
alpha, a, b],
operand_layouts=[scalar_layout] * 8 + [layout] * 2,
result_layouts=[layout],
@ -99,7 +105,11 @@ def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False,

# # ?getrf: LU decomposition

# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def getrf_mhlo(dtype, a):
return getrf_hlo(dtype, a)

def getrf_hlo(dtype, a):
_initialize()
dims = ir.RankedTensorType(a.type).shape
assert len(dims) >= 2
@ -131,7 +141,7 @@ def getrf_mhlo(dtype, a):
ir.RankedTensorType.get(batch_dims + (min(m, n),), i32_type),
ir.RankedTensorType.get(batch_dims, i32_type),
],
[_mhlo_s32(int(b)), _mhlo_s32(m), _mhlo_s32(n), a],
[_hlo_s32(int(b)), _hlo_s32(m), _hlo_s32(n), a],
operand_layouts=[scalar_layout] * 3 + [layout],
result_layouts=[
layout,
@ -144,7 +154,11 @@ def getrf_mhlo(dtype, a):

# # ?geqrf: QR decomposition

# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def geqrf_mhlo(dtype, a):
return geqrf_hlo(dtype, a)

def geqrf_hlo(dtype, a):
_initialize()
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -182,7 +196,7 @@ def geqrf_mhlo(dtype, a):
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
[_mhlo_s32(int(b)), _mhlo_s32(m), _mhlo_s32(n), _mhlo_s32(lwork), a],
[_hlo_s32(int(b)), _hlo_s32(m), _hlo_s32(n), _hlo_s32(lwork), a],
operand_layouts=[scalar_layout] * 4 + [layout],
result_layouts=[
layout,
@ -197,7 +211,11 @@ def geqrf_mhlo(dtype, a):

# # ?orgqr: product of elementary Householder reflectors:

# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def orgqr_mhlo(dtype, a, tau):
return orgqr_hlo(dtype, a, tau)

def orgqr_hlo(dtype, a, tau):
_initialize()
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -238,8 +256,8 @@ def orgqr_mhlo(dtype, a, tau):
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
[_mhlo_s32(int(b)), _mhlo_s32(m), _mhlo_s32(n), _mhlo_s32(k),
_mhlo_s32(lwork), a, tau],
[_hlo_s32(int(b)), _hlo_s32(m), _hlo_s32(n), _hlo_s32(k),
_hlo_s32(lwork), a, tau],
operand_layouts=[scalar_layout] * 5 + [
layout,
tuple(range(num_bd, -1, -1)),
@ -256,7 +274,11 @@ def orgqr_mhlo(dtype, a, tau):

# ?potrf: Cholesky decomposition

# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def potrf_mhlo(dtype, a, lower=False):
return potrf_hlo(dtype, a, lower=lower)

def potrf_hlo(dtype, a, lower=False):
_initialize()
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -286,7 +308,7 @@ def potrf_mhlo(dtype, a, lower=False):
fn,
[a.type,
ir.RankedTensorType.get(batch_dims, ir.IntegerType.get_signless(32))],
[_mhlo_s32(int(lower)), _mhlo_s32(b), _mhlo_s32(n), a],
[_hlo_s32(int(lower)), _hlo_s32(b), _hlo_s32(n), a],
operand_layouts=[scalar_layout] * 3 + [layout],
result_layouts=[layout, info_layout],
operand_output_aliases={3: 0},
@ -297,7 +319,11 @@ def potrf_mhlo(dtype, a, lower=False):

# # ?gesdd: Singular value decomposition

# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def gesdd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
return gesdd_hlo(dtype, a, full_matrices=full_matrices, compute_uv=compute_uv)

def gesdd_hlo(dtype, a, full_matrices=True, compute_uv=True):
_initialize()
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -370,8 +396,8 @@ def gesdd_mhlo(dtype, a, full_matrices=True, compute_uv=True):
a_type.element_type),
ir.RankedTensorType.get(batch_dims, i32_type),
] + workspace,
[_mhlo_s32(int(full_matrices)), _mhlo_s32(int(compute_uv)), _mhlo_s32(b),
_mhlo_s32(m), _mhlo_s32(n), _mhlo_s32(lwork), a],
[_hlo_s32(int(full_matrices)), _hlo_s32(int(compute_uv)), _hlo_s32(b),
_hlo_s32(m), _hlo_s32(n), _hlo_s32(lwork), a],
operand_layouts=[scalar_layout] * 6 + [layout],
result_layouts=[
layout,
@ -387,7 +413,11 @@ def gesdd_mhlo(dtype, a, full_matrices=True, compute_uv=True):

# # syevd: Symmetric eigendecomposition

# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def syevd_mhlo(dtype, a, lower=False):
return syevd_hlo(dtype, a, lower=lower)

def syevd_hlo(dtype, a, lower=False):
_initialize()
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -452,7 +482,7 @@ def syevd_mhlo(dtype, a, lower=False):
ir.RankedTensorType.get(batch_dims + (n,), eigvals_type),
ir.RankedTensorType.get(batch_dims, i32_type),
] + workspace,
[_mhlo_s32(1 if lower else 0), _mhlo_s32(b), _mhlo_s32(n), a],
[_hlo_s32(1 if lower else 0), _hlo_s32(b), _hlo_s32(n), a],
operand_layouts=[scalar_layout] * 3 + [layout],
result_layouts=[
layout,
@ -466,7 +496,11 @@ def syevd_mhlo(dtype, a, lower=False):

# # geev: Nonsymmetric eigendecomposition

# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def geev_mhlo(dtype, a, jobvl=True, jobvr=True):
return geev_hlo(dtype, a, jobvl=jobvl, jobvr=jobvr)

def geev_hlo(dtype, a, jobvl=True, jobvr=True):
_initialize()
dims = ir.RankedTensorType(a.type).shape
assert len(dims) >= 2
@ -539,19 +573,23 @@ def geev_mhlo(dtype, a, jobvl=True, jobvr=True):
ir.RankedTensorType.get(dims, eigvecs_type),
ir.RankedTensorType.get(batch_dims, i32_type),
],
[_mhlo_s32(b), _mhlo_s32(n), _mhlo_u8(jobvl_c), _mhlo_u8(jobvr_c), a],
[_hlo_s32(b), _hlo_s32(n), _hlo_u8(jobvl_c), _hlo_u8(jobvr_c), a],
operand_layouts=[scalar_layout] * 4 + [layout],
result_layouts=(workspace_layouts + eigvals_layouts + [layout] * 2 +
[info_layout])
)
if real:
return (mhlo.ComplexOp(out[3], out[4]).result, out[5], out[6], out[7])
return (hlo.ComplexOp(out[3], out[4]).result, out[5], out[6], out[7])
else:
return out[2:6]

# # gees : Schur factorization

# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
return gees_hlo(dtype, a, jobvs=jobvs, sort=sort, select=select)

def gees_hlo(dtype, a, jobvs=True, sort=False, select=None):
_initialize()
a_type = ir.RankedTensorType(a.type)
etype = a_type.element_type
@ -609,10 +647,10 @@ def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
ir.RankedTensorType.get(batch_dims, i32_type),
],
[
_mhlo_s32(b),
_mhlo_s32(n),
_mhlo_u8(np.uint8(jobvs)),
_mhlo_u8(np.uint8(sort)),
_hlo_s32(b),
_hlo_s32(n),
_hlo_u8(np.uint8(jobvs)),
_hlo_u8(np.uint8(sort)),
# TODO: figure out how to put the callable select function here
a
],
@ -630,8 +668,12 @@ def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None):
return (out[0], out[3], out[5])


# gehrd: Reduction of a non-symmetric square matrix to upper Hessenberg form.
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def gehrd_mhlo(dtype, a):
return gehrd_hlo(dtype, a)

# gehrd: Reduction of a non-symmetric square matrix to upper Hessenberg form.
def gehrd_hlo(dtype, a):
_initialize()
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -669,8 +711,8 @@ def gehrd_mhlo(dtype, a):
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
[_mhlo_s32(n), _mhlo_s32(1), _mhlo_s32(n), _mhlo_s32(n), _mhlo_s32(b),
_mhlo_s32(lwork), a],
[_hlo_s32(n), _hlo_s32(1), _hlo_s32(n), _hlo_s32(n), _hlo_s32(b),
_hlo_s32(lwork), a],
operand_layouts=[[]] * 6 + [layout],
result_layouts=[
layout,
@ -683,8 +725,8 @@ def gehrd_mhlo(dtype, a):
return out[:3]


# sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form.
# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41.
def sytrd_mhlo(dtype, a, *, lower):
return sytrd_hlo(dtype, a, lower=lower)

# sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form.
def sytrd_hlo(dtype, a, *, lower):
_initialize()
a_type = ir.RankedTensorType(a.type)
dims = a_type.shape
@ -728,8 +774,8 @@ def sytrd_mhlo(dtype, a, *, lower):
ir.RankedTensorType.get(batch_dims, i32_type),
ir.RankedTensorType.get([lwork], a_type.element_type),
],
[_mhlo_s32(n), _mhlo_s32(1 if lower else 0), _mhlo_s32(max(1, n)),
_mhlo_s32(b), _mhlo_s32(lwork), a],
[_hlo_s32(n), _hlo_s32(1 if lower else 0), _hlo_s32(max(1, n)),
_hlo_s32(b), _hlo_s32(lwork), a],
operand_layouts=[[]] * 5 + [layout],
result_layouts=[
layout,

@ -155,10 +155,10 @@ def _sp_indices_abstract_eval(mat):
# Note: cannot use lower_fun to define attribute access primitives
# because it leads to infinite recursion.

def _sp_indices_mhlo_lowering(ctx, data_and_indices):
def _sp_indices_hlo_lowering(ctx, data_and_indices):
return [data_and_indices[1]]

mlir.register_lowering(sp_indices_p, _sp_indices_mhlo_lowering)
mlir.register_lowering(sp_indices_p, _sp_indices_hlo_lowering)

sp_data_p = core.Primitive('sp_data')

@ -173,10 +173,10 @@ def _sp_data_abstract_eval(mat):
# Note: cannot use lower_fun to define attribute access primitives
# because it leads to infinite recursion.

def _sp_data_mhlo_lowering(ctx, data_and_indices):
def _sp_data_hlo_lowering(ctx, data_and_indices):
return [data_and_indices[0]]

mlir.register_lowering(sp_data_p, _sp_data_mhlo_lowering)
mlir.register_lowering(sp_data_p, _sp_data_hlo_lowering)

def identity(x):
return identity_p.bind(x)

@ -767,7 +767,7 @@ class DynamicShapeTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=True)

# TODO(mattjj,dougalm,phawkins): debug iree failure, "failed to legalize
# operation 'mhlo.while' that was explicitly marked illegal"
# operation 'while' that was explicitly marked illegal"
@unittest.skip("revising slicing logic")
def test_scan_basic(self):
def cumsum(x):
@ -1275,8 +1275,8 @@ class DynamicShapeTest(jtu.JaxTestCase):
return x.sum()

f_lowered = f.lower(np.arange(3, dtype='int32'))
mhlo = f_lowered.compiler_ir('mhlo')
self.assertIn('tensor<?xi32>', str(mhlo))
mlir_str = f_lowered.compiler_ir()
self.assertIn('tensor<?xi32>', str(mlir_str))

def test_lower_abstracted_axes_shapedtypestruct(self):
@partial(jax.jit, abstracted_axes=('n',))
@ -1284,8 +1284,8 @@ class DynamicShapeTest(jtu.JaxTestCase):
return x.sum()

f_lowered = f.lower(jax.ShapeDtypeStruct((3,), np.int32))
mhlo = f_lowered.compiler_ir('mhlo')
self.assertIn('tensor<?xi32>', str(mhlo))
mlir_str = f_lowered.compiler_ir()
self.assertIn('tensor<?xi32>', str(mlir_str))

def test_vmap_abstracted_axis(self):
def foo(x, y):

@ -1,6 +1,6 @@
This directory contains LLVM
[FileCheck](https://llvm.org/docs/CommandGuide/FileCheck.html) tests that verify
that JAX primitives can be lowered to MHLO.
that JAX primitives can be lowered to MLIR.

These tests are intended to be a quick and easy-to-understand way to catch
regressions from changes due the MLIR Python bindings and from changes to the

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Tests for lowering of array origami ops into MHLO.
# Tests for lowering of array origami ops into MLIR.

# RUN: %PYTHON %s | FileCheck %s

@ -32,62 +32,62 @@ jax.config.update("jax_enable_x64", True)

def main(_):
# CHECK-LABEL: TEST: concatenate bool[2,7] bool[2,5]
# CHECK: mhlo.concatenate
# CHECK: hlo.concatenate
# CHECK-SAME: tensor<2x12xi1>
print_ir([np.empty([2, 7], np.bool_), np.empty([2, 5], np.bool_)])(
partial(lax.concatenate, dimension=1))

# CHECK-LABEL: TEST: broadcast_in_dim bool[2,7]
# CHECK: mhlo.broadcast_in_dim
# CHECK: hlo.broadcast_in_dim
# CHECK-SAME: tensor<3x2x5x7x2xi1>
print_ir(np.empty([2, 7], np.bool_))(
partial(lax.broadcast_in_dim, shape=(3, 2, 5, 7, 2),
broadcast_dimensions=(1, 3)))

# CHECK-LABEL: TEST: iota
# CHECK: mhlo.iota
# CHECK: hlo.iota
# CHECK-SAME: tensor<10xf32>
print_ir()(partial(lax.iota, dtype=np.float32, size=10))

# CHECK-LABEL: TEST: pad int32[2,7]
# CHECK: mhlo.pad
# CHECK: hlo.pad
# CHECK-SAME: tensor<11x52xi32>
print_ir(np.empty([2, 7], np.int32))(
partial(lax.pad, padding_value=np.int32(7),
padding_config=((2, 3, 4), (4, 5, 6))))

# CHECK-LABEL: TEST: _reduce_sum int32[2,3,7]
# CHECK: mhlo.reduce
# CHECK: mhlo.add
# CHECK: hlo.reduce
# CHECK: hlo.add
# CHECK: tensor<3xi32>
print_ir(np.empty([2, 3, 7], np.int32))(
partial(lax_internal._reduce_sum, axes=(0, 2)))

# CHECK-LABEL: TEST: reshape int32[2,3,7]
# CHECK: mhlo.reshape
# CHECK: hlo.reshape
# CHECK-SAME: tensor<42xi32>
print_ir(np.empty([2, 3, 7], np.int32))(
partial(lax.reshape, new_sizes=(42,)))

# CHECK-LABEL: TEST: rev int32[2,7]
# CHECK: mhlo.rev
# CHECK: hlo.rev
# CHECK-SAME: tensor<2x7xi32>
print_ir(np.empty([2, 7], np.int32))(
partial(lax.rev, dimensions=(0, 1)))

# CHECK-LABEL: TEST: select bool[2,7] int32[2,7] int32[2,7]
# CHECK: mhlo.select
# CHECK: hlo.select
# CHECK-SAME: tensor<2x7xi1>, tensor<2x7xi32>
print_ir(np.empty([2, 7], np.bool_), np.empty([2, 7], np.int32),
np.empty([2, 7], np.int32))(lax.select)

# CHECK-LABEL: TEST: sort int32[2,7]
# CHECK: mhlo.sort
# CHECK: hlo.sort
# CHECK: tensor<2x7xi32>
print_ir(np.empty([2, 7], np.int32))(lax.sort)

# CHECK-LABEL: TEST: squeeze int32[2,1,7]
# CHECK: mhlo.reshape
# CHECK: hlo.reshape
# CHECK-SAME: tensor<2x7xi32>
print_ir(np.empty([2, 1, 7], np.int32))(
partial(lax.squeeze, dimensions=(1,)))
@ -98,7 +98,7 @@ def main(_):
print_ir(np.empty([2, 7], np.int32))(partial(lax.top_k, k=7))

# CHECK-LABEL: TEST: transpose int32[2,7]
# CHECK: mhlo.transpose
# CHECK: hlo.transpose
# CHECK-SAME: tensor<7x2xi32>
print_ir(np.empty([2, 7], np.int32))(
partial(lax.transpose, permutation=(1, 0)))

@ -20,7 +20,7 @@ import numpy as np

def print_ir(*prototypes):
def lower(f):
"""Prints the MHLO IR that results from lowering `f`.
"""Prints the MLIR IR that results from lowering `f`.

The arguments to `f` are taken to be arrays shaped like `prototypes`."""
inputs = tree_util.tree_map(np.array, prototypes)
@ -29,5 +29,5 @@ def print_ir(*prototypes):
for x in flat_inputs])
name = f.func.__name__ if hasattr(f, "func") else f.__name__
print(f"\nTEST: {name} {shape_strs}")
print(jax.jit(f).lower(*inputs).compiler_ir(dialect="mhlo"))
print(jax.jit(f).lower(*inputs).compiler_ir())
return lower

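For context, a minimal sketch of the lowering pattern this helper now relies on (it assumes only a locally installed JAX; the function `f` below is purely illustrative and not part of the change): `compiler_ir()` is called with no dialect argument, so the test prints whatever MLIR module JAX produces by default instead of asking for the "mhlo" dialect explicitly.

import jax
import jax.numpy as jnp
import numpy as np

def f(x):
  # A toy function; lowering it yields a small MLIR module.
  return jnp.sin(x) + 1.0

# No dialect argument: print the default MLIR module for this lowering.
print(jax.jit(f).lower(np.float32(0.0)).compiler_ir())
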
@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Tests for lowerings of elementwise ops to MHLO.
# Tests for lowerings of elementwise ops to MLIR.

# RUN: %PYTHON %s | FileCheck %s

@ -31,17 +31,17 @@ jax.config.update("jax_enable_x64", True)

def main(_):
# CHECK-LABEL: TEST: abs int32[]
# CHECK: mhlo.abs
# CHECK: hlo.abs
# CHECK-SAME: tensor<i32>
print_ir(np.int32(0))(lax.abs)

# CHECK-LABEL: TEST: add float32[] float32[]
# CHECK: mhlo.add
# CHECK: hlo.add
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.add)

# CHECK-LABEL: TEST: acos float32[]
# CHECK: mhlo.atan2
# CHECK: hlo.atan2
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1))(lax.acos)

@ -71,7 +71,7 @@ def main(_):
print_ir(np.float32(0))(lax.atanh)

# CHECK-LABEL: TEST: atan2 float64[] float64[]
# CHECK: mhlo.atan2
# CHECK: hlo.atan2
# CHECK-SAME: tensor<f64>
print_ir(np.float64(1), np.float64(2))(lax.atan2)

@ -91,93 +91,93 @@ def main(_):
print_ir(np.float32(0), np.float32(0), np.float32(0))(lax.betainc)

# CHECK-LABEL: TEST: bitcast_convert_type uint32[7]
# CHECK: mhlo.bitcast_convert
# CHECK: hlo.bitcast_convert
# CHECK-SAME: tensor<7xui32>
# CHECK-SAME: tensor<7xf32>
print_ir(np.empty((7,), np.uint32))(
partial(lax.bitcast_convert_type, new_dtype=np.float32))

# CHECK-LABEL: TEST: bitwise_and int32[] int32[]
# CHECK: mhlo.and
# CHECK: hlo.and
# CHECK-SAME: tensor<i32>
print_ir(np.int32(1), np.int32(2))(lax.bitwise_and)

# CHECK-LABEL: TEST: bitwise_and bool[] bool[]
# CHECK: mhlo.and
# CHECK: hlo.and
# CHECK-SAME: tensor<i1>
print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_and)

# CHECK-LABEL: TEST: bitwise_or int32[] int32[]
# CHECK: mhlo.or
# CHECK: hlo.or
# CHECK-SAME: tensor<i32>
print_ir(np.int32(1), np.int32(2))(lax.bitwise_or)

# CHECK-LABEL: TEST: bitwise_or bool[] bool[]
# CHECK: mhlo.or
# CHECK: hlo.or
# CHECK-SAME: tensor<i1>
print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_or)

# CHECK-LABEL: TEST: bitwise_xor int32[] int32[]
# CHECK: mhlo.xor
# CHECK: hlo.xor
# CHECK-SAME: tensor<i32>
print_ir(np.int32(1), np.int32(2))(lax.bitwise_xor)

# CHECK-LABEL: TEST: bitwise_xor bool[] bool[]
# CHECK: mhlo.xor
# CHECK: hlo.xor
# CHECK-SAME: tensor<i1>
print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_xor)

# CHECK-LABEL: TEST: cbrt bfloat16[]
# CHECK: mhlo.cbrt
# CHECK: hlo.cbrt
# CHECK-SAME: tensor<bf16>
print_ir(jnp.bfloat16(0))(lax.cbrt)

# CHECK-LABEL: TEST: clamp bfloat16[] bfloat16[] bfloat16[]
# CHECK: mhlo.clamp
# CHECK: hlo.clamp
# CHECK-SAME: tensor<bf16>
print_ir(jnp.bfloat16(0), jnp.bfloat16(0), jnp.bfloat16(0))(lax.clamp)

# CHECK-LABEL: TEST: ceil float16[7]
# CHECK: mhlo.ceil
# CHECK: hlo.ceil
# CHECK-SAME: tensor<7xf16>
print_ir(np.empty((7,), np.float16))(lax.ceil)

# CHECK-LABEL: TEST: convert_element_type float16[7]
# CHECK: mhlo.convert
# CHECK: hlo.convert
# CHECK-SAME: tensor<7xf16>
# CHECK-SAME: tensor<7xf32>
print_ir(np.empty((7,), np.float16))(
partial(lax.convert_element_type, new_dtype=np.float32))

# CHECK-LABEL: TEST: convert_element_type complex64[7]
# CHECK: mhlo.real
# CHECK: hlo.real
# CHECK-SAME: tensor<7xcomplex<f32>>
# CHECK-SAME: tensor<7xf32>
print_ir(np.empty((7,), np.complex64))(
partial(lax.convert_element_type, new_dtype=np.float32))

# CHECK-LABEL: TEST: convert_element_type float32[7]
# CHECK: mhlo.compare
# CHECK: hlo.compare
# CHECK-SAME: tensor<7xf32>
# CHECK-SAME: tensor<7xi1>
print_ir(np.empty((7,), np.float32))(
partial(lax.convert_element_type, new_dtype=np.bool_))

# CHECK-LABEL: TEST: clz uint32[]
# CHECK: mhlo.count_leading_zeros
# CHECK: hlo.count_leading_zeros
# CHECK-SAME: tensor<ui32>
print_ir(np.uint32(0))(lax.clz)

# CHECK-LABEL: TEST: conj complex64[]
# CHECK-DAG: mhlo.real
# CHECK-DAG: mhlo.imag
# CHECK-DAG: mhlo.neg
# CHECK-DAG: mhlo.complex
# CHECK-DAG: hlo.real
# CHECK-DAG: hlo.imag
# CHECK-DAG: hlo.neg
# CHECK-DAG: hlo.complex
# CHECK-SAME: tensor<complex<f32>>
print_ir(np.complex64(0))(lax.conj)

# CHECK-LABEL: TEST: cos float32[]
# CHECK: mhlo.cos
# CHECK: hlo.cos
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.cos)

@ -192,30 +192,30 @@ def main(_):
print_ir(np.float32(0))(lax.digamma)

# CHECK-LABEL: TEST: div float32[] float32[]
# CHECK: mhlo.div
# CHECK: hlo.div
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.div)

# CHECK-LABEL: TEST: eq float32[] float32[]
# CHECK: mhlo.compare EQ
# CHECK: hlo.compare EQ
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.eq)

# CHECK-LABEL: TEST: eq complex128[] complex128[]
# CHECK: mhlo.compare EQ
# CHECK: hlo.compare EQ
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<complex<f64>>
print_ir(np.complex128(1), np.complex128(2))(lax.eq)

# CHECK-LABEL: TEST: eq int64[] int64[]
# CHECK: mhlo.compare EQ
# CHECK: hlo.compare EQ
# CHECK-SAME: SIGNED
# CHECK-SAME: tensor<i64>
print_ir(np.int64(1), np.int64(2))(lax.eq)

# CHECK-LABEL: TEST: eq uint16[] uint16[]
# CHECK: mhlo.compare EQ
# CHECK: hlo.compare EQ
# CHECK-SAME: UNSIGNED
# CHECK-SAME: tensor<ui16>
print_ir(np.uint16(1), np.uint16(2))(lax.eq)
@ -236,28 +236,28 @@ def main(_):
print_ir(np.float32(0))(lax.erf_inv)

# CHECK-LABEL: TEST: exp float16[]
# CHECK: mhlo.exp
# CHECK: hlo.exp
# CHECK-SAME: tensor<f16>
print_ir(np.float16(0))(lax.exp)

# CHECK-LABEL: TEST: expm1 bfloat16[]
# CHECK: mhlo.exponential_minus_one
# CHECK: hlo.exponential_minus_one
# CHECK-SAME: tensor<bf16>
print_ir(jnp.bfloat16(0))(lax.expm1)

# CHECK-LABEL: TEST: floor bfloat16[2,3]
# CHECK: mhlo.floor
# CHECK: hlo.floor
# CHECK-SAME: tensor<2x3xbf16>
print_ir(np.empty((2, 3), jnp.bfloat16))(lax.floor)

# CHECK-LABEL: TEST: ge float32[] float32[]
# CHECK: mhlo.compare GE
# CHECK: hlo.compare GE
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.ge)

# CHECK-LABEL: TEST: gt float32[] float32[]
# CHECK: mhlo.compare GT
# CHECK: hlo.compare GT
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.gt)
@ -278,23 +278,23 @@ def main(_):
print_ir(np.float32(0), np.float32(0))(lax.igamma_grad_a)

# CHECK-LABEL: TEST: imag complex64[]
# CHECK: mhlo.imag
# CHECK: hlo.imag
# CHECK-SAME: tensor<complex<f32>>
print_ir(np.complex64(0))(lax.imag)

# CHECK-LABEL: TEST: integer_pow float32[]
# CHECK-DAG: mhlo.mul
# CHECK-DAG: hlo.mul
# CHECK-SAME: tensor<f32>
@print_ir(np.float32(1))
def integer_pow(x): return lax.integer_pow(x, 3)

# CHECK-LABEL: TEST: is_finite float64[]
# CHECK: mhlo.is_finite
# CHECK: hlo.is_finite
# CHECK-SAME: tensor<f64>
print_ir(np.float64(0))(lax.is_finite)

# CHECK-LABEL: TEST: le float32[] float32[]
# CHECK: mhlo.compare LE
# CHECK: hlo.compare LE
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.le)
@ -305,44 +305,44 @@ def main(_):
print_ir(np.float32(0))(lax.lgamma)

# CHECK-LABEL: TEST: log float32[]
# CHECK: mhlo.log
# CHECK: hlo.log
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.log)

# CHECK-LABEL: TEST: log1p float32[]
# CHECK: mhlo.log_plus_one
# CHECK: hlo.log_plus_one
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.log1p)

# CHECK-LABEL: TEST: lt float32[] float32[]
# CHECK: mhlo.compare LT
# CHECK: hlo.compare LT
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.lt)

# CHECK-LABEL: TEST: max float32[] float32[]
# CHECK: mhlo.max
# CHECK: hlo.max
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.max)

# CHECK-LABEL: TEST: min float32[] float32[]
# CHECK: mhlo.min
# CHECK: hlo.min
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.min)

# CHECK-LABEL: TEST: mul float32[] float32[]
# CHECK: mhlo.mul
# CHECK: hlo.mul
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.mul)

# CHECK-LABEL: TEST: ne float32[] float32[]
# CHECK: mhlo.compare NE
# CHECK: hlo.compare NE
# CHECK-SAME: FLOAT
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.ne)

# CHECK-LABEL: TEST: neg int64[]
# CHECK: mhlo.negate
# CHECK: hlo.negate
# CHECK-SAME: tensor<i64>
print_ir(np.int64(0))(lax.neg)

@ -352,22 +352,22 @@ def main(_):
print_ir(np.float32(0), np.float32(0))(lax.nextafter)

# CHECK-LABEL: TEST: bitwise_not int64[]
# CHECK: mhlo.not
# CHECK: hlo.not
# CHECK-SAME: tensor<i64>
print_ir(np.int64(0))(lax.bitwise_not)

# CHECK-LABEL: TEST: bitwise_not bool[]
# CHECK: mhlo.not
# CHECK: hlo.not
# CHECK-SAME: tensor<i1>
print_ir(np.bool_(0))(lax.bitwise_not)

# CHECK-LABEL: TEST: population_count uint32[]
# CHECK: mhlo.popcnt
# CHECK: hlo.popcnt
# CHECK-SAME: tensor<ui32>
print_ir(np.uint32(0))(lax.population_count)

# CHECK-LABEL: TEST: pow float32[] float32[]
# CHECK: mhlo.power
# CHECK: hlo.power
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.pow)

@ -377,59 +377,59 @@ def main(_):
print_ir(np.float32(0), np.float32(0))(lax.random_gamma_grad)

# CHECK-LABEL: TEST: real complex128[]
# CHECK: mhlo.real
# CHECK: hlo.real
# CHECK-SAME: tensor<complex<f64>>
print_ir(np.complex128(0))(lax.real)

# CHECK-LABEL: TEST: reduce_precision bfloat16[]
# CHECK: mhlo.reduce_precision
# CHECK: hlo.reduce_precision
# CHECK-SAME: tensor<bf16>
print_ir(jnp.bfloat16(0))(
partial(lax.reduce_precision, exponent_bits=2, mantissa_bits=2))

# CHECK-LABEL: TEST: rem float32[] float32[]
# CHECK: mhlo.rem
# CHECK: hlo.rem
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.rem)

# CHECK-LABEL: TEST: round float64[7,1]
# CHECK: mhlo.round
# CHECK: hlo.round
# CHECK-SAME: tensor<7x1xf64>
print_ir(np.empty((7,1), np.float64))(
partial(lax.round, rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO))

# CHECK-LABEL: TEST: rsqrt complex64[]
# CHECK: mhlo.rsqrt
# CHECK: hlo.rsqrt
# CHECK-SAME: tensor<complex<f32>>
print_ir(jnp.complex64(0))(lax.rsqrt)

# CHECK-LABEL: TEST: shift_left uint32[] uint32[]
# CHECK: mhlo.shift_left
# CHECK: hlo.shift_left
# CHECK-SAME: tensor<ui32>
print_ir(np.uint32(0), np.uint32(0))(lax.shift_left)

# CHECK-LABEL: TEST: shift_right_arithmetic uint8[] uint8[]
# CHECK: mhlo.shift_right_arithmetic
# CHECK: hlo.shift_right_arithmetic
# CHECK-SAME: tensor<ui8>
print_ir(np.uint8(0), np.uint8(0))(lax.shift_right_arithmetic)

# CHECK-LABEL: TEST: shift_right_logical uint16[] uint16[]
# CHECK: mhlo.shift_right_logical
# CHECK: hlo.shift_right_logical
# CHECK-SAME: tensor<ui16>
print_ir(np.uint16(0), np.uint16(0))(lax.shift_right_logical)

# CHECK-LABEL: TEST: sign int64[]
# CHECK: mhlo.sign
# CHECK: hlo.sign
# CHECK-SAME: tensor<i64>
print_ir(np.int64(0))(lax.sign)

# CHECK-LABEL: TEST: sign uint32[]
# CHECK: mhlo.compare
# CHECK: hlo.compare
# CHECK-SAME: tensor<ui32>
print_ir(np.uint32(0))(lax.sign)

# CHECK-LABEL: TEST: sin float32[]
# CHECK: mhlo.sin
# CHECK: hlo.sin
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.sin)

@ -439,12 +439,12 @@ def main(_):
print_ir(np.float32(0))(lax.sinh)

# CHECK-LABEL: TEST: sub float32[] float32[]
# CHECK: mhlo.sub
# CHECK: hlo.sub
# CHECK-SAME: tensor<f32>
print_ir(np.float32(1), np.float32(2))(lax.sub)

# CHECK-LABEL: TEST: sqrt bfloat16[]
# CHECK: mhlo.sqrt
# CHECK: hlo.sqrt
# CHECK-SAME: tensor<bf16>
print_ir(jnp.bfloat16(0))(lax.sqrt)

@ -454,7 +454,7 @@ def main(_):
print_ir(np.float16(0))(lax.tan)

# CHECK-LABEL: TEST: tanh float32[]
# CHECK: mhlo.tanh
# CHECK: hlo.tanh
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.tanh)

@ -30,77 +30,77 @@ jax.config.update("jax_enable_x64", True)

def main(_):
# CHECK-LABEL: TEST: bitwise_not bool[7]
# CHECK: mhlo.not
# CHECK: hlo.not
# CHECK-SAME: tensor<7xi1>
print_ir(np.empty([7], np.bool_))(lax.bitwise_not)

# CHECK-LABEL: TEST: neg int8[]
# CHECK: mhlo.negate
# CHECK: hlo.negate
# CHECK-SAME: tensor<i8>
print_ir(np.int8(0))(lax.neg)

# CHECK-LABEL: TEST: neg int16[0]
# CHECK: mhlo.negate
# CHECK: hlo.negate
# CHECK-SAME: tensor<0xi16>
print_ir(np.empty([0], np.int16))(lax.neg)

# CHECK-LABEL: TEST: neg int32[2,3]
# CHECK: mhlo.negate
# CHECK: hlo.negate
# CHECK-SAME: tensor<2x3xi32>
print_ir(np.empty([2, 3], np.int32))(lax.neg)

# CHECK-LABEL: TEST: neg int64[2,3,4]
# CHECK: mhlo.negate
# CHECK: hlo.negate
# CHECK-SAME: tensor<2x3x4xi64>
print_ir(np.empty([2,3,4], np.int64))(lax.neg)

# CHECK-LABEL: TEST: add uint8[4,0,1] uint8[4,0,1]
# CHECK: mhlo.add
# CHECK: hlo.add
# CHECK-SAME: tensor<4x0x1xui8>
print_ir(np.empty([4,0,1], np.uint8), np.empty([4,0,1], np.uint8))(lax.add)

# CHECK-LABEL: TEST: add uint16[] uint16[]
# CHECK: mhlo.add
# CHECK: hlo.add
# CHECK-SAME: tensor<ui16>
print_ir(np.uint16(0), np.uint16(0))(lax.add)

# CHECK-LABEL: TEST: add uint32[] uint32[]
# CHECK: mhlo.add
# CHECK: hlo.add
# CHECK-SAME: tensor<ui32>
print_ir(np.uint32(0), np.uint32(0))(lax.add)

# CHECK-LABEL: TEST: add uint64[] uint64[]
# CHECK: mhlo.add
# CHECK: hlo.add
# CHECK-SAME: tensor<ui64>
print_ir(np.uint64(0), np.uint64(0))(lax.add)

# CHECK-LABEL: TEST: sin float16[]
# CHECK: mhlo.sine
# CHECK: hlo.sine
# CHECK-SAME: tensor<f16>
print_ir(np.float16(0))(lax.sin)

# CHECK-LABEL: TEST: sin bfloat16[]
# CHECK: mhlo.sine
# CHECK: hlo.sine
# CHECK-SAME: tensor<bf16>
print_ir(jnp.bfloat16(0))(lax.sin)

# CHECK-LABEL: TEST: sin float32[]
# CHECK: mhlo.sine
# CHECK: hlo.sine
# CHECK-SAME: tensor<f32>
print_ir(np.float32(0))(lax.sin)

# CHECK-LABEL: TEST: sin float64[]
# CHECK: mhlo.sine
# CHECK: hlo.sine
# CHECK-SAME: tensor<f64>
print_ir(np.float64(0))(lax.sin)

# CHECK-LABEL: TEST: cos complex64[]
# CHECK: mhlo.cosine
# CHECK: hlo.cosine
# CHECK-SAME: tensor<complex<f32>>
print_ir(np.complex64(0))(lax.cos)

# CHECK-LABEL: TEST: cos complex128[]
# CHECK: mhlo.cosine
# CHECK: hlo.cosine
# CHECK-SAME: tensor<complex<f64>>
print_ir(np.complex128(0))(lax.cos)

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Tests for lowering of array origami ops into MHLO.
# Tests for lowering of array origami ops into MLIR.

# RUN: %PYTHON %s | FileCheck %s

@ -65,7 +65,7 @@ def main(_):
with m1.context:
# Reparse m2 in m1's context.
m2_copy = ir.Module.parse(m2)
mlir.merge_mhlo_modules(m1, "m2_main_renamed", m2_copy)
mlir.merge_mlir_modules(m1, "m2_main_renamed", m2_copy)
print("\nTEST: merge_modules")
print(str(m1))

@ -370,43 +370,43 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def f(x):
effect_p.bind(effect='foo')
return x + 1.
mhlo = f.lower(2.).compiler_ir()
main = mhlo.body.operations[0]
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
first_op = main.body.blocks[0].operations[0]
self.assertEqual(first_op.operation.name, "mhlo.create_token")
self.assertIn('hlo.create_token', first_op.operation.name)

@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='foo2')
return x + 1.
mhlo = f.lower(2.).compiler_ir()
main = mhlo.body.operations[0]
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
first_op = main.body.blocks[0].operations[0]
self.assertEqual(first_op.operation.name, "mhlo.create_token")
self.assertIn('hlo.create_token', first_op.operation.name)
second_op = main.body.blocks[0].operations[1]
self.assertEqual(second_op.operation.name, "mhlo.create_token")
self.assertIn('hlo.create_token', second_op.operation.name)

@jax.jit
def f(x):
effect_p.bind(effect='foo')
return x + 1.
mhlo = f.lower(2.).compiler_ir()
main = mhlo.body.operations[0]
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
first_op = main.body.blocks[0].operations[0]
self.assertEqual(first_op.operation.name, "mhlo.create_token")
self.assertIn('hlo.create_token', first_op.operation.name)

@jax.jit
def f(x):
effect_p.bind(effect='foo')
effect_p.bind(effect='foo2')
return x + 1.
mhlo = f.lower(2.).compiler_ir()
main = mhlo.body.operations[0]
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
first_op = main.body.blocks[0].operations[0]
self.assertEqual(first_op.operation.name, "mhlo.create_token")
self.assertIn('hlo.create_token', first_op.operation.name)
second_op = main.body.blocks[0].operations[1]
self.assertEqual(second_op.operation.name, "mhlo.create_token")
self.assertIn('hlo.create_token', second_op.operation.name)

def test_nontrivial_lowering_with_ordered_effect_should_consume_token(self):

@ -416,19 +416,18 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def f(x):
effect_p.bind(effect='foo')
return x + 1.

mhlo = f.lower(2.).compiler_ir()
main = mhlo.body.operations[0]
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
first_op = main.body.blocks[0].operations[0]
self.assertEqual(first_op.operation.name, "mhlo.create_token")
self.assertIn('hlo.create_token', first_op.operation.name)
second_op = main.body.blocks[0].operations[1]
self.assertEqual(second_op.operation.name, "func.call")
self.assertEqual(str(second_op.attributes["callee"]), "@effect")
self.assertEqual(second_op.operands[0].owner, first_op)
func = mhlo.body.operations[1]
func = module.body.operations[1]
self.assertEqual(func.name.value, "effect")
self.assertEqual(str(func.type.inputs[0]), "!mhlo.token")
self.assertEqual(str(func.type.results[0]), "!mhlo.token")
self.assertIn('hlo.token', str(func.type.inputs[0]))
self.assertIn('hlo.token', str(func.type.results[0]))

def test_nontrivial_lowering_with_unordered_effect_should_consume_token(self):

@ -438,14 +437,13 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def f(x):
effect_p.bind(effect='bar')
return x + 1.

mhlo = f.lower(2.).compiler_ir()
main = mhlo.body.operations[0]
module = f.lower(2.).compiler_ir()
main = module.body.operations[0]
first_op = main.body.blocks[0].operations[0]
self.assertEqual(first_op.operation.name, "func.call")
self.assertEqual(str(first_op.attributes["callee"]), "@effect")
self.assertLen(list(first_op.operands), 0)
func = mhlo.body.operations[1]
func = module.body.operations[1]
self.assertEqual(func.name.value, "effect")
self.assertLen(list(func.type.inputs), 0)
self.assertLen(list(func.type.results), 0)
@ -455,13 +453,13 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def f(x):
effect_p.bind(effect='bar')
return x + 1.
mhlo = f.lower(1.).compiler_ir(dialect='mhlo')
input_types = mhlo.body.operations[0].type.inputs
module = f.lower(1.).compiler_ir()
input_types = module.body.operations[0].type.inputs
self.assertLen(list(input_types), 1)
self.assertEqual(str(input_types[0]), 'tensor<f32>')

# First output should be output token
result_types = mhlo.body.operations[0].type.results
result_types = module.body.operations[0].type.results
if not can_execute_with_token:
self.assertLen(list(result_types), 2)
self.assertEqual(str(result_types[0]), 'tensor<0xi1>')
@ -476,14 +474,14 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
def f(x):
effect_p.bind(effect='foo')
return x + 1.
mhlo = f.lower(1.).compiler_ir(dialect='mhlo')
input_types = mhlo.body.operations[0].type.inputs
module = f.lower(1.).compiler_ir()
input_types = module.body.operations[0].type.inputs
# First argument should be dummy token
self.assertLen(list(input_types), 2)
self.assertEqual(str(input_types[0]), 'tensor<0xi1>')

# First output should be dummy token
result_types = mhlo.body.operations[0].type.results
result_types = module.body.operations[0].type.results
self.assertLen(list(result_types), 2)
self.assertEqual(str(result_types[0]), 'tensor<0xi1>')

@ -493,15 +491,15 @@ class EffectfulJaxprLoweringTest(jtu.JaxTestCase):
effect_p.bind(effect='foo')
effect_p.bind(effect='foo2')
return x + 1.
mhlo = f.lower(1.).compiler_ir(dialect='mhlo')
input_types = mhlo.body.operations[0].type.inputs
module = f.lower(1.).compiler_ir()
input_types = module.body.operations[0].type.inputs
# First two arguments should be dummy values
self.assertLen(list(input_types), 3)
self.assertEqual(str(input_types[0]), 'tensor<0xi1>')
self.assertEqual(str(input_types[1]), 'tensor<0xi1>')

# First two outputs should be dummy values
result_types = mhlo.body.operations[0].type.results
result_types = module.body.operations[0].type.results
self.assertLen(list(result_types), 3)
self.assertEqual(str(result_types[0]), 'tensor<0xi1>')
self.assertEqual(str(result_types[1]), 'tensor<0xi1>')

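For readers following the test changes above, here is a minimal sketch of the introspection pattern the updated assertions use (it assumes only a locally installed JAX; the function `f` is illustrative and not part of the commit): the tests walk the lowered module's operations and match op names by substring, so the same check passes whether the lowering emits the old or the new HLO dialect spelling.

import jax

@jax.jit
def f(x):
  return x + 1.

module = f.lower(2.).compiler_ir()
main = module.body.operations[0]            # the module's main function
first_op = main.body.blocks[0].operations[0]
# Matching on 'hlo.' rather than 'mhlo.' keeps the check dialect-agnostic.
print(first_op.operation.name)
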
@ -40,7 +40,7 @@ from jax.interpreters import mlir
from jax.interpreters import batching
from jax.interpreters import pxla
from jax._src import array
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import hlo
from jax._src import dispatch
from jax._src import dtypes
from jax._src import test_util as jtu
|
@ -2867,7 +2867,7 @@ class FooTyRules:
start_indices = (*start_indices, 0)
limit_indices = (*limit_indices, 2)
strides = (*strides, 1)
return mhlo.SliceOp(x,
return hlo.SliceOp(x,
mlir.dense_int_elements(start_indices),
mlir.dense_int_elements(limit_indices),
mlir.dense_int_elements(strides)).result
|
@ -2877,7 +2877,7 @@ class FooTyRules:
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
start_indices = (*start_indices, mlir.ir_constant(np.array(0, dtype=dtype)))
slice_sizes_ = mlir.dense_int_elements((*aval_out.shape, 2))
return mhlo.DynamicSliceOp(x, start_indices, slice_sizes_).result
return hlo.DynamicSliceOp(x, start_indices, slice_sizes_).result
|
@staticmethod
def dynamic_update_slice_mlir(ctx, aval_out, x, update, *start_indices):
|
@ -2885,22 +2885,22 @@ class FooTyRules:
dtype = dtypes.canonicalize_dtype(np.dtype('int64'))
start_indices = (*start_indices, mlir.ir_constant(np.array(0, dtype=dtype)))
if xc.mlir_api_version < 40:
return mhlo.DynamicUpdateSliceOp(
return hlo.DynamicUpdateSliceOp(
mlir.aval_to_ir_type(aval_out), x, update, start_indices).result
else:
return mhlo.DynamicUpdateSliceOp(x, update, start_indices).result
return hlo.DynamicUpdateSliceOp(x, update, start_indices).result
|
@staticmethod
def broadcast_in_dim_mlir(ctx, aval_out, x, broadcast_dimensions):
broadcast_dimensions = [*broadcast_dimensions, aval_out.ndim]
return mhlo.BroadcastInDimOp(
return hlo.BroadcastInDimOp(
mlir.aval_to_ir_type(aval_out), x,
mlir.dense_int_elements(broadcast_dimensions)).result
|
@staticmethod
def transpose_mlir(ctx, aval_out, x, *, permutation):
perm = [*permutation, len(permutation)]
return mhlo.TransposeOp(x, mlir.dense_int_elements(perm)).result
return hlo.TransposeOp(x, mlir.dense_int_elements(perm)).result
|
@staticmethod
def gather_mlir(ctx, avals_in, aval_out, x, indices, *,
|
@ -501,9 +501,9 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertLen(actual[0]['a'].device_buffers, 4)
|
mhlo_str = str(f.lower(x).compiler_ir(dialect="mhlo"))
self.assertIn("unspecified_dims=[0]", mhlo_str)
self.assertIn("unspecified_dims=[1]", mhlo_str)
mlir_str = str(f.lower(x).compiler_ir())
self.assertIn("unspecified_dims=[0]", mlir_str)
self.assertIn("unspecified_dims=[1]", mlir_str)
|
@jtu.with_mesh([('x', 2), ('y', 2)])
def testShardingConstraintPyTreeVmapWithUnconstrainedDims(self):
|
@ -521,9 +521,9 @@ class PJitTest(jtu.BufferDonationTestCase):
v = np.arange(prod(shape)).reshape(shape)
x = [{'a': v, 'b': v * 2}, v * 3]
|
mhlo_str = str(f.lower(x).compiler_ir(dialect="mhlo"))
self.assertIn("unspecified_dims=[0,1]", mhlo_str)
self.assertIn("unspecified_dims=[0,2]", mhlo_str)
mlir_str = str(f.lower(x).compiler_ir())
self.assertIn("unspecified_dims=[0,1]", mlir_str)
self.assertIn("unspecified_dims=[0,2]", mlir_str)
|
def testCaching(self):
def f(x):