Mirror of https://github.com/ROCm/jax.git
Improve the GPU lowering error message if users forget to link the GPU library.
PiperOrigin-RevId: 564530960
This commit is contained in:
commit 997b35e1d9 (parent 6c3b42d33c)
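Context: the optional GPU extension modules that jaxlib hands to these lowerings (e.g. the gpu_linalg handle below) are None in a build without GPU support, so users previously hit a bare AttributeError deep inside lowering. A minimal sketch of that old failure mode; the descriptor call is illustrative, not the real call path:

# Sketch of the pre-change failure, assuming a CPU-only jaxlib build
# where the GPU extension module handle is None.
gpu_linalg = None

try:
  gpu_linalg.lu_pivots_to_permutation_descriptor(8, 4, 4)
except AttributeError as err:
  # The old, cryptic error users saw:
  # 'NoneType' object has no attribute 'lu_pivots_to_permutation_descriptor'
  print(err)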
jaxlib/BUILD
@@ -34,6 +34,7 @@ py_library_providing_imports_info(
     name = "jaxlib",
     srcs = [
         "ducc_fft.py",
+        "gpu_common_utils.py",
         "gpu_linalg.py",
         "gpu_prng.py",
         "gpu_rnn.py",
jaxlib/gpu_common_utils.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+# Copyright 2023 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common utility functions for GPU."""
+
+
+class GpuLibNotLinkedError(Exception):
+  """Raised when the GPU library is not linked."""
+
+  error_msg = (
+      'JAX was not built with GPU support. Please use a GPU-enabled JAX to use'
+      ' this function.'
+  )
+
+  def __init__(self):
+    super().__init__(self.error_msg)
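Every lowering touched below guards its GPU module handle with this exception before any attribute access. A condensed sketch of the pattern; `_example_gpu_lowering` and `build_descriptor` are hypothetical names, not part of the commit:

from jaxlib.gpu_common_utils import GpuLibNotLinkedError

def _example_gpu_lowering(gpu_module, *args):
  # gpu_module is the optional GPU extension handle; it is None (falsy)
  # when the GPU library was not linked into this jaxlib build.
  if not gpu_module:
    raise GpuLibNotLinkedError()
  return gpu_module.build_descriptor(*args)  # hypothetical descriptor call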
jaxlib/gpu_linalg.py
@@ -19,6 +19,7 @@ import operator
 import jaxlib.mlir.ir as ir

 from .hlo_helpers import custom_call
+from .gpu_common_utils import GpuLibNotLinkedError

 from jaxlib import xla_client

@@ -50,6 +51,9 @@ def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_size
   batch_size = _prod(dims[:-1])
   pivot_size = dims[-1]

+  if not gpu_linalg:
+    raise GpuLibNotLinkedError()
+
   opaque = gpu_linalg.lu_pivots_to_permutation_descriptor(
       batch_size, pivot_size, permutation_size)
   pivots_layout = tuple(range(len(dims) - 1, -1, -1))
jaxlib/gpu_prng.py
@@ -24,6 +24,7 @@ import jaxlib.mlir.ir as ir
 from jaxlib import xla_client

 from .hlo_helpers import custom_call
+from .gpu_common_utils import GpuLibNotLinkedError

 try:
   from .cuda import _prng as _cuda_prng  # pytype: disable=import-error
@@ -41,7 +42,6 @@ except ImportError:

 _prod = lambda xs: functools.reduce(operator.mul, xs, 1)


 def _threefry2x32_lowering(prng, platform, keys, data,
                            length: Optional[Union[int, ir.Value]] = None,
                            output_shape: Optional[ir.Value] = None):
@@ -50,6 +50,8 @@ def _threefry2x32_lowering(prng, platform, keys, data,
   In presence of dynamic shapes, `length` is an `ir.Value` and `output_shape`
   is a 1D tensor describing the shape of the two outputs.
   """
+  if not prng:
+    raise GpuLibNotLinkedError()
   assert len(keys) == 2, keys
   assert len(data) == 2, data
   assert (ir.RankedTensorType(keys[0].type).element_type ==
jaxlib/gpu_rnn.py
@@ -18,6 +18,7 @@ import jaxlib.mlir.dialects.stablehlo as hlo
 import numpy as np

 from jaxlib import xla_client
+from .gpu_common_utils import GpuLibNotLinkedError

 try:
   from .cuda import _rnn  # pytype: disable=import-error
@@ -58,8 +59,9 @@ def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *,
   reserve_space_shape = ctx.avals_out[3].shape
   reserve_space_type = ir.RankedTensorType.get(reserve_space_shape,
                                                ir.F32Type.get())
-  if _rnn is None:
-    raise RuntimeError("cuda couldn't be imported")
+  if not _rnn:
+    raise GpuLibNotLinkedError()
+
   opaque = _rnn.build_rnn_descriptor(input_size, hidden_size, num_layers,
                                      batch_size, max_seq_length, dropout,
                                      bidirectional, workspace_shape[0],
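In gpu_rnn.py the new check also replaces the earlier RuntimeError("cuda couldn't be imported"), so every guarded lowering now raises the same exception type. A small self-contained sketch of what callers can rely on; `demo_lowering` is a hypothetical stand-in:

from jaxlib.gpu_common_utils import GpuLibNotLinkedError

def demo_lowering(gpu_module):
  # Hypothetical stand-in for any of the guarded lowerings in this commit.
  if not gpu_module:
    raise GpuLibNotLinkedError()

try:
  demo_lowering(None)  # simulate a jaxlib build without the GPU library
except GpuLibNotLinkedError as err:
  print(err)  # JAX was not built with GPU support. Please use a GPU-enabled JAX to use this function.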
jaxlib/gpu_solver.py
@@ -21,6 +21,8 @@ import jaxlib.mlir.dialects.stablehlo as hlo

 import numpy as np

+from .gpu_common_utils import GpuLibNotLinkedError
+
 from jaxlib import xla_client

 from .hlo_helpers import (
@@ -72,6 +74,9 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a):
   num_bd = len(batch_dims)
   batch = math.prod(batch_dims)

+  if not gpu_blas:
+    raise GpuLibNotLinkedError()
+
   if batch > 1 and m == n and m // batch <= 128:
     lwork, opaque = gpu_blas.build_getrf_batched_descriptor(
         np.dtype(dtype), batch, m)
@@ -157,6 +162,9 @@ def _geqrf_batched_hlo(platform, gpu_blas, dtype, a):
   num_bd = len(batch_dims)
   batch = math.prod(batch_dims)

+  if not gpu_blas:
+    raise GpuLibNotLinkedError()
+
   lwork, opaque = gpu_blas.build_geqrf_batched_descriptor(
       np.dtype(dtype), batch, m, n)
