Improve the gpu lowering error message if users forget link the gpu library.

PiperOrigin-RevId: 564530960
This commit is contained in:
John QiangZhang 2023-09-11 16:13:09 -07:00 committed by jax authors
parent 6c3b42d33c
commit 997b35e1d9
6 changed files with 47 additions and 3 deletions

View File

@ -34,6 +34,7 @@ py_library_providing_imports_info(
name = "jaxlib",
srcs = [
"ducc_fft.py",
"gpu_common_utils.py",
"gpu_linalg.py",
"gpu_prng.py",
"gpu_rnn.py",

View File

@ -0,0 +1,27 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Those common utility functions for gpu."""
class GpuLibNotLinkedError(Exception):
"""Raised when the GPU library is not linked."""
error_msg = (
'JAX was not built with GPU support. Please use a GPU-enabled JAX to use'
' this function.'
)
def __init__(self):
super().__init__(self.error_msg)

View File

@ -19,6 +19,7 @@ import operator
import jaxlib.mlir.ir as ir
from .hlo_helpers import custom_call
from .gpu_common_utils import GpuLibNotLinkedError
from jaxlib import xla_client
@ -50,6 +51,9 @@ def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_s
batch_size = _prod(dims[:-1])
pivot_size = dims[-1]
if not gpu_linalg:
raise GpuLibNotLinkedError()
opaque = gpu_linalg.lu_pivots_to_permutation_descriptor(
batch_size, pivot_size, permutation_size)
pivots_layout = tuple(range(len(dims) - 1, -1, -1))

View File

@ -24,6 +24,7 @@ import jaxlib.mlir.ir as ir
from jaxlib import xla_client
from .hlo_helpers import custom_call
from .gpu_common_utils import GpuLibNotLinkedError
try:
from .cuda import _prng as _cuda_prng # pytype: disable=import-error
@ -41,7 +42,6 @@ except ImportError:
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
def _threefry2x32_lowering(prng, platform, keys, data,
length: Optional[Union[int, ir.Value]] = None,
output_shape: Optional[ir.Value] = None):
@ -50,6 +50,8 @@ def _threefry2x32_lowering(prng, platform, keys, data,
In presence of dynamic shapes, `length` is an `ir.Value` and `output_shape`
is a 1D tensor describing the shape of the two outputs.
"""
if not prng:
raise GpuLibNotLinkedError()
assert len(keys) == 2, keys
assert len(data) == 2, data
assert (ir.RankedTensorType(keys[0].type).element_type ==

View File

@ -18,6 +18,7 @@ import jaxlib.mlir.dialects.stablehlo as hlo
import numpy as np
from jaxlib import xla_client
from .gpu_common_utils import GpuLibNotLinkedError
try:
from .cuda import _rnn # pytype: disable=import-error
@ -58,8 +59,9 @@ def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *,
reserve_space_shape = ctx.avals_out[3].shape
reserve_space_type = ir.RankedTensorType.get(reserve_space_shape,
ir.F32Type.get())
if _rnn is None:
raise RuntimeError("cuda couldn't be imported")
if not _rnn:
raise GpuLibNotLinkedError()
opaque = _rnn.build_rnn_descriptor(input_size, hidden_size, num_layers,
batch_size, max_seq_length, dropout,
bidirectional, workspace_shape[0],

View File

@ -21,6 +21,8 @@ import jaxlib.mlir.dialects.stablehlo as hlo
import numpy as np
from .gpu_common_utils import GpuLibNotLinkedError
from jaxlib import xla_client
from .hlo_helpers import (
@ -72,6 +74,9 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a):
num_bd = len(batch_dims)
batch = math.prod(batch_dims)
if not gpu_blas:
raise GpuLibNotLinkedError()
if batch > 1 and m == n and m // batch <= 128:
lwork, opaque = gpu_blas.build_getrf_batched_descriptor(
np.dtype(dtype), batch, m)
@ -157,6 +162,9 @@ def _geqrf_batched_hlo(platform, gpu_blas, dtype, a):
num_bd = len(batch_dims)
batch = math.prod(batch_dims)
if not gpu_blas:
raise GpuLibNotLinkedError()
lwork, opaque = gpu_blas.build_geqrf_batched_descriptor(
np.dtype(dtype), batch, m, n)