mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Package XLA FFI headers with jaxlib wheel
The new "typed" API that XLA provides for foreign function calls is header-only and packaging it as part of jaxlib could simplify the open source workflow for building custom calls. It's not completely obvious that we need to include this, because jaxlib isn't strictly required as a _build_ dependency for FFI calls, although it typically will be required as a _run time_ dependency. Also, it probably wouldn't be too painful for external projects to use the headers directly from the openxla/xla repo. All that being said, I wanted to figure out how to do this, and it has been requested a few times.
This commit is contained in:
parent
47420a3825
commit
88790711e8
@ -228,6 +228,7 @@ py_library_providing_imports_info(
|
||||
"_src/dispatch.py",
|
||||
"_src/dlpack.py",
|
||||
"_src/earray.py",
|
||||
"_src/ffi.py",
|
||||
"_src/flatten_util.py",
|
||||
"_src/interpreters/__init__.py",
|
||||
"_src/interpreters/ad.py",
|
||||
|
@ -158,6 +158,7 @@ from jax import debug as debug
|
||||
from jax import dlpack as dlpack
|
||||
from jax import dtypes as dtypes
|
||||
from jax import errors as errors
|
||||
from jax import ffi as ffi
|
||||
from jax import image as image
|
||||
from jax import lax as lax
|
||||
from jax import monitoring as monitoring
|
||||
|
25
jax/_src/ffi.py
Normal file
25
jax/_src/ffi.py
Normal file
@ -0,0 +1,25 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
from jax._src.lib import jaxlib
|
||||
|
||||
|
||||
def include_dir() -> str:
|
||||
"""Get the path to the directory containing header files bundled with jaxlib"""
|
||||
jaxlib_dir = os.path.dirname(os.path.abspath(jaxlib.__file__))
|
||||
return os.path.join(jaxlib_dir, "include")
|
15
jax/ffi.py
Normal file
15
jax/ffi.py
Normal file
@ -0,0 +1,15 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.ffi import include_dir as include_dir
|
@ -118,6 +118,7 @@ setup(
|
||||
'triton/*.pyi',
|
||||
'triton/*.pyd',
|
||||
'triton/*.so',
|
||||
'include/xla/ffi/api/*.h',
|
||||
],
|
||||
'jaxlib.xla_extension': ['*.pyi'],
|
||||
},
|
||||
|
@ -32,6 +32,9 @@ py_binary(
|
||||
"//jaxlib:setup.py",
|
||||
"@xla//xla/python:xla_client.py",
|
||||
"@xla//xla/python:xla_extension",
|
||||
"@xla//xla/ffi/api:c_api.h",
|
||||
"@xla//xla/ffi/api:api.h",
|
||||
"@xla//xla/ffi/api:ffi.h",
|
||||
] + if_windows([
|
||||
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
|
||||
]) + if_cuda([
|
||||
|
@ -407,6 +407,14 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
|
||||
"__main__/jaxlib/triton/_triton_ops_gen.py", dst_dir=triton_dir
|
||||
)
|
||||
|
||||
copy_runfiles(
|
||||
dst_dir=jaxlib_dir / "include" / "xla" / "ffi" / "api",
|
||||
src_files=[
|
||||
"xla/xla/ffi/api/c_api.h",
|
||||
"xla/xla/ffi/api/api.h",
|
||||
"xla/xla/ffi/api/ffi.h",
|
||||
],
|
||||
)
|
||||
|
||||
tmpdir = None
|
||||
sources_path = args.sources_path
|
||||
|
@ -126,6 +126,15 @@ jax_test(
|
||||
deps = ["//jax:extend"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "ffi_test",
|
||||
srcs = ["ffi_test.py"],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:test_util",
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "fft_test",
|
||||
srcs = ["fft_test.py"],
|
||||
|
32
tests/ffi_test.py
Normal file
32
tests/ffi_test.py
Normal file
@ -0,0 +1,32 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from jax import ffi
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
|
||||
class IncludeDirTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 265, "Requires jaxlib 0.4.29")
|
||||
def testHeadersExist(self):
|
||||
base_dir = os.path.join(ffi.include_dir(), "xla", "ffi", "api")
|
||||
for header in ["c_api.h", "api.h", "ffi.h"]:
|
||||
print(os.path.join(base_dir, header))
|
||||
self.assertTrue(os.path.exists(os.path.join(base_dir, header)))
|
Loading…
x
Reference in New Issue
Block a user