Merge pull request #23458 from dfm:ffi-layouts

PiperOrigin-RevId: 671465163
This commit is contained in:
jax authors 2024-09-05 12:19:07 -07:00
commit 0cfb9ac35a
2 changed files with 63 additions and 42 deletions

View File

@ -14,7 +14,7 @@
from __future__ import annotations
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Mapping, Sequence
import ctypes
import functools
import os
@ -27,12 +27,14 @@ from jax._src.callback import _check_shape_dtype, callback_batching_rule
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.layout import DeviceLocalLayout
from jax._src.lib import jaxlib
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, Shape
from jax._src.typing import Array, ArrayLike, DuckTypedArray, Shape
map, unsafe_map = util.safe_map, map
FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None
def register_ffi_target(
@ -104,15 +106,24 @@ def _aval_shape(aval: core.AbstractValue) -> Shape:
return () if aval is core.abstract_token else aval.shape # pytype: disable=attribute-error
def _default_layouts(avals: Iterable[core.AbstractValue]) -> list[list[DimSize]]:
return [list(reversed(range(len(_aval_shape(aval))))) for aval in avals]
def _convert_layout(aval: core.AbstractValue,
layout: FfiLayoutOptions = None) -> Sequence[int]:
"""Convert a layout to the minor-to-major order used by the custom call API."""
if layout is None:
return list(reversed(range(len(_aval_shape(aval)))))
elif isinstance(layout, DeviceLocalLayout):
if layout._tiling is not None:
raise ValueError("The FFI does not support layouts with tiling")
return layout.major_to_minor[::-1]
else:
return layout
def ffi_lowering(
call_target_name: str,
*,
operand_layouts: Sequence[Sequence[DimSize]] | None = None,
result_layouts: Sequence[Sequence[DimSize]] | None = None,
operand_layouts: Sequence[FfiLayoutOptions] | None = None,
result_layouts: Sequence[FfiLayoutOptions] | None = None,
backend_config: Mapping[str, ir.Attribute] | None = None,
**lowering_args: Any
) -> mlir.LoweringRule:
@ -147,13 +158,15 @@ def ffi_lowering(
if "result_types" not in kwargs:
kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
if operand_layouts is None:
kwargs["operand_layouts"] = _default_layouts(ctx.avals_in)
kwargs["operand_layouts"] = map(_convert_layout, ctx.avals_in)
else:
kwargs["operand_layouts"] = operand_layouts
kwargs["operand_layouts"] = [
_convert_layout(*args) for args in zip(ctx.avals_in, operand_layouts)]
if result_layouts is None:
kwargs["result_layouts"] = _default_layouts(ctx.avals_out)
kwargs["result_layouts"] = map(_convert_layout, ctx.avals_out)
else:
kwargs["result_layouts"] = result_layouts
kwargs["result_layouts"] = [
_convert_layout(*args) for args in zip(ctx.avals_out, result_layouts)]
if "result_shapes" not in kwargs and not all(
core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out):
kwargs["result_shapes"] = [

View File

@ -31,6 +31,7 @@ from jax._src import prng
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.interpreters import mlir
from jax._src.layout import DeviceLocalLayout
from jax._src.lib.mlir.dialects import hlo
jax.config.parse_flags_with_absl()
@ -97,33 +98,50 @@ class RandomTest(jtu.JaxTestCase):
class FfiTest(jtu.JaxTestCase):
def find_custom_call_in_module(self, module):
for func in module.body.operations:
for block in func.body.blocks:
for op in block.operations:
if op.OPERATION_NAME == "stablehlo.custom_call":
return op
self.fail("No custom_call found in the lowered IR")
def testHeadersExist(self):
base_dir = os.path.join(jex.ffi.include_dir(), "xla", "ffi", "api")
for header in ["c_api.h", "api.h", "ffi.h"]:
self.assertTrue(os.path.exists(os.path.join(base_dir, header)))
def testLoweringLayouts(self):
@parameterized.parameters([
(tuple(range(3)), tuple(range(3))),
(None, tuple(reversed(range(3)))),
(DeviceLocalLayout(tuple(range(3))), tuple(reversed(range(3)))),
])
def testLoweringLayouts(self, layout_spec, expected_layout):
# Regression test to ensure that the lowering rule properly captures
# layouts.
def lowering_rule(ctx, x):
aval, = ctx.avals_in
ndim = len(aval.shape)
layout = tuple(range(ndim))
return jex.ffi.ffi_lowering("test_ffi", operand_layouts=[layout],
result_layouts=[layout])(ctx, x)
return jex.ffi.ffi_lowering("test_ffi", operand_layouts=[layout_spec],
result_layouts=[layout_spec])(ctx, x)
prim = core.Primitive("test_ffi")
prim.def_impl(lambda x: x)
prim.def_abstract_eval(lambda x: x)
mlir.register_lowering(prim, lowering_rule)
x = jnp.linspace(0, 1, 5)
x = jnp.ones((3,) * len(expected_layout))
lowered = jax.jit(prim.bind).lower(x)
module = lowered.compiler_ir("stablehlo")
for func in module.body.operations:
for block in func.body.blocks:
for op in block.operations:
if op.OPERATION_NAME == "stablehlo.custom_call":
self.assertIn("operand_layouts", op.attributes)
self.assertIn("result_layouts", op.attributes)
op = self.find_custom_call_in_module(module)
self.assertIn("operand_layouts", op.attributes)
self.assertIn("result_layouts", op.attributes)
text = lowered.as_text()
expected = ", ".join(map(str, expected_layout))
pattern = rf"operand_layouts = \[dense<\[{expected}\]>"
self.assertRegex(text, pattern)
pattern = rf"result_layouts = \[dense<\[{expected}\]>"
self.assertRegex(text, pattern)
@parameterized.parameters([
(True, mlir.ir.BoolAttr.get),
@ -140,19 +158,14 @@ class FfiTest(jtu.JaxTestCase):
# Here we inspect the lowered IR to test that the parameter has been
# serialized with the appropriate type.
module = jax.jit(fun).lower(0.5).compiler_ir("stablehlo")
for func in module.body.operations:
for block in func.body.blocks:
for op in block.operations:
if op.OPERATION_NAME == "stablehlo.custom_call":
config = op.attributes["mhlo.backend_config"]
self.assertIsInstance(config, mlir.ir.DictAttr)
self.assertIn("param", config)
with mlir.make_ir_context(), mlir.ir.Location.unknown():
expected = expected_builder(param)
self.assertEqual(type(config["param"]), type(expected))
self.assertTrue(expected.type.isinstance(config["param"].type))
return
self.fail("No custom_call found in the lowered IR")
op = self.find_custom_call_in_module(module)
config = op.attributes["mhlo.backend_config"]
self.assertIsInstance(config, mlir.ir.DictAttr)
self.assertIn("param", config)
with mlir.make_ir_context(), mlir.ir.Location.unknown():
expected = expected_builder(param)
self.assertEqual(type(config["param"]), type(expected))
self.assertTrue(expected.type.isinstance(config["param"].type))
def testToken(self):
def fun():
@ -161,14 +174,9 @@ class FfiTest(jtu.JaxTestCase):
# Ensure that token inputs and outputs are translated to the correct type
module = jax.jit(fun).lower().compiler_ir("stablehlo")
for func in module.body.operations:
for block in func.body.blocks:
for op in block.operations:
if op.OPERATION_NAME == "stablehlo.custom_call":
self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))
return
self.fail("No custom_call found in the lowered IR")
op = self.find_custom_call_in_module(module)
self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))
@jtu.sample_product(
shape=[(1,), (4,), (5,)],