[MLIR] Fix CPU test failures for MLIR lowering.

The remaining failures relate to buffer donation and xmap_p, which are not yet implemented.

Quite a few primitives still use fallback paths.

PiperOrigin-RevId: 413130158
This commit is contained in:
Peter Hawkins 2021-11-30 06:08:26 -08:00 committed by jax authors
parent 42647e013f
commit fa411d864e
10 changed files with 128 additions and 72 deletions

View File

@ -717,16 +717,23 @@ def _complex_mul(mul, x, y):
k3 = mul(x_im, lax.add(y_re, y_im))
return lax.complex(lax.sub(k1, k3), lax.add(k1, k2))
_real_dtype = lambda dtype: np.finfo(dtype).dtype
def _conv_general_dilated_lower(
ctx, avals_in, avals_out, lhs, rhs, *, window_strides, padding,
lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
batch_group_count, precision, expand_complex_convolutions=False,
**unused_kwargs):
batch_group_count, precision, preferred_element_type,
expand_complex_convolutions=False, **unused_kwargs):
lhs_aval, rhs_aval = avals_in
aval_out, = avals_out
assert isinstance(dimension_numbers, ConvDimensionNumbers)
dtype = lhs_aval.dtype
if expand_complex_convolutions and np.issubdtype(dtype, np.complexfloating):
if preferred_element_type is not None:
# Convert complex dtype to types used for real and imaginary parts
assert np.issubdtype(preferred_element_type, np.complexfloating)
preferred_element_type = _real_dtype(preferred_element_type)
complex_conv = mlir.lower_fun(
partial(
_complex_mul,
@ -734,7 +741,8 @@ def _conv_general_dilated_lower(
padding=padding, lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation, dimension_numbers=dimension_numbers,
feature_group_count=feature_group_count,
batch_group_count=batch_group_count, precision=precision)),
batch_group_count=batch_group_count, precision=precision,
preferred_element_type=preferred_element_type)),
multiple_results=False)
return complex_conv(ctx, avals_in, avals_out, lhs, rhs)

View File

