diff --git a/docs/cuda_custom_call/BUILD b/docs/cuda_custom_call/BUILD
new file mode 100644
index 000000000..0c19dccda
--- /dev/null
+++ b/docs/cuda_custom_call/BUILD
@@ -0,0 +1,66 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load("@rules_python//python:defs.bzl", "py_test")
+load(
+    "//jaxlib:jax.bzl",
+    "cuda_library",
+    "jax_test",
+)
+
+licenses(["notice"])
+
+package(
+    default_applicable_licenses = [],
+    default_visibility = ["//visibility:private"],
+)
+
+jax_test(
+    name = "cuda_custom_call_test",
+    srcs = ["cuda_custom_call_test.py"],
+    disable_backends = [
+        "cpu",
+        "tpu",
+    ],
+    # libfoo.so is a runtime dependency for this test
+    data = [":foo"],
+    tags = [
+        "notap",
+        "manual",
+    ],
+    deps = [
+        "//jax:extend",
+    ],
+)
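+
+# Illustrative only: with a CUDA toolchain configured, the test above can
+# typically be run with something like
+#   bazel test //docs/cuda_custom_call:cuda_custom_call_test --config=cuda
+# though the exact target name and flags depend on your local setup.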
+
+# this second target is needed to properly link in CUDA runtime symbols
+# such as cudaLaunchKernel, even though we are only building one library.
+cc_shared_library(
+    name = "foo",
+    deps = [
+        ":foo_",
+        "@xla//xla/tsl/cuda:cudart",
+    ],
+)
+
+cuda_library(
+    name = "foo_",
+    srcs = ["foo.cu.cc"],
+    deps = [
+        "@xla//xla/ffi/api:ffi",
+        "@xla//xla/ffi/api:api",
+        "@xla//xla/ffi/api:c_api",
+        "@local_config_cuda//cuda:cuda_headers",
+    ],
+)
diff --git a/docs/cuda_custom_call/Makefile b/docs/cuda_custom_call/Makefile
new file mode 100644
index 000000000..ca51b63b5
--- /dev/null
+++ b/docs/cuda_custom_call/Makefile
@@ -0,0 +1,35 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This Makefile is not used by Bazel for this test; it is intended to serve as
+# documentation of build instructions for JAX users who are not using Bazel to
+# build their custom call code. For that reason, this Makefile is likely subject
+# to bitrot over time. Please file a JAX issue on GitHub if typing "make" in
+# this directory no longer runs the test to completion.
+NVCC = nvcc
+NVCCFLAGS += -I$(shell python -c 'from jax.extend import ffi; print(ffi.include_dir())')
+NVCCFLAGS += -arch native
+# since the file extension is .cu.cc, tell NVCC explicitly to treat it as .cu
+NVCCFLAGS += -x cu
+
+# depends on libfoo.so being in the same directory as cuda_custom_call_test.py
+check: libfoo.so
+	python cuda_custom_call_test.py
+
+lib%.so: %.cu.cc
+	$(NVCC) $(NVCCFLAGS) --compiler-options=-shared,-fPIC -o $@ $<
+
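+# For reference only: with the variables above, the lib%.so rule expands to a
+# command along the lines of (the include path is whatever ffi.include_dir()
+# prints on your machine; a placeholder is shown here):
+#   nvcc -I<jax-ffi-include-dir> -arch native -x cu \
+#       --compiler-options=-shared,-fPIC -o libfoo.so foo.cu.cc
+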
+clean:
+	rm -rf *.so
diff --git a/docs/cuda_custom_call/cuda_custom_call_test.py b/docs/cuda_custom_call/cuda_custom_call_test.py
new file mode 100644
index 000000000..563462feb
--- /dev/null
+++ b/docs/cuda_custom_call/cuda_custom_call_test.py
@@ -0,0 +1,216 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This test is intentionally structured to stay close to what a standalone JAX
+# custom call integration might look like. The JAX test harness is in a
+# separate section towards the end of this file. The test can be run standalone
+# by typing "make" in the directory containing this file.
+
+import os
+import ctypes
+import unittest
+
+import numpy as np
+
+import jax
+import jax.numpy as jnp
+from jax.extend import ffi
+from jax.lib import xla_client
+from jax.interpreters import mlir
+
+# start test boilerplate
+from absl.testing import absltest
+from jax._src import config
+from jax._src import test_util as jtu
+
+config.parse_flags_with_absl()
+# end test boilerplate
+
+# XLA needs uppercase; "cuda" isn't recognized
+XLA_PLATFORM = "CUDA"
+
+# JAX needs lowercase; "CUDA" isn't recognized
+JAX_PLATFORM = "cuda"
+
+# 0 = original ("opaque"), 1 = FFI
+XLA_CUSTOM_CALL_API_VERSION = 1
+
+# these strings are how we identify kernels to XLA:
+# - first we register a pointer to the kernel with XLA under this name
+# - then we "tell" JAX to emit StableHLO specifying this name to XLA
+XLA_CUSTOM_CALL_TARGET_FWD = "foo-fwd"
+XLA_CUSTOM_CALL_TARGET_BWD = "foo-bwd"
+
+# independently, the corresponding JAX primitives must also be named; their
+# names can differ from the XLA targets, but here they are the same
+JAX_PRIMITIVE_FWD = "foo-fwd"
+JAX_PRIMITIVE_BWD = "foo-bwd"
+
+if jtu.is_running_under_pytest():
+  raise unittest.SkipTest("libfoo.so hasn't been built")
+SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "libfoo.so")
+
+library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
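+
+# library.FooFwd and library.FooBwd (used below) are the C-linkage handler
+# symbols created by XLA_FFI_DEFINE_HANDLER_SYMBOL in foo.cu.cc.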
+
+#-----------------------------------------------------------------------------#
+#                              Forward pass                                   #
+#-----------------------------------------------------------------------------#
+
+# register the XLA FFI binding pointer with XLA
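+# (ffi.pycapsule wraps the raw function pointer obtained via ctypes in a
+# PyCapsule, which is the form register_custom_call_target expects for
+# FFI handlers, i.e. api_version=1)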
+xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_FWD,
+                                       fn=ffi.pycapsule(library.FooFwd),
+                                       platform=XLA_PLATFORM,
+                                       api_version=XLA_CUSTOM_CALL_API_VERSION)
+
+
+# our forward primitive will also return the intermediate output b+1
+# so it can be reused in the backward pass computation
+def _foo_fwd_abstract_eval(a, b):
+  assert a.shape == b.shape
+  assert a.dtype == b.dtype
+  shaped_array = jax.core.ShapedArray(a.shape, a.dtype)
+  return (
+      shaped_array,  # output c
+      shaped_array,  # intermediate output b+1
+  )
+
+
+def _foo_fwd_lowering(ctx, a, b):
+  # ffi.ffi_lowering does most of the heavy lifting of building a lowering.
+  # Keyword arguments passed to the lowering constructed by ffi_lowering are
+  # turned into custom call backend_config entries, which we take advantage of
+  # here for the dynamically computed n.
+  n = np.prod(a.type.shape).astype(np.uint64)
+  return ffi.ffi_lowering(XLA_CUSTOM_CALL_TARGET_FWD)(ctx, a, b, n=n)
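+
+
+# For intuition only (an illustrative sketch, not verbatim compiler output):
+# the lowering above emits a stablehlo.custom_call targeting "foo-fwd", and
+# the n keyword argument becomes part of that op's backend_config, roughly
+#   stablehlo.custom_call @foo-fwd(%a, %b) {backend_config = {n = 6 : ui64}}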
+
+
+# construct a new JAX primitive
+foo_fwd_p = jax.core.Primitive(JAX_PRIMITIVE_FWD)
+# register the abstract evaluation rule for the forward primitive
+foo_fwd_p.def_abstract_eval(_foo_fwd_abstract_eval)
+foo_fwd_p.multiple_results = True
+mlir.register_lowering(foo_fwd_p, _foo_fwd_lowering, platform=JAX_PLATFORM)
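+
+# With the registrations above in place, the forward primitive can already be
+# invoked directly, for example (illustrative):
+#   c, b_plus_1 = jax.jit(lambda a, b: foo_fwd_p.bind(a, b))(a, b)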
+
+#-----------------------------------------------------------------------------#
+#                              Backward pass                                  #
+#-----------------------------------------------------------------------------#
+
+# register the XLA FFI binding pointer with XLA
+xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_BWD,
+                                       fn=ffi.pycapsule(library.FooBwd),
+                                       platform=XLA_PLATFORM,
+                                       api_version=XLA_CUSTOM_CALL_API_VERSION)
+
+
+def _foo_bwd_abstract_eval(c_grad, a, b_plus_1):
+  assert c_grad.shape == a.shape
+  assert a.shape == b_plus_1.shape
+  assert c_grad.dtype == a.dtype
+  assert a.dtype == b_plus_1.dtype
+
+  shaped_array = jax.core.ShapedArray(a.shape, a.dtype)
+  return (
+      shaped_array,  # a_grad
+      shaped_array,  # b_grad
+  )
+
+
+def _foo_bwd_lowering(ctx, c_grad, a, b_plus_1):
+  n = np.prod(a.type.shape).astype(np.uint64)
+  return ffi.ffi_lowering(XLA_CUSTOM_CALL_TARGET_BWD)(ctx,
+                                                      c_grad,
+                                                      a,
+                                                      b_plus_1,
+                                                      n=n)
+
+
+# construct a new JAX primitive
+foo_bwd_p = jax.core.Primitive(JAX_PRIMITIVE_BWD)
+# register the abstract evaluation rule for the backward primitive
+foo_bwd_p.def_abstract_eval(_foo_bwd_abstract_eval)
+foo_bwd_p.multiple_results = True
+mlir.register_lowering(foo_bwd_p, _foo_bwd_lowering, platform=JAX_PLATFORM)
+
+#-----------------------------------------------------------------------------#
+#                              User facing API                                #
+#-----------------------------------------------------------------------------#
+
+
+def foo_fwd(a, b):
+  c, b_plus_1 = foo_fwd_p.bind(a, b)
+  return c, (a, b_plus_1)
+
+
+def foo_bwd(res, c_grad):
+  a, b_plus_1 = res
+  return foo_bwd_p.bind(c_grad, a, b_plus_1)
+
+
+@jax.custom_vjp
+def foo(a, b):
+  c, _ = foo_fwd(a, b)
+  return c
+
+
+foo.defvjp(foo_fwd, foo_bwd)
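+
+# Example usage (illustrative): under jit, both the primal computation and
+# its VJP dispatch to the CUDA kernels registered above, e.g.
+#   c = jax.jit(foo)(a, b)  # c == a * (b + 1)
+#   da, db = jax.grad(lambda a, b: jnp.sum(foo(a, b)), argnums=(0, 1))(a, b)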
+
+#-----------------------------------------------------------------------------#
+#                                  Test                                       #
+#-----------------------------------------------------------------------------#
+
+
+class CustomCallTest(jtu.JaxTestCase):
+
+  def test_fwd_interpretable(self):
+    shape = (2, 3)
+    a = 2. * jnp.ones(shape)
+    b = 3. * jnp.ones(shape)
+    observed = jax.jit(foo)(a, b)
+    expected = (2. * (3. + 1.))
+    self.assertArraysEqual(observed, expected)
+
+  def test_bwd_interpretable(self):
+    shape = (2, 3)
+    a = 2. * jnp.ones(shape)
+    b = 3. * jnp.ones(shape)
+
+    def loss(a, b):
+      return jnp.sum(foo(a, b))
+
+    da_observed, db_observed = jax.jit(jax.grad(loss, argnums=(0, 1)))(a, b)
+    da_expected = b + 1
+    db_expected = a
+    self.assertArraysEqual(da_observed, da_expected)
+    self.assertArraysEqual(db_observed, db_expected)
+
+  def test_fwd_random(self):
+    shape = (2, 3)
+    akey, bkey = jax.random.split(jax.random.key(0))
+    a = jax.random.normal(key=akey, shape=shape)
+    b = jax.random.normal(key=bkey, shape=shape)
+    observed = jax.jit(foo)(a, b)
+    expected = a * (b + 1)
+    self.assertAllClose(observed, expected)
+
+  def test_bwd_random(self):
+    shape = (2, 3)
+    akey, bkey = jax.random.split(jax.random.key(0))
+    a = jax.random.normal(key=akey, shape=shape)
+    b = jax.random.normal(key=bkey, shape=shape)
+    jtu.check_grads(f=jax.jit(foo), args=(a, b), order=1, modes=("rev",))
+
+
+if __name__ == "__main__":
+  absltest.main(testLoader=jtu.JaxTestLoader())
diff --git a/docs/cuda_custom_call/foo.cu.cc b/docs/cuda_custom_call/foo.cu.cc
new file mode 100644
index 000000000..c154f52fb
--- /dev/null
+++ b/docs/cuda_custom_call/foo.cu.cc
@@ -0,0 +1,137 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/ffi/api/api.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/api/ffi.h"
+
+namespace ffi = xla::ffi;
+
+//----------------------------------------------------------------------------//
+//                            Forward pass                                    //
+//----------------------------------------------------------------------------//
+
+// c = a * (b+1)
+// This strawman operation works well for demo purposes because:
+// 1. it's simple enough to be quickly understood,
+// 2. it's complex enough to require intermediate outputs in grad computation,
+//    like many operations in practice do, and
+// 3. it does not have a built-in implementation in JAX.
+__global__ void FooFwdKernel(const float *a, const float *b, float *c,
+                             float *b_plus_1,  // intermediate output b+1
+                             size_t n) {
+  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const size_t grid_stride = blockDim.x * gridDim.x;
+  for (size_t i = tid; i < n; i += grid_stride) {
+    b_plus_1[i] = b[i] + 1.0f;
+    c[i] = a[i] * b_plus_1[i];
+  }
+}
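+
+// Note: the host wrapper below launches a single block of 128 threads, so the
+// grid-stride loop above is what lets those threads cover all n elements when
+// n > 128. A production kernel would normally size the grid based on n.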
+
+// Host function wrapper that launches the kernel with a hardcoded grid/block
+// size. Note that it uses types from the XLA FFI. The return type must be
+// ffi::Error. The Buffer type provides buffer dimensions, so the "n" argument
+// here is not strictly necessary, but it allows us to demonstrate the use of
+// attributes (.Attr in the FFI handler definition below).
+ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer<ffi::DataType::F32> a,
+                      ffi::Buffer<ffi::DataType::F32> b,
+                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> c,
+                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> b_plus_1,
+                      size_t n) {
+  const int block_dim = 128;
+  const int grid_dim = 1;
+  // Note how we access regular Buffer data vs Result Buffer data:
+  FooFwdKernel<<<grid_dim, block_dim, /*shared_mem=*/0, stream>>>(
+      a.data, b.data, c->data, b_plus_1->data, n);
+  // Check for launch time errors. Note that this function may also
+  // return error codes from previous, asynchronous launches. This
+  // means that an error status returned here could have been caused
+  // by a different kernel previously launched by XLA.
+  cudaError_t last_error = cudaGetLastError();
+  if (last_error != cudaSuccess) {
+    return ffi::Error(
+        XLA_FFI_Error_Code_INTERNAL,
+        std::string("CUDA error: ") + cudaGetErrorString(last_error));
+  }
+  return ffi::Error::Success();
+}
+
+// Creates symbol FooFwd with C linkage that can be loaded using Python ctypes
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    FooFwd, FooFwdHost,
+    ffi::Ffi::Bind()
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()  // stream
+        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // a
+        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // b
+        .Ret<ffi::Buffer<ffi::DataType::F32>>()    // c
+        .Ret<ffi::Buffer<ffi::DataType::F32>>()    // b_plus_1
+        .Attr<size_t>("n"));
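+
+// Note: the Ctx/Arg/Ret/Attr declarations in the binding above correspond, in
+// order, to the parameters of FooFwdHost; the FFI machinery checks buffer
+// dtypes and the attribute name "n" when the custom call is invoked.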
+
+//----------------------------------------------------------------------------//
+//                            Backward pass                                   //
+//----------------------------------------------------------------------------//
+
+// compute da = dc * (b+1), and
+//         db = dc * a
+__global__ void FooBwdKernel(const float *c_grad,    // incoming gradient wrt c
+                             const float *a,         // original input a
+                             const float *b_plus_1,  // intermediate output b+1
+                             float *a_grad,          // outgoing gradient wrt a
+                             float *b_grad,          // outgoing gradient wrt b
+                             size_t n) {
+  size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+  const size_t grid_stride = blockDim.x * gridDim.x;
+  for (size_t i = tid; i < n; i += grid_stride) {
+    // In practice, on GPUs b_plus_1 could be recomputed practically for free
+    // instead of being stored and reused, so the reuse here is a bit
+    // contrived. We do it to demonstrate passing residual/intermediate outputs
+    // between the forward and the backward pass, which becomes useful when
+    // recomputation is more expensive than reuse.
+    a_grad[i] = c_grad[i] * b_plus_1[i];
+    b_grad[i] = c_grad[i] * a[i];
+  }
+}
+
+ffi::Error FooBwdHost(cudaStream_t stream,
+                      ffi::Buffer<ffi::DataType::F32> c_grad,
+                      ffi::Buffer<ffi::DataType::F32> a,
+                      ffi::Buffer<ffi::DataType::F32> b_plus_1,
+                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> a_grad,
+                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> b_grad,
+                      size_t n) {
+  const int block_dim = 128;
+  const int grid_dim = 1;
+  FooBwdKernel<<<grid_dim, block_dim, /*shared_mem=*/0, stream>>>(
+      c_grad.data, a.data, b_plus_1.data, a_grad->data, b_grad->data, n);
+  cudaError_t last_error = cudaGetLastError();
+  if (last_error != cudaSuccess) {
+    return ffi::Error(
+        XLA_FFI_Error_Code_INTERNAL,
+        std::string("CUDA error: ") + cudaGetErrorString(last_error));
+  }
+  return ffi::Error::Success();
+}
+
+// Creates symbol FooBwd with C linkage that can be loaded using Python ctypes
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+    FooBwd, FooBwdHost,
+    ffi::Ffi::Bind()
+        .Ctx<ffi::PlatformStream<cudaStream_t>>()  // stream
+        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // c_grad
+        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // a
+        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // b_plus_1
+        .Ret<ffi::Buffer<ffi::DataType::F32>>()    // a_grad
+        .Ret<ffi::Buffer<ffi::DataType::F32>>()    // b_grad
+        .Attr<size_t>("n"));
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 7e20e5911..e2952bca6 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -557,6 +557,18 @@ def pytest_mark_if_available(marker: str):
   return wrap
 
 
+def is_running_under_pytest():
+  return "pytest" in sys.modules
+
+
+def skip_under_pytest(reason: str):
+  """A decorator for test methods to skip the test when run under pytest."""
+  reason = "Running under pytest: " + reason
+  def skip(test_method):
+    return unittest.skipIf(is_running_under_pytest(), reason)(test_method)
+  return skip
+
+
 def format_test_name_suffix(opname, shapes, dtypes):
   arg_descriptions = (format_shape_dtype_string(shape, dtype)
                       for shape, dtype in zip(shapes, dtypes))
diff --git a/tests/gpu_memory_flags_test.py b/tests/gpu_memory_flags_test.py
index d788d881f..308fff257 100644
--- a/tests/gpu_memory_flags_test.py
+++ b/tests/gpu_memory_flags_test.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import os
-import sys
 import unittest
 
 from absl.testing import absltest
@@ -27,10 +26,7 @@ config.parse_flags_with_absl()
 class GpuMemoryAllocationTest(absltest.TestCase):
 
   # This test must be run in its own subprocess.
-  @unittest.skipIf(
-      "pytest" in sys.modules,
-      "Test must run in an isolated process",
-  )
+  @jtu.skip_under_pytest("Test must run in an isolated process")
   @unittest.skipIf(
       "XLA_PYTHON_CLIENT_ALLOCATOR" in os.environ,
       "Test does not work if the python client allocator has been overriden",