From 997b35e1d96b7172f40fe7dc411d1ceab23ea431 Mon Sep 17 00:00:00 2001 From: John QiangZhang Date: Mon, 11 Sep 2023 16:13:09 -0700 Subject: [PATCH] Improve the gpu lowering error message if users forget link the gpu library. PiperOrigin-RevId: 564530960 --- jaxlib/BUILD | 1 + jaxlib/gpu_common_utils.py | 27 +++++++++++++++++++++++++++ jaxlib/gpu_linalg.py | 4 ++++ jaxlib/gpu_prng.py | 4 +++- jaxlib/gpu_rnn.py | 6 ++++-- jaxlib/gpu_solver.py | 8 ++++++++ 6 files changed, 47 insertions(+), 3 deletions(-) create mode 100644 jaxlib/gpu_common_utils.py diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 548f4d420..a98695bdb 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -34,6 +34,7 @@ py_library_providing_imports_info( name = "jaxlib", srcs = [ "ducc_fft.py", + "gpu_common_utils.py", "gpu_linalg.py", "gpu_prng.py", "gpu_rnn.py", diff --git a/jaxlib/gpu_common_utils.py b/jaxlib/gpu_common_utils.py new file mode 100644 index 000000000..3bf6d1544 --- /dev/null +++ b/jaxlib/gpu_common_utils.py @@ -0,0 +1,27 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Those common utility functions for gpu.""" + + +class GpuLibNotLinkedError(Exception): + """Raised when the GPU library is not linked.""" + + error_msg = ( + 'JAX was not built with GPU support. Please use a GPU-enabled JAX to use' + ' this function.' + ) + + def __init__(self): + super().__init__(self.error_msg) diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index 49958d5ea..273a1017f 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -19,6 +19,7 @@ import operator import jaxlib.mlir.ir as ir from .hlo_helpers import custom_call +from .gpu_common_utils import GpuLibNotLinkedError from jaxlib import xla_client @@ -50,6 +51,9 @@ def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_s batch_size = _prod(dims[:-1]) pivot_size = dims[-1] + if not gpu_linalg: + raise GpuLibNotLinkedError() + opaque = gpu_linalg.lu_pivots_to_permutation_descriptor( batch_size, pivot_size, permutation_size) pivots_layout = tuple(range(len(dims) - 1, -1, -1)) diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index 9b637fe46..8a5c28577 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -24,6 +24,7 @@ import jaxlib.mlir.ir as ir from jaxlib import xla_client from .hlo_helpers import custom_call +from .gpu_common_utils import GpuLibNotLinkedError try: from .cuda import _prng as _cuda_prng # pytype: disable=import-error @@ -41,7 +42,6 @@ except ImportError: _prod = lambda xs: functools.reduce(operator.mul, xs, 1) - def _threefry2x32_lowering(prng, platform, keys, data, length: Optional[Union[int, ir.Value]] = None, output_shape: Optional[ir.Value] = None): @@ -50,6 +50,8 @@ def _threefry2x32_lowering(prng, platform, keys, data, In presence of dynamic shapes, `length` is an `ir.Value` and `output_shape` is a 1D tensor describing the shape of the two outputs. """ + if not prng: + raise GpuLibNotLinkedError() assert len(keys) == 2, keys assert len(data) == 2, data assert (ir.RankedTensorType(keys[0].type).element_type == diff --git a/jaxlib/gpu_rnn.py b/jaxlib/gpu_rnn.py index 0393ef37a..85779b40f 100644 --- a/jaxlib/gpu_rnn.py +++ b/jaxlib/gpu_rnn.py @@ -18,6 +18,7 @@ import jaxlib.mlir.dialects.stablehlo as hlo import numpy as np from jaxlib import xla_client +from .gpu_common_utils import GpuLibNotLinkedError try: from .cuda import _rnn # pytype: disable=import-error @@ -58,8 +59,9 @@ def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *, reserve_space_shape = ctx.avals_out[3].shape reserve_space_type = ir.RankedTensorType.get(reserve_space_shape, ir.F32Type.get()) - if _rnn is None: - raise RuntimeError("cuda couldn't be imported") + if not _rnn: + raise GpuLibNotLinkedError() + opaque = _rnn.build_rnn_descriptor(input_size, hidden_size, num_layers, batch_size, max_seq_length, dropout, bidirectional, workspace_shape[0], diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index 11f2fa4cf..9650b71b2 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -21,6 +21,8 @@ import jaxlib.mlir.dialects.stablehlo as hlo import numpy as np +from .gpu_common_utils import GpuLibNotLinkedError + from jaxlib import xla_client from .hlo_helpers import ( @@ -72,6 +74,9 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a): num_bd = len(batch_dims) batch = math.prod(batch_dims) + if not gpu_blas: + raise GpuLibNotLinkedError() + if batch > 1 and m == n and m // batch <= 128: lwork, opaque = gpu_blas.build_getrf_batched_descriptor( np.dtype(dtype), batch, m) @@ -157,6 +162,9 @@ def _geqrf_batched_hlo(platform, gpu_blas, dtype, a): num_bd = len(batch_dims) batch = math.prod(batch_dims) + if not gpu_blas: + raise GpuLibNotLinkedError() + lwork, opaque = gpu_blas.build_geqrf_batched_descriptor( np.dtype(dtype), batch, m, n)