Merge pull request #23394 from gspschmid:gschmid/ffi-support-token

PiperOrigin-RevId: 670652970
This commit is contained in:
jax authors 2024-09-03 12:06:07 -07:00
commit cda2408e14
2 changed files with 48 additions and 14 deletions

View File

@ -30,7 +30,7 @@ from jax._src.interpreters import mlir
from jax._src.lib import jaxlib
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, Shape
map, unsafe_map = util.safe_map, map
@ -100,6 +100,14 @@ def include_dir() -> str:
return os.path.join(jaxlib_dir, "include")
def _aval_shape(aval: core.AbstractValue) -> Shape:
return () if aval is core.abstract_token else aval.shape # pytype: disable=attribute-error
def _default_layouts(avals: Iterable[core.AbstractValue]) -> list[list[DimSize]]:
return [list(reversed(range(len(_aval_shape(aval))))) for aval in avals]
def ffi_lowering(
call_target_name: str,
*,
@ -139,17 +147,17 @@ def ffi_lowering(
if "result_types" not in kwargs:
kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
if operand_layouts is None:
kwargs["operand_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_in) # pytype: disable=attribute-error
kwargs["operand_layouts"] = _default_layouts(ctx.avals_in)
else:
kwargs["operand_layouts"] = operand_layouts
if result_layouts is None:
kwargs["result_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_out)
kwargs["result_layouts"] = _default_layouts(ctx.avals_out)
else:
kwargs["result_layouts"] = result_layouts
if "result_shapes" not in kwargs and not all(
core.is_constant_shape(aval.shape) for aval in ctx.avals_out):
core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out):
kwargs["result_shapes"] = [
mlir.shape_tensor(mlir.eval_dynamic_shape_as_ivals(ctx, aval.shape))
mlir.shape_tensor(mlir.eval_dynamic_shape_as_ivals(ctx, _aval_shape(aval)))
for aval in ctx.avals_out]
return mlir.custom_call(call_target_name, operands=operands, **kwargs).results # type: ignore
@ -157,13 +165,23 @@ def ffi_lowering(
return _lowering
def _default_layouts(shapes: Iterable[Sequence[DimSize]]) -> list[list[DimSize]]:
return [list(reversed(range(len(shape)))) for shape in shapes]
ResultMetadata = DuckTypedArray | core.AbstractToken
def _result_avals(results: Sequence[ResultMetadata]) -> tuple[core.AbstractValue, ...]:
avals: list[core.AbstractValue] = []
for result in results:
if isinstance(result, core.AbstractToken):
avals.append(result)
else:
_check_shape_dtype(result)
avals.append(core.ShapedArray(result.shape, result.dtype))
return tuple(avals)
def ffi_call(
target_name: str,
result_shape_dtypes: DuckTypedArray | Sequence[DuckTypedArray],
result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],
*args: ArrayLike,
vectorized: bool = False,
**kwargs: Any,
@ -189,6 +207,7 @@ def ffi_call(
``dtype`` attributes which are expected to match the shape and dtype of
the custom call output or outputs. :class:`~jax.ShapeDtypeStruct` is often
used to define the elements of ``result_shape_dtypes``.
``jax.core.abstract_token`` may be used to represent a token-typed output.
*args: the arguments passed to the custom call.
vectorized: boolean specifying whether the callback function can operate in
a vectorized manner, as described above.
@ -201,12 +220,10 @@ def ffi_call(
"""
if isinstance(result_shape_dtypes, Sequence):
multiple_results = True
result_types = result_shape_dtypes
result_avals = _result_avals(result_shape_dtypes)
else:
multiple_results = False
result_types = (result_shape_dtypes,)
map(_check_shape_dtype, result_types)
result_avals = tuple(core.ShapedArray(x.shape, x.dtype) for x in result_types)
result_avals = _result_avals((result_shape_dtypes,))
results = ffi_call_p.bind(
*args,
result_avals=result_avals,
@ -222,7 +239,7 @@ def ffi_call(
def ffi_call_abstract_eval(
*avals_in,
result_avals: tuple[core.ShapedArray, ...],
result_avals: tuple[core.AbstractValue, ...],
target_name: str,
vectorized: bool,
**kwargs: Any,
@ -248,7 +265,7 @@ def ffi_call_transpose(*args, target_name, **kwargs):
def ffi_call_lowering(
ctx: mlir.LoweringRuleContext,
*operands: ir.Value,
result_avals: tuple[core.ShapedArray, ...],
result_avals: tuple[core.AbstractValue, ...],
target_name: str,
vectorized: bool,
**kwargs: Any,

View File

@ -31,6 +31,7 @@ from jax._src import prng
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.interpreters import mlir
from jax._src.lib.mlir.dialects import hlo
jax.config.parse_flags_with_absl()
@ -153,6 +154,22 @@ class FfiTest(jtu.JaxTestCase):
return
self.fail("No custom_call found in the lowered IR")
def testToken(self):
def fun():
token = lax.create_token()
return jex.ffi.ffi_call("test_ffi", core.abstract_token, token)
# Ensure that token inputs and outputs are translated to the correct type
module = jax.jit(fun).lower().compiler_ir("stablehlo")
for func in module.body.operations:
for block in func.body.blocks:
for op in block.operations:
if op.OPERATION_NAME == "stablehlo.custom_call":
self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))
return
self.fail("No custom_call found in the lowered IR")
@jtu.sample_product(
shape=[(1,), (4,), (5,)],
dtype=(np.int32,),