mirror of https://github.com/ROCm/jax.git
[shape_poly] Shape polymorphism support for approx_top_k

PiperOrigin-RevId: 543633818

parent 744a64fce6
commit cb42fae810
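
What the change enables, sketched below in Python (this example is illustrative and not part of the commit; the function and the polymorphic_shapes spec are assumptions): jax.lax.approx_max_k can now be lowered with a symbolic batch dimension, e.g. via jax2tf.

    import jax
    from jax.experimental import jax2tf

    def top8(x):
      # approx_top_k primitive: values and indices of (approximately)
      # the 8 largest entries along the last dimension.
      return jax.lax.approx_max_k(x, k=8)

    # "b" is a symbolic dimension. Before this commit, lowering approx_max_k
    # under shape polymorphism failed; with it, the output shapes may depend
    # on b.
    top8_tf = jax2tf.convert(top8, polymorphic_shapes=["b, 1024"])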
@@ -2008,7 +2008,7 @@ def custom_call(
     out_types: Sequence[ir.Type],
     operands: Sequence[ir.Value],
     *,
-    backend_config: Optional[str] = None,
+    backend_config: Union[str, dict] = "",
     has_side_effect: bool = False,
     result_shapes: Optional[Sequence[ir.Value]] = None,
     called_computations: Sequence[str] = (),
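
For orientation, a hedged sketch of the two call styles this widened signature accepts (out_types and operands are assumed to be built as usual in a lowering rule; only the backend_config kwarg differs):

    # Existing path: an opaque string becomes the op's backend_config
    # StringAttr.
    op = custom_call("MyTarget", out_types, operands,
                     backend_config="opaque-bytes")

    # New path: a dict of MLIR attributes; as the hunks below show, it is
    # attached as an unregistered mhlo.backend_config DictAttr with
    # api_version=1.
    op = custom_call("ApproxTopK", out_types, operands,
                     backend_config={"top_k": i64_attr(8)})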
@@ -2023,11 +2023,27 @@ def custom_call(
       number of the results.
     called_computations: the list of function names called by the custom call.
   """
+  if backend_config is None:
+    backend_config_attr = ir.StringAttr.get("")
+  elif isinstance(backend_config, str):
+    backend_config_attr = ir.StringAttr.get(backend_config)
+  elif isinstance(backend_config, dict):
+    # TODO(necula): it seems that the CustomCallOp constructor requires that
+    # backend_config_attr be a string attribute, even though in some cases we
+    # need it to be a DictAttr, e.g., for ApproxTopK on TPU.
+    # "Verification failed: 'stablehlo.custom_call' op attribute 'backend_config' failed to satisfy constraint: string attribute"
+    # To workaround this limitation we first set it to the empty string and we
+    # use an unregistered attribute mhlo.backend_config to hold the DictAttr.
+    # We must also use api_version=1 to ensure that mhlo.backend_config is
+    # handled properly.
+    backend_config_attr = ir.StringAttr.get("")
+    api_version = 1
+  else:
+    raise ValueError("custom_call backend_config unexpected type: " + str(backend_config))
   attributes = dict(
       call_target_name=ir.StringAttr.get(call_target_name),
       has_side_effect=ir.BoolAttr.get(has_side_effect),
-      backend_config=ir.StringAttr.get(
-          "" if backend_config is None else backend_config),
+      backend_config=backend_config_attr,
       api_version=i32_attr(api_version),
       called_computations=ir.ArrayAttr.get([
           ir.FlatSymbolRefAttr.get(name) for name in called_computations]),
@@ -2043,7 +2059,12 @@ def custom_call(
           dtype=np.int64))
     operands = list(operands) + list(result_shapes)
 
-  return hlo.CustomCallOp.build_generic(results=out_types, operands=operands, attributes=attributes)
+  op = hlo.CustomCallOp.build_generic(results=out_types, operands=operands, attributes=attributes)
+  if isinstance(backend_config, dict):
+    backend_config_attr = ir.DictAttr.get(backend_config)
+    op.operation.attributes["mhlo.backend_config"] = backend_config_attr
+  return op
 
 
 def reduce_window(
     ctx: LoweringRuleContext, *,
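
Net effect of the hunks above: the verifier-checked backend_config stays an empty StringAttr, while the structured payload is carried by the unregistered mhlo.backend_config attribute. A minimal standalone sketch of the attribute being built, using jaxlib's MLIR bindings (the top_k value is made up):

    from jax._src.lib.mlir import ir

    with ir.Context():
      i64 = ir.IntegerType.get_signless(64)
      # The kind of DictAttr that custom_call now derives from a dict
      # backend_config and stashes on the op.
      payload = ir.DictAttr.get({"top_k": ir.IntegerAttr.get(i64, 8)})
      print(payload)  # {top_k = 8 : i64}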
@@ -81,13 +81,13 @@ from jax._src import dispatch
 from jax._src import dtypes
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
+from jax._src.interpreters import mlir
 from jax._src.interpreters import xla
 from jax._src.lax import lax
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func
 from jax._src.lib.mlir.dialects import hlo
-from jax.interpreters import mlir
 
 
 Array = Any
@@ -316,22 +316,21 @@ def _approx_top_k_lowering(ctx, operand, *, k,
 
   op_dims = op_shape
   op_type = mlir.dtype_to_ir_type(ctx.avals_in[0].dtype)
-  index_type = ir.IntegerType.get_signless(32)
   recall_type = ir.F32Type.get()
   if reduction_dimension < 0:
     reduction_dimension = len(op_dims) + reduction_dimension
 
   comparator = _comparator_builder_mlir(ctx, op_type, is_max_k)
-  iota = hlo.IotaOp(ir.RankedTensorType.get(op_dims, index_type),
-                    reduction_dimension)
+  iota = mlir.iota(ctx, core.ShapedArray(ctx.avals_in[0].shape, np.int32),
+                   dimension=reduction_dimension)
 
-  init_arg = hlo.ConstantOp(ir.DenseElementsAttr.get(np.int32(-1)))
+  init_arg = hlo.ConstantOp(ir.DenseElementsAttr.get(np.int32(-1))).result
   # Can't write bf16 literals, so we write a f64 literal and convert it.
   init_val_literal = _get_init_val_literal(np.float64, is_max_k)
   init_val_array = np.array(init_val_literal, dtype=np.float64).reshape(())
   init_val = mlir.ir_constant(init_val_array)
   init_val = hlo.ConvertOp(ir.RankedTensorType.get([],
-      mlir.dtype_to_ir_type(ctx.avals_in[0].dtype)), init_val)
+      mlir.dtype_to_ir_type(ctx.avals_in[0].dtype)), init_val).result
 
   backend_config = {
       "top_k" : mlir.i64_attr(k),
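
Aside: _get_init_val_literal, called above but defined elsewhere in this file, supplies the identity element for the selection. A plausible sketch of its behavior (an assumption for the reader's benefit, not code from the commit):

    import numpy as np

    def _get_init_val_literal_sketch(dtype, is_max_k):
      # -inf is the identity when collecting maxima, +inf when collecting
      # minima.
      return np.array(-np.inf if is_max_k else np.inf, dtype=dtype)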
@@ -343,42 +342,38 @@ def _approx_top_k_lowering(ctx, operand, *, k,
   if fallback:
     backend_config["is_fallback"] = mlir.ir.BoolAttr.get(fallback)
 
-  out = hlo.CustomCallOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
-                         [operand, iota, init_val, init_arg],
-                         call_target_name=b"ApproxTopK",
-                         called_computations=mlir.ir.ArrayAttr.get(
-                           [mlir.ir.FlatSymbolRefAttr.get(comparator.name.value)]))
-  backend_config_attr = mlir.ir.DictAttr.get(backend_config,
-                                             ctx.module_context.context)
-  out.operation.attributes["mhlo.backend_config"] = backend_config_attr
+  if xc.mlir_api_version >= 51:  # jaxlib >= 0.4.14
+    if all(core.is_constant_shape(aval_out.shape) for aval_out in ctx.avals_out):
+      result_shapes = None
+    else:
+      result_shapes = [
+          mlir.shape_tensor(mlir.eval_dynamic_shape(ctx, aval_out.shape))
+          for aval_out in ctx.avals_out]
+
+    out = mlir.custom_call(
+        "ApproxTopK",
+        [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
+        [operand, iota, init_val, init_arg],
+        called_computations=[comparator.name.value],
+        backend_config=backend_config,
+        result_shapes=result_shapes)
+  else:
+    # Older versions do not support has_side_effect attribute; we just use
+    # the old lowering code.
+    if any(not core.is_constant_shape(aval_out.shape) for aval_out in ctx.avals_out):
+      raise ValueError("approx_top_k not supported with shape polymorphism; "
+                       "try upgrading jaxlib")
+    out = hlo.CustomCallOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
+                           [operand, iota, init_val, init_arg],
+                           call_target_name=b"ApproxTopK",
+                           called_computations=mlir.ir.ArrayAttr.get(
+                             [mlir.ir.FlatSymbolRefAttr.get(comparator.name.value)]))
+    backend_config_attr = mlir.ir.DictAttr.get(backend_config,
+                                               ctx.module_context.context)
+    out.operation.attributes["mhlo.backend_config"] = backend_config_attr
 
   return out.results
 
-def _approx_top_k_fallback_translation(ctx, avals_in, avals_out, operand, *, k,
-                                       reduction_dimension, recall_target,
-                                       is_max_k, reduction_input_size_override,
-                                       aggregate_to_topk):
-  c = ctx.builder
-  op_shape = c.get_shape(operand)
-  if not op_shape.is_array():
-    raise ValueError(f'operand must be an array, but was {op_shape}')
-  op_dims = op_shape.dimensions()
-  op_type = op_shape.element_type()
-
-  if reduction_dimension < 0:
-    reduction_dimension = len(op_dims) + reduction_dimension
-  comparator = _comparator_builder(op_type, is_max_k)
-  iota = xc.ops.Iota(c, xc.Shape.array_shape(np.dtype(np.int32), op_dims),
-                     reduction_dimension)
-  init_val_literal = _get_init_val_literal(op_type, is_max_k)
-  init_val = xc.ops.Constant(c, init_val_literal)
-  init_arg = xc.ops.Constant(c, np.int32(-1))
-  out = xc.ops.ApproxTopKFallback(c, [operand, iota], [init_val, init_arg], k,
-                                  reduction_dimension, comparator,
-                                  recall_target, aggregate_to_topk,
-                                  reduction_input_size_override)
-  return xla.xla_destructure(c, out)
-
 
 def _approx_top_k_batch_rule(batch_operands, batch_axes, *, k,
                              reduction_dimension, recall_target, is_max_k,
                              reduction_input_size_override, aggregate_to_topk):
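
The shape-polymorphic branch above hinges on result_shapes: when an output shape contains a symbolic dimension, mlir.eval_dynamic_shape evaluates each dimension and mlir.shape_tensor packs them into a 1-D i32 tensor, which custom_call appends to the operands (see operands = list(operands) + list(result_shapes) in the first file). A toy illustration of the encoding, with made-up concrete values:

    import numpy as np

    # Hypothetical output aval of shape (b, k) once b=16 and k=8 are known:
    # the shape rides along as an ordinary 1-D i32 operand of the custom call.
    shape_operand = np.asarray((16, 8), dtype=np.int32)
    print(shape_operand)  # [16  8]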
@@ -2898,7 +2898,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
       raise unittest.SkipTest(
           "native lowering with shape polymorphism requires additional StableHLO feature support")
 
-    if "top_k" in harness.fullname:
+    if "top_k" in harness.fullname and "approx_top_k" not in harness.fullname:
      # https://github.com/openxla/stablehlo/issues/1255: need DynamicTopK
      raise unittest.SkipTest("native lowering with shape polymorphism not implemented for top_k")
 
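
The added clause matters because any approx_top_k harness name also contains the substring "top_k", so the old condition skipped exactly the cases this commit enables. A quick illustration (the harness name is illustrative):

    name = "approx_top_k"          # illustrative harness-name fragment
    assert "top_k" in name         # matched the old, too-broad skip
    assert "approx_top_k" in name  # exempted by the new clause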