Merge pull request #24341 from phu0ngng:cuda_graph_ex

PiperOrigin-RevId: 686577115
This commit is contained in:
jax authors 2024-10-16 11:23:28 -07:00
commit 089e4aa904

View File

@ -27,7 +27,6 @@ import numpy as np
import jax
import jax.numpy as jnp
from jax.extend import ffi
from jax.lib import xla_client
# start test boilerplate
from absl.testing import absltest
@ -58,16 +57,13 @@ if jtu.is_running_under_pytest():
SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "libfoo.so")
library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
# register the custom calls targets with XLA
xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_FWD,
fn=ffi.pycapsule(library.FooFwd),
platform=XLA_PLATFORM,
api_version=XLA_CUSTOM_CALL_API_VERSION)
xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_BWD,
fn=ffi.pycapsule(library.FooBwd),
platform=XLA_PLATFORM,
api_version=XLA_CUSTOM_CALL_API_VERSION)
# register the custom calls targets with XLA, api_version=1 by default
ffi.register_ffi_target(name=XLA_CUSTOM_CALL_TARGET_FWD,
fn=ffi.pycapsule(library.FooFwd),
platform=XLA_PLATFORM)
ffi.register_ffi_target(name=XLA_CUSTOM_CALL_TARGET_BWD,
fn=ffi.pycapsule(library.FooBwd),
platform=XLA_PLATFORM)
def foo_fwd(a, b):
assert a.dtype == jnp.float32