[shape_poly] Shape polymorphism support for approx_top_k

PiperOrigin-RevId: 543633818
George Necula 2023-06-26 22:02:03 -07:00 committed by jax authors
parent 744a64fce6
commit cb42fae810
3 changed files with 61 additions and 45 deletions
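
For orientation, a minimal usage sketch (not part of this commit; the function, the shapes, and the dimension name "b" are illustrative): after this change, lax.approx_max_k can be lowered with a symbolic batch dimension, e.g. via jax2tf shape polymorphism, which the shape_poly test update below stops skipping.

import jax
from jax.experimental import jax2tf

def top8(x):
  # approx_top_k is reached through lax.approx_max_k / lax.approx_min_k.
  return jax.lax.approx_max_k(x, k=8)

# Illustrative: "b" is a symbolic batch dimension; the reduced axis stays static.
# native_serialization=True selects the native (StableHLO) lowering path.
top8_tf = jax2tf.convert(top8, polymorphic_shapes=["(b, 1024)"],
                         native_serialization=True)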


@@ -2008,7 +2008,7 @@ def custom_call(
     out_types: Sequence[ir.Type],
     operands: Sequence[ir.Value],
     *,
-    backend_config: Optional[str] = None,
+    backend_config: Union[str, dict] = "",
     has_side_effect: bool = False,
     result_shapes: Optional[Sequence[ir.Value]] = None,
     called_computations: Sequence[str] = (),
@@ -2023,11 +2023,27 @@ def custom_call(
       number of the results.
     called_computations: the list of function names called by the custom call.
   """
+  if backend_config is None:
+    backend_config_attr = ir.StringAttr.get("")
+  elif isinstance(backend_config, str):
+    backend_config_attr = ir.StringAttr.get(backend_config)
+  elif isinstance(backend_config, dict):
+    # TODO(necula): it seems that the CustomCallOp constructor requires that
+    # backend_config_attr be a string attribute, even though in some cases we
+    # need it to be a DictAttr, e.g., for ApproxTopK on TPU.
+    # "Verification failed: 'stablehlo.custom_call' op attribute 'backend_config' failed to satisfy constraint: string attribute"
+    # To workaround this limitation we first set it to the empty string and we
+    # use an unregistered attribute mhlo.backend_config to hold the DictAttr.
+    # We must also use api_version=1 to ensure that mhlo.backend_config is
+    # handled properly.
+    backend_config_attr = ir.StringAttr.get("")
+    api_version = 1
+  else:
+    raise ValueError("custom_call backend_config unexpected type: " + str(backend_config))
   attributes = dict(
       call_target_name=ir.StringAttr.get(call_target_name),
       has_side_effect=ir.BoolAttr.get(has_side_effect),
-      backend_config=ir.StringAttr.get(
-          "" if backend_config is None else backend_config),
+      backend_config=backend_config_attr,
       api_version=i32_attr(api_version),
       called_computations=ir.ArrayAttr.get([
           ir.FlatSymbolRefAttr.get(name) for name in called_computations]),
@@ -2043,7 +2059,12 @@ def custom_call(
                    dtype=np.int64))
     operands = list(operands) + list(result_shapes)
 
-  return hlo.CustomCallOp.build_generic(results=out_types, operands=operands, attributes=attributes)
+  op = hlo.CustomCallOp.build_generic(results=out_types, operands=operands, attributes=attributes)
+  if isinstance(backend_config, dict):
+    backend_config_attr = ir.DictAttr.get(backend_config)
+    op.operation.attributes["mhlo.backend_config"] = backend_config_attr
+  return op
 
 
 def reduce_window(
     ctx: LoweringRuleContext, *,
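
An illustrative sketch (not part of the diff) of how a lowering rule can now pass a dictionary to mlir.custom_call; the call target name and the attribute keys are hypothetical, and only the pattern mirrors the ApproxTopK lowering further down:

from jax._src.interpreters import mlir

def _my_lowering(ctx, operand):
  # With a dict, custom_call leaves the standard backend_config string empty,
  # stores the dict in the unregistered mhlo.backend_config attribute, and
  # forces api_version=1.
  backend_config = {
      "block_size": mlir.i64_attr(128),             # hypothetical attribute
      "use_fast_path": mlir.ir.BoolAttr.get(True),  # hypothetical attribute
  }
  out = mlir.custom_call(
      "MyCustomTarget",  # hypothetical call target
      [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
      [operand],
      backend_config=backend_config)
  return out.results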


@@ -81,13 +81,13 @@ from jax._src import dispatch
 from jax._src import dtypes
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
+from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.lax import lax
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func
 from jax._src.lib.mlir.dialects import hlo
-from jax.interpreters import mlir
 
 Array = Any
@@ -316,22 +316,21 @@ def _approx_top_k_lowering(ctx, operand, *, k,
   op_dims = op_shape
   op_type = mlir.dtype_to_ir_type(ctx.avals_in[0].dtype)
 
-  index_type = ir.IntegerType.get_signless(32)
   recall_type = ir.F32Type.get()
   if reduction_dimension < 0:
     reduction_dimension = len(op_dims) + reduction_dimension
 
   comparator = _comparator_builder_mlir(ctx, op_type, is_max_k)
-  iota = hlo.IotaOp(ir.RankedTensorType.get(op_dims, index_type),
-                    reduction_dimension)
+  iota = mlir.iota(ctx, core.ShapedArray(ctx.avals_in[0].shape, np.int32),
+                   dimension=reduction_dimension)
 
-  init_arg = hlo.ConstantOp(ir.DenseElementsAttr.get(np.int32(-1)))
+  init_arg = hlo.ConstantOp(ir.DenseElementsAttr.get(np.int32(-1))).result
   # Can't write bf16 literals, so we write a f64 literal and convert it.
   init_val_literal = _get_init_val_literal(np.float64, is_max_k)
   init_val_array = np.array(init_val_literal, dtype=np.float64).reshape(())
   init_val = mlir.ir_constant(init_val_array)
   init_val = hlo.ConvertOp(ir.RankedTensorType.get([],
-      mlir.dtype_to_ir_type(ctx.avals_in[0].dtype)), init_val)
+      mlir.dtype_to_ir_type(ctx.avals_in[0].dtype)), init_val).result
 
   backend_config = {
     "top_k" : mlir.i64_attr(k),
@@ -343,42 +342,38 @@ def _approx_top_k_lowering(ctx, operand, *, k,
   if fallback:
     backend_config["is_fallback"] = mlir.ir.BoolAttr.get(fallback)
-  out = hlo.CustomCallOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
-                         [operand, iota, init_val, init_arg],
-                         call_target_name=b"ApproxTopK",
-                         called_computations=mlir.ir.ArrayAttr.get(
-                             [mlir.ir.FlatSymbolRefAttr.get(comparator.name.value)]))
-  backend_config_attr = mlir.ir.DictAttr.get(backend_config,
-                                             ctx.module_context.context)
-  out.operation.attributes["mhlo.backend_config"] = backend_config_attr
+  if xc.mlir_api_version >= 51:  # jaxlib >= 0.4.14
+    if all(core.is_constant_shape(aval_out.shape) for aval_out in ctx.avals_out):
+      result_shapes = None
+    else:
+      result_shapes = [
+          mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval_out.shape))
+          for aval_out in ctx.avals_out]
+    out = mlir.custom_call(
+        "ApproxTopK",
+        [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
+        [operand, iota, init_val, init_arg],
+        called_computations=[comparator.name.value],
+        backend_config=backend_config,
+        result_shapes=result_shapes)
+  else:
+    # Older versions do not support has_side_effect attribute; we just use
+    # the old lowering code.
+    if any(not core.is_constant_shape(aval_out.shape) for aval_out in ctx.avals_out):
+      raise ValueError("approx_top_k not supported with shape polymorphism; "
+                       "try upgrading jaxlib")
+    out = hlo.CustomCallOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
+                           [operand, iota, init_val, init_arg],
+                           call_target_name=b"ApproxTopK",
+                           called_computations=mlir.ir.ArrayAttr.get(
+                               [mlir.ir.FlatSymbolRefAttr.get(comparator.name.value)]))
+    backend_config_attr = mlir.ir.DictAttr.get(backend_config,
+                                               ctx.module_context.context)
+    out.operation.attributes["mhlo.backend_config"] = backend_config_attr
   return out.results
 
 
-def _approx_top_k_fallback_translation(ctx, avals_in, avals_out, operand, *, k,
-                                       reduction_dimension, recall_target,
-                                       is_max_k, reduction_input_size_override,
-                                       aggregate_to_topk):
-  c = ctx.builder
-  op_shape = c.get_shape(operand)
-  if not op_shape.is_array():
-    raise ValueError(f'operand must be an array, but was {op_shape}')
-  op_dims = op_shape.dimensions()
-  op_type = op_shape.element_type()
-  if reduction_dimension < 0:
-    reduction_dimension = len(op_dims) + reduction_dimension
-  comparator = _comparator_builder(op_type, is_max_k)
-  iota = xc.ops.Iota(c, xc.Shape.array_shape(np.dtype(np.int32), op_dims),
-                     reduction_dimension)
-  init_val_literal = _get_init_val_literal(op_type, is_max_k)
-  init_val = xc.ops.Constant(c, init_val_literal)
-  init_arg = xc.ops.Constant(c, np.int32(-1))
-  out = xc.ops.ApproxTopKFallback(c, [operand, iota], [init_val, init_arg], k,
-                                  reduction_dimension, comparator,
-                                  recall_target, aggregate_to_topk,
-                                  reduction_input_size_override)
-  return xla.xla_destructure(c, out)
 
 
 def _approx_top_k_batch_rule(batch_operands, batch_axes, *, k,
                              reduction_dimension, recall_target, is_max_k,
                              reduction_input_size_override, aggregate_to_topk):
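
A condensed sketch of the dynamic-shape branch above (assumes an active lowering rule, i.e. a LoweringRuleContext ctx and MLIR operand values args; other names follow the diff):

from jax._src import core
from jax._src.interpreters import mlir

def _emit_approx_top_k(ctx, args, backend_config, comparator_name):
  if all(core.is_constant_shape(aval.shape) for aval in ctx.avals_out):
    result_shapes = None  # every output dimension is a known integer
  else:
    # One 1-D shape tensor per result, evaluated from the symbolic output
    # shapes; custom_call appends these as extra operands.
    result_shapes = [
        mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval.shape))
        for aval in ctx.avals_out]
  return mlir.custom_call(
      "ApproxTopK",
      [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
      args,
      called_computations=[comparator_name],
      backend_config=backend_config,
      result_shapes=result_shapes).results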


@@ -2898,7 +2898,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
       raise unittest.SkipTest(
           "native lowering with shape polymorphism requires additional StableHLO feature support")
-    if "top_k" in harness.fullname:
+    if "top_k" in harness.fullname and "approx_top_k" not in harness.fullname:
       # https://github.com/openxla/stablehlo/issues/1255: need DynamicTopK
       raise unittest.SkipTest("native lowering with shape polymorphism not implemented for top_k")