Mirror of https://github.com/ROCm/jax.git, synced 2025-04-19 05:16:06 +00:00
Merge pull request #23394 from gspschmid:gschmid/ffi-support-token
PiperOrigin-RevId: 670652970
Commit cda2408e14
@@ -30,7 +30,7 @@ from jax._src.interpreters import mlir
 from jax._src.lib import jaxlib
 from jax._src.lib import xla_client
 from jax._src.lib.mlir import ir
-from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray
+from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, Shape

 map, unsafe_map = util.safe_map, map
@@ -100,6 +100,14 @@ def include_dir() -> str:
   return os.path.join(jaxlib_dir, "include")


+def _aval_shape(aval: core.AbstractValue) -> Shape:
+  return () if aval is core.abstract_token else aval.shape  # pytype: disable=attribute-error
+
+
+def _default_layouts(avals: Iterable[core.AbstractValue]) -> list[list[DimSize]]:
+  return [list(reversed(range(len(_aval_shape(aval))))) for aval in avals]
+
+
 def ffi_lowering(
     call_target_name: str,
     *,
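The two new helpers are the heart of the change: `_aval_shape` maps a token aval to the empty shape, and `_default_layouts` now consumes avals directly instead of raw shapes, so a token simply gets an empty layout. A standalone sketch of the behavior, using hypothetical stand-ins rather than the private `jax._src` types:

# Stand-ins for jax internals; names here are illustrative only.
from collections.abc import Iterable

class _FakeToken:
  """Plays the role of core.abstract_token: a sentinel with no shape."""

_abstract_token = _FakeToken()

class _FakeShapedAval:
  """Plays the role of core.ShapedArray for this sketch."""
  def __init__(self, shape):
    self.shape = shape

def _aval_shape(aval):
  # Tokens carry no shape, so report them as rank-0.
  return () if aval is _abstract_token else aval.shape

def _default_layouts(avals: Iterable) -> list[list[int]]:
  # Row-major layout lists dimensions from major to minor:
  # rank 3 -> [2, 1, 0]; a token -> [].
  return [list(reversed(range(len(_aval_shape(aval))))) for aval in avals]

print(_default_layouts([_FakeShapedAval((2, 3, 4)), _abstract_token]))
# -> [[2, 1, 0], []]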
@@ -139,17 +147,17 @@ def ffi_lowering(
     if "result_types" not in kwargs:
       kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
     if operand_layouts is None:
-      kwargs["operand_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_in)  # pytype: disable=attribute-error
+      kwargs["operand_layouts"] = _default_layouts(ctx.avals_in)
     else:
       kwargs["operand_layouts"] = operand_layouts
     if result_layouts is None:
-      kwargs["result_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_out)
+      kwargs["result_layouts"] = _default_layouts(ctx.avals_out)
     else:
       kwargs["result_layouts"] = result_layouts
     if "result_shapes" not in kwargs and not all(
-        core.is_constant_shape(aval.shape) for aval in ctx.avals_out):
+        core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out):
       kwargs["result_shapes"] = [
-          mlir.shape_tensor(mlir.eval_dynamic_shape_as_ivals(ctx, aval.shape))
+          mlir.shape_tensor(mlir.eval_dynamic_shape_as_ivals(ctx, _aval_shape(aval)))
           for aval in ctx.avals_out]

     return mlir.custom_call(call_target_name, operands=operands, **kwargs).results  # type: ignore
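These call sites previously pulled `.shape` off each aval directly; `core.AbstractToken` defines no such attribute (hence the old `pytype: disable=attribute-error` suppression), so any token reaching the lowering would have raised. A minimal illustration, assuming the private `jax._src.core` module keeps its current layout:

from jax._src import core  # private API, used here only to illustrate

try:
  core.abstract_token.shape
except AttributeError as err:
  # AbstractToken has no .shape; the new _aval_shape helper returns () instead.
  print(err)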
@@ -157,13 +165,23 @@ def ffi_lowering(
   return _lowering


-def _default_layouts(shapes: Iterable[Sequence[DimSize]]) -> list[list[DimSize]]:
-  return [list(reversed(range(len(shape)))) for shape in shapes]
+ResultMetadata = DuckTypedArray | core.AbstractToken
+
+
+def _result_avals(results: Sequence[ResultMetadata]) -> tuple[core.AbstractValue, ...]:
+  avals: list[core.AbstractValue] = []
+  for result in results:
+    if isinstance(result, core.AbstractToken):
+      avals.append(result)
+    else:
+      _check_shape_dtype(result)
+      avals.append(core.ShapedArray(result.shape, result.dtype))
+  return tuple(avals)


 def ffi_call(
     target_name: str,
-    result_shape_dtypes: DuckTypedArray | Sequence[DuckTypedArray],
+    result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],
     *args: ArrayLike,
     vectorized: bool = False,
     **kwargs: Any,
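`ResultMetadata` widens the accepted result descriptions, and `_result_avals` normalizes them: array-like entries become `ShapedArray` avals after the usual shape/dtype check, while `AbstractToken` instances pass through untouched. A usage sketch (`jax.ShapeDtypeStruct` satisfies the `DuckTypedArray` protocol through its `shape` and `dtype` attributes):

import jax
import jax.numpy as jnp

# Array metadata and the token sentinel can now sit side by side in
# result_shape_dtypes.
meta = (jax.ShapeDtypeStruct((4,), jnp.float32), jax.core.abstract_token)
# _result_avals(meta) would yield, in order:
#   a ShapedArray((4,), float32) for the array entry, and
#   the AbstractToken singleton for the token entry.
print(meta)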
@@ -189,6 +207,7 @@ def ffi_call(
       ``dtype`` attributes which are expected to match the shape and dtype of
       the custom call output or outputs. :class:`~jax.ShapeDtypeStruct` is often
       used to define the elements of ``result_shape_dtypes``.
+      ``jax.core.abstract_token`` may be used to represent a token-typed output.
     *args: the arguments passed to the custom call.
     vectorized: boolean specifying whether the callback function can operate in
       a vectorized manner, as described above.
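Putting the new docstring sentence to work: a caller that wants a token-typed output passes `jax.core.abstract_token` where it would normally pass a `ShapeDtypeStruct`. A sketch mirroring the test further down; "my_ffi_target" is a hypothetical name that would need to be registered with the XLA FFI before the call could actually execute:

import jax
import jax.extend as jex
from jax import lax

def fun():
  token = lax.create_token()
  # "my_ffi_target" is hypothetical; lowering works without registration,
  # but compiling/running requires a real registered FFI target.
  return jex.ffi.ffi_call("my_ffi_target", jax.core.abstract_token, token)

lowered = jax.jit(fun).lower()  # traces and lowers without executing the target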
@@ -201,12 +220,10 @@ def ffi_call(
   """
   if isinstance(result_shape_dtypes, Sequence):
     multiple_results = True
-    result_types = result_shape_dtypes
+    result_avals = _result_avals(result_shape_dtypes)
   else:
     multiple_results = False
-    result_types = (result_shape_dtypes,)
-  map(_check_shape_dtype, result_types)
-  result_avals = tuple(core.ShapedArray(x.shape, x.dtype) for x in result_types)
+    result_avals = _result_avals((result_shape_dtypes,))
   results = ffi_call_p.bind(
       *args,
       result_avals=result_avals,
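The branch structure this hunk simplifies still determines the return convention: a sequence of result descriptions makes `ffi_call` return a list of outputs, while a single description (including a lone token) returns the value itself. A sketch of both conventions, with placeholder target names:

import jax
import jax.extend as jex

def fun(x):
  # Single result description -> a single value is returned.
  # ("unary_target" and "binary_target" are placeholders, not real targets.)
  y = jex.ffi.ffi_call("unary_target", jax.ShapeDtypeStruct(x.shape, x.dtype), x)
  # Sequence of descriptions -> a list of values, in the same order.
  z, tok = jex.ffi.ffi_call(
      "binary_target",
      (jax.ShapeDtypeStruct(x.shape, x.dtype), jax.core.abstract_token),
      x)
  return y, z, tok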
@@ -222,7 +239,7 @@ def ffi_call(

 def ffi_call_abstract_eval(
     *avals_in,
-    result_avals: tuple[core.ShapedArray, ...],
+    result_avals: tuple[core.AbstractValue, ...],
     target_name: str,
     vectorized: bool,
     **kwargs: Any,
@@ -248,7 +265,7 @@ def ffi_call_transpose(*args, target_name, **kwargs):
 def ffi_call_lowering(
     ctx: mlir.LoweringRuleContext,
     *operands: ir.Value,
-    result_avals: tuple[core.ShapedArray, ...],
+    result_avals: tuple[core.AbstractValue, ...],
     target_name: str,
     vectorized: bool,
     **kwargs: Any,
(The remaining hunks are in the FFI test suite, extending the FfiTest class.)

@@ -31,6 +31,7 @@ from jax._src import prng
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
 from jax._src.interpreters import mlir
+from jax._src.lib.mlir.dialects import hlo

 jax.config.parse_flags_with_absl()
@@ -153,6 +154,22 @@ class FfiTest(jtu.JaxTestCase):
            return
     self.fail("No custom_call found in the lowered IR")

+  def testToken(self):
+    def fun():
+      token = lax.create_token()
+      return jex.ffi.ffi_call("test_ffi", core.abstract_token, token)
+
+    # Ensure that token inputs and outputs are translated to the correct type
+    module = jax.jit(fun).lower().compiler_ir("stablehlo")
+    for func in module.body.operations:
+      for block in func.body.blocks:
+        for op in block.operations:
+          if op.OPERATION_NAME == "stablehlo.custom_call":
+            self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
+            self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))
+            return
+    self.fail("No custom_call found in the lowered IR")
+
   @jtu.sample_product(
       shape=[(1,), (4,), (5,)],
       dtype=(np.int32,),
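Outside the test harness, the same property can be checked by printing the lowered module; the custom call should be typed on tokens at both ends. A sketch reusing the test's (unregistered) "test_ffi" target name:

import jax
import jax.extend as jex
from jax import lax

def fun():
  token = lax.create_token()
  return jex.ffi.ffi_call("test_ffi", jax.core.abstract_token, token)

# The printed StableHLO should contain, approximately (hand-written sketch,
# not captured output):
#   %0 = stablehlo.create_token : !stablehlo.token
#   %1 = stablehlo.custom_call @test_ffi(%0) : (!stablehlo.token) -> !stablehlo.token
print(jax.jit(fun).lower().as_text())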