mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Plumbing for dynamic shapes for custom calls.
PiperOrigin-RevId: 521439418
This commit is contained in:
parent
0d32724882
commit
607c7c1fdd
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
# Helpers for building MLIR operators
|
||||
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
import jaxlib.mlir.ir as ir
|
||||
import jaxlib.mlir.dialects.stablehlo as hlo
|
||||
@ -30,6 +29,7 @@ def custom_call(
|
||||
has_side_effect: bool = False,
|
||||
api_version: int = 2,
|
||||
operand_output_aliases: Dict[int, int] = {},
|
||||
indices_of_shape_operands: Sequence[int] = (),
|
||||
) -> Union[ir.Value, Sequence[ir.Value]]:
|
||||
"""Less-verbose helper for building a CustomCallOp.
|
||||
|
||||
@ -40,12 +40,14 @@ def custom_call(
|
||||
...
|
||||
operand_output_alias: a dictionary mapping input numbers -> output numbers
|
||||
that must alias.
|
||||
indices_of_shape_operands: in presence of dynamic shapes, must pass in the
|
||||
output shapes as some of the operands. These are the indices of those
|
||||
operands.
|
||||
"""
|
||||
i32_type = ir.IntegerType.get_signless(32)
|
||||
out = hlo.CustomCallOp(
|
||||
(out_types
|
||||
if len(out_types) == 1 else [ir.TupleType.get_tuple(out_types)]),
|
||||
operands,
|
||||
results = (out_types
|
||||
if len(out_types) == 1 else [ir.TupleType.get_tuple(out_types)])
|
||||
attributes = dict(
|
||||
call_target_name=ir.StringAttr.get(call_target_name),
|
||||
has_side_effect=ir.BoolAttr.get(has_side_effect),
|
||||
backend_config=ir.StringAttr.get(
|
||||
@ -68,7 +70,20 @@ def custom_call(
|
||||
operand_index=input,
|
||||
operand_tuple_indices=[])
|
||||
for input, output in operand_output_aliases.items()
|
||||
]))
|
||||
])
|
||||
)
|
||||
if indices_of_shape_operands:
|
||||
attributes["indices_of_shape_operands"] = ir.DenseIntElementsAttr.get(
|
||||
np.asarray(indices_of_shape_operands, dtype=np.int64))
|
||||
|
||||
# TODO(necula): CustomCall constructor does not yet support
|
||||
# indices_of_shape_operands, so we use the generic builder
|
||||
|
||||
# The generic builder is pickier about the type of the operands, and some
|
||||
# of the callers did not call .result
|
||||
operands = [opnd if isinstance(opnd, ir.Value) else opnd.result
|
||||
for opnd in operands]
|
||||
out = hlo.CustomCallOp.build_generic(results=results, operands=operands, attributes=attributes)
|
||||
if len(out_types) == 1:
|
||||
return out.result
|
||||
else:
|
||||
|
Loading…
x
Reference in New Issue
Block a user