Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 03:46:06 +00:00)
[MLIR] Fix CPU test failures for MLIR lowering.

The remaining failures relate to buffer donation and xmap_p, which are not yet implemented. Quite a few primitives still use fallback paths.

PiperOrigin-RevId: 413130158
This commit is contained in:
parent 42647e013f
commit fa411d864e
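For orientation before the diff: the recurring pattern in this change is registering MLIR (MHLO) lowerings for primitives via mlir.register_lowering, often built with mlir.lower_fun, which traces a reference Python implementation and emits the corresponding MHLO ops. Below is a minimal sketch of that pattern; double_p and _double_impl are hypothetical names introduced for illustration, while the registration calls mirror the ones in the diff (compare the identity_p and approx_top_k_p hunks).

# A minimal sketch, not part of the commit. `double_p` and `_double_impl`
# are hypothetical; mlir.register_lowering / mlir.lower_fun are the JAX
# APIs the diff below exercises.
from functools import partial

from jax import core
from jax.interpreters import mlir, xla

double_p = core.Primitive('double')
double_p.def_impl(partial(xla.apply_primitive, double_p))
double_p.def_abstract_eval(lambda x: x)

def _double_impl(x):
  # Reference implementation in terms of existing JAX ops; lower_fun
  # traces it and emits the equivalent MHLO.
  return x + x

mlir.register_lowering(
    double_p, mlir.lower_fun(_double_impl, multiple_results=False))

Primitives with no registered lowering instead go through the xla_fallback_lowering hook that appears later in the diff; these are the "fallback paths" the commit message refers to.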
@@ -717,16 +717,23 @@ def _complex_mul(mul, x, y):
   k3 = mul(x_im, lax.add(y_re, y_im))
   return lax.complex(lax.sub(k1, k3), lax.add(k1, k2))
 
+
+_real_dtype = lambda dtype: np.finfo(dtype).dtype
+
 def _conv_general_dilated_lower(
     ctx, avals_in, avals_out, lhs, rhs, *, window_strides, padding,
     lhs_dilation, rhs_dilation, dimension_numbers, feature_group_count,
-    batch_group_count, precision, expand_complex_convolutions=False,
-    **unused_kwargs):
+    batch_group_count, precision, preferred_element_type,
+    expand_complex_convolutions=False, **unused_kwargs):
   lhs_aval, rhs_aval = avals_in
   aval_out, = avals_out
   assert isinstance(dimension_numbers, ConvDimensionNumbers)
   dtype = lhs_aval.dtype
   if expand_complex_convolutions and np.issubdtype(dtype, np.complexfloating):
+    if preferred_element_type is not None:
+      # Convert complex dtype to types used for real and imaginary parts
+      assert np.issubdtype(preferred_element_type, np.complexfloating)
+      preferred_element_type = _real_dtype(preferred_element_type)
     complex_conv = mlir.lower_fun(
         partial(
             _complex_mul,

@@ -734,7 +741,8 @@ def _conv_general_dilated_lower(
                 padding=padding, lhs_dilation=lhs_dilation,
                 rhs_dilation=rhs_dilation, dimension_numbers=dimension_numbers,
                 feature_group_count=feature_group_count,
-                batch_group_count=batch_group_count, precision=precision)),
+                batch_group_count=batch_group_count, precision=precision,
+                preferred_element_type=preferred_element_type)),
         multiple_results=False)
     return complex_conv(ctx, avals_in, avals_out, lhs, rhs)
 

@@ -1784,7 +1784,7 @@ mlir.register_lowering(complex_p, partial(_nary_lower_mhlo, mhlo.ComplexOp))
 
 conj_p = unop(_complex_dtype, _complex_elem_types | _complex, 'conj')
 
-def _conj_impl(x, *, input_dtype):
+def _conj_impl(x, **kw):
   if dtypes.issubdtype(x.dtype, np.complexfloating):
     return complex(real(x), -imag(x))
   else:

@@ -2075,30 +2075,6 @@ def _minmax_translation_rule(ctx, avals_in, avals_out, x, y, *, op_minmax=None,
   else:
     return [op_minmax(x, y)]
 
-def _minmax_mhlo(op, cmp, x, y):
-  """Min/max that compares complex values lexicographically as pairs."""
-  tensor_type = ir.RankedTensorType(x.type)
-  if ir.ComplexType.isinstance(tensor_type.element_type):
-    rx = mhlo.RealOp(x).result
-    ry = mhlo.RealOp(y).result
-    dims = [tensor_type.get_dim_size(i) for i in range(tensor_type.rank)]
-    bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1))
-    real_eq = mhlo.CompareOp(bool_shape, rx, ry, ir.StringAttr.get("EQ"),
-                             ir.StringAttr.get("FLOAT"))
-    real_cmp = mhlo.CompareOp(bool_shape, rx, ry,
-                              ir.StringAttr.get(cmp),
-                              ir.StringAttr.get("FLOAT"))
-    imag_cmp = mhlo.CompareOp(bool_shape, mhlo.ImagOp(x).result,
-                              mhlo.ImagOp(y).result,
-                              ir.StringAttr.get(cmp),
-                              ir.StringAttr.get("FLOAT"))
-    which = mhlo.SelectOp(real_eq, imag_cmp, real_cmp).result
-    return mhlo.SelectOp(which, x, y)
-  else:
-    return op(x, y)
-
-_min_mhlo = partial(_minmax_mhlo, mhlo.MinOp, "LT")
-_max_mhlo = partial(_minmax_mhlo, mhlo.MaxOp, "GT")
 
 max_p: core.Primitive = standard_naryop(
     [_any, _any], 'max', translation_rule=partial(

@@ -2106,7 +2082,7 @@ max_p: core.Primitive = standard_naryop(
 ad.defjvp2(max_p,
            lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
            lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
-mlir.register_lowering(max_p, partial(_nary_lower_mhlo, _max_mhlo))
+mlir.register_lowering(max_p, partial(_nary_lower_mhlo, mlir.max_mhlo))
 
 min_p: core.Primitive = standard_naryop(
     [_any, _any], 'min', translation_rule=partial(

@@ -2114,7 +2090,7 @@ min_p: core.Primitive = standard_naryop(
 ad.defjvp2(min_p,
            lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)),
            lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x)))
-mlir.register_lowering(min_p, partial(_nary_lower_mhlo, _min_mhlo))
+mlir.register_lowering(min_p, partial(_nary_lower_mhlo, mlir.min_mhlo))
 
 shift_left_p = standard_naryop([_int, _int], 'shift_left')
 ad.defjvp_zero(shift_left_p)

@@ -3590,9 +3566,9 @@ mlir.register_lowering(reduce_or_p, partial(_unary_reduce_lower, mhlo.OrOp,
                                             lambda dtype: np.array(False, dtype)))
 mlir.register_lowering(reduce_and_p, partial(_unary_reduce_lower, mhlo.AndOp,
                                              lambda dtype: np.array(True, dtype)))
-mlir.register_lowering(reduce_min_p, partial(_unary_reduce_lower, _min_mhlo,
+mlir.register_lowering(reduce_min_p, partial(_unary_reduce_lower, mlir.min_mhlo,
                                              _get_min_identity))
-mlir.register_lowering(reduce_max_p, partial(_unary_reduce_lower, _max_mhlo,
+mlir.register_lowering(reduce_max_p, partial(_unary_reduce_lower, mlir.max_mhlo,
                                              _get_max_identity))
 
 

@@ -3960,7 +3936,7 @@ xla.register_translation(infeed_p, _infeed_translation_rule)
 
 def _infeed_lowering(ctx, avals_in, avals_out, token, *, shapes, partitions):
   assert partitions is None, partitions  # TODO(phawkins): implement me.
-  output_types = map(mlir.aval_to_ir_types, avals_out[:-1])
+  output_types = safe_map(mlir.aval_to_ir_types, avals_out[:-1])
   flat_output_types = util.flatten(output_types)
   output_tuple_type = ir.TupleType.get_tuple(flat_output_types)
   # TODO(phawkins): verify `shapes` have a major-to-minor layout.

@@ -1984,7 +1984,8 @@ def _real_dtype(dtype): return np.finfo(dtype).dtype
 def _scatter_add_lower_gpu(ctx, avals_in, avals_out, operand, indices, updates,
                            *, update_jaxpr, update_consts, dimension_numbers,
                            indices_are_sorted, unique_indices, mode):
-  if operand.dtype != np.complex128:
+  operand_aval_in, _, updates_aval_in = avals_in
+  if operand_aval_in.dtype != np.complex128:
     return _scatter_lower(ctx, avals_in, avals_out, operand, indices, updates,
                           update_jaxpr=update_jaxpr,
                           update_consts=update_consts,

@@ -1992,7 +1993,6 @@ def _scatter_add_lower_gpu(ctx, avals_in, avals_out, operand, indices, updates,
                           indices_are_sorted=indices_are_sorted,
                           unique_indices=unique_indices, mode=mode)
   assert mode == GatherScatterMode.PROMISE_IN_BOUNDS, mode
-  _, _, updates_aval_in = avals_in
   aval_out, = avals_out
   dnums = dimension_numbers
   scatter_dnums = mhlo.ScatterDimensionNumbers.get(

@@ -510,9 +510,9 @@ def _reduce_window_lower(
 mlir.register_lowering(reduce_window_sum_p, partial(
     _reduce_window_lower, mhlo.AddOp, lambda _: 0))
 mlir.register_lowering(reduce_window_min_p, partial(
-    _reduce_window_lower, mhlo.MinOp, lax._get_min_identity))
+    _reduce_window_lower, mlir.min_mhlo, lax._get_min_identity))
 mlir.register_lowering(reduce_window_max_p, partial(
-    _reduce_window_lower, mhlo.MaxOp, lax._get_max_identity))
+    _reduce_window_lower, mlir.max_mhlo, lax._get_max_identity))
 
 

@@ -39,6 +39,7 @@ from jax._src.util import prod, unzip2
 from jax.tree_util import tree_multimap, tree_all, tree_map, tree_reduce
 from jax._src.lib import xla_bridge
 from jax._src import dispatch
+from jax.interpreters import mlir
 from jax.interpreters import xla
 from jax.experimental.maps import mesh
 

@@ -368,18 +369,25 @@ def count_jit_and_pmap_compiles():
   # No need to clear any caches since we generally jit and pmap fresh callables
   # in tests.
 
-  jaxpr_subcomp = xla.jaxpr_subcomp
+  xla_jaxpr_subcomp = xla.jaxpr_subcomp
+  mlir_jaxpr_subcomp = mlir.jaxpr_subcomp
   count = [0]
 
-  def jaxpr_subcomp_and_count(*args, **kwargs):
+  def xla_jaxpr_subcomp_and_count(*args, **kwargs):
     count[0] += 1
-    return jaxpr_subcomp(*args, **kwargs)
+    return xla_jaxpr_subcomp(*args, **kwargs)
 
-  xla.jaxpr_subcomp = jaxpr_subcomp_and_count
+  def mlir_jaxpr_subcomp_and_count(*args, **kwargs):
+    count[0] += 1
+    return mlir_jaxpr_subcomp(*args, **kwargs)
+
+  xla.jaxpr_subcomp = xla_jaxpr_subcomp_and_count
+  mlir.jaxpr_subcomp = mlir_jaxpr_subcomp_and_count
   try:
     yield count
   finally:
-    xla.jaxpr_subcomp = jaxpr_subcomp
+    xla.jaxpr_subcomp = xla_jaxpr_subcomp
+    mlir.jaxpr_subcomp = mlir_jaxpr_subcomp
 
 @contextmanager
 def assert_num_jit_and_pmap_compilations(times):

@@ -178,9 +178,10 @@ def _comparator_builder(operand, op_type, is_max_k):
   return c.build(cmp_result)
 
 
-def _approx_top_k_tpu_translation(c, operand, k, reduction_dimension,
-                                  recall_target, is_max_k,
+def _approx_top_k_tpu_translation(ctx, avals_in, avals_out, operand, *, k,
+                                  reduction_dimension, recall_target, is_max_k,
                                   reduction_input_size_override):
+  c = ctx.builder
   op_shape = c.get_shape(operand)
   if not op_shape.is_array():
     raise ValueError('operand must be an array, but was {}'.format(op_shape))

@@ -203,14 +204,17 @@ def _approx_top_k_tpu_translation(c, operand, k, reduction_dimension,
                                    reduction_dimension)
   init_val = xc.ops.Constant(c, init_literal)
   init_arg = xc.ops.Constant(c, np.int32(-1))
-  return xc.ops.ApproxTopK(c, [operand, iota], [init_val, init_arg], k,
-                           reduction_dimension, comparator, recall_target, True,
-                           reduction_input_size_override)
+  out = xc.ops.ApproxTopK(c, [operand, iota], [init_val, init_arg], k,
+                          reduction_dimension, comparator, recall_target, True,
+                          reduction_input_size_override)
+  return xla.xla_destructure(c, out)
 
 
-def _approx_top_k_fallback_translation(c, operand, k, reduction_dimension,
+def _approx_top_k_fallback_translation(ctx, avals_in, avals_out, operand, *, k,
+                                       reduction_dimension,
                                        recall_target, is_max_k,
                                        reduction_input_size_override):
+  c = ctx.builder
   op_shape = c.get_shape(operand)
   if not op_shape.is_array():
     raise ValueError('operand must be an array, but was {}'.format(op_shape))

@@ -226,7 +230,7 @@ def _approx_top_k_fallback_translation(c, operand, k, reduction_dimension,
   args = xc.ops.GetTupleElement(val_arg, 1)
   sliced_vals = xc.ops.SliceInDim(vals, 0, k, 1, reduction_dimension)
   sliced_args = xc.ops.SliceInDim(args, 0, k, 1, reduction_dimension)
-  return xc.ops.Tuple(c, [sliced_vals, sliced_args])
+  return sliced_vals, sliced_args
 
 
 def _approx_top_k_batch_rule(batched_args, batch_dims, *, k,

@@ -292,11 +296,8 @@ approx_top_k_p = core.Primitive('approx_top_k')
 approx_top_k_p.multiple_results = True
 approx_top_k_p.def_impl(partial(xla.apply_primitive, approx_top_k_p))
 approx_top_k_p.def_abstract_eval(_approx_top_k_abstract_eval)
-xla.backend_specific_translations['tpu'][
-    approx_top_k_p] = _approx_top_k_tpu_translation
-xla.backend_specific_translations['cpu'][
-    approx_top_k_p] = _approx_top_k_fallback_translation
-xla.backend_specific_translations['gpu'][
-    approx_top_k_p] = _approx_top_k_fallback_translation
+xla.register_translation(approx_top_k_p, _approx_top_k_fallback_translation)
+xla.register_translation(approx_top_k_p, _approx_top_k_tpu_translation,
+                         platform='tpu')
 batching.primitive_batchers[approx_top_k_p] = _approx_top_k_batch_rule
 ad.primitive_jvps[approx_top_k_p] = _approx_top_k_jvp

@@ -69,7 +69,7 @@ def i64_attr(i): return ir.IntegerAttr.get(ir.IntegerType.get_signless(64), i)
 # IR Types
 
 # Non-canonicalized dtype to IR type mapping.
-_dtype_to_ir_type : Dict[np.dtype, Callable[[], ir.Type]] = {
+dtype_to_ir_type : Dict[np.dtype, Callable[[], ir.Type]] = {
   np.dtype(dtypes.float0): partial(ir.IntegerType.get_signless, 1),
   np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
   np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),

@@ -91,13 +91,13 @@ _dtype_to_ir_type : Dict[np.dtype, Callable[[], ir.Type]] = {
 
 def _array_ir_types(aval: core.ShapedArray) -> ir.Type:
   try:
-    ir_type_factory = _dtype_to_ir_type[aval.dtype]
+    ir_type_factory = dtype_to_ir_type[aval.dtype]
   except KeyError as err:
     raise TypeError(
         f"No dtype_to_ir_type handler for dtype: {aval.dtype}") from err
   return (ir.RankedTensorType.get(aval.shape, ir_type_factory()),)
 
-_ir_type_handlers: Dict[Type[core.AbstractValue],
+ir_type_handlers: Dict[Type[core.AbstractValue],
                         Callable[[Any], Sequence[ir.Type]]] = {}
 
 def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[ir.Type]:

@@ -106,14 +106,14 @@ def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[ir.Type]:
   In general, a JAX value may be represented by multiple IR values, so this
   function returns multiple types."""
   try:
-    return _ir_type_handlers[type(aval)](aval)
+    return ir_type_handlers[type(aval)](aval)
   except KeyError as err:
     raise TypeError(f"No ir_type_handler for aval type: {type(aval)}") from err
 
-_ir_type_handlers[core.AbstractUnit] = lambda _: ()
-_ir_type_handlers[core.ShapedArray] = _array_ir_types
-_ir_type_handlers[core.ConcreteArray] = _array_ir_types
-_ir_type_handlers[core.AbstractToken] = lambda _: [mhlo.TokenType.get()]
+ir_type_handlers[core.AbstractUnit] = lambda _: ()
+ir_type_handlers[core.ShapedArray] = _array_ir_types
+ir_type_handlers[core.ConcreteArray] = _array_ir_types
+ir_type_handlers[core.AbstractToken] = lambda _: [mhlo.TokenType.get()]
 
 def aval_to_ir_type(aval: core.AbstractValue) -> ir.Type:
   """Convenience wrapper around aval_to_ir_types for single types.

@@ -310,6 +310,7 @@ def register_lowering(prim: core.Primitive, rule: LoweringRule,
 def _unwrap_singleton_ir_values(x): return x[0] if len(x) == 1 else x
 def wrap_singleton_ir_values(x: Union[ir.Value, Sequence[ir.Value]]
                              ) -> Sequence[ir.Value]:
   """Adds a consistent tuples to a mixture of tupled and untuple values."""
   return (x,) if isinstance(x, ir.Value) else tuple(x)
 
 def flatten_lowering_ir_args(

@@ -568,6 +569,32 @@ register_lowering(ad_util.stop_gradient_p,
                   lambda ctx, avals_in, avals_out, x: [x])
 
 
+def _minmax_mhlo(op, cmp, x, y):
+  """Min/max that compares complex values lexicographically as pairs."""
+  tensor_type = ir.RankedTensorType(x.type)
+  if ir.ComplexType.isinstance(tensor_type.element_type):
+    rx = mhlo.RealOp(x).result
+    ry = mhlo.RealOp(y).result
+    dims = [tensor_type.get_dim_size(i) for i in range(tensor_type.rank)]
+    bool_shape = ir.RankedTensorType.get(dims, ir.IntegerType.get_signless(1))
+    real_eq = mhlo.CompareOp(bool_shape, rx, ry, ir.StringAttr.get("EQ"),
+                             ir.StringAttr.get("FLOAT"))
+    real_cmp = mhlo.CompareOp(bool_shape, rx, ry,
+                              ir.StringAttr.get(cmp),
+                              ir.StringAttr.get("FLOAT"))
+    imag_cmp = mhlo.CompareOp(bool_shape, mhlo.ImagOp(x).result,
+                              mhlo.ImagOp(y).result,
+                              ir.StringAttr.get(cmp),
+                              ir.StringAttr.get("FLOAT"))
+    which = mhlo.SelectOp(real_eq, imag_cmp, real_cmp).result
+    return mhlo.SelectOp(which, x, y)
+  else:
+    return op(x, y)
+
+min_mhlo = partial(_minmax_mhlo, mhlo.MinOp, "LT")
+max_mhlo = partial(_minmax_mhlo, mhlo.MaxOp, "GT")
+
+
 # MLIR lowerings for lax primitives
 
 def xla_fallback_lowering(prim: core.Primitive, ctx: LoweringContext,

@@ -1637,7 +1637,8 @@ def _mhlo_unshard(aval, axis_env, out_axis, xs, platform):
     x = mhlo.ConvertOp(mlir.aval_to_ir_type(aval), x).result
 
   dims = list(aval.shape)
-  padded = mlir.full_like_aval(0, aval.update(shape=[axis_env.sizes[-1]] + dims))
+  padded = mlir.full_like_aval(
+      0, aval.update(shape=[axis_env.sizes[-1]] + dims))
   zero = mlir.ir_constant(np.zeros((), dtype=np.uint32))
   idxs = [_unravel_index_mhlo(axis_env)] + [zero] * len(dims)
   padded = mhlo.DynamicUpdateSliceOp(

@@ -1695,7 +1696,8 @@ def _pmap_lowering(ctx, avals_in, avals_out, *in_nodes, axis_name,
       axis_env=new_env,
      name_stack=xla.extend_name_stack(ctx.name_stack,
                                       util.wrap_name(name, 'pmap')))
-  sharded_outs = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, (), *in_nodes_sharded)
+  sharded_outs = mlir.jaxpr_subcomp(sub_ctx, call_jaxpr, (),
+                                    *in_nodes_sharded)
   out_avals = [v.aval for v in call_jaxpr.outvars]
   outs = [_mhlo_unshard(aval, new_env, out_axis, shard, platform=ctx.platform)
           for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)]

@@ -45,6 +45,7 @@ from jax._src import api, dtypes
 from jax.core import Primitive
 from jax.errors import UnexpectedTracerError
 from jax.interpreters import ad
+from jax.interpreters import mlir
 from jax.interpreters import xla
 from jax.interpreters import pxla
 from jax.interpreters.sharded_jit import PartitionSpec as P

@@ -243,6 +244,8 @@ class CPPJitTest(jtu.BufferDonationTestCase):
   # Jit and Donate arguments
 
   def test_jit_donate_argnums_warning_raised(self):
+    if jax.config.jax_enable_mlir:
+      raise unittest.SkipTest("Buffer donation not yet implemented via MLIR")
     x = jnp.array([1.0, 2.0], jnp.float32)
     y = jnp.array([1, 2], jnp.int32)
     f = self.jit(lambda x, y: x.sum() + y.sum(), donate_argnums=(0, 1))

@@ -984,8 +987,8 @@ class APITest(jtu.JaxTestCase):
 
     foo_p.def_abstract_eval(lambda x: x)
 
-    jtu.check_raises(lambda: jit(foo)(1.0), NotImplementedError,
-                     "XLA translation rule for primitive 'foo' not found")
+    jtu.check_raises_regexp(lambda: jit(foo)(1.0), NotImplementedError,
+                            ".* rule for primitive 'foo' not found.*")
 
     foo_p.def_impl(lambda x: x)
     ad.defjvp(foo_p, lambda g, x: foo(g))

@@ -2406,18 +2409,24 @@
     def g(x):
       return f(2, x)
 
-    jaxpr_subcomp = xla.jaxpr_subcomp
+    xla_jaxpr_subcomp = xla.jaxpr_subcomp
+    mlir_jaxpr_subcomp = mlir.jaxpr_subcomp
 
     jaxprs = []
-    def jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
+    def xla_jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
       jaxprs.append(jaxpr)
-      return jaxpr_subcomp(c, jaxpr, *args, **kwargs)
+      return xla_jaxpr_subcomp(c, jaxpr, *args, **kwargs)
+    def mlir_jaxpr_subcomp_and_collect(c, jaxpr, *args, **kwargs):
+      jaxprs.append(jaxpr)
+      return mlir_jaxpr_subcomp(c, jaxpr, *args, **kwargs)
 
     try:
-      xla.jaxpr_subcomp = jaxpr_subcomp_and_collect
+      xla.jaxpr_subcomp = xla_jaxpr_subcomp_and_collect
+      mlir.jaxpr_subcomp = mlir_jaxpr_subcomp_and_collect
       ans = g(3)
     finally:
-      xla.jaxpr_subcomp = jaxpr_subcomp
+      xla.jaxpr_subcomp = xla_jaxpr_subcomp
+      mlir.jaxpr_subcomp = mlir_jaxpr_subcomp
 
     self.assertEqual(ans, (7, 3))
     self.assertLen(jaxprs, 2)

@@ -21,7 +21,9 @@ import jax.numpy as jnp
 from jax import core, jit, lax, make_jaxpr
 from jax._src import device_array
 from jax._src import dispatch
+from jax.interpreters import mlir
 from jax.interpreters import xla
+from jax._src.lib.mlir import ir
 from jax._src.lib import xla_bridge, xla_client
 xops = xla_client.ops
 xb = xla_bridge

@@ -137,6 +139,15 @@ dispatch.num_buffers_handlers[AbstractSparseArray] = lambda _: 2
 xla.xla_shape_handlers[AbstractSparseArray] = sparse_array_shape_handler
 xla.register_constant_handler(SparseArray, sparse_array_constant_handler)
 
+def sparse_array_mlir_type_handler(a):
+  return (
+    ir.RankedTensorType.get(
+      a.data_aval.shape, mlir.dtype_to_ir_type[a.data_aval.dtype]()),
+    ir.RankedTensorType.get(
+      a.indices_aval.shape, mlir.dtype_to_ir_type[a.indices_aval.dtype]()),
+  )
+
+mlir.ir_type_handlers[AbstractSparseArray] = sparse_array_mlir_type_handler
 
 sp_indices_p = core.Primitive('sp_indices')
 

@@ -155,6 +166,11 @@ def _sp_indices_translation_rule(ctx, avals_in, avals_out, data, indices):
 # because it leads to infinite recursion.
 xla.register_translation(sp_indices_p, _sp_indices_translation_rule)
 
+def _sp_indices_mhlo_lowering(ctx, avals_in, avals_out, data_and_indices):
+  return [data_and_indices[1]]
+
+mlir.register_lowering(sp_indices_p, _sp_indices_mhlo_lowering)
+
 sp_data_p = core.Primitive('sp_data')
 
 @sp_data_p.def_impl

@@ -172,6 +188,11 @@ def _sp_data_translation_rule(ctx, avals_in, avals_out, data, indices):
 # because it leads to infinite recursion.
 xla.register_translation(sp_data_p, _sp_data_translation_rule)
 
+def _sp_data_mhlo_lowering(ctx, avals_in, avals_out, data_and_indices):
+  return [data_and_indices[0]]
+
+mlir.register_lowering(sp_data_p, _sp_data_mhlo_lowering)
+
 def identity(x):
   return identity_p.bind(x)
 

@@ -189,6 +210,10 @@ xla.register_translation(
     identity_p, xla.lower_fun(_identity_impl, multiple_results=False,
                               new_style=True))
 
+
+mlir.register_lowering(
+    identity_p, mlir.lower_fun(_identity_impl, multiple_results=False))
+
 def split(x):
   return split_p.bind(x)
 