Mirror of https://github.com/ROCm/jax.git (synced 2025-04-14 10:56:06 +00:00)
Add CUDA custom call example as a JAX test
Commit: ec5c4f5a10
Parent: a9edaeb38e
docs/cuda_custom_call/BUILD (new file, 66 lines)
@@ -0,0 +1,66 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@rules_python//python:defs.bzl", "py_test")
load(
    "//jaxlib:jax.bzl",
    "cuda_library",
    "jax_test",
)

licenses(["notice"])

package(
    default_applicable_licenses = [],
    default_visibility = ["//visibility:private"],
)

jax_test(
    name = "cuda_custom_call_test",
    srcs = ["cuda_custom_call_test.py"],
    disable_backends = [
        "cpu",
        "tpu",
    ],
    # libfoo.so is a runtime dependency for this test
    data = [":foo"],
    tags = [
        "notap",
        "manual",
    ],
    deps = [
        "//jax:extend",
    ],
)

# this second target is needed to properly link in CUDA runtime symbols
# such as cudaLaunchKernel, even though we are only building one library.
cc_shared_library(
    name = "foo",
    deps = [
        ":foo_",
        "@xla//xla/tsl/cuda:cudart",
    ],
)

cuda_library(
    name = "foo_",
    srcs = ["foo.cu.cc"],
    deps = [
        "@xla//xla/ffi/api:ffi",
        "@xla//xla/ffi/api:api",
        "@xla//xla/ffi/api:c_api",
        "@local_config_cuda//cuda:cuda_headers",
    ],
)
docs/cuda_custom_call/Makefile (new file, 35 lines)
@@ -0,0 +1,35 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# This Makefile is not used by Bazel for this test; it is intended to serve as
# documentation of build instructions for JAX users that are not using Bazel to
# build their custom call code. For that reason, this Makefile is likely subject
# to bitrot over time. Please file a JAX issue on GitHub if typing "make" in
# this directory no longer runs the test to completion.
NVCC = nvcc
NVCCFLAGS += -I$(shell python -c 'from jax.extend import ffi; print(ffi.include_dir())')
NVCCFLAGS += -arch native
# since the file extension is .cu.cc, tell NVCC explicitly to treat it as .cu
NVCCFLAGS += -x cu

# depends on libfoo.so being in the same directory as cuda_custom_call_test.py
check: libfoo.so
	python cuda_custom_call_test.py

lib%.so: %.cu.cc
	$(NVCC) $(NVCCFLAGS) --compiler-options=-shared,-fPIC -o $@ $<

clean:
	rm -rf *.so
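The -I flag above comes from JAX itself: jax.extend.ffi exposes the directory containing the XLA FFI headers that nvcc needs. A minimal sketch of the same lookup done directly from Python (this is exactly the call the Makefile shells out to):

# Print the XLA FFI header directory that the Makefile's -I flag expands to.
from jax.extend import ffi

print(ffi.include_dir())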
docs/cuda_custom_call/cuda_custom_call_test.py (new file, 216 lines)
@@ -0,0 +1,216 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# This test is intentionally structured to stay close to what a standalone JAX
# custom call integration might look like. JAX test harness is in a separate
# section towards the end of this file. The test can be run standalone by typing
# "make" in the directory containing this file.

import os
import ctypes
import unittest

import numpy as np

import jax
import jax.numpy as jnp
from jax.extend import ffi
from jax.lib import xla_client
from jax.interpreters import mlir

# start test boilerplate
from absl.testing import absltest
from jax._src import config
from jax._src import test_util as jtu

config.parse_flags_with_absl()
# end test boilerplate

# XLA needs uppercase, "cuda" isn't recognized
XLA_PLATFORM = "CUDA"

# JAX needs lowercase, "CUDA" isn't recognized
JAX_PLATFORM = "cuda"

# 0 = original ("opaque"), 1 = FFI
XLA_CUSTOM_CALL_API_VERSION = 1

# these strings are how we identify kernels to XLA:
# - first we register a pointer to the kernel with XLA under this name
# - then we "tell" JAX to emit StableHLO specifying this name to XLA
XLA_CUSTOM_CALL_TARGET_FWD = "foo-fwd"
XLA_CUSTOM_CALL_TARGET_BWD = "foo-bwd"

# independently, corresponding JAX primitives must also be named,
# names can be different from XLA targets, here they are the same
JAX_PRIMITIVE_FWD = "foo-fwd"
JAX_PRIMITIVE_BWD = "foo-bwd"

if jtu.is_running_under_pytest():
  raise unittest.SkipTest("libfoo.so hasn't been built")
SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "libfoo.so")

library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)

#-----------------------------------------------------------------------------#
#                                 Forward pass                                 #
#-----------------------------------------------------------------------------#

# register the XLA FFI binding pointer with XLA
xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_FWD,
                                       fn=ffi.pycapsule(library.FooFwd),
                                       platform=XLA_PLATFORM,
                                       api_version=XLA_CUSTOM_CALL_API_VERSION)


# our forward primitive will also return the intermediate output b+1
# so it can be reused in the backward pass computation
def _foo_fwd_abstract_eval(a, b):
  assert a.shape == b.shape
  assert a.dtype == b.dtype
  shaped_array = jax.core.ShapedArray(a.shape, a.dtype)
  return (
      shaped_array,  # output c
      shaped_array,  # intermediate output b+1
  )


def _foo_fwd_lowering(ctx, a, b):
  # ffi.ffi_lowering does most of the heavy lifting building a lowering.
  # Keyword arguments passed to the lowering constructed by ffi_lowering are
  # turned into custom call backend_config entries, which we take advantage of
  # here for the dynamically computed n.
  n = np.prod(a.type.shape).astype(np.uint64)
  return ffi.ffi_lowering(XLA_CUSTOM_CALL_TARGET_FWD)(ctx, a, b, n=n)


# construct a new JAX primitive
foo_fwd_p = jax.core.Primitive(JAX_PRIMITIVE_FWD)
# register the abstract evaluation rule for the forward primitive
foo_fwd_p.def_abstract_eval(_foo_fwd_abstract_eval)
foo_fwd_p.multiple_results = True
mlir.register_lowering(foo_fwd_p, _foo_fwd_lowering, platform=JAX_PLATFORM)

#-----------------------------------------------------------------------------#
#                                 Backward pass                                #
#-----------------------------------------------------------------------------#

# register the XLA FFI binding pointer with XLA
xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_BWD,
                                       fn=ffi.pycapsule(library.FooBwd),
                                       platform=XLA_PLATFORM,
                                       api_version=XLA_CUSTOM_CALL_API_VERSION)


def _foo_bwd_abstract_eval(c_grad, a, b_plus_1):
  assert c_grad.shape == a.shape
  assert a.shape == b_plus_1.shape
  assert c_grad.dtype == a.dtype
  assert a.dtype == b_plus_1.dtype

  shaped_array = jax.core.ShapedArray(a.shape, a.dtype)
  return (
      shaped_array,  # a_grad
      shaped_array,  # b_grad
  )


def _foo_bwd_lowering(ctx, c_grad, a, b_plus_1):
  n = np.prod(a.type.shape).astype(np.uint64)
  return ffi.ffi_lowering(XLA_CUSTOM_CALL_TARGET_BWD)(ctx,
                                                      c_grad,
                                                      a,
                                                      b_plus_1,
                                                      n=n)


# construct a new JAX primitive
foo_bwd_p = jax.core.Primitive(JAX_PRIMITIVE_BWD)
# register the abstract evaluation rule for the backward primitive
foo_bwd_p.def_abstract_eval(_foo_bwd_abstract_eval)
foo_bwd_p.multiple_results = True
mlir.register_lowering(foo_bwd_p, _foo_bwd_lowering, platform=JAX_PLATFORM)

#-----------------------------------------------------------------------------#
#                               User facing API                                #
#-----------------------------------------------------------------------------#


def foo_fwd(a, b):
  c, b_plus_1 = foo_fwd_p.bind(a, b)
  return c, (a, b_plus_1)


def foo_bwd(res, c_grad):
  a, b_plus_1 = res
  return foo_bwd_p.bind(c_grad, a, b_plus_1)


@jax.custom_vjp
def foo(a, b):
  c, _ = foo_fwd(a, b)
  return c


foo.defvjp(foo_fwd, foo_bwd)

#-----------------------------------------------------------------------------#
#                                     Test                                     #
#-----------------------------------------------------------------------------#


class CustomCallTest(jtu.JaxTestCase):

  def test_fwd_interpretable(self):
    shape = (2, 3)
    a = 2. * jnp.ones(shape)
    b = 3. * jnp.ones(shape)
    observed = jax.jit(foo)(a, b)
    expected = (2. * (3. + 1.))
    self.assertArraysEqual(observed, expected)

  def test_bwd_interpretable(self):
    shape = (2, 3)
    a = 2. * jnp.ones(shape)
    b = 3. * jnp.ones(shape)

    def loss(a, b):
      return jnp.sum(foo(a, b))

    da_observed, db_observed = jax.jit(jax.grad(loss, argnums=(0, 1)))(a, b)
    da_expected = b + 1
    db_expected = a
    self.assertArraysEqual(da_observed, da_expected)
    self.assertArraysEqual(db_observed, db_expected)

  def test_fwd_random(self):
    shape = (2, 3)
    akey, bkey = jax.random.split(jax.random.key(0))
    a = jax.random.normal(key=akey, shape=shape)
    b = jax.random.normal(key=bkey, shape=shape)
    observed = jax.jit(foo)(a, b)
    expected = a * (b + 1)
    self.assertAllClose(observed, expected)

  def test_bwd_random(self):
    shape = (2, 3)
    akey, bkey = jax.random.split(jax.random.key(0))
    a = jax.random.normal(key=akey, shape=shape)
    b = jax.random.normal(key=bkey, shape=shape)
    jtu.check_grads(f=jax.jit(foo), args=(a, b), order=1, modes=("rev",))


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())
docs/cuda_custom_call/foo.cu.cc (new file, 137 lines)
@@ -0,0 +1,137 @@
/* Copyright 2024 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/ffi/api/api.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"

namespace ffi = xla::ffi;

//----------------------------------------------------------------------------//
//                                 Forward pass                                //
//----------------------------------------------------------------------------//

// c = a * (b+1)
// This strawman operation works well for demo purposes because:
// 1. it's simple enough to be quickly understood,
// 2. it's complex enough to require intermediate outputs in grad computation,
//    like many operations in practice do, and
// 3. it does not have a built-in implementation in JAX.
__global__ void FooFwdKernel(const float *a, const float *b, float *c,
                             float *b_plus_1,  // intermediate output b+1
                             size_t n) {
  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t grid_stride = blockDim.x * gridDim.x;
  for (size_t i = tid; i < n; i += grid_stride) {
    b_plus_1[i] = b[i] + 1.0f;
    c[i] = a[i] * b_plus_1[i];
  }
}

// Host function wrapper that launches the kernel with hardcoded grid/block
// size. Note, it uses types from XLA FFI. The return type must be ffi::Error.
// Buffer type provides buffer dimensions, so the "n" argument here is not
// strictly necessary, but it allows us to demonstrate the use of attributes
// (.Attr in the FFI handler definition below).
ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer<ffi::DataType::F32> a,
                      ffi::Buffer<ffi::DataType::F32> b,
                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> c,
                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> b_plus_1,
                      size_t n) {
  const int block_dim = 128;
  const int grid_dim = 1;
  // Note how we access regular Buffer data vs Result Buffer data:
  FooFwdKernel<<<grid_dim, block_dim, /*shared_mem=*/0, stream>>>(
      a.data, b.data, c->data, b_plus_1->data, n);
  // Check for launch time errors. Note that this function may also
  // return error codes from previous, asynchronous launches. This
  // means that an error status returned here could have been caused
  // by a different kernel previously launched by XLA.
  cudaError_t last_error = cudaGetLastError();
  if (last_error != cudaSuccess) {
    return ffi::Error(
        XLA_FFI_Error_Code_INTERNAL,
        std::string("CUDA error: ") + cudaGetErrorString(last_error));
  }
  return ffi::Error::Success();
}

// Creates symbol FooFwd with C linkage that can be loaded using Python ctypes
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    FooFwd, FooFwdHost,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()  // stream
        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // a
        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // b
        .Ret<ffi::Buffer<ffi::DataType::F32>>()    // c
        .Ret<ffi::Buffer<ffi::DataType::F32>>()    // b_plus_1
        .Attr<size_t>("n"));

//----------------------------------------------------------------------------//
//                                Backward pass                                //
//----------------------------------------------------------------------------//

// compute da = dc * (b+1), and
//         db = dc * a
__global__ void FooBwdKernel(const float *c_grad,    // incoming gradient wrt c
                             const float *a,         // original input a
                             const float *b_plus_1,  // intermediate output b+1
                             float *a_grad,          // outgoing gradient wrt a
                             float *b_grad,          // outgoing gradient wrt b
                             size_t n) {
  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t grid_stride = blockDim.x * gridDim.x;
  for (size_t i = tid; i < n; i += grid_stride) {
    // In practice on GPUs b_plus_1 can be recomputed for practically free
    // instead of storing it out and reusing, so the reuse here is a bit
    // contrived. We do it to demonstrate residual/intermediate output passing
    // between the forward and the backward pass which becomes useful when
    // recomputation is more expensive than reuse.
    a_grad[i] = c_grad[i] * b_plus_1[i];
    b_grad[i] = c_grad[i] * a[i];
  }
}

ffi::Error FooBwdHost(cudaStream_t stream,
                      ffi::Buffer<ffi::DataType::F32> c_grad,
                      ffi::Buffer<ffi::DataType::F32> a,
                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> b_plus_1,
                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> a_grad,
                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> b_grad,
                      size_t n) {
  const int block_dim = 128;
  const int grid_dim = 1;
  FooBwdKernel<<<grid_dim, block_dim, /*shared_mem=*/0, stream>>>(
      c_grad.data, a.data, b_plus_1->data, a_grad->data, b_grad->data, n);
  cudaError_t last_error = cudaGetLastError();
  if (last_error != cudaSuccess) {
    return ffi::Error(
        XLA_FFI_Error_Code_INTERNAL,
        std::string("CUDA error: ") + cudaGetErrorString(last_error));
  }
  return ffi::Error::Success();
}

// Creates symbol FooBwd with C linkage that can be loaded using Python ctypes
XLA_FFI_DEFINE_HANDLER_SYMBOL(
    FooBwd, FooBwdHost,
    ffi::Ffi::Bind()
        .Ctx<ffi::PlatformStream<cudaStream_t>>()  // stream
        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // c_grad
        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // a
        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // b_plus_1
        .Ret<ffi::Buffer<ffi::DataType::F32>>()    // a_grad
        .Ret<ffi::Buffer<ffi::DataType::F32>>()    // b_grad
        .Attr<size_t>("n"));
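As a sanity check of the math implemented by FooFwdKernel and FooBwdKernel above (c = a * (b+1), da = dc * (b+1), db = dc * a), here is a minimal pure-JAX sketch that computes the same forward value and VJP without any custom call; foo_reference is an illustrative name, not part of the example files:

# Pure-JAX reference for the CUDA kernels above; no custom call involved.
import jax
import jax.numpy as jnp

def foo_reference(a, b):
  return a * (b + 1.0)  # forward: c = a * (b + 1)

a = 2.0 * jnp.ones((2, 3))
b = 3.0 * jnp.ones((2, 3))
dc = jnp.ones((2, 3))  # incoming gradient wrt c

c, vjp = jax.vjp(foo_reference, a, b)
da, db = vjp(dc)
assert jnp.allclose(da, b + 1.0)  # da = dc * (b + 1)
assert jnp.allclose(db, a)        # db = dc * a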
@@ -557,6 +557,18 @@ def pytest_mark_if_available(marker: str):
   return wrap
 
 
+def is_running_under_pytest():
+  return "pytest" in sys.modules
+
+
+def skip_under_pytest(reason: str):
+  """A decorator for test methods to skip the test when run under pytest."""
+  reason = "Running under pytest: " + reason
+  def skip(test_method):
+    return unittest.skipIf(is_running_under_pytest(), reason)(test_method)
+  return skip
+
+
 def format_test_name_suffix(opname, shapes, dtypes):
   arg_descriptions = (format_shape_dtype_string(shape, dtype)
                       for shape, dtype in zip(shapes, dtypes))
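The skip_under_pytest helper added above is applied directly to test methods, as the hunk below does for GpuMemoryAllocationTest. A minimal standalone sketch (ExampleTest and test_needs_isolation are hypothetical names used only for illustration):

import unittest

from jax._src import test_util as jtu

class ExampleTest(unittest.TestCase):

  @jtu.skip_under_pytest("Test must run in an isolated process")
  def test_needs_isolation(self):
    # Skipped when collected by pytest; runs under plain unittest/absltest.
    self.assertTrue(True)

if __name__ == "__main__":
  unittest.main()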
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import os
-import sys
 import unittest
 
 from absl.testing import absltest
@@ -27,10 +26,7 @@ config.parse_flags_with_absl()
 class GpuMemoryAllocationTest(absltest.TestCase):
 
   # This test must be run in its own subprocess.
-  @unittest.skipIf(
-      "pytest" in sys.modules,
-      "Test must run in an isolated process",
-  )
+  @jtu.skip_under_pytest("Test must run in an isolated process")
   @unittest.skipIf(
       "XLA_PYTHON_CLIENT_ALLOCATOR" in os.environ,
       "Test does not work if the python client allocator has been overriden",