mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #24341 from phu0ngng:cuda_graph_ex
PiperOrigin-RevId: 686577115
This commit is contained in:
commit
089e4aa904
@ -27,7 +27,6 @@ import numpy as np
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.extend import ffi
|
||||
from jax.lib import xla_client
|
||||
|
||||
# start test boilerplate
|
||||
from absl.testing import absltest
|
||||
@ -58,16 +57,13 @@ if jtu.is_running_under_pytest():
|
||||
SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "libfoo.so")
|
||||
library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
|
||||
|
||||
# register the custom calls targets with XLA
|
||||
xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_FWD,
|
||||
fn=ffi.pycapsule(library.FooFwd),
|
||||
platform=XLA_PLATFORM,
|
||||
api_version=XLA_CUSTOM_CALL_API_VERSION)
|
||||
xla_client.register_custom_call_target(name=XLA_CUSTOM_CALL_TARGET_BWD,
|
||||
fn=ffi.pycapsule(library.FooBwd),
|
||||
platform=XLA_PLATFORM,
|
||||
api_version=XLA_CUSTOM_CALL_API_VERSION)
|
||||
|
||||
# register the custom calls targets with XLA, api_version=1 by default
|
||||
ffi.register_ffi_target(name=XLA_CUSTOM_CALL_TARGET_FWD,
|
||||
fn=ffi.pycapsule(library.FooFwd),
|
||||
platform=XLA_PLATFORM)
|
||||
ffi.register_ffi_target(name=XLA_CUSTOM_CALL_TARGET_BWD,
|
||||
fn=ffi.pycapsule(library.FooBwd),
|
||||
platform=XLA_PLATFORM)
|
||||
|
||||
def foo_fwd(a, b):
|
||||
assert a.dtype == jnp.float32
|
||||
|
Loading…
x
Reference in New Issue
Block a user