mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
[ROCm] Add dlpack backend support
Depends on the Tensorflow commit included in this PR https://github.com/tensorflow/tensorflow/pull/57640
This commit is contained in:
parent
f9e7629c3f
commit
4370b3385f
@ -64,6 +64,14 @@ def from_dlpack(dlpack):
|
||||
gpu_backend = xla_bridge.get_backend("cuda")
|
||||
except RuntimeError:
|
||||
gpu_backend = None
|
||||
|
||||
# Try ROCm if CUDA backend not found
|
||||
if gpu_backend is None:
|
||||
try:
|
||||
gpu_backend = xla_bridge.get_backend("rocm")
|
||||
except RuntimeError:
|
||||
gpu_backend = None
|
||||
|
||||
buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
|
||||
dlpack, cpu_backend, gpu_backend)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user