mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Update is_device_cuda to support testing for GPU plugin.
GPU plugin platform version is "PJRT C API\ncuda ...". PiperOrigin-RevId: 573017348
This commit is contained in:
parent
ab161bbd40
commit
6fb776beac
@ -304,7 +304,7 @@ def is_device_rocm():
|
||||
return xla_bridge.get_backend().platform_version.startswith('rocm')
|
||||
|
||||
def is_device_cuda():
|
||||
return xla_bridge.get_backend().platform_version.startswith('cuda')
|
||||
return 'cuda' in xla_bridge.get_backend().platform_version
|
||||
|
||||
def is_cloud_tpu():
|
||||
return 'libtpu' in xla_bridge.get_backend().platform_version
|
||||
|
Loading…
x
Reference in New Issue
Block a user