mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Fix jax 0.3.11 GPU breakge when used with jaxlib 0.3.10.
This commit is contained in:
parent
1381afc37f
commit
337ec47d13
@ -183,7 +183,7 @@ except ImportError:
|
||||
hip_linalg = None
|
||||
|
||||
try:
|
||||
import jaxlib.cuda_linalg as gpu_linalg # pytype: disable=import-error
|
||||
import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error
|
||||
except ImportError:
|
||||
gpu_linalg = None
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user