Package XLA FFI headers with jaxlib wheel

The new "typed" API that XLA provides for foreign function calls is
header-only and packaging it as part of jaxlib could simplify the open
source workflow for building custom calls.

It's not completely obvious that we need to include this, because jaxlib
isn't strictly required as a _build_ dependency for FFI calls, although
it typically will be required as a _run time_ dependency. Also, it
probably wouldn't be too painful for external projects to use the
headers directly from the openxla/xla repo.

All that being said, I wanted to figure out how to do this, and it has
been requested a few times.
This commit is contained in:
Dan Foreman-Mackey 2024-05-21 10:22:13 -04:00
parent 47420a3825
commit 88790711e8
9 changed files with 95 additions and 0 deletions

View File

@ -228,6 +228,7 @@ py_library_providing_imports_info(
"_src/dispatch.py",
"_src/dlpack.py",
"_src/earray.py",
"_src/ffi.py",
"_src/flatten_util.py",
"_src/interpreters/__init__.py",
"_src/interpreters/ad.py",

View File

@ -158,6 +158,7 @@ from jax import debug as debug
from jax import dlpack as dlpack
from jax import dtypes as dtypes
from jax import errors as errors
from jax import ffi as ffi
from jax import image as image
from jax import lax as lax
from jax import monitoring as monitoring

25
jax/_src/ffi.py Normal file
View File

@ -0,0 +1,25 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
from jax._src.lib import jaxlib
def include_dir() -> str:
"""Get the path to the directory containing header files bundled with jaxlib"""
jaxlib_dir = os.path.dirname(os.path.abspath(jaxlib.__file__))
return os.path.join(jaxlib_dir, "include")

15
jax/ffi.py Normal file
View File

@ -0,0 +1,15 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.ffi import include_dir as include_dir

View File

@ -118,6 +118,7 @@ setup(
'triton/*.pyi',
'triton/*.pyd',
'triton/*.so',
'include/xla/ffi/api/*.h',
],
'jaxlib.xla_extension': ['*.pyi'],
},

View File

@ -32,6 +32,9 @@ py_binary(
"//jaxlib:setup.py",
"@xla//xla/python:xla_client.py",
"@xla//xla/python:xla_extension",
"@xla//xla/ffi/api:c_api.h",
"@xla//xla/ffi/api:api.h",
"@xla//xla/ffi/api:ffi.h",
] + if_windows([
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
]) + if_cuda([

View File

@ -407,6 +407,14 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
"__main__/jaxlib/triton/_triton_ops_gen.py", dst_dir=triton_dir
)
copy_runfiles(
dst_dir=jaxlib_dir / "include" / "xla" / "ffi" / "api",
src_files=[
"xla/xla/ffi/api/c_api.h",
"xla/xla/ffi/api/api.h",
"xla/xla/ffi/api/ffi.h",
],
)
tmpdir = None
sources_path = args.sources_path

View File

@ -126,6 +126,15 @@ jax_test(
deps = ["//jax:extend"],
)
py_test(
name = "ffi_test",
srcs = ["ffi_test.py"],
deps = [
"//jax",
"//jax:test_util",
],
)
jax_test(
name = "fft_test",
srcs = ["fft_test.py"],

32
tests/ffi_test.py Normal file
View File

@ -0,0 +1,32 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import unittest
from jax import ffi
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
class IncludeDirTest(jtu.JaxTestCase):
@unittest.skipIf(xla_extension_version < 265, "Requires jaxlib 0.4.29")
def testHeadersExist(self):
base_dir = os.path.join(ffi.include_dir(), "xla", "ffi", "api")
for header in ["c_api.h", "api.h", "ffi.h"]:
print(os.path.join(base_dir, header))
self.assertTrue(os.path.exists(os.path.join(base_dir, header)))