mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #23458 from dfm:ffi-layouts
PiperOrigin-RevId: 671465163
This commit is contained in:
commit
0cfb9ac35a
@ -14,7 +14,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from collections.abc import Mapping, Sequence
|
||||
import ctypes
|
||||
import functools
|
||||
import os
|
||||
@ -27,12 +27,14 @@ from jax._src.callback import _check_shape_dtype, callback_batching_rule
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.layout import DeviceLocalLayout
|
||||
from jax._src.lib import jaxlib
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, Shape
|
||||
from jax._src.typing import Array, ArrayLike, DuckTypedArray, Shape
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None
|
||||
|
||||
|
||||
def register_ffi_target(
|
||||
@ -104,15 +106,24 @@ def _aval_shape(aval: core.AbstractValue) -> Shape:
|
||||
return () if aval is core.abstract_token else aval.shape # pytype: disable=attribute-error
|
||||
|
||||
|
||||
def _default_layouts(avals: Iterable[core.AbstractValue]) -> list[list[DimSize]]:
|
||||
return [list(reversed(range(len(_aval_shape(aval))))) for aval in avals]
|
||||
def _convert_layout(aval: core.AbstractValue,
|
||||
layout: FfiLayoutOptions = None) -> Sequence[int]:
|
||||
"""Convert a layout to the minor-to-major order used by the custom call API."""
|
||||
if layout is None:
|
||||
return list(reversed(range(len(_aval_shape(aval)))))
|
||||
elif isinstance(layout, DeviceLocalLayout):
|
||||
if layout._tiling is not None:
|
||||
raise ValueError("The FFI does not support layouts with tiling")
|
||||
return layout.major_to_minor[::-1]
|
||||
else:
|
||||
return layout
|
||||
|
||||
|
||||
def ffi_lowering(
|
||||
call_target_name: str,
|
||||
*,
|
||||
operand_layouts: Sequence[Sequence[DimSize]] | None = None,
|
||||
result_layouts: Sequence[Sequence[DimSize]] | None = None,
|
||||
operand_layouts: Sequence[FfiLayoutOptions] | None = None,
|
||||
result_layouts: Sequence[FfiLayoutOptions] | None = None,
|
||||
backend_config: Mapping[str, ir.Attribute] | None = None,
|
||||
**lowering_args: Any
|
||||
) -> mlir.LoweringRule:
|
||||
@ -147,13 +158,15 @@ def ffi_lowering(
|
||||
if "result_types" not in kwargs:
|
||||
kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
|
||||
if operand_layouts is None:
|
||||
kwargs["operand_layouts"] = _default_layouts(ctx.avals_in)
|
||||
kwargs["operand_layouts"] = map(_convert_layout, ctx.avals_in)
|
||||
else:
|
||||
kwargs["operand_layouts"] = operand_layouts
|
||||
kwargs["operand_layouts"] = [
|
||||
_convert_layout(*args) for args in zip(ctx.avals_in, operand_layouts)]
|
||||
if result_layouts is None:
|
||||
kwargs["result_layouts"] = _default_layouts(ctx.avals_out)
|
||||
kwargs["result_layouts"] = map(_convert_layout, ctx.avals_out)
|
||||
else:
|
||||
kwargs["result_layouts"] = result_layouts
|
||||
kwargs["result_layouts"] = [
|
||||
_convert_layout(*args) for args in zip(ctx.avals_out, result_layouts)]
|
||||
if "result_shapes" not in kwargs and not all(
|
||||
core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out):
|
||||
kwargs["result_shapes"] = [
|
||||
|
@ -31,6 +31,7 @@ from jax._src import prng
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.layout import DeviceLocalLayout
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
@ -97,33 +98,50 @@ class RandomTest(jtu.JaxTestCase):
|
||||
|
||||
class FfiTest(jtu.JaxTestCase):
|
||||
|
||||
def find_custom_call_in_module(self, module):
|
||||
for func in module.body.operations:
|
||||
for block in func.body.blocks:
|
||||
for op in block.operations:
|
||||
if op.OPERATION_NAME == "stablehlo.custom_call":
|
||||
return op
|
||||
self.fail("No custom_call found in the lowered IR")
|
||||
|
||||
def testHeadersExist(self):
|
||||
base_dir = os.path.join(jex.ffi.include_dir(), "xla", "ffi", "api")
|
||||
for header in ["c_api.h", "api.h", "ffi.h"]:
|
||||
self.assertTrue(os.path.exists(os.path.join(base_dir, header)))
|
||||
|
||||
def testLoweringLayouts(self):
|
||||
@parameterized.parameters([
|
||||
(tuple(range(3)), tuple(range(3))),
|
||||
(None, tuple(reversed(range(3)))),
|
||||
(DeviceLocalLayout(tuple(range(3))), tuple(reversed(range(3)))),
|
||||
])
|
||||
def testLoweringLayouts(self, layout_spec, expected_layout):
|
||||
# Regression test to ensure that the lowering rule properly captures
|
||||
# layouts.
|
||||
def lowering_rule(ctx, x):
|
||||
aval, = ctx.avals_in
|
||||
ndim = len(aval.shape)
|
||||
layout = tuple(range(ndim))
|
||||
return jex.ffi.ffi_lowering("test_ffi", operand_layouts=[layout],
|
||||
result_layouts=[layout])(ctx, x)
|
||||
return jex.ffi.ffi_lowering("test_ffi", operand_layouts=[layout_spec],
|
||||
result_layouts=[layout_spec])(ctx, x)
|
||||
prim = core.Primitive("test_ffi")
|
||||
prim.def_impl(lambda x: x)
|
||||
prim.def_abstract_eval(lambda x: x)
|
||||
mlir.register_lowering(prim, lowering_rule)
|
||||
x = jnp.linspace(0, 1, 5)
|
||||
|
||||
x = jnp.ones((3,) * len(expected_layout))
|
||||
lowered = jax.jit(prim.bind).lower(x)
|
||||
module = lowered.compiler_ir("stablehlo")
|
||||
for func in module.body.operations:
|
||||
for block in func.body.blocks:
|
||||
for op in block.operations:
|
||||
if op.OPERATION_NAME == "stablehlo.custom_call":
|
||||
self.assertIn("operand_layouts", op.attributes)
|
||||
self.assertIn("result_layouts", op.attributes)
|
||||
op = self.find_custom_call_in_module(module)
|
||||
self.assertIn("operand_layouts", op.attributes)
|
||||
self.assertIn("result_layouts", op.attributes)
|
||||
|
||||
text = lowered.as_text()
|
||||
expected = ", ".join(map(str, expected_layout))
|
||||
pattern = rf"operand_layouts = \[dense<\[{expected}\]>"
|
||||
self.assertRegex(text, pattern)
|
||||
pattern = rf"result_layouts = \[dense<\[{expected}\]>"
|
||||
self.assertRegex(text, pattern)
|
||||
|
||||
@parameterized.parameters([
|
||||
(True, mlir.ir.BoolAttr.get),
|
||||
@ -140,19 +158,14 @@ class FfiTest(jtu.JaxTestCase):
|
||||
# Here we inspect the lowered IR to test that the parameter has been
|
||||
# serialized with the appropriate type.
|
||||
module = jax.jit(fun).lower(0.5).compiler_ir("stablehlo")
|
||||
for func in module.body.operations:
|
||||
for block in func.body.blocks:
|
||||
for op in block.operations:
|
||||
if op.OPERATION_NAME == "stablehlo.custom_call":
|
||||
config = op.attributes["mhlo.backend_config"]
|
||||
self.assertIsInstance(config, mlir.ir.DictAttr)
|
||||
self.assertIn("param", config)
|
||||
with mlir.make_ir_context(), mlir.ir.Location.unknown():
|
||||
expected = expected_builder(param)
|
||||
self.assertEqual(type(config["param"]), type(expected))
|
||||
self.assertTrue(expected.type.isinstance(config["param"].type))
|
||||
return
|
||||
self.fail("No custom_call found in the lowered IR")
|
||||
op = self.find_custom_call_in_module(module)
|
||||
config = op.attributes["mhlo.backend_config"]
|
||||
self.assertIsInstance(config, mlir.ir.DictAttr)
|
||||
self.assertIn("param", config)
|
||||
with mlir.make_ir_context(), mlir.ir.Location.unknown():
|
||||
expected = expected_builder(param)
|
||||
self.assertEqual(type(config["param"]), type(expected))
|
||||
self.assertTrue(expected.type.isinstance(config["param"].type))
|
||||
|
||||
def testToken(self):
|
||||
def fun():
|
||||
@ -161,14 +174,9 @@ class FfiTest(jtu.JaxTestCase):
|
||||
|
||||
# Ensure that token inputs and outputs are translated to the correct type
|
||||
module = jax.jit(fun).lower().compiler_ir("stablehlo")
|
||||
for func in module.body.operations:
|
||||
for block in func.body.blocks:
|
||||
for op in block.operations:
|
||||
if op.OPERATION_NAME == "stablehlo.custom_call":
|
||||
self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
|
||||
self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))
|
||||
return
|
||||
self.fail("No custom_call found in the lowered IR")
|
||||
op = self.find_custom_call_in_module(module)
|
||||
self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
|
||||
self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(1,), (4,), (5,)],
|
||||
|
Loading…
x
Reference in New Issue
Block a user