diff --git a/docs/cuda_custom_call/BUILD b/docs/cuda_custom_call/BUILD
new file mode 100644
index 000000000..0c19dccda
--- /dev/null
+++ b/docs/cuda_custom_call/BUILD
@@ -0,0 +1,66 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@rules_python//python:defs.bzl", "py_test")
+load(
+    "//jaxlib:jax.bzl",
+    "cuda_library",
+    "jax_test",
+)
+
+licenses(["notice"])
+
+package(
+    default_applicable_licenses = [],
+    default_visibility = ["//visibility:private"],
+)
+
+jax_test(
+    name = "cuda_custom_call_test",
+    srcs = ["cuda_custom_call_test.py"],
+    disable_backends = [
+        "cpu",
+        "tpu",
+    ],
+    # libfoo.so is a runtime dependency for this test
+    data = [":foo"],
+    tags = [
+        "notap",
+        "manual",
+    ],
+    deps = [
+        "//jax:extend",
+    ],
+)
+
+# this second target is needed to properly link in CUDA runtime symbols
+# such as cudaLaunchKernel, even though we are only building one library
+cc_shared_library(
+    name = "foo",
+    deps = [
+        ":foo_",
+        "@xla//xla/tsl/cuda:cudart",
+    ],
+)
+
+cuda_library(
+    name = "foo_",
+    srcs = ["foo.cu.cc"],
+    deps = [
+        "@xla//xla/ffi/api:ffi",
+        "@xla//xla/ffi/api:api",
+        "@xla//xla/ffi/api:c_api",
+        "@local_config_cuda//cuda:cuda_headers",
+    ],
+)
diff --git a/docs/cuda_custom_call/Makefile b/docs/cuda_custom_call/Makefile
new file mode 100644
index 000000000..ca51b63b5
--- /dev/null
+++ b/docs/cuda_custom_call/Makefile
@@ -0,0 +1,35 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This Makefile is not used by Bazel for this test; it is intended to serve as
+# documentation of build instructions for JAX users who are not using Bazel to
+# build their custom call code. For that reason, this Makefile is likely subject
+# to bitrot over time. Please file a JAX issue on GitHub if typing "make" in
+# this directory no longer runs the test to completion.
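+#
+# A typical invocation might look like this (assuming nvcc is on the PATH and
+# a CUDA-enabled JAX is installed in the active Python environment):
+#
+#   make          # builds libfoo.so, then runs cuda_custom_call_test.py
+#   make clean    # removes the built shared library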
+NVCC = nvcc
+NVCCFLAGS += -I$(shell python -c 'from jax.extend import ffi; print(ffi.include_dir())')
+NVCCFLAGS += -arch native
+# since the file extension is .cu.cc, tell NVCC explicitly to treat it as .cu
+NVCCFLAGS += -x cu
+
+# depends on libfoo.so being in the same directory as cuda_custom_call_test.py
+check: libfoo.so
+	python cuda_custom_call_test.py
+
+lib%.so: %.cu.cc
+	$(NVCC) $(NVCCFLAGS) --compiler-options=-shared,-fPIC -o $@ $<
+
+clean:
+	rm -rf *.so
diff --git a/docs/cuda_custom_call/cuda_custom_call_test.py b/docs/cuda_custom_call/cuda_custom_call_test.py
new file mode 100644
index 000000000..563462feb
--- /dev/null
+++ b/docs/cuda_custom_call/cuda_custom_call_test.py
@@ -0,0 +1,216 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This test is intentionally structured to stay close to what a standalone JAX
+# custom call integration might look like. The JAX test harness is in a
+# separate section towards the end of this file. The test can be run standalone
+# by typing "make" in the directory containing this file.
+
+import os
+import ctypes
+import unittest
+
+import numpy as np
+
+import jax
+import jax.numpy as jnp
+from jax.extend import ffi
+from jax.lib import xla_client
+from jax.interpreters import mlir
+
+# start test boilerplate
+from absl.testing import absltest
+from jax._src import config
+from jax._src import test_util as jtu
+
+config.parse_flags_with_absl()
+# end test boilerplate
+
+# XLA needs uppercase, "cuda" isn't recognized
+XLA_PLATFORM = "CUDA"
+
+# JAX needs lowercase, "CUDA" isn't recognized
+JAX_PLATFORM = "cuda"
+
+# 0 = original ("opaque"), 1 = FFI
+XLA_CUSTOM_CALL_API_VERSION = 1
+
+# these strings are how we identify kernels to XLA:
+# - first we register a pointer to the kernel with XLA under this name
+# - then we "tell" JAX to emit StableHLO specifying this name to XLA
+XLA_CUSTOM_CALL_TARGET_FWD = "foo-fwd"
+XLA_CUSTOM_CALL_TARGET_BWD = "foo-bwd"
+
+# independently, the corresponding JAX primitives must also be named; their
+# names can differ from the XLA targets, but here they are the same
+JAX_PRIMITIVE_FWD = "foo-fwd"
+JAX_PRIMITIVE_BWD = "foo-bwd"
+
+if jtu.is_running_under_pytest():
+  raise unittest.SkipTest("libfoo.so hasn't been built")
+SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "libfoo.so")
+
+library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
+
+#-----------------------------------------------------------------------------#
+#                                 Forward pass                                #
+#-----------------------------------------------------------------------------#
+
+# register the XLA FFI binding pointer with XLA
+xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_FWD,
+                                       fn=ffi.pycapsule(library.FooFwd),
+                                       platform=XLA_PLATFORM,
+                                       api_version=XLA_CUSTOM_CALL_API_VERSION)
+
+
+# our forward primitive will also return the intermediate output b+1
+# so it can be reused in the backward pass computation
+def _foo_fwd_abstract_eval(a, b):
+  assert a.shape == b.shape
+  assert a.dtype == b.dtype
+  shaped_array = jax.core.ShapedArray(a.shape, a.dtype)
+  return (
+      shaped_array,  # output c
+      shaped_array,  # intermediate output b+1
+  )
+
+
+def _foo_fwd_lowering(ctx, a, b):
+  # ffi.ffi_lowering does most of the heavy lifting of building a lowering.
+  # Keyword arguments passed to the lowering constructed by ffi_lowering are
+  # turned into custom call backend_config entries, which we take advantage of
+  # here for the dynamically computed n.
+  n = np.prod(a.type.shape).astype(np.uint64)
+  return ffi.ffi_lowering(XLA_CUSTOM_CALL_TARGET_FWD)(ctx, a, b, n=n)
+
+
+# construct a new JAX primitive
+foo_fwd_p = jax.core.Primitive(JAX_PRIMITIVE_FWD)
+# register the abstract evaluation rule for the forward primitive
+foo_fwd_p.def_abstract_eval(_foo_fwd_abstract_eval)
+foo_fwd_p.multiple_results = True
+mlir.register_lowering(foo_fwd_p, _foo_fwd_lowering, platform=JAX_PLATFORM)
+
+#-----------------------------------------------------------------------------#
+#                                Backward pass                                #
+#-----------------------------------------------------------------------------#
+
+# register the XLA FFI binding pointer with XLA
+xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_BWD,
+                                       fn=ffi.pycapsule(library.FooBwd),
+                                       platform=XLA_PLATFORM,
+                                       api_version=XLA_CUSTOM_CALL_API_VERSION)
+
+
+def _foo_bwd_abstract_eval(c_grad, a, b_plus_1):
+  assert c_grad.shape == a.shape
+  assert a.shape == b_plus_1.shape
+  assert c_grad.dtype == a.dtype
+  assert a.dtype == b_plus_1.dtype
+
+  shaped_array = jax.core.ShapedArray(a.shape, a.dtype)
+  return (
+      shaped_array,  # a_grad
+      shaped_array,  # b_grad
+  )
+
+
+def _foo_bwd_lowering(ctx, c_grad, a, b_plus_1):
+  n = np.prod(a.type.shape).astype(np.uint64)
+  return ffi.ffi_lowering(XLA_CUSTOM_CALL_TARGET_BWD)(ctx,
+                                                      c_grad,
+                                                      a,
+                                                      b_plus_1,
+                                                      n=n)
+
+
+# construct a new JAX primitive
+foo_bwd_p = jax.core.Primitive(JAX_PRIMITIVE_BWD)
+# register the abstract evaluation rule for the backward primitive
+foo_bwd_p.def_abstract_eval(_foo_bwd_abstract_eval)
+foo_bwd_p.multiple_results = True
+mlir.register_lowering(foo_bwd_p, _foo_bwd_lowering, platform=JAX_PLATFORM)
+
+#-----------------------------------------------------------------------------#
+#                               User facing API                               #
+#-----------------------------------------------------------------------------#
+
+
+def foo_fwd(a, b):
+  c, b_plus_1 = foo_fwd_p.bind(a, b)
+  return c, (a, b_plus_1)
+
+
+def foo_bwd(res, c_grad):
+  a, b_plus_1 = res
+  return foo_bwd_p.bind(c_grad, a, b_plus_1)
+
+
+@jax.custom_vjp
+def foo(a, b):
+  c, _ = foo_fwd(a, b)
+  return c
+
+
+foo.defvjp(foo_fwd, foo_bwd)
+
+#-----------------------------------------------------------------------------#
+#                                     Test                                    #
+#-----------------------------------------------------------------------------#
+
+
+class CustomCallTest(jtu.JaxTestCase):
+
+  def test_fwd_interpretable(self):
+    shape = (2, 3)
+    a = 2. * jnp.ones(shape)
+    b = 3. * jnp.ones(shape)
+    observed = jax.jit(foo)(a, b)
+    expected = (2. * (3. + 1.))
+    self.assertArraysEqual(observed, expected)
+
+  def test_bwd_interpretable(self):
+    shape = (2, 3)
+    a = 2. * jnp.ones(shape)
+    b = 3. * jnp.ones(shape)
+
+    def loss(a, b):
+      return jnp.sum(foo(a, b))
+
+    da_observed, db_observed = jax.jit(jax.grad(loss, argnums=(0, 1)))(a, b)
+    da_expected = b + 1
+    db_expected = a
+    self.assertArraysEqual(da_observed, da_expected)
+    self.assertArraysEqual(db_observed, db_expected)
+
+  def test_fwd_random(self):
+    shape = (2, 3)
+    akey, bkey = jax.random.split(jax.random.key(0))
+    a = jax.random.normal(key=akey, shape=shape)
+    b = jax.random.normal(key=bkey, shape=shape)
+    observed = jax.jit(foo)(a, b)
+    expected = a * (b + 1)
+    self.assertAllClose(observed, expected)
+
+  def test_bwd_random(self):
+    shape = (2, 3)
+    akey, bkey = jax.random.split(jax.random.key(0))
+    a = jax.random.normal(key=akey, shape=shape)
+    b = jax.random.normal(key=bkey, shape=shape)
+    jtu.check_grads(f=jax.jit(foo), args=(a, b), order=1, modes=("rev",))
+
+
+if __name__ == "__main__":
+  absltest.main(testLoader=jtu.JaxTestLoader())
diff --git a/docs/cuda_custom_call/foo.cu.cc b/docs/cuda_custom_call/foo.cu.cc
new file mode 100644
index 000000000..c154f52fb
--- /dev/null
+++ b/docs/cuda_custom_call/foo.cu.cc
@@ -0,0 +1,137 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/ffi/api/api.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/api/ffi.h"
+
+namespace ffi = xla::ffi;
+
+//----------------------------------------------------------------------------//
+//                                 Forward pass                               //
+//----------------------------------------------------------------------------//
+
+// c = a * (b+1)
+// This strawman operation works well for demo purposes because:
+// 1. it's simple enough to be quickly understood,
+// 2. it's complex enough to require intermediate outputs in grad computation,
+//    like many operations in practice do, and
+// 3. it does not have a built-in implementation in JAX.
+__global__ void FooFwdKernel(const float *a, const float *b, float *c,
+                             float *b_plus_1,  // intermediate output b+1
+                             size_t n) {
+  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const size_t grid_stride = blockDim.x * gridDim.x;
+  for (size_t i = tid; i < n; i += grid_stride) {
+    b_plus_1[i] = b[i] + 1.0f;
+    c[i] = a[i] * b_plus_1[i];
+  }
+}
+
+// Host function wrapper that launches the kernel with a hardcoded grid/block
+// size. Note that it uses types from the XLA FFI. The return type must be
+// ffi::Error. The Buffer type provides buffer dimensions, so the "n" argument
+// here is not strictly necessary, but it allows us to demonstrate the use of
+// attributes (.Attr in the FFI handler definition below).
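+//
+// Roughly speaking (as a reading aid, relying only on what this file uses):
+// the parameters below are matched positionally against the FFI binding
+// defined further down, i.e. stream comes from .Ctx, a and b from the two
+// .Arg entries, c and b_plus_1 from the two .Ret entries, and n from
+// .Attr<size_t>("n"), which is populated from the n= keyword argument passed
+// to ffi.ffi_lowering on the Python side.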
+ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer<ffi::DataType::F32> a,
+                      ffi::Buffer<ffi::DataType::F32> b,
+                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> c,
+                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> b_plus_1,
+                      size_t n) {
+  const int block_dim = 128;
+  const int grid_dim = 1;
+  // Note how we access regular Buffer data vs Result Buffer data:
+  FooFwdKernel<<<grid_dim, block_dim, /*shared_mem=*/0, stream>>>(
+      a.data, b.data, c->data, b_plus_1->data, n);
+  // Check for launch time errors. Note that this function may also
+  // return error codes from previous, asynchronous launches. This
+  // means that an error status returned here could have been caused
+  // by a different kernel previously launched by XLA.
+  cudaError_t last_error = cudaGetLastError();
+  if (last_error != cudaSuccess) {
+    return ffi::Error(
+        XLA_FFI_Error_Code_INTERNAL,
+        std::string("CUDA error: ") + cudaGetErrorString(last_error));
+  }
+  return ffi::Error::Success();
+}
+
+// Creates symbol FooFwd with C linkage that can be loaded using Python ctypes
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    FooFwd, FooFwdHost,
+    ffi::Ffi::Bind()
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()  // stream
+        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // a
+        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // b
+        .Ret<ffi::Buffer<ffi::DataType::F32>>()    // c
+        .Ret<ffi::Buffer<ffi::DataType::F32>>()    // b_plus_1
+        .Attr<size_t>("n"));
+
+//----------------------------------------------------------------------------//
+//                                Backward pass                               //
+//----------------------------------------------------------------------------//
+
+// compute da = dc * (b+1), and
+//         db = dc * a
+__global__ void FooBwdKernel(const float *c_grad,    // incoming gradient wrt c
+                             const float *a,         // original input a
+                             const float *b_plus_1,  // intermediate output b+1
+                             float *a_grad,          // outgoing gradient wrt a
+                             float *b_grad,          // outgoing gradient wrt b
+                             size_t n) {
+  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const size_t grid_stride = blockDim.x * gridDim.x;
+  for (size_t i = tid; i < n; i += grid_stride) {
+    // In practice on GPUs b_plus_1 can be recomputed for practically free
+    // instead of storing it out and reusing, so the reuse here is a bit
+    // contrived. We do it to demonstrate residual/intermediate output passing
+    // between the forward and the backward pass which becomes useful when
+    // recomputation is more expensive than reuse.
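+    //
+    // Per element, c = a * (b + 1), so dc/da = b + 1 and dc/db = a;
+    // the chain rule then gives the two products below.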
+    a_grad[i] = c_grad[i] * b_plus_1[i];
+    b_grad[i] = c_grad[i] * a[i];
+  }
+}
+
+ffi::Error FooBwdHost(cudaStream_t stream,
+                      ffi::Buffer<ffi::DataType::F32> c_grad,
+                      ffi::Buffer<ffi::DataType::F32> a,
+                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> b_plus_1,
+                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> a_grad,
+                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> b_grad,
+                      size_t n) {
+  const int block_dim = 128;
+  const int grid_dim = 1;
+  FooBwdKernel<<<grid_dim, block_dim, /*shared_mem=*/0, stream>>>(
+      c_grad.data, a.data, b_plus_1->data, a_grad->data, b_grad->data, n);
+  cudaError_t last_error = cudaGetLastError();
+  if (last_error != cudaSuccess) {
+    return ffi::Error(
+        XLA_FFI_Error_Code_INTERNAL,
+        std::string("CUDA error: ") + cudaGetErrorString(last_error));
+  }
+  return ffi::Error::Success();
+}
+
+// Creates symbol FooBwd with C linkage that can be loaded using Python ctypes
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    FooBwd, FooBwdHost,
+    ffi::Ffi::Bind()
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()  // stream
+        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // c_grad
+        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // a
+        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // b_plus_1
+        .Ret<ffi::Buffer<ffi::DataType::F32>>()    // a_grad
+        .Ret<ffi::Buffer<ffi::DataType::F32>>()    // b_grad
+        .Attr<size_t>("n"));
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 7e20e5911..e2952bca6 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -557,6 +557,18 @@ def pytest_mark_if_available(marker: str):
   return wrap
 
 
+def is_running_under_pytest():
+  return "pytest" in sys.modules
+
+
+def skip_under_pytest(reason: str):
+  """A decorator for test methods to skip the test when run under pytest."""
+  reason = "Running under pytest: " + reason
+  def skip(test_method):
+    return unittest.skipIf(is_running_under_pytest(), reason)(test_method)
+  return skip
+
+
 def format_test_name_suffix(opname, shapes, dtypes):
   arg_descriptions = (format_shape_dtype_string(shape, dtype)
                       for shape, dtype in zip(shapes, dtypes))
diff --git a/tests/gpu_memory_flags_test.py b/tests/gpu_memory_flags_test.py
index d788d881f..308fff257 100644
--- a/tests/gpu_memory_flags_test.py
+++ b/tests/gpu_memory_flags_test.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import os
-import sys
 import unittest
 
 from absl.testing import absltest
@@ -27,10 +26,7 @@ config.parse_flags_with_absl()
 class GpuMemoryAllocationTest(absltest.TestCase):
 
   # This test must be run in its own subprocess.
-  @unittest.skipIf(
-      "pytest" in sys.modules,
-      "Test must run in an isolated process",
-  )
+  @jtu.skip_under_pytest("Test must run in an isolated process")
   @unittest.skipIf(
       "XLA_PYTHON_CLIENT_ALLOCATOR" in os.environ,
       "Test does not work if the python client allocator has been overriden",