@ -1784,7 +1784,7 @@ mlir.register_lowering(complex_p, partial(_nary_lower_mhlo, mhlo.ComplexOp))
conj_p = unop(_complex_dtype, _complex_elem_types | _complex, 'conj')
def _conj_impl(x, *, input_dtype):
def _conj_impl(x, **kw):
if dtypes.issubdtype(x.dtype, np.complexfloating):
return complex(real(x), -imag(x))
else:
@ -2075,30 +2075,6 @@ def _minmax_translation_rule(ctx, avals_in, avals_out, x, y, *, op_minmax=None,
else:
return [op_minmax(x, y)]
def _minmax_mhlo(op, cmp, x, y):
"""Min/max that compares complex values lexicographically as pairs."""
tensor_type = ir.RankedTensorType(x.type)
if ir.ComplexType.isinstance(tensor_type.element_type):
# Complex element type: order by real part first, then by imaginary part.
rx = mhlo.RealOp(x).result
ry = mhlo.RealOp(y).result
# Result type of the comparisons: same dimensions as x, 1-bit signless
# integer elements.
dims = [tensor_type.get_dim_size(i) for i in range(tensor_type.rank)]
bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1))
real_eq = mhlo.CompareOp(bool_shape, rx, ry, ir.StringAttr.get("EQ"),
ir.StringAttr.get("FLOAT"))
real_cmp = mhlo.CompareOp(bool_shape, rx, ry,
ir.StringAttr.get(cmp),
ir.StringAttr.get("FLOAT"))
imag_cmp = mhlo.CompareOp(bool_shape, mhlo.ImagOp(x).result,
mhlo.ImagOp(y).result,
ir.StringAttr.get(cmp),
ir.StringAttr.get("FLOAT"))
# Where the real parts are equal use the imaginary-part comparison,
# otherwise use the real-part comparison.
which = mhlo.SelectOp(real_eq, imag_cmp, real_cmp).result
# Take x where the chosen comparison holds, else y.
return mhlo.SelectOp(which, x, y)
else:
# Non-complex element types are handled directly by `op`.
return op(x, y)
# With "LT" the select above keeps x where x compares less than y (min); with
# "GT" it keeps x where x compares greater than y (max).
_min_mhlo = partial(_minmax_mhlo, mhlo.MinOp, "LT")
_max_mhlo = partial(_minmax_mhlo, mhlo.MaxOp, "GT")
max_p: core.Primitive = standard_naryop(
[_any, _any], 'max', translation_rule=partial(
@ -2106,7 +2082,7 @@ max_p: core.Primitive = standard_naryop(
ad.defjvp2(max_p,
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
mlir.register_lowering(max_p, partial(_nary_lower_mhlo, _max_mhlo))
mlir.register_lowering(max_p, partial(_nary_lower_mhlo, mlir.max_mhlo))
min_p: core.Primitive = standard_naryop(
[_any, _any], 'min', translation_rule=partial(
@ -2114,7 +2090,7 @@ min_p: core.Primitive = standard_naryop(
ad.defjvp2(min_p,
lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
mlir.register_lowering(min_p, partial(_nary_lower_mhlo, _min_mhlo))
mlir.register_lowering(min_p, partial(_nary_lower_mhlo, mlir.min_mhlo))
shift_left_p = standard_naryop([_int, _int], 'shift_left')
ad.defjvp_zero(shift_left_p)
@ -3590,9 +3566,9 @@ mlir.register_lowering(reduce_or_p, partial(_unary_reduce_lower, mhlo.OrOp,
lambda dtype: np.array(False, dtype)))
mlir.register_lowering(reduce_and_p, partial(_unary_reduce_lower, mhlo.AndOp,
lambda dtype: np.array(True, dtype)))
mlir.register_lowering(reduce_min_p, partial(_unary_reduce_lower, _min_mhlo,
mlir.register_lowering(reduce_min_p, partial(_unary_reduce_lower, mlir.min_mhlo,
_get_min_identity))
mlir.register_lowering(reduce_max_p, partial(_unary_reduce_lower, _max_mhlo,
mlir.register_lowering(reduce_max_p, partial(_unary_reduce_lower, mlir.max_mhlo,
_get_max_identity))
@ -3960,7 +3936,7 @@ xla.register_translation(infeed_p, _infeed_translation_rule)
def _infeed_lowering(ctx, avals_in, avals_out, token, *, shapes, partitions):
assert partitions is None, partitions # TODO(phawkins): implement me.
output_types = map(mlir.aval_to_ir_types, avals_out[:-1])
output_types = safe_map(mlir.aval_to_ir_types, avals_out[:-1])
flat_output_types = util.flatten(output_types)
output_tuple_type = ir.TupleType.get_tuple(flat_output_types)
# TODO(phawkins): verify `shapes` have a major-to-minor layout.

View File

@ -1984,7 +1984,8 @@ def _real_dtype(dtype): return np.finfo(dtype).dtype
def _scatter_add_lower_gpu(ctx, avals_in, avals_out, operand, indices, updates,
*, update_jaxpr, update_consts, dimension_numbers,
indices_are_sorted, unique_indices, mode):
if operand.dtype != np.complex128:
operand_aval_in, _, updates_aval_in = avals_in
if operand_aval_in.dtype != np.complex128:
return _scatter_lower(ctx, avals_in, avals_out, operand, indices, updates,
update_jaxpr=update_jaxpr,
update_consts=update_consts,
@ -1992,7 +1993,6 @@ def _scatter_add_lower_gpu(ctx, avals_in, avals_out, operand, indices, updates,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices, mode=mode)
assert mode == GatherScatterMode.PROMISE_IN_BOUNDS, mode
_, _, updates_aval_in = avals_in
aval_out, = avals_out
dnums = dimension_numbers
scatter_dnums = mhlo.ScatterDimensionNumbers.get(

View File

@ -510,9 +510,9 @@ def _reduce_window_lower(
mlir.register_lowering(reduce_window_sum_p, partial(
_reduce_window_lower, mhlo.AddOp, lambda _: 0))
mlir.register_lowering(reduce_window_min_p, partial(
_reduce_window_lower, mhlo.MinOp, lax._get_min_identity))
_reduce_window_lower, mlir.min_mhlo, lax._get_min_identity))
mlir.register_lowering(reduce_window_max_p, partial(
_reduce_window_lower, mhlo.MaxOp, lax._get_max_identity))
_reduce_window_lower, mlir.max_mhlo, lax._get_max_identity))

View File

@ -39,6 +39,7 @@ from jax._src.util import prod, unzip2
from jax.tree_util import tree_multimap, tree_all, tree_map, tree_reduce
from jax._src.lib import xla_bridge
from jax._src import dispatch
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.experimental.maps import mesh
@ -368,18 +369,25 @@ def count_jit_and_pmap_compiles():
# No need to clear any caches since we generally jit and pmap fresh callables
# in tests.
jaxpr_subcomp = xla.jaxpr_subcomp
xla_jaxpr_subcomp = xla.jaxpr_subcomp
mlir_jaxpr_subcomp = mlir.jaxpr_subcomp
count = [0]
def jaxpr_subcomp_and_count(*args, **kwargs):
def xla_jaxpr_subcomp_and_count(*args, **kwargs):
count[0] += 1
return jaxpr_subcomp(*args, **kwargs)
return xla_jaxpr_subcomp(*args, **kwargs)
xla.jaxpr_subcomp = jaxpr_subcomp_and_count
def mlir_jaxpr_subcomp_and_count(*args, **kwargs):
count[0] += 1
return mlir_jaxpr_subcomp(*args, **kwargs)
xla.jaxpr_subcomp = xla_jaxpr_subcomp_and_count
mlir.jaxpr_subcomp = mlir_jaxpr_subcomp_and_count
try:
yield count
finally:
xla.jaxpr_subcomp = jaxpr_subcomp
xla.jaxpr_subcomp = xla_jaxpr_subcomp
mlir.jaxpr_subcomp = mlir_jaxpr_subcomp
@contextmanager
def assert_num_jit_and_pmap_compilations(times):

View File

@ -178,9 +178,10 @@ def _comparator_builder(operand, op_type, is_max_k):
return c.build(cmp_result)
def _approx_top_k_tpu_translation(c, operand, k, reduction_dimension,
recall_target, is_max_k,
def _approx_top_k_tpu_translation(ctx, avals_in, avals_out, operand, *, k,
reduction_dimension, recall_target, is_max_k,
reduction_input_size_override):
c = ctx.builder
op_shape = c.get_shape(operand)
if not op_shape.is_array():
raise ValueError('operand must be an array, but was {}'.format(op_shape))
@ -203,14 +204,17 @@ def _approx_top_k_tpu_translation(c, operand, k, reduction_dimension,
reduction_dimension)
init_val = xc.ops.Constant(c, init_literal)
init_arg = xc.ops.Constant(c, np.int32(-1))
return xc.ops.ApproxTopK(c, [operand, iota], [init_val, init_arg], k,
reduction_dimension, comparator, recall_target, True,
reduction_input_size_override)
out = xc.ops.ApproxTopK(c, [operand, iota], [init_val, init_arg], k,
reduction_dimension, comparator, recall_target, True,
reduction_input_size_override)
return xla.xla_destructure(c, out)
def _approx_top_k_fallback_translation(c, operand, k, reduction_dimension,
def _approx_top_k_fallback_translation(ctx, avals_in, avals_out, operand, *, k,
reduction_dimension,
recall_target, is_max_k,
reduction_input_size_override):
c = ctx.builder
op_shape = c.get_shape(operand)
if not op_shape.is_array():
raise ValueError('operand must be an array, but was {}'.format(op_shape))
@ -226,7 +230,7 @@ def _approx_top_k_fallback_translation(c, operand, k, reduction_dimension,
args = xc.ops.GetTupleElement(val_arg, 1)
sliced_vals = xc.ops.SliceInDim(vals, 0, k, 1, reduction_dimension)
sliced_args = xc.ops.SliceInDim(args, 0, k, 1, reduction_dimension)
return xc.ops.Tuple(c, [sliced_vals, sliced_args])
return sliced_vals, sliced_args
def _approx_top_k_batch_rule(batched_args, batch_dims, *, k,
@ -292,11 +296,8 @@ approx_top_k_p = core.Primitive('approx_top_k')
approx_top_k_p.multiple_results = True
approx_top_k_p.def_impl(partial(xla.apply_primitive, approx_top_k_p))
approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval)
xla.backend_specific_translations['tpu'][
approx_top_k_p] = _approx_top_k_tpu_translation
xla.backend_specific_translations['cpu'][
approx_top_k_p] = _approx_top_k_fallback_translation
xla.backend_specific_translations['gpu'][
approx_top_k_p] = _approx_top_k_fallback_translation
xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
xla.register_translation(approx_top_k_p, _approx_top_k_tpu_translation,
platform='tpu')
batching.primitive_batchers[approx_top_k_p] = _approx_top_k_batch_rule
ad.primitive_jvps[approx_top_k_p] = _approx_top_k_jvp

View File

@ -69,7 +69,7 @@ def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i)
# IR Types
# Non-canonicalized dtype to IR type mapping.
_dtype_to_ir_type : Dict[np.dtype, Callable[[], ir.Type]] = {
dtype_to_ir_type : Dict[np.dtype, Callable[[], ir.Type]] = {
np.dtype(dtypes.float0): partial(ir.IntegerType.get_signless, 1),
np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
@ -91,13 +91,13 @@ _dtype_to_ir_type : Dict[np.dtype, Callable[[], ir.Type]] = {
def _array_ir_types(aval: core.ShapedArray) -> ir.Type:
try:
ir_type_factory = _dtype_to_ir_type[aval.dtype]
ir_type_factory = dtype_to_ir_type[aval.dtype]
except KeyError as err:
raise TypeError(
f"No dtype_to_ir_type handler for dtype: {aval.dtype}") from err
return (ir.RankedTensorType.get(aval.shape, ir_type_factory()),)
_ir_type_handlers: Dict[Type[core.AbstractValue],
ir_type_handlers: Dict[Type[core.AbstractValue],
Callable[[Any], Sequence[ir.Type]]] = {}
def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[ir.Type]:
@ -106,14 +106,14 @@ def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[ir.Type]:
In general, a JAX value may be represented by multiple IR values, so this
function returns multiple types."""
try:
return _ir_type_handlers[type(aval)](aval)
return ir_type_handlers[type(aval)](aval)
except KeyError as err:
raise TypeError(f"No ir_type_handler for aval type: {type(aval)}") from err
_ir_type_handlers[core.AbstractUnit] = lambda _: ()
_ir_type_handlers[core.ShapedArray] = _array_ir_types
_ir_type_handlers[core.ConcreteArray] = _array_ir_types
_ir_type_handlers[core.AbstractToken] = lambda _: [mhlo.TokenType.get()]
ir_type_handlers[core.AbstractUnit] = lambda _: ()
ir_type_handlers[core.ShapedArray] = _array_ir_types
ir_type_handlers[core.ConcreteArray] = _array_ir_types
ir_type_handlers[core.AbstractToken] = lambda _: [mhlo.TokenType.get()]
def aval_to_ir_type(aval: core.AbstractValue) -> ir.Type:
"""Convenience wrapper around aval_to_ir_types for single types.
@ -310,6 +310,7 @@ def register_lowering(prim: core.Primitive, rule: LoweringRule,
def _unwrap_singleton_ir_values(x): return x[0] if len(x) == 1 else x
def wrap_singleton_ir_values(x: Union[ir.Value, Sequence[ir.Value]]
                             ) -> Sequence[ir.Value]:
  """Normalizes a bare `ir.Value` or a sequence of values to a tuple.

  A single `ir.Value` becomes a 1-tuple; any other input is converted with
  `tuple(x)`.
  """
  if isinstance(x, ir.Value):
    return (x,)
  return tuple(x)
def flatten_lowering_ir_args(
@ -568,6 +569,32 @@ register_lowering(ad_util.stop_gradient_p,
lambda ctx, avals_in, avals_out, x: [x])
def _minmax_mhlo(op, cmp, x, y):
  """Elementwise min/max that orders complex values lexicographically.

  Args:
    op: callable applied as `op(x, y)` when the element type is not complex.
    cmp: comparison direction string used for the complex case; the partials
      below pass "LT" for min and "GT" for max.
    x: first operand; `x.type` is read as an `ir.RankedTensorType`.
    y: second operand.

  Returns:
    The result of `op(x, y)` for non-complex element types; otherwise the
    final `mhlo.SelectOp` that picks between `x` and `y`.
  """
  x_type = ir.RankedTensorType(x.type)
  if not ir.ComplexType.isinstance(x_type.element_type):
    return op(x, y)

  # Complex case: compare real parts, falling back to the imaginary parts
  # wherever the real parts are equal.
  x_real = mhlo.RealOp(x).result
  y_real = mhlo.RealOp(y).result
  pred_type = ir.RankedTensorType.get(
      [x_type.get_dim_size(axis) for axis in range(x_type.rank)],
      ir.IntegerType.get_signless(1))

  def float_compare(lhs, rhs, direction):
    return mhlo.CompareOp(pred_type, lhs, rhs,
                          ir.StringAttr.get(direction),
                          ir.StringAttr.get("FLOAT"))

  reals_equal = float_compare(x_real, y_real, "EQ")
  reals_ordered = float_compare(x_real, y_real, cmp)
  x_imag = mhlo.ImagOp(x).result
  y_imag = mhlo.ImagOp(y).result
  imags_ordered = float_compare(x_imag, y_imag, cmp)
  pick_x = mhlo.SelectOp(reals_equal, imags_ordered, reals_ordered).result
  return mhlo.SelectOp(pick_x, x, y)


min_mhlo = partial(_minmax_mhlo, mhlo.MinOp, "LT")
max_mhlo = partial(_minmax_mhlo, mhlo.MaxOp, "GT")
# MLIR lowerings for lax primitives
def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringContext,

View File

@ -1637,7 +1637,8 @@ def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
x = mhlo.ConvertOp(mlir.aval_to_ir_type(aval), x).result
dims = list(aval.shape)
padded = mlir.full_like_aval(0, aval.update(shape=[axis_env.sizes[-1]] + dims))
padded = mlir.full_like_aval(
0, aval.update(shape=[axis_env.sizes[-1]] + dims))
zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
idxs = [_unravel_index_mhlo(axis_env)] + [zero] * len(dims)
padded = mhlo.DynamicUpdateSliceOp(
@ -1695,7 +1696,8 @@ def _pmap_lowering(ctx, avals_in, avals_out, *in_nodes, axis_name,
axis_env = new_env,
name_stack=xla.extend_name_stack(ctx.name_stack,
util.wrap_name(name, 'pmap')))
sharded_outs = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, (), *in_nodes_sharded)
sharded_outs = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, (),
*in_nodes_sharded)
out_avals = [v.aval for v in call_jaxpr.outvars]
outs = [_mhlo_unshard(aval, new_env, out_axis, shard, platform=ctx.platform)
for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)]

View File

@ -45,6 +45,7 @@ from jax._src import api, dtypes
from jax.core import Primitive
from jax.errors import UnexpectedTracerError
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters.sharded_jit import PartitionSpec as P
@ -243,6 +244,8 @@ class CPPJitTest(jtu.BufferDonationTestCase):
# Jit and Donate arguments
def test_jit_donate_argnums_warning_raised(self):
if jax.config.jax_enable_mlir:
raise unittest.SkipTest("Buffer donation not yet implemented via MLIR")
x = jnp.array([1.0, 2.0], jnp.float32)
y = jnp.array([1, 2], jnp.int32)
f = self.jit(lambda x, y: x.sum() + y.sum(), donate_argnums=(0, 1))
@ -984,8 +987,8 @@ class APITest(jtu.JaxTestCase):
foo_p.def_abstract_eval(lambda x: x)
jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
"XLA translation rule for primitive 'foo' not found")
jtu.check_raises_regexp(lambda: jit(foo)(1.0), NotImplementedError,
".* rule for primitive 'foo' not found.*")
foo_p.def_impl(lambda x: x)
ad.defjvp(foo_p, lambda g, x: foo(g))
@ -2406,18 +2409,24 @@ class APITest(jtu.JaxTestCase):
def g(x):
return f(2, x)
jaxpr_subcomp = xla.jaxpr_subcomp
xla_jaxpr_subcomp = xla.jaxpr_subcomp
mlir_jaxpr_subcomp = mlir.jaxpr_subcomp
jaxprs = []
def jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
def xla_jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
jaxprs.append(jaxpr)
return jaxpr_subcomp(c, jaxpr, *args, **kwargs)
return xla_jaxpr_subcomp(c, jaxpr, *args, **kwargs)
def mlir_jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
jaxprs.append(jaxpr)
return mlir_jaxpr_subcomp(c, jaxpr, *args, **kwargs)
try:
xla.jaxpr_subcomp = jaxpr_subcomp_and_collect
xla.jaxpr_subcomp = xla_jaxpr_subcomp_and_collect
mlir.jaxpr_subcomp = mlir_jaxpr_subcomp_and_collect
ans = g(3)
finally:
xla.jaxpr_subcomp = jaxpr_subcomp
xla.jaxpr_subcomp = xla_jaxpr_subcomp
mlir.jaxpr_subcomp = mlir_jaxpr_subcomp
self.assertEqual(ans, (7, 3))
self.assertLen(jaxprs, 2)

View File

@ -21,7 +21,9 @@ import jax.numpy as jnp
from jax import core, jit, lax, make_jaxpr
from jax._src import device_array
from jax._src import dispatch
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.lib.mlir import ir
from jax._src.lib import xla_bridge, xla_client
xops = xla_client.ops
xb = xla_bridge
@ -137,6 +139,15 @@ dispatch.num_buffers_handlers[AbstractSparseArray] = lambda _: 2
xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
xla.register_constant_handler(SparseArray, sparse_array_constant_handler)
def sparse_array_mlir_type_handler(a):
  """Returns the pair of IR tensor types (data, indices) for a sparse aval.

  Each type is a ranked tensor built from the shape and dtype of the
  corresponding component aval (`a.data_aval`, `a.indices_aval`).
  """
  def ranked_tensor_type(component_aval):
    element_type = mlir.dtype_to_ir_type[component_aval.dtype]()
    return ir.RankedTensorType.get(component_aval.shape, element_type)

  data_type = ranked_tensor_type(a.data_aval)
  indices_type = ranked_tensor_type(a.indices_aval)
  return (data_type, indices_type)


mlir.ir_type_handlers[AbstractSparseArray] = sparse_array_mlir_type_handler
sp_indices_p = core.Primitive('sp_indices')
@ -155,6 +166,11 @@ def _sp_indices_translation_rule(ctx, avals_in, avals_out, data, indices):
# because it leads to infinite recursion.
xla.register_translation(sp_indices_p, _sp_indices_translation_rule)
def _sp_indices_mhlo_lowering(ctx, avals_in, avals_out, data_and_indices):
return [data_and_indices[1]]
mlir.register_lowering(sp_indices_p, _sp_indices_mhlo_lowering)
sp_data_p = core.Primitive('sp_data')
@sp_data_p.def_impl
@ -172,6 +188,11 @@ def _sp_data_translation_rule(ctx, avals_in, avals_out, data, indices):
# because it leads to infinite recursion.
xla.register_translation(sp_data_p, _sp_data_translation_rule)
def _sp_data_mhlo_lowering(ctx, avals_in, avals_out, data_and_indices):
return [data_and_indices[0]]
mlir.register_lowering(sp_data_p, _sp_data_mhlo_lowering)
def identity(x):
  """Binds the `identity_p` primitive to `x` and returns the result."""
  bound = identity_p.bind(x)
  return bound
@ -189,6 +210,10 @@ xla.register_translation(
identity_p, xla.lower_fun(_identity_impl, multiple_results=False,
new_style=True))
mlir.register_lowering(
identity_p, mlir.lower_fun(_identity_impl, multiple_results=False))
def split(x):
return split_p.bind(x)