Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.

PiperOrigin-RevId: 490387796
Qiao Zhang 2022-11-22 18:52:56 -08:00 committed by jax authors
parent f33d5514c9
commit 78963b6020
15 changed files with 1378 additions and 0 deletions


@@ -248,3 +248,10 @@ pytype_library(
":jax",
],
)
pytype_library(
name = "rnn",
srcs = ["experimental/rnn.py"],
visibility = ["//visibility:public"],
deps = [":jax"],
)


@@ -105,6 +105,9 @@ import jaxlib.gpu_linalg as gpu_linalg # pytype: disable=import-error
# branch on the Jax github.
xla_extension_version = getattr(xla_client, '_version', 0)
if xla_extension_version > 108:
import jaxlib.gpu_rnn as gpu_rnn # pytype: disable=import-error
can_execute_with_token = (
xla_extension_version >= 89 and hasattr(
xla_client.LoadedExecutable # type: ignore

jax/experimental/rnn.py Normal file

@@ -0,0 +1,407 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""`jax.experimental.rnn`: GPU accelerated RNN
----------------------------------------------
This module provides experimental support for a CUDNN-backed LSTM.
Currently, the only supported RNN flavor is LSTM with double-bias. We use
notation and variable names similar to
https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM
and CUDNN_LSTM entry in
https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnRNNMode_t.
Note that a bidirectional LSTM is treated as having twice the number of layers,
where a forward layer i is followed by a reverse layer i. Each direction has
its own associated weights. We use the term pseudo-layer to denote such
layers, following the CUDNN documentation
https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnGetRNNWeightParams.
CUDNN takes an opaque 1D weight array that densely packs all the weight arrays
in a sparsely documented layout. Through trial-and-error testing, we believe
the layout is as follows. Assume a 2-layer bi-LSTM with double-bias, so 4
pseudo-layers in total (forward-0, reverse-0, forward-1, reverse-1).
There are 4 kinds of weights: W_ih, W_hh, b_ih and b_hh, where
W_ih = (W_ii, W_if, W_ig, W_io) concatenated on leading axis,
W_hh = (W_hi, W_hf, W_hg, W_ho) concatenated on leading axis,
b_ih = (b_ii, b_if, b_ig, b_io) concatenated on leading axis,
b_hh = (b_hi, b_hf, b_hg, b_ho) concatenated on leading axis.
Say W_ih^0 denotes W_ih from pseudo-layer 0. The linear weights are packed
together from all pseudo-layers followed by bias weights from all pseudo-layers.
In particular, for each layer, W_ih is followed by W_hh and b_ih by b_hh.
(W_ih^0, W_hh^0, W_ih^1, W_hh^1, W_ih^2, W_hh^2, W_ih^3, W_hh^3,
b_ih^0, b_hh^0, b_ih^1, b_hh^1, b_ih^2, b_hh^2, b_ih^3, b_hh^3)
See `_get_params_shapes_in_lstm`.
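As a worked size check (derived from the shape rules in this module): for
input_size=2, hidden_size=3, num_layers=1, bidirectional=False there is one
pseudo-layer, and the flat weight array has 4*3*2 + 4*3*3 + 4*3 + 4*3 = 84
elements, laid out as (W_ih^0, W_hh^0, b_ih^0, b_hh^0).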
Example usage:
```
x = jax.random.normal(
k1, (batch_size, seq_len, input_size), dtype=jnp.float32)
h_0 = jax.random.normal(
k2, (num_directions * num_layers, batch_size, hidden_size),
dtype=jnp.float32)
c_0 = jax.random.normal(
k3, (num_directions * num_layers, batch_size, hidden_size),
dtype=jnp.float32)
seq_lengths = jnp.ones((batch_size,), dtype=jnp.int32) * seq_len
weights = rnn.init_lstm_weight(k4, input_size, hidden_size, num_layers,
bidirectional)
y, h_n, c_n = rnn.lstm(
x,
h_0,
c_0,
weights,
seq_lengths=seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=0.0,
bidirectional=bidirectional)
```
TODO:
- Add support for input and weight dtypes other than float32.
- Support ragged inputs.
- Support RNNs other than LSTM.
"""
from functools import partial
from typing import Any, Dict, List, Tuple
import jax
import numpy as np
from jax import core
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.custom_derivatives import custom_vjp
from jax._src.typing import Array, Shape
from jax._src.util import prod
import jax.numpy as jnp
try:
from jax._src.lib import gpu_rnn
except ImportError:
gpu_rnn = None
PRNGKeyArray = Any
sigmoid = jax.nn.sigmoid
tanh = jax.nn.tanh
def _W_ih_l(layer_i: int, input_size: int, hidden_size: int,
bidirectional: bool) -> Shape:
"""Shape of W_ii|W_if|W_ig|W_io.
Note that layer_i is an index of pseudo-layers.
"""
if layer_i == 0 or (layer_i == 1 and bidirectional):
return (4 * hidden_size, input_size)
else:
num_directions = 2 if bidirectional else 1
return (4 * hidden_size, num_directions * hidden_size)
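# For illustration, with bidirectional=True, input_size=2, hidden_size=3:
#   _W_ih_l(0, 2, 3, True) == (12, 2)  # forward-0 reads the raw input
#   _W_ih_l(1, 2, 3, True) == (12, 2)  # reverse-0 reads the raw input
#   _W_ih_l(2, 2, 3, True) == (12, 6)  # forward-1 reads both directions' outputs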
def _W_hh_l(layer_i: int, input_size: int, hidden_size: int,
bidirectional: bool) -> Shape:
"""Shape of W_hi|W_hf|W_hg|W_ho."""
return (4 * hidden_size, hidden_size)
def _b_ih_l(layer_i: int, input_size: int, hidden_size: int,
bidirectional: bool) -> Shape:
"""Shape of b_ii|b_if|b_ig|b_io."""
return (4 * hidden_size,)
def _b_hh_l(layer_i: int, input_size: int, hidden_size: int,
bidirectional: bool) -> Shape:
"""Shape of b_hi|b_hf|b_hg|b_ho."""
return (4 * hidden_size,)
def _get_params_shapes_in_lstm(input_size: int, hidden_size: int,
num_layers: int,
bidirectional: bool) -> List[Shape]:
"""Get flat param shapes in LSTM. See module docstring for layout."""
layer_shapes = []
num_directions = 2 if bidirectional else 1
num_pseudo_layers = num_layers * num_directions
linear_weights = [_W_ih_l, _W_hh_l]
for i in range(num_pseudo_layers):
for w_kind in linear_weights:
layer_shape = w_kind(i, input_size, hidden_size, bidirectional)
layer_shapes.append(layer_shape)
bias_weights = [_b_ih_l, _b_hh_l]
for i in range(num_pseudo_layers):
for w_kind in bias_weights:
layer_shape = w_kind(i, input_size, hidden_size, bidirectional)
layer_shapes.append(layer_shape)
return layer_shapes
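# As a small example: for input_size=2, hidden_size=3, num_layers=1,
# bidirectional=False this returns
#   [(12, 2), (12, 3), (12,), (12,)]
# i.e. the shapes of (W_ih^0, W_hh^0, b_ih^0, b_hh^0) from the module docstring.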
def get_num_params_in_lstm(input_size: int, hidden_size: int, num_layers: int,
bidirectional: bool) -> int:
"""Get param count in LSTM."""
layer_shapes = _get_params_shapes_in_lstm(input_size, hidden_size, num_layers,
bidirectional)
param_count = sum([prod(shape) for shape in layer_shapes])
return param_count
def init_lstm_weight(rng: PRNGKeyArray, input_size: int, hidden_size: int,
num_layers: int, bidirectional: bool):
"""Random initialize LSTM weights from U(-k, k), k=sqrt(1/hidden_size)."""
param_count = get_num_params_in_lstm(input_size, hidden_size, num_layers,
bidirectional)
k = np.sqrt(1.0 / hidden_size)
return jax.random.uniform(
rng, shape=(param_count,), dtype=jnp.float32, minval=-k, maxval=k)
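# A usage sketch (k4 is any PRNG key; sizes match the worked example in the
# module docstring):
#   weights = init_lstm_weight(k4, input_size=2, hidden_size=3, num_layers=1,
#                              bidirectional=False)
#   assert weights.shape == (84,)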
def unpack_lstm_weights(
weights: Array, input_size: int, hidden_size: int, num_layers: int,
bidirectional: bool
) -> Tuple[Dict[int, Array], Dict[int, Array], Dict[int, Array], Dict[int,
Array]]:
"""Unpack cudnn LSTM weights into individual weights.
CUDNN LSTM weight layout: (num_layers, num_directions, W_ih, W_hh, b_ih, b_hh)
Returns W_ih, W_hh, b_ih, b_hh, each a dict keyed by pseudo-layer index.
e.g. W_ih[3] is the concatenation of the 4 weights (W_ii, W_if, W_ig, W_io)
for the reverse direction of the 2nd layer. See notations from
https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM.
"""
flat_shapes = _get_params_shapes_in_lstm(input_size, hidden_size, num_layers,
bidirectional)
flat_shapes_offset = 0
w_offsets = 0
num_directions = 2 if bidirectional else 1
num_pseudo_layers = num_layers * num_directions
W_ih: Dict[int, Array] = {}
W_hh: Dict[int, Array] = {}
for l in range(num_pseudo_layers):
for w_kind in [W_ih, W_hh]:
shape = flat_shapes[flat_shapes_offset]
flat_shapes_offset += 1
num_elems = prod(shape)
w_kind[l] = weights[w_offsets:w_offsets + num_elems].reshape(shape)
w_offsets += num_elems
b_ih: Dict[int, Array] = {}
b_hh: Dict[int, Array] = {}
for l in range(num_pseudo_layers):
for w_kind in [b_ih, b_hh]:
shape = flat_shapes[flat_shapes_offset]
flat_shapes_offset += 1
num_elems = prod(shape)
w_kind[l] = weights[w_offsets:w_offsets + num_elems].reshape(shape)
w_offsets += num_elems
return W_ih, W_hh, b_ih, b_hh
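# Pseudo-layer key ordering for the returned dicts in the 2-layer
# bidirectional case: 0 = forward-0, 1 = reverse-0, 2 = forward-1,
# 3 = reverse-1.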
@partial(custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
def lstm(x: Array, h_0: Array, c_0: Array, weights: Array, seq_lengths: Array,
input_size: int, hidden_size: int, num_layers: int, dropout: float,
bidirectional: bool) -> Tuple[Array, Array, Array]:
"""LSTM via CuDNN or HIPDNN (not-yet-supported).
Assume batch-first inputs.
Arguments:
x: (batch_size, max_seq_length, input_size)
h_0: (num_directions * num_layers, batch_size, hidden_size)
c_0: (num_directions * num_layers, batch_size, hidden_size)
weights: (num_params,) where num_params = get_num_params_in_lstm(...)
seq_lengths: (batch_size,)
Returns: (y, h_n, c_n).
y: (batch_size, max_seq_length, hidden_size * num_directions)
h_n: (num_directions * num_layers, batch_size, hidden_size)
c_n: (num_directions * num_layers, batch_size, hidden_size)
"""
(y, h_n, c_n), _ = lstm_fwd(
x,
h_0,
c_0,
weights,
seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
bidirectional=bidirectional)
return y, h_n, c_n
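# A minimal reverse-mode sketch (placeholder values, shaped per the docstring):
#   loss = lambda w: lstm(x, h_0, c_0, w, seq_lengths, input_size=input_size,
#                         hidden_size=hidden_size, num_layers=num_layers,
#                         dropout=0.0, bidirectional=bidirectional)[0].sum()
#   dw = jax.grad(loss)(weights)  # differentiates via the custom VJP below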
@partial(jax.jit, static_argnums=(8, 9, 10, 11, 12))
def lstm_ref(x: Array, h_0: Array, c_0: Array, W_ih: Dict[int, Array],
W_hh: Dict[int, Array], b_ih: Dict[int, Array],
b_hh: Dict[int, Array], seq_lengths: Array, input_size: int,
hidden_size: int, num_layers: int, dropout: float,
bidirectional: bool) -> Tuple[Array, Array, Array]:
"""Reference implementation of LSTM.
See https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#lstm
https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnRNNMode_t
"""
if dropout != 0.0:
raise NotImplementedError(
'Dropout not supported in LSTM reference because we cannot determine CUDNN dropout mask.'
)
# TODO(zhangqiaorjc): Handle ragged seq_lengths.
# batch_size, max_seq_length = x.shape[0], x.shape[1]
# assert seq_lengths.shape == (batch_size,)
# for i in range(batch_size):
# if int(seq_lengths[i]) != max_seq_length:
# raise NotImplementedError('Does not yet support ragged sequences.')
def lstm_cell(carry, x, *, W_ih, W_hh, b_ih, b_hh):
h, c = carry
W_ii, W_if, W_ig, W_io = jnp.split(W_ih, 4, axis=0)
W_hi, W_hf, W_hg, W_ho = jnp.split(W_hh, 4, axis=0)
b_ii, b_if, b_ig, b_io = jnp.split(b_ih, 4, axis=0)
b_hi, b_hf, b_hg, b_ho = jnp.split(b_hh, 4, axis=0)
i = sigmoid(x @ W_ii.T + b_ii[None] + h @ W_hi.T + b_hi[None])
f = sigmoid(x @ W_if.T + b_if[None] + h @ W_hf.T + b_hf[None])
g = tanh(x @ W_ig.T + b_ig[None] + h @ W_hg.T + b_hg[None])
o = sigmoid(x @ W_io.T + b_io[None] + h @ W_ho.T + b_ho[None])
c = f * c + i * g
h = o * tanh(c)
return (h, c), h
seq_first_y = x.transpose(1, 0, 2)
if not bidirectional:
final_h = []
final_c = []
for l in range(num_layers):
cell = partial(
lstm_cell, W_ih=W_ih[l], W_hh=W_hh[l], b_ih=b_ih[l], b_hh=b_hh[l])
(h_t, c_t), seq_first_y = jax.lax.scan(cell, (h_0[l], c_0[l]),
seq_first_y)
final_h.append(h_t)
final_c.append(c_t)
h_n = jnp.stack(final_h)
c_n = jnp.stack(final_c)
return seq_first_y.transpose(1, 0, 2), h_n, c_n
# bidirectional
final_h = []
final_c = []
for l in range(num_layers * 2):
cell = partial(
lstm_cell, W_ih=W_ih[l], W_hh=W_hh[l], b_ih=b_ih[l], b_hh=b_hh[l])
if l % 2 == 0:
(h_t, c_t), seq_first_y_fwd = jax.lax.scan(cell, (h_0[l], c_0[l]),
seq_first_y)
else:
(h_t, c_t), seq_first_y_bwd = jax.lax.scan(
cell, (h_0[l], c_0[l]), seq_first_y, reverse=True)
# Inputs to next layer are concat'ed from fwd and bwd.
seq_first_y = jnp.concatenate([seq_first_y_fwd, seq_first_y_bwd], axis=-1) # pytype: disable=name-error
final_h.append(h_t)
final_c.append(c_t)
h_n = jnp.stack(final_h)
c_n = jnp.stack(final_c)
return seq_first_y.transpose(1, 0, 2), h_n, c_n
def lstm_fwd(x: Array, h_0: Array, c_0: Array, w: Array, seq_lengths: Array,
input_size: int, hidden_size: int, num_layers: int, dropout: float,
bidirectional: bool):
y, h_n, c_n, workspace, reserve_space = rnn_fwd_p.bind(
x,
h_0,
c_0,
w,
seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
bidirectional=bidirectional)
return (y, h_n, c_n), (x, h_0, c_0, w, seq_lengths, y, workspace,
reserve_space)
def rnn_abstract_eval(x_aval, h_0_aval, c_0_aval, w_aval, seq_lengths_aval,
input_size: int, hidden_size: int, num_layers: int,
dropout: float, bidirectional: bool):
batch_size, max_seq_length = x_aval.shape[0], x_aval.shape[1]
num_directions = 2 if bidirectional else 1
output_shape = (batch_size, max_seq_length, num_directions * hidden_size)
output_aval = core.ShapedArray(output_shape, x_aval.dtype)
workspace_size, reserve_space_size = (
gpu_rnn.compute_rnn_workspace_reserve_space_sizes( # pytype: disable=attribute-error
input_size, hidden_size, num_layers, batch_size, max_seq_length,
dropout, bidirectional))
workspace_aval = core.ShapedArray((workspace_size,), jnp.float32)
reserve_space_aval = core.ShapedArray((reserve_space_size,), jnp.float32)
return output_aval, h_0_aval, c_0_aval, workspace_aval, reserve_space_aval
rnn_fwd_p = core.Primitive('rnn_fwd')
rnn_fwd_p.multiple_results = True
rnn_fwd_p.def_impl(partial(xla.apply_primitive, rnn_fwd_p))
rnn_fwd_p.def_abstract_eval(rnn_abstract_eval)
if gpu_rnn:
mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda')
def lstm_bwd(input_size: int, hidden_size: int, num_layers: int, dropout: float,
bidirectional, residuals, gradients):
x, h_0, c_0, w, seq_lengths, y, workspace, reserve_space = residuals
dy, dh_n, dc_n = gradients
dx, dh_0, dc_0, dw = rnn_bwd_p.bind(
dy,
dh_n,
dc_n,
x,
h_0,
c_0,
w,
y,
workspace,
reserve_space,
seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
bidirectional=bidirectional)
return (dx, dh_0, dc_0, dw, jnp.zeros_like(seq_lengths))
def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,
w_aval, y_aval, workspace_aval, reserve_space_aval,
seq_lengths_aval, input_size: int, hidden_size: int,
num_layers: int, dropout: float, bidirectional: bool):
return x_aval, h0_aval, c0_aval, w_aval
rnn_bwd_p = core.Primitive('rnn_bwd')
rnn_bwd_p.multiple_results = True
rnn_bwd_p.def_impl(partial(xla.apply_primitive, rnn_bwd_p))
rnn_bwd_p.def_abstract_eval(rnn_bwd_abstract_eval)
if gpu_rnn:
mlir.register_lowering(
rnn_bwd_p, gpu_rnn.cudnn_rnn_bwd_lowering, platform='cuda')
lstm.defvjp(lstm_fwd, lstm_bwd)


@@ -31,6 +31,7 @@ py_library(
"ducc_fft.py",
"gpu_linalg.py",
"gpu_prng.py",
"gpu_rnn.py",
"gpu_solver.py",
"gpu_sparse.py",
"init.py",


@@ -31,6 +31,7 @@ cc_library(
],
defines = ["JAX_GPU_CUDA=1"],
deps = [
"//third_party/gpus/cudnn:cudnn_header",
"@local_config_cuda//cuda:cuda_headers",
],
)
@@ -107,6 +108,43 @@ pybind_extension(
],
)
cc_library(
name = "cudnn_rnn_kernels",
srcs = ["//jaxlib/gpu:rnn_kernels.cc"],
hdrs = ["//jaxlib/gpu:rnn_kernels.h"],
deps = [
":cuda_gpu_kernel_helpers",
":cuda_vendor",
"//jaxlib:handle_pool",
"//jaxlib:kernel_helpers",
"@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
"@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@local_config_cuda//cuda:cuda_headers",
],
)
pybind_extension(
name = "_rnn",
srcs = ["//jaxlib/gpu:rnn.cc"],
copts = [
"-fexceptions",
"-fno-strict-aliasing",
],
features = ["-use_header_modules"],
module_name = "_rnn",
deps = [
":cuda_vendor",
":cudnn_rnn_kernels",
"//jaxlib:kernel_pybind11_helpers",
"//third_party/pybind11_abseil:status_casters",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings:str_format",
"@pybind11",
],
)
cc_library(
name = "cusolver_kernels",
srcs = ["//jaxlib/gpu:solver_kernels.cc"],
@@ -321,6 +359,7 @@ py_library(
":_blas",
":_linalg",
":_prng",
":_rnn",
":_solver",
":_sparse",
],


@@ -33,6 +33,9 @@ exports_files(srcs = [
"prng_kernels.cc",
"prng_kernels.cu.cc",
"prng_kernels.h",
"rnn.cc",
"rnn_kernels.cc",
"rnn_kernels.h",
"solver.cc",
"solver_kernels.cc",
"solver_kernels.h",


@@ -89,6 +89,10 @@ std::string ErrorString(gpublasStatus_t status) {
}
}
std::string ErrorString(gpudnnStatus_t status) {
return cudnnGetErrorString(status);
}
#else
std::string ErrorString(hipsparseStatus_t status) {
@@ -219,6 +223,13 @@ absl::Status AsStatus(gpublasStatus_t status, const char* file,
return absl::OkStatus();
}
absl::Status AsStatus(gpudnnStatus_t status, const char* file,
std::int64_t line, const char* expr) {
if (status != GPUDNN_STATUS_SUCCESS)
return absl::InternalError(ErrorString(status, file, line, expr));
return absl::OkStatus();
}
absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(
gpuStream_t stream, void* buffer, void* dev_ptrs, int batch,
int batch_elem_size) {


@@ -49,6 +49,8 @@ absl::Status AsStatus(gpusparseStatus_t status, const char* file,
std::int64_t line, const char* expr);
absl::Status AsStatus(gpublasStatus_t status, const char* file,
std::int64_t line, const char* expr);
absl::Status AsStatus(gpudnnStatus_t status, const char* file,
std::int64_t line, const char* expr);
// Builds an array of pointers to each array in a batch, in device memory.
// Caution: the return value must be kept alive (e.g., via a stream

jaxlib/gpu/rnn.cc Normal file

@@ -0,0 +1,54 @@
/* Copyright 2022 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/gpu/rnn_kernels.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "third_party/pybind11_abseil/status_casters.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {
namespace py = pybind11;
py::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers,
int batch_size, int max_seq_length, float dropout,
bool bidirectional, int workspace_size,
int reserve_space_size) {
return PackDescriptor(RnnDescriptor{
input_size, hidden_size, num_layers, batch_size, max_seq_length, dropout,
bidirectional, workspace_size, reserve_space_size});
}
py::dict Registrations() {
py::dict dict;
dict[JAX_GPU_PREFIX "dnn_rnn"] = EncapsulateFunction(RNNForward);
dict[JAX_GPU_PREFIX "dnn_rnn_bwd"] = EncapsulateFunction(RNNBackward);
return dict;
}
PYBIND11_MODULE(_rnn, m) {
m.def("registrations", &Registrations);
m.def("build_rnn_descriptor", &BuildRnnDescriptor);
m.def("compute_rnn_workspace_reserve_space_sizes",
&RnnComputeWorkspaceReserveSpaceSizes);
}
} // namespace
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

jaxlib/gpu/rnn_kernels.cc Normal file

@@ -0,0 +1,549 @@
/* Copyright 2022 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "jaxlib/gpu/rnn_kernels.h"
#include <utility>
#include <vector>
#include "absl/status/status.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
using DnnHandlePool = HandlePool<gpudnnHandle_t, gpuStream_t>;
template <>
/*static*/ absl::StatusOr<DnnHandlePool::Handle> DnnHandlePool::Borrow(
gpuStream_t stream) {
DnnHandlePool* pool = Instance();
absl::MutexLock lock(&pool->mu_);
gpudnnHandle_t handle;
if (pool->handles_[stream].empty()) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnCreate(&handle)));
} else {
handle = pool->handles_[stream].back();
pool->handles_[stream].pop_back();
}
if (stream) {
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetStream(handle, stream)));
}
return Handle(pool, handle, stream);
}
namespace JAX_GPU_NAMESPACE {
// struct RnnDescriptor {
// int input_size;
// int hidden_size;
// int num_layers;
// int batch_size;
// int max_seq_length;
// float dropout;
// bool bidirectional;
// int workspace_size;
// int reserve_space_size;
// };
static absl::StatusOr<std::pair<int, int>>
DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
int num_layers, int batch_size,
int max_seq_length, float dropout,
bool bidirectional) {
auto h = DnnHandlePool::Borrow();
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
cudnnRNNDescriptor_t rnn_desc;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateRNNDescriptor(&rnn_desc)));
cudnnDropoutDescriptor_t dropout_desc;
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnCreateDropoutDescriptor(&dropout_desc)));
size_t state_size;
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnDropoutGetStatesSize(handle.get(), &state_size)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetDropoutDescriptor(
dropout_desc, handle.get(), dropout, nullptr, state_size, 123)));
// TODO(zhangqiaorjc): Handle other kinds of RNN.
cudnnRNNMode_t cell_mode = CUDNN_LSTM;
cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;
int num_directions = 1;
cudnnDirectionMode_t dir_mode = CUDNN_UNIDIRECTIONAL;
if (bidirectional) {
dir_mode = CUDNN_BIDIRECTIONAL;
num_directions = 2;
}
cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
cudnnDataType_t math_prec = CUDNN_DATA_FLOAT;
cudnnMathType_t math_type = CUDNN_DEFAULT_MATH;
int32_t proj_size = hidden_size;
uint32_t aux_flags = CUDNN_RNN_PADDED_IO_ENABLED;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDescriptor_v8(
rnn_desc, CUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode,
input_mode, data_type, math_prec, math_type, input_size, hidden_size,
proj_size, num_layers, dropout_desc, aux_flags)));
cudnnForwardMode_t fwdMode = CUDNN_FWD_MODE_TRAINING;
cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
float padding = 0.0f;
std::vector<int32_t> seq_length_vector(batch_size, max_seq_length);
int32_t* seq_length_array = &seq_length_vector[0];
cudnnRNNDataDescriptor_t input_data_desc;
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&input_data_desc)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor(
input_data_desc, data_type, layout, max_seq_length, batch_size,
input_size, seq_length_array, &padding)));
size_t workSpaceSize;
size_t reserveSpaceSize;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnGetRNNTempSpaceSizes(
handle.get(), rnn_desc, fwdMode, input_data_desc, &workSpaceSize,
&reserveSpaceSize)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnDestroyDropoutDescriptor(dropout_desc)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnDestroyRNNDataDescriptor(input_data_desc)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyRNNDescriptor(rnn_desc)));
// Round up to nearest multiples of 4 so we can return them as f32 arrays.
workSpaceSize += (4 - workSpaceSize % 4) % 4;
reserveSpaceSize += (4 - reserveSpaceSize % 4) % 4;
return std::make_pair(workSpaceSize, reserveSpaceSize);
}
absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
int input_size, int hidden_size, int num_layers, int batch_size,
int max_seq_length, float dropout, bool bidirectional) {
return DoRnnComputeWorkspaceReserveSpaceSizes(
input_size, hidden_size, num_layers, batch_size, max_seq_length, dropout,
bidirectional);
}
static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<RnnDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const RnnDescriptor& d = **s;
auto h = DnnHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
cudnnRNNDescriptor_t rnn_desc;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateRNNDescriptor(&rnn_desc)));
cudnnDropoutDescriptor_t dropout_desc;
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnCreateDropoutDescriptor(&dropout_desc)));
size_t state_size;
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnDropoutGetStatesSize(handle.get(), &state_size)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetDropoutDescriptor(
dropout_desc, handle.get(), d.dropout, nullptr, state_size, 123)));
// TODO(zhangqiaorjc): Handle other kinds of RNN.
cudnnRNNMode_t cell_mode = CUDNN_LSTM;
cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;
int num_directions = 1;
cudnnDirectionMode_t dir_mode = CUDNN_UNIDIRECTIONAL;
if (d.bidirectional) {
dir_mode = CUDNN_BIDIRECTIONAL;
num_directions = 2;
}
cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
cudnnDataType_t math_prec = CUDNN_DATA_FLOAT;
cudnnMathType_t math_type = CUDNN_DEFAULT_MATH;
int32_t proj_size = d.hidden_size;
uint32_t aux_flags = CUDNN_RNN_PADDED_IO_ENABLED;
// cudnnStatus_t cudnnSetRNNDescriptor_v8(
// cudnnRNNDescriptor_t rnn_desc,
// cudnnRNNAlgo_t algo,
// cudnnRNNMode_t cell_mode,
// cudnnRNNBiasMode_t bias_mode,
// cudnnDirectionMode_t dir_mode,
// cudnnRNNInputMode_t input_mode,
// cudnnDataType_t data_type,
// cudnnDataType_t math_prec,
// cudnnMathType_t math_type,
// int32_t inputSize,
// int32_t hiddenSize,
// int32_t projSize,
// int32_t numLayers,
// cudnnDropoutDescriptor_t dropout_desc,
// uint32_t auxFlags);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDescriptor_v8(
rnn_desc, CUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode,
input_mode, data_type, math_prec, math_type, d.input_size, d.hidden_size,
proj_size, d.num_layers, dropout_desc, aux_flags)));
cudnnForwardMode_t fwdMode = CUDNN_FWD_MODE_TRAINING;
// cudnnForwardMode_t fwdMode = CUDNN_FWD_MODE_INFERENCE;
// cudnnStatus_t cudnnSetRNNDataDescriptor(
// cudnnRNNDataDescriptor_t RNNDataDesc,
// cudnnDataType_t data_type,
// cudnnRNNDataLayout_t layout,
// int maxSeqLength,
// int batchSize,
// int vectorSize,
// const int seq_length_array[],
// void *paddingFill);
cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
float padding = 0.0f;
std::vector<int32_t> seq_length_vector(d.batch_size, d.max_seq_length);
int32_t* seq_length_array = &seq_length_vector[0];
cudnnRNNDataDescriptor_t input_data_desc;
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&input_data_desc)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor(
input_data_desc, data_type, layout, d.max_seq_length, d.batch_size,
d.input_size, seq_length_array, &padding)));
cudnnRNNDataDescriptor_t output_data_desc;
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&output_data_desc)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor(
output_data_desc, data_type, layout, d.max_seq_length, d.batch_size,
d.hidden_size * num_directions, seq_length_array, &padding)));
// cudnnStatus_t cudnnSetTensor4dDescriptor(
// cudnnTensorDescriptor_t tensorDesc,
// cudnnTensorFormat_t format,
// cudnnDataType_t data_type,
// int n,
// int c,
// int h,
// int w)
// Shape is (num_directions * num_layers, batch_size, hidden_size)
int dims[3];
dims[0] = num_directions * d.num_layers;
dims[1] = d.batch_size;
dims[2] = d.hidden_size;
int strides[3];
strides[0] = dims[1] * dims[2];
strides[1] = dims[2];
strides[2] = 1;
cudnnTensorDescriptor_t h_desc;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateTensorDescriptor(&h_desc)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cudnnSetTensorNdDescriptor(h_desc, data_type, 3, dims, strides)));
cudnnTensorDescriptor_t c_desc;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateTensorDescriptor(&c_desc)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cudnnSetTensorNdDescriptor(c_desc, data_type, 3, dims, strides)));
// cudnnStatus_t cudnnGetRNNWeightSpaceSize(
// cudnnHandle_t handle,
// cudnnRNNDescriptor_t rnn_desc,
// size_t *weight_space_size);
size_t weight_space_size;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cudnnGetRNNWeightSpaceSize(handle.get(), rnn_desc, &weight_space_size)));
// cudnnStatus_t cudnnRNNForward(
// cudnnHandle_t handle,
// cudnnRNNDescriptor_t rnn_desc,
// cudnnForwardMode_t fwdMode,
// const int32_t devSeqLengths[],
// cudnnRNNDataDescriptor_t xDesc,
// const void *x,
// cudnnRNNDataDescriptor_t yDesc,
// void *y,
// cudnnTensorDescriptor_t h_desc,
// const void *hx,
// void *hy,
// cudnnTensorDescriptor_t cDesc,
// const void *cx,
// void *cy,
// size_t weight_space_size,
// const void *weightSpace,
// size_t workSpaceSize,
// void *workSpace,
// size_t reserveSpaceSize,
// void *reserveSpace);
auto input_buf = buffers[0];
auto h_0_buf = buffers[1];
auto c_0_buf = buffers[2];
auto weights_buf = buffers[3];
auto seq_lengths_buf = buffers[4];
auto output_buf = buffers[5];
auto h_n_buf = buffers[6];
auto c_n_buf = buffers[7];
auto workspace_buf = buffers[8];
auto reserve_space_buf = buffers[9];
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnRNNForward(
handle.get(), rnn_desc, fwdMode, (const int32_t*)seq_lengths_buf,
input_data_desc, input_buf, output_data_desc, output_buf, h_desc, h_0_buf,
h_n_buf, c_desc, c_0_buf, c_n_buf, weight_space_size, weights_buf,
d.workspace_size, /*workSpace=*/workspace_buf, d.reserve_space_size,
/*reserveSpace=*/reserve_space_buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyTensorDescriptor(h_desc)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyTensorDescriptor(c_desc)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnDestroyDropoutDescriptor(dropout_desc)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnDestroyRNNDataDescriptor(input_data_desc)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnDestroyRNNDataDescriptor(output_data_desc)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyRNNDescriptor(rnn_desc)));
return absl::OkStatus();
}
static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
const char* opaque, size_t opaque_len) {
auto s = UnpackDescriptor<RnnDescriptor>(opaque, opaque_len);
JAX_RETURN_IF_ERROR(s.status());
const RnnDescriptor& d = **s;
auto h = DnnHandlePool::Borrow(stream);
JAX_RETURN_IF_ERROR(h.status());
auto& handle = *h;
cudnnRNNDescriptor_t rnn_desc;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateRNNDescriptor(&rnn_desc)));
cudnnDropoutDescriptor_t dropout_desc;
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnCreateDropoutDescriptor(&dropout_desc)));
size_t state_size;
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnDropoutGetStatesSize(handle.get(), &state_size)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetDropoutDescriptor(
dropout_desc, handle.get(), d.dropout, nullptr, state_size, 123)));
// TODO(zhangqiaorjc): Handle other kinds of RNN.
cudnnRNNMode_t cell_mode = CUDNN_LSTM;
cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;
int num_directions = 1;
cudnnDirectionMode_t dir_mode = CUDNN_UNIDIRECTIONAL;
if (d.bidirectional) {
dir_mode = CUDNN_BIDIRECTIONAL;
num_directions = 2;
}
cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
cudnnDataType_t math_prec = CUDNN_DATA_FLOAT;
cudnnMathType_t math_type = CUDNN_DEFAULT_MATH;
int32_t proj_size = d.hidden_size;
uint32_t aux_flags = CUDNN_RNN_PADDED_IO_ENABLED;
// cudnnStatus_t cudnnSetRNNDescriptor_v8(
// cudnnRNNDescriptor_t rnn_desc,
// cudnnRNNAlgo_t algo,
// cudnnRNNMode_t cell_mode,
// cudnnRNNBiasMode_t bias_mode,
// cudnnDirectionMode_t dir_mode,
// cudnnRNNInputMode_t input_mode,
// cudnnDataType_t data_type,
// cudnnDataType_t math_prec,
// cudnnMathType_t math_type,
// int32_t inputSize,
// int32_t hiddenSize,
// int32_t projSize,
// int32_t numLayers,
// cudnnDropoutDescriptor_t dropout_desc,
// uint32_t auxFlags);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDescriptor_v8(
rnn_desc, CUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode,
input_mode, data_type, math_prec, math_type, d.input_size, d.hidden_size,
proj_size, d.num_layers, dropout_desc, aux_flags)));
// cudnnStatus_t cudnnSetRNNDataDescriptor(
// cudnnRNNDataDescriptor_t RNNDataDesc,
// cudnnDataType_t data_type,
// cudnnRNNDataLayout_t layout,
// int maxSeqLength,
// int batchSize,
// int vectorSize,
// const int seq_length_array[],
// void *paddingFill);
cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
float padding = 0.0f;
std::vector<int32_t> seq_length_vector(d.batch_size, d.max_seq_length);
int32_t* seq_length_array = &seq_length_vector[0];
cudnnRNNDataDescriptor_t input_data_desc;
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&input_data_desc)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor(
input_data_desc, data_type, layout, d.max_seq_length, d.batch_size,
d.input_size, seq_length_array, &padding)));
cudnnRNNDataDescriptor_t output_data_desc;
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&output_data_desc)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor(
output_data_desc, data_type, layout, d.max_seq_length, d.batch_size,
d.hidden_size * num_directions, seq_length_array, &padding)));
// cudnnStatus_t cudnnSetTensor4dDescriptor(
// cudnnTensorDescriptor_t tensorDesc,
// cudnnTensorFormat_t format,
// cudnnDataType_t data_type,
// int n,
// int c,
// int h,
// int w)
// Shape is (num_directions * num_layers, batch_size, hidden_size)
int dims[3];
dims[0] = num_directions * d.num_layers;
dims[1] = d.batch_size;
dims[2] = d.hidden_size;
int strides[3];
strides[0] = dims[1] * dims[2];
strides[1] = dims[2];
strides[2] = 1;
cudnnTensorDescriptor_t h_desc;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateTensorDescriptor(&h_desc)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cudnnSetTensorNdDescriptor(h_desc, data_type, 3, dims, strides)));
cudnnTensorDescriptor_t c_desc;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateTensorDescriptor(&c_desc)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cudnnSetTensorNdDescriptor(c_desc, data_type, 3, dims, strides)));
// cudnnStatus_t cudnnGetRNNWeightSpaceSize(
// cudnnHandle_t handle,
// cudnnRNNDescriptor_t rnn_desc,
// size_t *weight_space_size);
size_t weight_space_size;
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
cudnnGetRNNWeightSpaceSize(handle.get(), rnn_desc, &weight_space_size)));
auto dy_buf = buffers[0];
auto dh_n_buf = buffers[1];
auto dc_n_buf = buffers[2];
auto x_buf = buffers[3];
auto h_0_buf = buffers[4];
auto c_0_buf = buffers[5];
auto w_buf = buffers[6];
auto y_buf = buffers[7];
auto workspace_buf = buffers[8];
auto reserve_space_buf = buffers[9];
auto zeroed_dw_buf = buffers[10];
auto seq_lengths_buf = buffers[11];
auto dx_buf = buffers[12];
auto dh_0_buf = buffers[13];
auto dc_0_buf = buffers[14];
// auto dw_buf = buffers[15];
// cudnnStatus_t cudnnRNNBackwardData_v8(
// cudnnHandle_t handle,
// cudnnRNNDescriptor_t rnn_desc,
// const int32_t devSeqLengths[],
// cudnnRNNDataDescriptor_t yDesc,
// const void *y,
// const void *dy,
// cudnnRNNDataDescriptor_t xDesc,
// void *dx,
// cudnnTensorDescriptor_t h_desc,
// const void *hx,
// const void *dhy,
// void *dhx,
// cudnnTensorDescriptor_t c_desc,
// const void *cx,
// const void *dcy,
// void *dcx,
// size_t weight_space_size,
// const void *weightSpace,
// size_t workSpaceSize,
// void *workSpace,
// size_t reserveSpaceSize,
// void *reserveSpace);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnRNNBackwardData_v8(
handle.get(), rnn_desc, (const int32_t*)seq_lengths_buf, output_data_desc,
y_buf, dy_buf, input_data_desc, dx_buf, h_desc, h_0_buf, dh_n_buf,
dh_0_buf, c_desc, c_0_buf, dc_n_buf, dc_0_buf, weight_space_size, w_buf,
d.workspace_size, workspace_buf, d.reserve_space_size,
reserve_space_buf)));
// cudnnStatus_t cudnnRNNBackwardWeights_v8(
// cudnnHandle_t handle,
// cudnnRNNDescriptor_t rnn_desc,
// cudnnWgradMode_t addGrad,
// const int32_t devSeqLengths[],
// cudnnRNNDataDescriptor_t xDesc,
// const void *x,
// cudnnTensorDescriptor_t h_desc,
// const void *hx,
// cudnnRNNDataDescriptor_t yDesc,
// const void *y,
// size_t weight_space_size,
// void *dweightSpace,
// size_t workSpaceSize,
// void *workSpace,
// size_t reserveSpaceSize,
// void *reserveSpace);
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnRNNBackwardWeights_v8(
handle.get(), rnn_desc, CUDNN_WGRAD_MODE_ADD,
(const int32_t*)seq_lengths_buf, input_data_desc, x_buf, h_desc, h_0_buf,
output_data_desc, y_buf, weight_space_size, zeroed_dw_buf,
d.workspace_size, workspace_buf, d.reserve_space_size,
reserve_space_buf)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyTensorDescriptor(h_desc)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyTensorDescriptor(c_desc)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnDestroyDropoutDescriptor(dropout_desc)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnDestroyRNNDataDescriptor(input_data_desc)));
JAX_RETURN_IF_ERROR(
JAX_AS_STATUS(cudnnDestroyRNNDataDescriptor(output_data_desc)));
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyRNNDescriptor(rnn_desc)));
return absl::OkStatus();
}
void RNNForward(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = DnnRNNForward_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
void RNNBackward(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status) {
auto s = DnnRNNBackward_(stream, buffers, opaque, opaque_len);
if (!s.ok()) {
XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
s.message().length());
}
}
} // namespace JAX_GPU_NAMESPACE
} // namespace jax

jaxlib/gpu/rnn_kernels.h Normal file

@@ -0,0 +1,53 @@
/* Copyright 2022 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef JAXLIB_GPU_RNN_KERNELS_H_
#define JAXLIB_GPU_RNN_KERNELS_H_
#include "absl/status/statusor.h"
#include "jaxlib/gpu/vendor.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"
namespace jax {
namespace JAX_GPU_NAMESPACE {
// Compile-time info passed as `opaque` to custom kernel.
struct RnnDescriptor {
int input_size;
int hidden_size;
int num_layers;
int batch_size;
int max_seq_length;
float dropout;
bool bidirectional;
int workspace_size;
int reserve_space_size;
};
// Return (workspace size, reserve space size).
absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
int input_size, int hidden_size, int num_layers, int batch_size,
int max_seq_length, float dropout, bool bidirectional);
void RNNForward(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
void RNNBackward(gpuStream_t stream, void** buffers, const char* opaque,
size_t opaque_len, XlaCustomCallStatus* status);
} // namespace JAX_GPU_NAMESPACE
} // namespace jax
#endif // JAXLIB_GPU_RNN_KERNELS_H_


@@ -28,6 +28,7 @@ limitations under the License.
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "third_party/gpus/cuda/include/cusparse.h"
#include "third_party/gpus/cudnn/cudnn.h"
// Some sparse functionality is only available in CUSPARSE 11.3 or newer.
#define JAX_GPU_HAVE_SPARSE (CUSPARSE_VERSION >= 11300)
@@ -64,6 +65,8 @@ typedef cublasHandle_t gpublasHandle_t;
typedef cudaDataType gpuDataType;
typedef cudaStream_t gpuStream_t;
typedef cudaError_t gpuError_t;
typedef cudnnHandle_t gpudnnHandle_t;
typedef cudnnStatus_t gpudnnStatus_t;
typedef cusolverDnHandle_t gpusolverDnHandle_t;
typedef cusolverStatus_t gpusolverStatus_t;
typedef cusolverEigMode_t gpusolverEigMode_t;
@@ -97,6 +100,11 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
#define GPUBLAS_STATUS_SUCCESS CUBLAS_STATUS_SUCCESS
#define gpudnnCreate cudnnCreate
#define gpudnnSetStream cudnnSetStream
#define GPUDNN_STATUS_SUCCESS CUDNN_STATUS_SUCCESS
#define gpusolverDnCreate cusolverDnCreate
#define gpusolverDnSetStream cusolverDnSetStream
#define gpusolverDnCreateSyevjInfo cusolverDnCreateSyevjInfo
@@ -233,6 +241,7 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
#define gpuGetLastError cudaGetLastError
#define gpuGetErrorString cudaGetErrorString
#define gpuMemcpy cudaMemcpy
#define gpuMemcpyAsync cudaMemcpyAsync
#define gpuMemcpyDeviceToDevice cudaMemcpyDeviceToDevice
#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice

jaxlib/gpu_rnn.py Normal file

@ -0,0 +1,125 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo
import numpy as np
from jaxlib import xla_client
try:
from .cuda import _rnn as _rnn
for _name, _value in _rnn.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform='CUDA')
except ImportError:
_rnn = None
if _rnn:
compute_rnn_workspace_reserve_space_sizes = _rnn.compute_rnn_workspace_reserve_space_sizes
def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *,
input_size: int, hidden_size: int, num_layers: int,
dropout: float, bidirectional: bool):
"""CuDnn RNN."""
out_dtype = ctx.avals_out[0].dtype
if out_dtype == np.float32:
out_type = ir.F32Type.get()
elif out_dtype == np.float64:
out_type = ir.F64Type.get()
elif out_dtype == np.complex64:
out_type = ir.ComplexType.get(ir.F32Type.get())
elif out_dtype == np.complex128:
out_type = ir.ComplexType.get(ir.F64Type.get())
else:
raise ValueError(f'Unknown output type {out_dtype}')
output_type = ir.RankedTensorType.get(ctx.avals_out[0].shape, out_type)
batch_size = ctx.avals_in[0].shape[0]
max_seq_length = ctx.avals_in[0].shape[1]
workspace_shape = ctx.avals_out[3].shape
reserve_space_shape = ctx.avals_out[4].shape
workspace_type = ir.RankedTensorType.get(workspace_shape, ir.F32Type.get())
reserve_space_type = ir.RankedTensorType.get(reserve_space_shape,
ir.F32Type.get())
opaque = _rnn.build_rnn_descriptor(input_size, hidden_size, num_layers,
batch_size, max_seq_length, dropout,
bidirectional, workspace_shape[0],
reserve_space_shape[0])
i32_type = ir.IntegerType.get_signless(32)
out = mhlo.CustomCallOp(
[
ir.TupleType.get_tuple([
output_type, h_0.type, c_0.type, workspace_type,
reserve_space_type
])
],
[input, h_0, c_0, weights, seq_lengths],
call_target_name=ir.StringAttr.get('cudnn_rnn'),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
)
return [
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
for i in range(5)
]
def _mhlo_zeros_f32(shape):
return mhlo.ConstantOp(
ir.DenseElementsAttr.get(
np.zeros(shape, dtype=np.float32), type=ir.F32Type.get())).result
def cudnn_rnn_bwd_lowering(ctx, dy, dhn, dcn, x, h0, c0, w, y, workspace,
reserve_space, seq_lengths, *, input_size: int,
hidden_size: int, num_layers: int, dropout: float,
bidirectional: bool):
"""CuDnn RNN Backward pass."""
batch_size = ctx.avals_in[3].shape[0]
max_seq_length = ctx.avals_in[3].shape[1]
workspace_shape = ctx.avals_in[8].shape
reserve_space_shape = ctx.avals_in[9].shape
opaque = _rnn.build_rnn_descriptor(input_size, hidden_size, num_layers,
batch_size, max_seq_length, dropout,
bidirectional, workspace_shape[0],
reserve_space_shape[0])
i32_type = ir.IntegerType.get_signless(32)
zeroed_dw = _mhlo_zeros_f32(ctx.avals_out[3].shape)
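# zeroed_dw is passed as an operand and aliased to the dw output (see
# output_operand_aliases below). The backward kernel calls
# cudnnRNNBackwardWeights_v8 with CUDNN_WGRAD_MODE_ADD, which accumulates
# into this buffer, so it must start out zeroed.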
out = mhlo.CustomCallOp(
[ir.TupleType.get_tuple([x.type, h0.type, c0.type, w.type])], [
dy, dhn, dcn, x, h0, c0, w, y, workspace, reserve_space, zeroed_dw,
seq_lengths
],
call_target_name=ir.StringAttr.get('cudnn_rnn_bwd'),
has_side_effect=ir.BoolAttr.get(False),
backend_config=ir.StringAttr.get(opaque),
api_version=ir.IntegerAttr.get(i32_type, 2),
called_computations=ir.ArrayAttr.get([]),
output_operand_aliases=ir.ArrayAttr.get([
mhlo.OutputOperandAlias.get(
output_tuple_indices=[3],
operand_index=10,
operand_tuple_indices=[])
]))
return [
mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
for i in range(4)
]


@@ -1001,6 +1001,18 @@ jax_test(
srcs = ["clear_backends_test.py"],
)
jax_test(
name = "experimental_rnn_test",
srcs = ["experimental_rnn_test.py"],
disable_backends = [
"tpu",
"cpu",
],
deps = [
"//jax:rnn",
],
)
exports_files(
[
"api_test.py",

tests/experimental_rnn_test.py Normal file

@@ -0,0 +1,103 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
from jax.experimental import rnn
from jax._src.config import config
config.parse_flags_with_absl()
class RnnTest(jtu.JaxTestCase):
@jtu.sample_product(
batch_size=[1, 4],
seq_len=[1, 4],
input_size=[1, 2],
hidden_size=[1, 6],
num_layers=[1, 4],
bidirectional=[True, False],
)
def test_lstm(self, batch_size: int, seq_len: int, input_size: int,
hidden_size: int, num_layers: int, bidirectional: bool):
if xla_extension_version < 109:
self.skipTest('rnn module added at xla_extension_version 109')
num_directions = 2 if bidirectional else 1
seq_lengths = jnp.ones((batch_size,), dtype=jnp.int32) * seq_len
root_key = jax.random.PRNGKey(1)
k1, k2, k3, k4 = jax.random.split(root_key, 4)
x = jax.random.normal(
k1, (batch_size, seq_len, input_size), dtype=jnp.float32)
h_0 = jax.random.normal(
k2, (num_directions * num_layers, batch_size, hidden_size),
dtype=jnp.float32)
c_0 = jax.random.normal(
k3, (num_directions * num_layers, batch_size, hidden_size),
dtype=jnp.float32)
weights = rnn.init_lstm_weight(k4, input_size, hidden_size, num_layers,
bidirectional)
def f(x, h_0, c_0, weights):
return rnn.lstm(
x,
h_0,
c_0,
weights,
seq_lengths=seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=0.0,
bidirectional=bidirectional)
y, h_n, c_n = f(x, h_0, c_0, weights)
jtu.check_grads(f, (x, h_0, c_0, weights), modes=['rev'], order=1)
W_ih, W_hh, b_ih, b_hh = rnn.unpack_lstm_weights(weights, input_size,
hidden_size, num_layers,
bidirectional)
y_ref, h_n_ref, c_n_ref = rnn.lstm_ref(
x,
h_0,
c_0,
W_ih,
W_hh,
b_ih,
b_hh,
seq_lengths=seq_lengths,
input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=0.0,
bidirectional=bidirectional)
np.testing.assert_allclose(y_ref, y, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(h_n_ref, h_n, rtol=1e-5, atol=1e-5)
np.testing.assert_allclose(c_n_ref, c_n, rtol=1e-5, atol=1e-5)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())