Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 11:56:07 +00:00
Move the CUDA end-to-end example to FFI examples workflow + hosted runner.

parent 8abedda8a6
commit ce8dba98fb
.github/workflows/ci-build.yaml (vendored, 22 changed lines)
@@ -61,7 +61,7 @@ jobs:
       - name: Image Setup
         run: |
           apt update
-          apt install -y libssl-dev
+          apt install -y libssl-dev
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
         with:
@@ -217,14 +217,16 @@ jobs:
 
   ffi:
     name: FFI example
-    runs-on: ubuntu-latest
-    timeout-minutes: 5
+    runs-on: linux-x86-g2-16-l4-1gpu
+    container:
+      image: index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12
+    timeout-minutes: 30
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-      - name: Set up Python 3.11
+      - name: Set up Python
        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
        with:
-          python-version: 3.11
+          python-version: 3.12
       - name: Get pip cache dir
         id: pip-cache
         run: |
@@ -236,7 +238,7 @@ jobs:
           path: ${{ steps.pip-cache.outputs.dir }}
           key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }}
       - name: Install JAX
-        run: pip install .
+        run: pip install .[cuda12]
       - name: Build and install example project
         run: python -m pip install -v ./examples/ffi[test]
         env:
@@ -245,6 +247,10 @@ jobs:
           # a different toolchain. GCC is the default compiler on the
           # 'ubuntu-latest' runner, but we still set this explicitly just to be
           # clear.
-          CMAKE_ARGS: -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++
-      - name: Run tests
+          CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
+      - name: Run CPU tests
         run: python -m pytest examples/ffi/tests
+        env:
+          JAX_PLATFORM_NAME: cpu
+      - name: Run GPU tests
+        run: python -m pytest examples/ffi/tests
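Note: the split into separate CPU and GPU test steps above works because JAX_PLATFORM_NAME pins the backend before JAX initializes. A minimal sketch of that mechanism (not part of this commit):

# Minimal sketch: how JAX_PLATFORM_NAME pins the backend, as the
# "Run CPU tests" step does through its env block.
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"  # must be set before importing jax

import jax
print(jax.devices())  # lists only CPU devices, even on a machine with GPUs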
@ -1,60 +0,0 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
load(
|
||||
"//jaxlib:jax.bzl",
|
||||
"cuda_library",
|
||||
"jax_generate_backend_suites",
|
||||
"jax_multiplatform_test",
|
||||
)
|
||||
|
||||
licenses(["notice"])
|
||||
|
||||
package(
|
||||
default_applicable_licenses = [],
|
||||
default_visibility = ["//visibility:private"],
|
||||
)
|
||||
|
||||
jax_generate_backend_suites()
|
||||
|
||||
jax_multiplatform_test(
|
||||
name = "cuda_custom_call_test",
|
||||
srcs = ["cuda_custom_call_test.py"],
|
||||
data = [":foo"],
|
||||
enable_backends = ["gpu"],
|
||||
tags = ["notap"],
|
||||
deps = [
|
||||
"//jax:extend",
|
||||
],
|
||||
)
|
||||
|
||||
# this second target is needed to properly link in CUDA runtime symbols
|
||||
# such as cudaLaunchKernel, even though we are only building one library.
|
||||
cc_shared_library(
|
||||
name = "foo",
|
||||
deps = [
|
||||
":foo_",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
],
|
||||
)
|
||||
|
||||
cuda_library(
|
||||
name = "foo_",
|
||||
srcs = ["foo.cu.cc"],
|
||||
deps = [
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@xla//xla/ffi/api:c_api",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
],
|
||||
)
|
@@ -1,35 +0,0 @@ (deleted file)
-# Copyright 2024 The JAX Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-# This Makefile is not used by Bazel for this test, it is intended to serve as
-# documentation of build instructions for JAX users that are not using Bazel to
-# build their custom call code. For that reason, this Makefile is likely subject
-# to bitrot over time. Please file a JAX issue on GitHub if typing "make" in
-# this directory no longer runs the test to completion.
-NVCC = nvcc
-NVCCFLAGS += -I$(shell python -c 'from jax.extend import ffi; print(ffi.include_dir())')
-NVCCFLAGS += -arch native
-# since the file extension is .cu.cc, tell NVCC explicitly to treat it as .cu
-NVCCFLAGS += -x cu
-
-# depends on libfoo.so being in the same directory as cuda_custom_call_test.py
-check: libfoo.so
-	python cuda_custom_call_test.py
-
-lib%.so: %.cu.cc
-	$(NVCC) $(NVCCFLAGS) --compiler-options=-shared,-fPIC -o $@ $<
-
-clean:
-	rm -rf *.so
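The header lookup the deleted Makefile embeds in NVCCFLAGS can also be run standalone; it is the same jax.extend.ffi.include_dir call the Makefile shells out to:

# Standalone version of the header lookup from the deleted Makefile: nvcc is
# pointed at JAX's bundled XLA FFI headers via -I with this path.
from jax.extend import ffi

print(ffi.include_dir())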
@@ -1,147 +0,0 @@ (deleted file)
-# Copyright 2024 The JAX Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-# This test is intentionally structured to stay close to what a standalone JAX
-# custom call integration might look like. JAX test harness is in a separate
-# section towards the end of this file. The test can be run standalone by typing
-# "make" in the directory containing this file.
-
-import os
-import ctypes
-import unittest
-
-import numpy as np
-
-import jax
-import jax.numpy as jnp
-from jax.extend import ffi
-
-# start test boilerplate
-from absl.testing import absltest
-from jax._src import config
-from jax._src import test_util as jtu
-
-config.parse_flags_with_absl()
-# end test boilerplate
-
-# XLA needs uppercase, "cuda" isn't recognized
-XLA_PLATFORM = "CUDA"
-
-# JAX needs lowercase, "CUDA" isn't recognized
-JAX_PLATFORM = "cuda"
-
-# 0 = original ("opaque"), 1 = FFI
-XLA_CUSTOM_CALL_API_VERSION = 1
-
-# these strings are how we identify kernels to XLA:
-# - first we register a pointer to the kernel with XLA under this name
-# - then we "tell" JAX to emit StableHLO specifying this name to XLA
-XLA_CUSTOM_CALL_TARGET_FWD = "foo-fwd"
-XLA_CUSTOM_CALL_TARGET_BWD = "foo-bwd"
-
-# load the shared library with the FFI target definitions
-if jtu.is_running_under_pytest():
-  raise unittest.SkipTest("libfoo.so hasn't been built")
-SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "libfoo.so")
-library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
-
-# register the custom calls targets with XLA, api_version=1 by default
-ffi.register_ffi_target(name=XLA_CUSTOM_CALL_TARGET_FWD,
-                        fn=ffi.pycapsule(library.FooFwd),
-                        platform=XLA_PLATFORM)
-ffi.register_ffi_target(name=XLA_CUSTOM_CALL_TARGET_BWD,
-                        fn=ffi.pycapsule(library.FooBwd),
-                        platform=XLA_PLATFORM)
-
-def foo_fwd(a, b):
-  assert a.dtype == jnp.float32
-  assert a.shape == b.shape
-  assert a.dtype == b.dtype
-  n = np.prod(a.shape).astype(np.uint64)
-  out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
-  c, b_plus_1 = ffi.ffi_call(XLA_CUSTOM_CALL_TARGET_FWD, (out_type, out_type),
-                             a, b, n=n)
-  return c, (a, b_plus_1)
-
-
-def foo_bwd(res, c_grad):
-  a, b_plus_1 = res
-  assert c_grad.dtype == jnp.float32
-  assert c_grad.shape == a.shape
-  assert a.shape == b_plus_1.shape
-  assert c_grad.dtype == a.dtype
-  assert a.dtype == b_plus_1.dtype
-  n = np.prod(a.shape).astype(np.uint64)
-  out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
-  return ffi.ffi_call(XLA_CUSTOM_CALL_TARGET_BWD, (out_type, out_type),
-                      c_grad, a, b_plus_1, n=n)
-
-
-@jax.custom_vjp
-def foo(a, b):
-  c, _ = foo_fwd(a, b)
-  return c
-
-
-foo.defvjp(foo_fwd, foo_bwd)
-
-#-----------------------------------------------------------------------------#
-#                                    Test                                     #
-#-----------------------------------------------------------------------------#
-
-
-class CustomCallTest(jtu.JaxTestCase):
-
-  def test_fwd_interpretable(self):
-    shape = (2, 3)
-    a = 2. * jnp.ones(shape)
-    b = 3. * jnp.ones(shape)
-    observed = jax.jit(foo)(a, b)
-    expected = (2. * (3. + 1.))
-    self.assertArraysEqual(observed, expected)
-
-  def test_bwd_interpretable(self):
-    shape = (2, 3)
-    a = 2. * jnp.ones(shape)
-    b = 3. * jnp.ones(shape)
-
-    def loss(a, b):
-      return jnp.sum(foo(a, b))
-
-    da_observed, db_observed = jax.jit(jax.grad(loss, argnums=(0, 1)))(a, b)
-    da_expected = b + 1
-    db_expected = a
-    self.assertArraysEqual(da_observed, da_expected)
-    self.assertArraysEqual(db_observed, db_expected)
-
-  def test_fwd_random(self):
-    shape = (2, 3)
-    akey, bkey = jax.random.split(jax.random.key(0))
-    a = jax.random.normal(key=akey, shape=shape)
-    b = jax.random.normal(key=bkey, shape=shape)
-    observed = jax.jit(foo)(a, b)
-    expected = a * (b + 1)
-    self.assertAllClose(observed, expected)
-
-  def test_bwd_random(self):
-    shape = (2, 3)
-    akey, bkey = jax.random.split(jax.random.key(0))
-    a = jax.random.normal(key=akey, shape=shape)
-    b = jax.random.normal(key=bkey, shape=shape)
-    jtu.check_grads(f=jax.jit(foo), args=(a, b), order=1, modes=("rev",))
-
-
-if __name__ == "__main__":
-  absltest.main(testLoader=jtu.JaxTestLoader())
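From the deleted test's expected values, the kernel computes c = a * (b + 1), with gradients da = b + 1 and db = a. A pure-JAX reference sketch (not part of the original file) that pins this down:

# Pure-JAX reference for the CUDA kernel under test, derived from
# test_fwd_random (expected = a * (b + 1)) and test_bwd_interpretable
# (da = b + 1, db = a) above.
import jax
import jax.numpy as jnp

def foo_ref(a, b):
  b_plus_1 = b + 1.0   # the FFI version saves this as a residual for the vjp
  return a * b_plus_1

da, db = jax.grad(lambda a, b: jnp.sum(foo_ref(a, b)), argnums=(0, 1))(
    2. * jnp.ones((2, 3)), 3. * jnp.ones((2, 3)))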
@@ -1,6 +1,8 @@
 cmake_minimum_required(VERSION 3.15...3.30)
 project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX)
 
+option(JAX_FFI_EXAMPLE_ENABLE_CUDA "Enable CUDA support" OFF)
+
 find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module)
 execute_process(
   COMMAND "${Python_EXECUTABLE}"
@@ -17,3 +19,12 @@ install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
 nanobind_add_module(_attrs NB_STATIC "src/jax_ffi_example/attrs.cc")
 target_include_directories(_attrs PUBLIC ${XLA_DIR})
 install(TARGETS _attrs LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
+
+if(JAX_FFI_EXAMPLE_ENABLE_CUDA)
+  enable_language(CUDA)
+  add_library(_cuda_e2e SHARED "src/jax_ffi_example/cuda_e2e.cu")
+  set_target_properties(_cuda_e2e PROPERTIES POSITION_INDEPENDENT_CODE ON
+                                             CUDA_STANDARD 17)
+  target_include_directories(_cuda_e2e PUBLIC ${XLA_DIR})
+  install(TARGETS _cuda_e2e LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
+endif()
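To reproduce the CI build locally with the new option enabled, CMAKE_ARGS can be forwarded through pip the same way the workflow does (scikit-build-core passes it on to CMake). A hypothetical local equivalent, assuming the CUDA toolkit is installed:

# Hypothetical local equivalent of the updated CI install step; the
# CMAKE_ARGS value is the one the workflow above sets.
import os
import subprocess
import sys

env = dict(os.environ)
env["CMAKE_ARGS"] = "-DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON"
subprocess.run(
    [sys.executable, "-m", "pip", "install", "-v", "./examples/ffi[test]"],
    check=True, env=env)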
examples/ffi/src/jax_ffi_example/cuda_e2e.py (new file, 68 lines)
@@ -0,0 +1,68 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""An end-to-end example demonstrating the use of the JAX FFI with CUDA.
+
+The specifics of the kernels are not very important, but the general structure,
+and packaging of the extension are useful for testing.
+"""
+
+import os
+import ctypes
+
+import numpy as np
+
+import jax
+import jax.numpy as jnp
+import jax.extend as jex
+
+# Load the shared library with the FFI target definitions
+SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_e2e.so")
+library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
+
+jex.ffi.register_ffi_target("foo-fwd", jex.ffi.pycapsule(library.FooFwd),
+                            platform="CUDA")
+jex.ffi.register_ffi_target("foo-bwd", jex.ffi.pycapsule(library.FooBwd),
+                            platform="CUDA")
+
+
+def foo_fwd(a, b):
+  assert a.dtype == jnp.float32
+  assert a.shape == b.shape
+  assert a.dtype == b.dtype
+  n = np.prod(a.shape).astype(np.uint64)
+  out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
+  c, b_plus_1 = jex.ffi.ffi_call("foo-fwd", (out_type, out_type))(a, b, n=n)
+  return c, (a, b_plus_1)
+
+
+def foo_bwd(res, c_grad):
+  a, b_plus_1 = res
+  assert c_grad.dtype == jnp.float32
+  assert c_grad.shape == a.shape
+  assert a.shape == b_plus_1.shape
+  assert c_grad.dtype == a.dtype
+  assert a.dtype == b_plus_1.dtype
+  n = np.prod(a.shape).astype(np.uint64)
+  out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
+  return jex.ffi.ffi_call("foo-bwd", (out_type, out_type))(c_grad, a, b_plus_1,
+                                                           n=n)
+
+
+@jax.custom_vjp
+def foo(a, b):
+  c, _ = foo_fwd(a, b)
+  return c
+
+
+foo.defvjp(foo_fwd, foo_bwd)
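A short usage sketch for the new module (assumes a CUDA device and an extension built with JAX_FFI_EXAMPLE_ENABLE_CUDA=ON; mirrors the tests added below):

# Sketch: exercising the new FFI example end to end on a GPU.
import jax
import jax.numpy as jnp
from jax_ffi_example import cuda_e2e

a = 2. * jnp.ones((2, 3))
b = 3. * jnp.ones((2, 3))
print(jax.jit(cuda_e2e.foo)(a, b))  # a * (b + 1), i.e. 8.0 everywhere

# Gradients flow through the registered custom_vjp: da = b + 1, db = a.
print(jax.grad(lambda a, b: jnp.sum(cuda_e2e.foo(a, b)), argnums=(0, 1))(a, b))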
@@ -24,6 +24,11 @@ jax.config.parse_flags_with_absl()
 
 
 class AttrsTests(jtu.JaxTestCase):
+  def setUp(self):
+    super().setUp()
+    if not jtu.test_device_matches(["cpu"]):
+      self.skipTest("Unsupported platform")
+
   def test_array_attr(self):
     self.assertEqual(attrs.array_attr(5), jnp.arange(5).sum())
     self.assertEqual(attrs.array_attr(3), jnp.arange(3).sum())
examples/ffi/tests/cuda_e2e_test.py (new file, 75 lines)
@@ -0,0 +1,75 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from absl.testing import absltest
+
+import jax
+import jax.numpy as jnp
+from jax._src import test_util as jtu
+
+jax.config.parse_flags_with_absl()
+
+
+class CudaE2eTests(jtu.JaxTestCase):
+  def setUp(self):
+    super().setUp()
+    if not jtu.test_device_matches(["cuda"]):
+      self.skipTest("Unsupported platform")
+
+    # Import here to avoid trying to load the library when it's not built.
+    from jax_ffi_example import cuda_e2e
+    self.foo = cuda_e2e.foo
+
+  def test_fwd_interpretable(self):
+    shape = (2, 3)
+    a = 2. * jnp.ones(shape)
+    b = 3. * jnp.ones(shape)
+    observed = jax.jit(self.foo)(a, b)
+    expected = (2. * (3. + 1.))
+    self.assertArraysEqual(observed, expected)
+
+  def test_bwd_interpretable(self):
+    shape = (2, 3)
+    a = 2. * jnp.ones(shape)
+    b = 3. * jnp.ones(shape)
+
+    def loss(a, b):
+      return jnp.sum(self.foo(a, b))
+
+    da_observed, db_observed = jax.jit(jax.grad(loss, argnums=(0, 1)))(a, b)
+    da_expected = b + 1
+    db_expected = a
+    self.assertArraysEqual(da_observed, da_expected)
+    self.assertArraysEqual(db_observed, db_expected)
+
+  def test_fwd_random(self):
+    shape = (2, 3)
+    akey, bkey = jax.random.split(jax.random.key(0))
+    a = jax.random.normal(key=akey, shape=shape)
+    b = jax.random.normal(key=bkey, shape=shape)
+    observed = jax.jit(self.foo)(a, b)
+    expected = a * (b + 1)
+    self.assertAllClose(observed, expected)
+
+  def test_bwd_random(self):
+    shape = (2, 3)
+    akey, bkey = jax.random.split(jax.random.key(0))
+    a = jax.random.normal(key=akey, shape=shape)
+    b = jax.random.normal(key=bkey, shape=shape)
+    jtu.check_grads(f=jax.jit(self.foo), args=(a, b), order=1,
+                    modes=("rev",))
+
+
+if __name__ == "__main__":
+  absltest.main(testLoader=jtu.JaxTestLoader())
@@ -29,6 +29,11 @@ def rms_norm_ref(x, eps=1e-5):
 
 
 class RmsNormTests(jtu.JaxTestCase):
+  def setUp(self):
+    super().setUp()
+    if not jtu.test_device_matches(["cpu"]):
+      self.skipTest("Unsupported platform")
+
   def test_basic(self):
     x = jnp.linspace(-0.5, 0.5, 15)
     self.assertAllClose(rms_norm.rms_norm(x), rms_norm_ref(x))