Move the CUDA end-to-end example to FFI examples workflow + hosted runner.
Dan Foreman-Mackey 2024-10-08 10:03:58 -04:00
parent 8abedda8a6
commit ce8dba98fb
10 changed files with 178 additions and 250 deletions

View File

@@ -61,7 +61,7 @@ jobs:
       - name: Image Setup
         run: |
           apt update
-          apt install -y libssl-dev
+          apt install -y libssl-dev
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
         with:
@@ -217,14 +217,16 @@ jobs:
   ffi:
     name: FFI example
-    runs-on: ubuntu-latest
-    timeout-minutes: 5
+    runs-on: linux-x86-g2-16-l4-1gpu
+    container:
+      image: index.docker.io/tensorflow/build:latest-python3.12@sha256:48e99608fe9434ada5b14e19fdfd8e64f4cfc83aacd328b9c2101b210e984295 # ratchet:index.docker.io/tensorflow/build:latest-python3.12
+    timeout-minutes: 30
     steps:
       - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-      - name: Set up Python 3.11
+      - name: Set up Python
         uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
         with:
-          python-version: 3.11
+          python-version: 3.12
       - name: Get pip cache dir
         id: pip-cache
         run: |
@@ -236,7 +238,7 @@
           path: ${{ steps.pip-cache.outputs.dir }}
           key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }}
       - name: Install JAX
-        run: pip install .
+        run: pip install .[cuda12]
       - name: Build and install example project
         run: python -m pip install -v ./examples/ffi[test]
         env:
@@ -245,6 +247,10 @@
           # a different toolchain. GCC is the default compiler on the
           # 'ubuntu-latest' runner, but we still set this explicitly just to be
           # clear.
-          CMAKE_ARGS: -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++
-      - name: Run tests
+          CMAKE_ARGS: -DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON
+      - name: Run CPU tests
         run: python -m pytest examples/ffi/tests
+        env:
+          JAX_PLATFORM_NAME: cpu
+      - name: Run GPU tests
+        run: python -m pytest examples/ffi/tests
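
The CPU step pins JAX_PLATFORM_NAME=cpu so the test suite ignores the runner's GPU, while the GPU step leaves the variable unset and falls through to the CUDA backend. A minimal sketch of the effect, assuming jax with the cuda12 extra is installed as in the step above:

# Sketch: JAX_PLATFORM_NAME must be set before jax is first imported.
import os
os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax

# With the variable set, jax runs on CPU even when a GPU is present;
# left unset, it prefers the best available backend (CUDA on this runner).
print(jax.default_backend())  # -> "cpu"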

View File

@@ -1,60 +0,0 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load(
"//jaxlib:jax.bzl",
"cuda_library",
"jax_generate_backend_suites",
"jax_multiplatform_test",
)
licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//visibility:private"],
)
jax_generate_backend_suites()
jax_multiplatform_test(
name = "cuda_custom_call_test",
srcs = ["cuda_custom_call_test.py"],
data = [":foo"],
enable_backends = ["gpu"],
tags = ["notap"],
deps = [
"//jax:extend",
],
)
# this second target is needed to properly link in CUDA runtime symbols
# such as cudaLaunchKernel, even though we are only building one library.
cc_shared_library(
name = "foo",
deps = [
":foo_",
"@xla//xla/tsl/cuda:cudart",
],
)
cuda_library(
name = "foo_",
srcs = ["foo.cu.cc"],
deps = [
"@local_config_cuda//cuda:cuda_headers",
"@xla//xla/ffi/api:c_api",
"@xla//xla/ffi/api:ffi",
],
)
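
As the comment above explains, the cc_shared_library wrapper exists so that CUDA runtime symbols such as cudaLaunchKernel get linked into the final shared object through the cudart dependency. A hypothetical sanity check for the built artifact, assuming it was produced as ./libfoo.so:

# Hypothetical check that the library loads and exposes the FFI entry points,
# which only succeeds if its CUDA runtime symbols were resolved at link time.
import ctypes

lib = ctypes.cdll.LoadLibrary("./libfoo.so")  # path is an assumption
print(hasattr(lib, "FooFwd"), hasattr(lib, "FooBwd"))  # True True when linked correctly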

View File

@@ -1,35 +0,0 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This Makefile is not used by Bazel for this test, it is intended to serve as
# documentation of build instructions for JAX users that are not using Bazel to
# build their custom call code. For that reason, this Makefile is likely subject
# to bitrot over time. Please file a JAX issue on GitHub if typing "make" in
# this directory no longer runs the test to completion.
NVCC = nvcc
NVCCFLAGS += -I$(shell python -c 'from jax.extend import ffi; print(ffi.include_dir())')
NVCCFLAGS += -arch native
# since the file extension is .cu.cc, tell NVCC explicitly to treat it as .cu
NVCCFLAGS += -x cu
# depends on libfoo.so being in the same directory as cuda_custom_call_test.py
check: libfoo.so
python cuda_custom_call_test.py
lib%.so: %.cu.cc
$(NVCC) $(NVCCFLAGS) --compiler-options=-shared,-fPIC -o $@ $<
clean:
rm -rf *.so
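
The include path added to NVCCFLAGS above is not hard-coded; the Makefile asks the installed JAX package for it. The same lookup written out in plain Python (this is the one-liner the Makefile shells out to):

# jax.extend.ffi.include_dir() returns the directory containing the XLA FFI
# headers that ship with JAX; nvcc needs it on its include path.
from jax.extend import ffi

print(ffi.include_dir())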

View File

@@ -1,147 +0,0 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This test is intentionally structured to stay close to what a standalone JAX
# custom call integration might look like. JAX test harness is in a separate
# section towards the end of this file. The test can be run standalone by typing
# "make" in the directory containing this file.
import os
import ctypes
import unittest
import numpy as np
import jax
import jax.numpy as jnp
from jax.extend import ffi
# start test boilerplate
from absl.testing import absltest
from jax._src import config
from jax._src import test_util as jtu
config.parse_flags_with_absl()
# end test boilerplate
# XLA needs uppercase, "cuda" isn't recognized
XLA_PLATFORM = "CUDA"
# JAX needs lowercase, "CUDA" isn't recognized
JAX_PLATFORM = "cuda"
# 0 = original ("opaque"), 1 = FFI
XLA_CUSTOM_CALL_API_VERSION = 1
# these strings are how we identify kernels to XLA:
# - first we register a pointer to the kernel with XLA under this name
# - then we "tell" JAX to emit StableHLO specifying this name to XLA
XLA_CUSTOM_CALL_TARGET_FWD = "foo-fwd"
XLA_CUSTOM_CALL_TARGET_BWD = "foo-bwd"
# load the shared library with the FFI target definitions
if jtu.is_running_under_pytest():
raise unittest.SkipTest("libfoo.so hasn't been built")
SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "libfoo.so")
library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
# register the custom calls targets with XLA, api_version=1 by default
ffi.register_ffi_target(name=XLA_CUSTOM_CALL_TARGET_FWD,
fn=ffi.pycapsule(library.FooFwd),
platform=XLA_PLATFORM)
ffi.register_ffi_target(name=XLA_CUSTOM_CALL_TARGET_BWD,
fn=ffi.pycapsule(library.FooBwd),
platform=XLA_PLATFORM)
def foo_fwd(a, b):
assert a.dtype == jnp.float32
assert a.shape == b.shape
assert a.dtype == b.dtype
n = np.prod(a.shape).astype(np.uint64)
out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
c, b_plus_1 = ffi.ffi_call(XLA_CUSTOM_CALL_TARGET_FWD, (out_type, out_type),
a, b, n=n)
return c, (a, b_plus_1)
def foo_bwd(res, c_grad):
a, b_plus_1 = res
assert c_grad.dtype == jnp.float32
assert c_grad.shape == a.shape
assert a.shape == b_plus_1.shape
assert c_grad.dtype == a.dtype
assert a.dtype == b_plus_1.dtype
n = np.prod(a.shape).astype(np.uint64)
out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
return ffi.ffi_call(XLA_CUSTOM_CALL_TARGET_BWD, (out_type, out_type),
c_grad, a, b_plus_1, n=n)
@jax.custom_vjp
def foo(a, b):
c, _ = foo_fwd(a, b)
return c
foo.defvjp(foo_fwd, foo_bwd)
#-----------------------------------------------------------------------------#
# Test #
#-----------------------------------------------------------------------------#
class CustomCallTest(jtu.JaxTestCase):
def test_fwd_interpretable(self):
shape = (2, 3)
a = 2. * jnp.ones(shape)
b = 3. * jnp.ones(shape)
observed = jax.jit(foo)(a, b)
expected = (2. * (3. + 1.))
self.assertArraysEqual(observed, expected)
def test_bwd_interpretable(self):
shape = (2, 3)
a = 2. * jnp.ones(shape)
b = 3. * jnp.ones(shape)
def loss(a, b):
return jnp.sum(foo(a, b))
da_observed, db_observed = jax.jit(jax.grad(loss, argnums=(0, 1)))(a, b)
da_expected = b + 1
db_expected = a
self.assertArraysEqual(da_observed, da_expected)
self.assertArraysEqual(db_observed, db_expected)
def test_fwd_random(self):
shape = (2, 3)
akey, bkey = jax.random.split(jax.random.key(0))
a = jax.random.normal(key=akey, shape=shape)
b = jax.random.normal(key=bkey, shape=shape)
observed = jax.jit(foo)(a, b)
expected = a * (b + 1)
self.assertAllClose(observed, expected)
def test_bwd_random(self):
shape = (2, 3)
akey, bkey = jax.random.split(jax.random.key(0))
a = jax.random.normal(key=akey, shape=shape)
b = jax.random.normal(key=bkey, shape=shape)
jtu.check_grads(f=jax.jit(foo), args=(a, b), order=1, modes=("rev",))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@@ -1,6 +1,8 @@
 cmake_minimum_required(VERSION 3.15...3.30)
 project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX)
 
+option(JAX_FFI_EXAMPLE_ENABLE_CUDA "Enable CUDA support" OFF)
+
 find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module)
 execute_process(
   COMMAND "${Python_EXECUTABLE}"
@@ -17,3 +19,12 @@ install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
 nanobind_add_module(_attrs NB_STATIC "src/jax_ffi_example/attrs.cc")
 target_include_directories(_attrs PUBLIC ${XLA_DIR})
 install(TARGETS _attrs LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
+
+if(JAX_FFI_EXAMPLE_ENABLE_CUDA)
+  enable_language(CUDA)
+  add_library(_cuda_e2e SHARED "src/jax_ffi_example/cuda_e2e.cu")
+  set_target_properties(_cuda_e2e PROPERTIES POSITION_INDEPENDENT_CODE ON
+                                             CUDA_STANDARD 17)
+  target_include_directories(_cuda_e2e PUBLIC ${XLA_DIR})
+  install(TARGETS _cuda_e2e LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
+endif()
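
Since the new option defaults to OFF, a plain pip install of ./examples/ffi still builds only the CPU extensions; CUDA is opt-in through CMAKE_ARGS, exactly as the workflow above does it. A hedged sketch of the equivalent local invocation, assuming a CUDA toolchain is installed and scikit-build-core forwards CMAKE_ARGS to CMake:

# Hypothetical local equivalent of the CI "Build and install example project"
# step with CUDA enabled; requires nvcc and a checkout of the JAX repo.
import os
import subprocess

env = dict(os.environ)
env["CMAKE_ARGS"] = "-DCMAKE_CXX_COMPILER=g++ -DJAX_FFI_EXAMPLE_ENABLE_CUDA=ON"
subprocess.run(
    ["python", "-m", "pip", "install", "-v", "./examples/ffi[test]"],
    check=True,
    env=env,
)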

View File

@@ -0,0 +1,68 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An end-to-end example demonstrating the use of the JAX FFI with CUDA.
The specifics of the kernels are not very important, but the general structure,
and packaging of the extension are useful for testing.
"""
import os
import ctypes
import numpy as np
import jax
import jax.numpy as jnp
import jax.extend as jex
# Load the shared library with the FFI target definitions
SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_e2e.so")
library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
jex.ffi.register_ffi_target("foo-fwd", jex.ffi.pycapsule(library.FooFwd),
platform="CUDA")
jex.ffi.register_ffi_target("foo-bwd", jex.ffi.pycapsule(library.FooBwd),
platform="CUDA")
def foo_fwd(a, b):
assert a.dtype == jnp.float32
assert a.shape == b.shape
assert a.dtype == b.dtype
n = np.prod(a.shape).astype(np.uint64)
out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
c, b_plus_1 = jex.ffi.ffi_call("foo-fwd", (out_type, out_type))(a, b, n=n)
return c, (a, b_plus_1)
def foo_bwd(res, c_grad):
a, b_plus_1 = res
assert c_grad.dtype == jnp.float32
assert c_grad.shape == a.shape
assert a.shape == b_plus_1.shape
assert c_grad.dtype == a.dtype
assert a.dtype == b_plus_1.dtype
n = np.prod(a.shape).astype(np.uint64)
out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
return jex.ffi.ffi_call("foo-bwd", (out_type, out_type))(c_grad, a, b_plus_1,
n=n)
@jax.custom_vjp
def foo(a, b):
c, _ = foo_fwd(a, b)
return c
foo.defvjp(foo_fwd, foo_bwd)
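
For reference, a short usage sketch of the op defined above, assuming the extension was built with JAX_FFI_EXAMPLE_ENABLE_CUDA=ON and a GPU is visible (the tests added below exercise the same paths):

# Sketch: forward and reverse passes through the custom op.
import jax
import jax.numpy as jnp
from jax_ffi_example import cuda_e2e

a = 2.0 * jnp.ones((2, 3), dtype=jnp.float32)
b = 3.0 * jnp.ones((2, 3), dtype=jnp.float32)

print(jax.jit(cuda_e2e.foo)(a, b))  # a * (b + 1) == 8.0 everywhere

loss = lambda a, b: jnp.sum(cuda_e2e.foo(a, b))
da, db = jax.grad(loss, argnums=(0, 1))(a, b)  # da == b + 1, db == a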

View File

@@ -24,6 +24,11 @@ jax.config.parse_flags_with_absl()
 class AttrsTests(jtu.JaxTestCase):
+  def setUp(self):
+    super().setUp()
+    if not jtu.test_device_matches(["cpu"]):
+      self.skipTest("Unsupported platform")
+
   def test_array_attr(self):
     self.assertEqual(attrs.array_attr(5), jnp.arange(5).sum())
     self.assertEqual(attrs.array_attr(3), jnp.arange(3).sum())

View File

@@ -0,0 +1,75 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
jax.config.parse_flags_with_absl()
class CudaE2eTests(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if not jtu.test_device_matches(["cuda"]):
self.skipTest("Unsupported platform")
# Import here to avoid trying to load the library when it's not built.
from jax_ffi_example import cuda_e2e
self.foo = cuda_e2e.foo
def test_fwd_interpretable(self):
shape = (2, 3)
a = 2. * jnp.ones(shape)
b = 3. * jnp.ones(shape)
observed = jax.jit(self.foo)(a, b)
expected = (2. * (3. + 1.))
self.assertArraysEqual(observed, expected)
def test_bwd_interpretable(self):
shape = (2, 3)
a = 2. * jnp.ones(shape)
b = 3. * jnp.ones(shape)
def loss(a, b):
return jnp.sum(self.foo(a, b))
da_observed, db_observed = jax.jit(jax.grad(loss, argnums=(0, 1)))(a, b)
da_expected = b + 1
db_expected = a
self.assertArraysEqual(da_observed, da_expected)
self.assertArraysEqual(db_observed, db_expected)
def test_fwd_random(self):
shape = (2, 3)
akey, bkey = jax.random.split(jax.random.key(0))
a = jax.random.normal(key=akey, shape=shape)
b = jax.random.normal(key=bkey, shape=shape)
observed = jax.jit(self.foo)(a, b)
expected = a * (b + 1)
self.assertAllClose(observed, expected)
def test_bwd_random(self):
shape = (2, 3)
akey, bkey = jax.random.split(jax.random.key(0))
a = jax.random.normal(key=akey, shape=shape)
b = jax.random.normal(key=bkey, shape=shape)
jtu.check_grads(f=jax.jit(self.foo), args=(a, b), order=1,
modes=("rev",))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@@ -29,6 +29,11 @@ def rms_norm_ref(x, eps=1e-5):
 class RmsNormTests(jtu.JaxTestCase):
+  def setUp(self):
+    super().setUp()
+    if not jtu.test_device_matches(["cpu"]):
+      self.skipTest("Unsupported platform")
+
   def test_basic(self):
     x = jnp.linspace(-0.5, 0.5, 15)
     self.assertAllClose(rms_norm.rms_norm(x), rms_norm_ref(x))