Add CUDA custom call example as a JAX test

Andrey Portnoy 2024-06-04 16:23:33 -04:00
parent a9edaeb38e
commit ec5c4f5a10
6 changed files with 467 additions and 5 deletions


@@ -0,0 +1,66 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
load("@rules_python//python:defs.bzl", "py_test")
load(
"//jaxlib:jax.bzl",
"cuda_library",
"jax_test",
)
licenses(["notice"])
package(
default_applicable_licenses = [],
default_visibility = ["//visibility:private"],
)
jax_test(
name = "cuda_custom_call_test",
srcs = ["cuda_custom_call_test.py"],
disable_backends = [
"cpu",
"tpu",
],
# libfoo.so is a runtime dependency for this test
data = [":foo"],
tags = [
"notap",
"manual",
],
deps = [
"//jax:extend",
]
)
# this second target is needed to properly link in CUDA runtime symbols
# such as cudaLaunchKernel, even though we are only building one library.
cc_shared_library(
name = "foo",
deps = [
":foo_",
"@xla//xla/tsl/cuda:cudart",
],
)
cuda_library(
name = "foo_",
srcs = ["foo.cu.cc"],
deps = [
"@xla//xla/ffi/api:ffi",
"@xla//xla/ffi/api:api",
"@xla//xla/ffi/api:c_api",
"@local_config_cuda//cuda:cuda_headers",
],
)
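
Not part of the commit, but the two comments in this BUILD file are easy to sanity-check outside of Bazel: if the ":foo" wrapper did not link in cudart, loading the resulting shared library would typically fail with an unresolved CUDA runtime symbol. A minimal Python sketch, assuming libfoo.so has been built and the path below is adjusted accordingly:

import ctypes

# Loading the library forces its dynamic dependencies to resolve; a missing
# CUDA runtime link usually shows up here as "undefined symbol: cudaLaunchKernel".
lib = ctypes.CDLL("./libfoo.so")

# The FFI handlers are exported with C linkage, so they can be looked up by name.
for symbol in ("FooFwd", "FooBwd"):
  getattr(lib, symbol)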


@@ -0,0 +1,35 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This Makefile is not used by Bazel for this test; it is intended to serve as
# documentation of build instructions for JAX users who are not using Bazel to
# build their custom call code. For that reason, this Makefile is likely subject
# to bitrot over time. Please file a JAX issue on GitHub if typing "make" in
# this directory no longer runs the test to completion.
NVCC = nvcc
NVCCFLAGS += -I$(shell python -c 'from jax.extend import ffi; print(ffi.include_dir())')
NVCCFLAGS += -arch native
# since the file extension is .cu.cc, tell NVCC explicitly to treat it as .cu
NVCCFLAGS += -x cu
# depends on libfoo.so being in the same directory as cuda_custom_call_test.py
check: libfoo.so
python cuda_custom_call_test.py
lib%.so: %.cu.cc
$(NVCC) $(NVCCFLAGS) --compiler-options=-shared,-fPIC -o $@ $<
clean:
rm -rf *.so
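
The key ingredient above is the include path returned by jax.extend.ffi.include_dir(), which points at the XLA FFI headers bundled with jaxlib. Roughly the same nvcc invocation can be driven from Python; a sketch assuming nvcc is on PATH and a local GPU (for "-arch native"):

import subprocess
from jax.extend import ffi

cmd = [
    "nvcc",
    "-I" + ffi.include_dir(),  # XLA FFI headers shipped with jaxlib
    "-arch", "native",
    "-x", "cu",                # the source uses a .cu.cc extension
    "--compiler-options=-shared,-fPIC",
    "-o", "libfoo.so",
    "foo.cu.cc",
]
subprocess.run(cmd, check=True)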


@@ -0,0 +1,216 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This test is intentionally structured to stay close to what a standalone JAX
# custom call integration might look like. The JAX test harness is in a separate
# section towards the end of this file. The test can be run standalone by typing
# "make" in the directory containing this file.
import os
import ctypes
import unittest
import numpy as np
import jax
import jax.numpy as jnp
from jax.extend import ffi
from jax.lib import xla_client
from jax.interpreters import mlir
# start test boilerplate
from absl.testing import absltest
from jax._src import config
from jax._src import test_util as jtu
config.parse_flags_with_absl()
# end test boilerplate
# XLA needs uppercase, "cuda" isn't recognized
XLA_PLATFORM = "CUDA"
# JAX needs lowercase, "CUDA" isn't recognized
JAX_PLATFORM = "cuda"
# 0 = original ("opaque"), 1 = FFI
XLA_CUSTOM_CALL_API_VERSION = 1
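# (api_version=1 selects the typed XLA FFI calling convention implemented by the
# handlers in foo.cu.cc; under the legacy "opaque" convention (0) the library
# would instead export functions taking a raw stream, an array of buffer
# pointers, and an opaque byte string.)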
# these strings are how we identify kernels to XLA:
# - first we register a pointer to the kernel with XLA under this name
# - then we "tell" JAX to emit StableHLO specifying this name to XLA
XLA_CUSTOM_CALL_TARGET_FWD = "foo-fwd"
XLA_CUSTOM_CALL_TARGET_BWD = "foo-bwd"
# independently, the corresponding JAX primitives must also be named;
# the names can differ from the XLA targets, but here they are the same
JAX_PRIMITIVE_FWD = "foo-fwd"
JAX_PRIMITIVE_BWD = "foo-bwd"
if jtu.is_running_under_pytest():
raise unittest.SkipTest("libfoo.so hasn't been built")
SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "libfoo.so")
library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
#-----------------------------------------------------------------------------#
# Forward pass #
#-----------------------------------------------------------------------------#
# register the XLA FFI binding pointer with XLA
xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_FWD,
fn=ffi.pycapsule(library.FooFwd),
platform=XLA_PLATFORM,
api_version=XLA_CUSTOM_CALL_API_VERSION)
# our forward primitive will also return the intermediate output b+1
# so it can be reused in the backward pass computation
def _foo_fwd_abstract_eval(a, b):
assert a.shape == b.shape
assert a.dtype == b.dtype
shaped_array = jax.core.ShapedArray(a.shape, a.dtype)
return (
shaped_array, # output c
shaped_array, # intermediate output b+1
)
def _foo_fwd_lowering(ctx, a, b):
# ffi.ffi_lowering does most of the heavy lifting of building a lowering.
# Keyword arguments passed to the lowering constructed by ffi_lowering are
# turned into custom call backend_config entries, which we take advantage of
# here for the dynamically computed n.
n = np.prod(a.type.shape).astype(np.uint64)
return ffi.ffi_lowering(XLA_CUSTOM_CALL_TARGET_FWD)(ctx, a, b, n=n)
# construct a new JAX primitive
foo_fwd_p = jax.core.Primitive(JAX_PRIMITIVE_FWD)
# register the abstract evaluation rule for the forward primitive
foo_fwd_p.def_abstract_eval(_foo_fwd_abstract_eval)
foo_fwd_p.multiple_results = True
mlir.register_lowering(foo_fwd_p, _foo_fwd_lowering, platform=JAX_PLATFORM)
#-----------------------------------------------------------------------------#
# Backward pass #
#-----------------------------------------------------------------------------#
# register the XLA FFI binding pointer with XLA
xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_BWD,
fn=ffi.pycapsule(library.FooBwd),
platform=XLA_PLATFORM,
api_version=XLA_CUSTOM_CALL_API_VERSION)
def _foo_bwd_abstract_eval(c_grad, a, b_plus_1):
assert c_grad.shape == a.shape
assert a.shape == b_plus_1.shape
assert c_grad.dtype == a.dtype
assert a.dtype == b_plus_1.dtype
shaped_array = jax.core.ShapedArray(a.shape, a.dtype)
return (
shaped_array, # a_grad
shaped_array, # b_grad
)
def _foo_bwd_lowering(ctx, c_grad, a, b_plus_1):
n = np.prod(a.type.shape).astype(np.uint64)
return ffi.ffi_lowering(XLA_CUSTOM_CALL_TARGET_BWD)(ctx,
c_grad,
a,
b_plus_1,
n=n)
# construct a new JAX primitive
foo_bwd_p = jax.core.Primitive(JAX_PRIMITIVE_BWD)
# register the abstract evaluation rule for the backward primitive
foo_bwd_p.def_abstract_eval(_foo_bwd_abstract_eval)
foo_bwd_p.multiple_results = True
mlir.register_lowering(foo_bwd_p, _foo_bwd_lowering, platform=JAX_PLATFORM)
#-----------------------------------------------------------------------------#
# User facing API #
#-----------------------------------------------------------------------------#
def foo_fwd(a, b):
c, b_plus_1 = foo_fwd_p.bind(a, b)
return c, (a, b_plus_1)
def foo_bwd(res, c_grad):
a, b_plus_1 = res
return foo_bwd_p.bind(c_grad, a, b_plus_1)
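# note: jax.custom_vjp expects the bwd rule to return one cotangent per primal
# input of foo, i.e. a (da, db) pair; foo_bwd_p.bind produces exactly these two
# outputs, so they can be returned directly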
@jax.custom_vjp
def foo(a, b):
c, _ = foo_fwd(a, b)
return c
foo.defvjp(foo_fwd, foo_bwd)
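# quick wiring check (a sketch, not exercised by the tests below): the jitted
# computation's StableHLO should contain a custom call whose target is
# XLA_CUSTOM_CALL_TARGET_FWD
def _check_lowering_sketch():
  a = b = jnp.ones((2, 3))
  hlo_text = jax.jit(foo).lower(a, b).as_text()
  assert XLA_CUSTOM_CALL_TARGET_FWD in hlo_text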
#-----------------------------------------------------------------------------#
# Test #
#-----------------------------------------------------------------------------#
class CustomCallTest(jtu.JaxTestCase):
def test_fwd_interpretable(self):
shape = (2, 3)
a = 2. * jnp.ones(shape)
b = 3. * jnp.ones(shape)
observed = jax.jit(foo)(a, b)
expected = (2. * (3. + 1.))
self.assertArraysEqual(observed, expected)
def test_bwd_interpretable(self):
shape = (2, 3)
a = 2. * jnp.ones(shape)
b = 3. * jnp.ones(shape)
def loss(a, b):
return jnp.sum(foo(a, b))
da_observed, db_observed = jax.jit(jax.grad(loss, argnums=(0, 1)))(a, b)
da_expected = b + 1
db_expected = a
self.assertArraysEqual(da_observed, da_expected)
self.assertArraysEqual(db_observed, db_expected)
def test_fwd_random(self):
shape = (2, 3)
akey, bkey = jax.random.split(jax.random.key(0))
a = jax.random.normal(key=akey, shape=shape)
b = jax.random.normal(key=bkey, shape=shape)
observed = jax.jit(foo)(a, b)
expected = a * (b + 1)
self.assertAllClose(observed, expected)
def test_bwd_random(self):
shape = (2, 3)
akey, bkey = jax.random.split(jax.random.key(0))
a = jax.random.normal(key=akey, shape=shape)
b = jax.random.normal(key=bkey, shape=shape)
jtu.check_grads(f=jax.jit(foo), args=(a, b), order=1, modes=("rev",))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
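
For reference, the math implemented by the CUDA kernels, c = a * (b + 1) with gradients da = dc * (b + 1) and db = dc * a, can be cross-checked against a few lines of pure JAX. This sketch is independent of the commit and relies only on JAX's built-in autodiff:

import jax
import jax.numpy as jnp

def foo_reference(a, b):
  # same math as the custom call: c = a * (b + 1)
  return a * (b + 1)

a = 2. * jnp.ones((2, 3))
b = 3. * jnp.ones((2, 3))
da, db = jax.grad(lambda a, b: jnp.sum(foo_reference(a, b)), argnums=(0, 1))(a, b)
assert (da == b + 1).all() and (db == a).all()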


@@ -0,0 +1,137 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/ffi/api/api.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
namespace ffi = xla::ffi;
//----------------------------------------------------------------------------//
// Forward pass //
//----------------------------------------------------------------------------//
// c = a * (b+1)
// This strawman operation works well for demo purposes because:
// 1. it's simple enough to be quickly understood,
// 2. it's complex enough to require intermediate outputs in grad computation,
//    as many operations in practice do, and
// 3. it does not have a built-in implementation in JAX.
__global__ void FooFwdKernel(const float *a, const float *b, float *c,
float *b_plus_1, // intermediate output b+1
size_t n) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t grid_stride = blockDim.x * gridDim.x;
for (size_t i = tid; i < n; i += grid_stride) {
b_plus_1[i] = b[i] + 1.0f;
c[i] = a[i] * b_plus_1[i];
}
}
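// (The grid-stride loop above keeps the kernel correct for any n regardless of
// the launch configuration, which is why the hardcoded grid_dim = 1 used below
// is valid, if not tuned for performance.)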
// Host function wrapper that launches the kernel with a hardcoded grid/block
// size. Note that it uses types from the XLA FFI API. The return type must be
// ffi::Error. The Buffer type provides buffer dimensions, so the "n" argument
// here is not strictly necessary, but it allows us to demonstrate the use of
// attributes (.Attr in the FFI handler definition below).
ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer<ffi::DataType::F32> a,
ffi::Buffer<ffi::DataType::F32> b,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> c,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> b_plus_1,
size_t n) {
const int block_dim = 128;
const int grid_dim = 1;
// Note how we access regular Buffer data vs Result Buffer data:
FooFwdKernel<<<grid_dim, block_dim, /*shared_mem=*/0, stream>>>(
a.data, b.data, c->data, b_plus_1->data, n);
// Check for launch time errors. Note that this function may also
// return error codes from previous, asynchronous launches. This
// means that an error status returned here could have been caused
// by a different kernel previously launched by XLA.
cudaError_t last_error = cudaGetLastError();
if (last_error != cudaSuccess) {
return ffi::Error(
XLA_FFI_Error_Code_INTERNAL,
std::string("CUDA error: ") + cudaGetErrorString(last_error));
}
return ffi::Error::Success();
}
// Creates symbol FooFwd with C linkage that can be loaded using Python ctypes
XLA_FFI_DEFINE_HANDLER_SYMBOL(
FooFwd, FooFwdHost,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<cudaStream_t>>() // stream
.Arg<ffi::Buffer<ffi::DataType::F32>>() // a
.Arg<ffi::Buffer<ffi::DataType::F32>>() // b
.Ret<ffi::Buffer<ffi::DataType::F32>>() // c
.Ret<ffi::Buffer<ffi::DataType::F32>>() // b_plus_1
.Attr<size_t>("n"));
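// The order of Ctx/Arg/Ret/Attr declarations in this binding mirrors the
// parameter order of FooFwdHost; XLA decodes the call frame according to the
// binding before invoking the host function.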
//----------------------------------------------------------------------------//
// Backward pass //
//----------------------------------------------------------------------------//
// compute da = dc * (b+1), and
// db = dc * a
__global__ void FooBwdKernel(const float *c_grad, // incoming gradient wrt c
const float *a, // original input a
const float *b_plus_1, // intermediate output b+1
float *a_grad, // outgoing gradient wrt a
float *b_grad, // outgoing gradient wrt b
size_t n) {
size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t grid_stride = blockDim.x * gridDim.x;
for (size_t i = tid; i < n; i += grid_stride) {
    // In practice, on GPUs b_plus_1 can be recomputed practically for free
    // instead of being stored and reused, so the reuse here is a bit
    // contrived. We do it to demonstrate residual/intermediate output passing
    // between the forward and the backward pass, which becomes useful when
    // recomputation is more expensive than reuse.
a_grad[i] = c_grad[i] * b_plus_1[i];
b_grad[i] = c_grad[i] * a[i];
}
}
ffi::Error FooBwdHost(cudaStream_t stream,
ffi::Buffer<ffi::DataType::F32> c_grad,
ffi::Buffer<ffi::DataType::F32> a,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> b_plus_1,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> a_grad,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> b_grad,
size_t n) {
const int block_dim = 128;
const int grid_dim = 1;
FooBwdKernel<<<grid_dim, block_dim, /*shared_mem=*/0, stream>>>(
c_grad.data, a.data, b_plus_1->data, a_grad->data, b_grad->data, n);
cudaError_t last_error = cudaGetLastError();
if (last_error != cudaSuccess) {
return ffi::Error(
XLA_FFI_Error_Code_INTERNAL,
std::string("CUDA error: ") + cudaGetErrorString(last_error));
}
return ffi::Error::Success();
}
// Creates symbol FooBwd with C linkage that can be loaded using Python ctypes
XLA_FFI_DEFINE_HANDLER_SYMBOL(
FooBwd, FooBwdHost,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<cudaStream_t>>() // stream
.Arg<ffi::Buffer<ffi::DataType::F32>>() // c_grad
.Arg<ffi::Buffer<ffi::DataType::F32>>() // a
.Arg<ffi::Buffer<ffi::DataType::F32>>() // b_plus_1
.Ret<ffi::Buffer<ffi::DataType::F32>>() // a_grad
.Ret<ffi::Buffer<ffi::DataType::F32>>() // b_grad
.Attr<size_t>("n"));
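
Both XLA_FFI_DEFINE_HANDLER_SYMBOL invocations give the handlers plain C linkage, which is what makes the ctypes-based loading in the test work. Condensed from the test above, the Python side that consumes these symbols looks roughly like this (the library path is illustrative):

import ctypes
from jax.extend import ffi
from jax.lib import xla_client

library = ctypes.cdll.LoadLibrary("./libfoo.so")
for target, handler in (("foo-fwd", library.FooFwd), ("foo-bwd", library.FooBwd)):
  xla_client.register_custom_call_target(
      name=target, fn=ffi.pycapsule(handler), platform="CUDA", api_version=1)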


@@ -557,6 +557,18 @@ def pytest_mark_if_available(marker: str):
return wrap
def is_running_under_pytest():
return "pytest" in sys.modules
def skip_under_pytest(reason: str):
"""A decorator for test methods to skip the test when run under pytest."""
reason = "Running under pytest: " + reason
def skip(test_method):
return unittest.skipIf(is_running_under_pytest(), reason)(test_method)
return skip
def format_test_name_suffix(opname, shapes, dtypes):
arg_descriptions = (format_shape_dtype_string(shape, dtype)
for shape, dtype in zip(shapes, dtypes))
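
The new jtu.skip_under_pytest helper simply composes is_running_under_pytest with unittest.skipIf. A hypothetical usage sketch (the class and method names are illustrative; the next file shows the real migration):

import unittest
from jax._src import test_util as jtu

class SubprocessOnlyTest(unittest.TestCase):

  @jtu.skip_under_pytest("Test must run in an isolated process")
  def test_runs_outside_pytest(self):
    pass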


@@ -13,7 +13,6 @@
# limitations under the License.
import os
import sys
import unittest
from absl.testing import absltest
@@ -27,10 +26,7 @@ config.parse_flags_with_absl()
class GpuMemoryAllocationTest(absltest.TestCase):
# This test must be run in its own subprocess.
@unittest.skipIf(
"pytest" in sys.modules,
"Test must run in an isolated process",
)
@jtu.skip_under_pytest("Test must run in an isolated process")
@unittest.skipIf(
"XLA_PYTHON_CLIENT_ALLOCATOR" in os.environ,
"Test does not work if the python client allocator has been overriden",