Fix jax 0.3.11 GPU breakge when used with jaxlib 0.3.10.

This commit is contained in:
Peter Hawkins 2022-05-16 00:24:04 +00:00
parent 1381afc37f
commit 337ec47d13

View File

@ -183,7 +183,7 @@ except ImportError:
hip_linalg = None
try:
import jaxlib.cuda_linalg as gpu_linalg # pytype: disable=import-error
import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error
except ImportError:
gpu_linalg = None