Update is_device_cuda to support testing for GPU plugin.

GPU plugin platform version is "PJRT C API\ncuda ...".

PiperOrigin-RevId: 573017348
This commit is contained in:
Jieying Luo 2023-10-12 14:44:35 -07:00 committed by jax authors
parent ab161bbd40
commit 6fb776beac

View File

@ -304,7 +304,7 @@ def is_device_rocm():
return xla_bridge.get_backend().platform_version.startswith('rocm')
def is_device_cuda():
return xla_bridge.get_backend().platform_version.startswith('cuda')
return 'cuda' in xla_bridge.get_backend().platform_version
def is_cloud_tpu():
return 'libtpu' in xla_bridge.get_backend().platform_version