[ROCm] Add dlpack backend support

Depends on the Tensorflow commit included in this
	PR https://github.com/tensorflow/tensorflow/pull/57640
This commit is contained in:
Rahul Batra 2022-08-16 18:44:36 +00:00
parent f9e7629c3f
commit 4370b3385f

View File

@ -64,6 +64,14 @@ def from_dlpack(dlpack):
gpu_backend = xla_bridge.get_backend("cuda")
except RuntimeError:
gpu_backend = None
# Try ROCm if CUDA backend not found
if gpu_backend is None:
try:
gpu_backend = xla_bridge.get_backend("rocm")
except RuntimeError:
gpu_backend = None
buf = xla_client._xla.dlpack_managed_tensor_to_buffer(
dlpack, cpu_backend, gpu_backend)