Fixed the signature of the fallback get_compute_capability

PiperOrigin-RevId: 615338312
This commit is contained in:
Sergei Lebedev 2024-03-13 02:34:00 -07:00 committed by jax authors
parent 187b7aa8e6
commit a7964445e6

View File

@ -22,7 +22,8 @@ try:
get_compute_capability = triton_kernel_call_lib.get_compute_capability
except AttributeError:
def get_compute_capability() -> int:
def get_compute_capability(device) -> int:
del device # Unused.
raise RuntimeError(
"get_compute_capability is not available. Try installing jaxlib with"
" GPU support following instructions in"