From 88790711e847d2a3363e9c418712e356b9a2a138 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 21 May 2024 10:22:13 -0400 Subject: [PATCH] Package XLA FFI headers with jaxlib wheel The new "typed" API that XLA provides for foreign function calls is header-only and packaging it as part of jaxlib could simplify the open source workflow for building custom calls. It's not completely obvious that we need to include this, because jaxlib isn't strictly required as a _build_ dependency for FFI calls, although it typically will be required as a _run time_ dependency. Also, it probably wouldn't be too painful for external projects to use the headers directly from the openxla/xla repo. All that being said, I wanted to figure out how to do this, and it has been requested a few times. --- jax/BUILD | 1 + jax/__init__.py | 1 + jax/_src/ffi.py | 25 +++++++++++++++++++++++++ jax/ffi.py | 15 +++++++++++++++ jaxlib/setup.py | 1 + jaxlib/tools/BUILD.bazel | 3 +++ jaxlib/tools/build_wheel.py | 8 ++++++++ tests/BUILD | 9 +++++++++ tests/ffi_test.py | 32 ++++++++++++++++++++++++++++++++ 9 files changed, 95 insertions(+) create mode 100644 jax/_src/ffi.py create mode 100644 jax/ffi.py create mode 100644 tests/ffi_test.py diff --git a/jax/BUILD b/jax/BUILD index 25d7c073f..343ab07e6 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -228,6 +228,7 @@ py_library_providing_imports_info( "_src/dispatch.py", "_src/dlpack.py", "_src/earray.py", + "_src/ffi.py", "_src/flatten_util.py", "_src/interpreters/__init__.py", "_src/interpreters/ad.py", diff --git a/jax/__init__.py b/jax/__init__.py index f3ee0edd7..d7b4479e2 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -158,6 +158,7 @@ from jax import debug as debug from jax import dlpack as dlpack from jax import dtypes as dtypes from jax import errors as errors +from jax import ffi as ffi from jax import image as image from jax import lax as lax from jax import monitoring as monitoring diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py new file mode 100644 index 000000000..1b394802a --- /dev/null +++ b/jax/_src/ffi.py @@ -0,0 +1,25 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os + +from jax._src.lib import jaxlib + + +def include_dir() -> str: + """Get the path to the directory containing header files bundled with jaxlib""" + jaxlib_dir = os.path.dirname(os.path.abspath(jaxlib.__file__)) + return os.path.join(jaxlib_dir, "include") diff --git a/jax/ffi.py b/jax/ffi.py new file mode 100644 index 000000000..ddbac4fa3 --- /dev/null +++ b/jax/ffi.py @@ -0,0 +1,15 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from jax._src.ffi import include_dir as include_dir diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 58a4d43c5..c6131b7ea 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -118,6 +118,7 @@ setup( 'triton/*.pyi', 'triton/*.pyd', 'triton/*.so', + 'include/xla/ffi/api/*.h', ], 'jaxlib.xla_extension': ['*.pyi'], }, diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index b5d087388..832f53249 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -32,6 +32,9 @@ py_binary( "//jaxlib:setup.py", "@xla//xla/python:xla_client.py", "@xla//xla/python:xla_extension", + "@xla//xla/ffi/api:c_api.h", + "@xla//xla/ffi/api:api.h", + "@xla//xla/ffi/api:ffi.h", ] + if_windows([ "//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll", ]) + if_cuda([ diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 2e0ca1353..7c9bfa120 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -407,6 +407,14 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): "__main__/jaxlib/triton/_triton_ops_gen.py", dst_dir=triton_dir ) + copy_runfiles( + dst_dir=jaxlib_dir / "include" / "xla" / "ffi" / "api", + src_files=[ + "xla/xla/ffi/api/c_api.h", + "xla/xla/ffi/api/api.h", + "xla/xla/ffi/api/ffi.h", + ], + ) tmpdir = None sources_path = args.sources_path diff --git a/tests/BUILD b/tests/BUILD index 1a3b85afb..54c77cd54 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -126,6 +126,15 @@ jax_test( deps = ["//jax:extend"], ) +py_test( + name = "ffi_test", + srcs = ["ffi_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ], +) + jax_test( name = "fft_test", srcs = ["fft_test.py"], diff --git a/tests/ffi_test.py b/tests/ffi_test.py new file mode 100644 index 000000000..e0e652131 --- /dev/null +++ b/tests/ffi_test.py @@ -0,0 +1,32 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +import unittest + +from jax import ffi +from jax._src import test_util as jtu +from jax._src.lib import xla_extension_version + + +class IncludeDirTest(jtu.JaxTestCase): + + @unittest.skipIf(xla_extension_version < 265, "Requires jaxlib 0.4.29") + def testHeadersExist(self): + base_dir = os.path.join(ffi.include_dir(), "xla", "ffi", "api") + for header in ["c_api.h", "api.h", "ffi.h"]: + print(os.path.join(base_dir, header)) + self.assertTrue(os.path.exists(os.path.join(base_dir, header)))