Add CUDNN custom call for LSTM. Exposed as jax.experimental.rnn module.
PiperOrigin-RevId: 490387796
parent f33d5514c9 → commit 78963b6020
jax/BUILD
@@ -248,3 +248,10 @@ pytype_library(
        ":jax",
    ],
)

pytype_library(
    name = "rnn",
    srcs = ["experimental/rnn.py"],
    visibility = ["//visibility:public"],
    deps = [":jax"],
)
jax/_src/lib/__init__.py
@@ -105,6 +105,9 @@ import jaxlib.gpu_linalg as gpu_linalg  # pytype: disable=import-error
# branch on the Jax github.
xla_extension_version = getattr(xla_client, '_version', 0)

if xla_extension_version > 108:
  import jaxlib.gpu_rnn as gpu_rnn  # pytype: disable=import-error

can_execute_with_token = (
    xla_extension_version >= 89 and hasattr(
        xla_client.LoadedExecutable  # type: ignore
jax/experimental/rnn.py (new file, 407 lines)
@@ -0,0 +1,407 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""`jax.experimental.rnn`: GPU accelerated RNN

----------------------------------------------

This module provides experimental support for CUDNN-backed LSTM.

Currently, the only supported RNN flavor is LSTM with double bias. We use
notation and variable names similar to
https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM

and the CUDNN_LSTM entry in
https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnRNNMode_t.

Note that a bidirectional LSTM is treated as having twice the number of layers,
where a forward layer i is followed by a reverse layer i. Each direction has
its own associated weights. We use "pseudo-layer" to denote such layers,
following the CUDNN documentation
https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnGetRNNWeightParams.

CUDNN takes an opaque 1D weight array that densely packs all the weight arrays
in a sparsely documented layout. Through trial-and-error and testing, we believe
the layout is the following. Assume a 2-layer bi-LSTM with double bias, so 4
pseudo-layers in total (forward-0, reverse-0, forward-1, reverse-1).

There are 4 kinds of weights: W_ih, W_hh, b_ih and b_hh, where

W_ih = (W_ii, W_if, W_ig, W_io) concatenated on the leading axis,
W_hh = (W_hi, W_hf, W_hg, W_ho) concatenated on the leading axis,
b_ih = (b_ii, b_if, b_ig, b_io) concatenated on the leading axis,
b_hh = (b_hi, b_hf, b_hg, b_ho) concatenated on the leading axis.

Say W_ih^0 denotes W_ih from pseudo-layer 0. The linear weights are packed
together from all pseudo-layers, followed by the bias weights from all
pseudo-layers. In particular, for each layer, W_ih is followed by W_hh and
b_ih by b_hh:

(W_ih^0, W_hh^0, W_ih^1, W_hh^1, W_ih^2, W_hh^2, W_ih^3, W_hh^3,
 b_ih^0, b_hh^0, b_ih^1, b_hh^1, b_ih^2, b_hh^2, b_ih^3, b_hh^3)

See `_get_params_shapes_in_lstm`.

Example usage:
```
x = jax.random.normal(
    k1, (batch_size, seq_len, input_size), dtype=jnp.float32)
h_0 = jax.random.normal(
    k2, (num_directions * num_layers, batch_size, hidden_size),
    dtype=jnp.float32)
c_0 = jax.random.normal(
    k3, (num_directions * num_layers, batch_size, hidden_size),
    dtype=jnp.float32)
seq_lengths = jnp.ones((batch_size,), dtype=jnp.int32) * seq_len
weights = rnn.init_lstm_weight(k4, input_size, hidden_size, num_layers,
                               bidirectional)
y, h_n, c_n = rnn.lstm(
    x,
    h_0,
    c_0,
    weights,
    seq_lengths=seq_lengths,
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=False,
    bidirectional=bidirectional)
```

TODO:
  - Add support for input and weight dtypes other than float32.
  - Support ragged inputs.
  - Support RNNs other than LSTM.
"""
from functools import partial
from typing import Any, Dict, List, Tuple

import jax
import numpy as np
from jax import core
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.custom_derivatives import custom_vjp
from jax._src.typing import Array, Shape
from jax._src.util import prod
import jax.numpy as jnp
try:
  from jax._src.lib import gpu_rnn
except ImportError:
  gpu_rnn = None

PRNGKeyArray = Any
sigmoid = jax.nn.sigmoid
tanh = jax.nn.tanh


def _W_ih_l(layer_i: int, input_size: int, hidden_size: int,
            bidirectional: bool) -> Shape:
  """Shape of W_ii|W_if|W_ig|W_io.

  Note that layer_i is an index of pseudo-layers.
  """
  if layer_i == 0 or (layer_i == 1 and bidirectional):
    return (4 * hidden_size, input_size)
  else:
    num_directions = 2 if bidirectional else 1
    return (4 * hidden_size, num_directions * hidden_size)


def _W_hh_l(layer_i: int, input_size: int, hidden_size: int,
            bidirectional: bool) -> Shape:
  """Shape of W_hi|W_hf|W_hg|W_ho."""
  return (4 * hidden_size, hidden_size)


def _b_ih_l(layer_i: int, input_size: int, hidden_size: int,
            bidirectional: bool) -> Shape:
  """Shape of b_ii|b_if|b_ig|b_io."""
  return (4 * hidden_size,)


def _b_hh_l(layer_i: int, input_size: int, hidden_size: int,
            bidirectional: bool) -> Shape:
  """Shape of b_hi|b_hf|b_hg|b_ho."""
  return (4 * hidden_size,)


def _get_params_shapes_in_lstm(input_size: int, hidden_size: int,
                               num_layers: int,
                               bidirectional: bool) -> List[Shape]:
  """Get flat param shapes in LSTM. See module docstring for layout."""
  layer_shapes = []
  num_directions = 2 if bidirectional else 1
  num_pseudo_layers = num_layers * num_directions
  linear_weights = [_W_ih_l, _W_hh_l]
  for i in range(num_pseudo_layers):
    for w_kind in linear_weights:
      layer_shape = w_kind(i, input_size, hidden_size, bidirectional)
      layer_shapes.append(layer_shape)

  bias_weights = [_b_ih_l, _b_hh_l]
  for i in range(num_pseudo_layers):
    for w_kind in bias_weights:
      layer_shape = w_kind(i, input_size, hidden_size, bidirectional)
      layer_shapes.append(layer_shape)
  return layer_shapes


def get_num_params_in_lstm(input_size: int, hidden_size: int, num_layers: int,
                           bidirectional: bool) -> int:
  """Get param count in LSTM."""
  layer_shapes = _get_params_shapes_in_lstm(input_size, hidden_size, num_layers,
                                            bidirectional)
  param_count = sum([prod(shape) for shape in layer_shapes])
  return param_count


def init_lstm_weight(rng: PRNGKeyArray, input_size: int, hidden_size: int,
                     num_layers: int, bidirectional: bool):
  """Randomly initialize LSTM weights from U(-k, k), k=sqrt(1/hidden_size)."""
  param_count = get_num_params_in_lstm(input_size, hidden_size, num_layers,
                                       bidirectional)
  k = np.sqrt(1.0 / hidden_size)
  return jax.random.uniform(
      rng, shape=(param_count,), dtype=jnp.float32, minval=-k, maxval=k)


def unpack_lstm_weights(
    weights: Array, input_size: int, hidden_size: int, num_layers: int,
    bidirectional: bool
) -> Tuple[Dict[int, Array], Dict[int, Array], Dict[int, Array], Dict[int,
                                                                       Array]]:
  """Unpack cudnn LSTM weights into individual weights.

  CUDNN LSTM weight layout: (num_layers, num_directions, W_ih, W_hh, b_ih, b_hh)
  Returns W_ih, W_hh, b_ih, b_hh. e.g. W_ih[2][1] is the concatenation of the
  4 weights (W_ii, W_if, W_ig, W_io), each of shape (hidden_size, input_size),
  at the 2nd layer for the reverse direction. See notation in
  https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM.
  """
  flat_shapes = _get_params_shapes_in_lstm(input_size, hidden_size, num_layers,
                                           bidirectional)
  flat_shapes_offset = 0
  w_offsets = 0
  num_directions = 2 if bidirectional else 1
  num_pseudo_layers = num_layers * num_directions

  W_ih: Dict[int, Array] = {}
  W_hh: Dict[int, Array] = {}
  for l in range(num_pseudo_layers):
    for w_kind in [W_ih, W_hh]:
      shape = flat_shapes[flat_shapes_offset]
      flat_shapes_offset += 1
      num_elems = prod(shape)
      w_kind[l] = weights[w_offsets:w_offsets + num_elems].reshape(shape)
      w_offsets += num_elems

  b_ih: Dict[int, Array] = {}
  b_hh: Dict[int, Array] = {}
  for l in range(num_pseudo_layers):
    for w_kind in [b_ih, b_hh]:
      shape = flat_shapes[flat_shapes_offset]
      flat_shapes_offset += 1
      num_elems = prod(shape)
      w_kind[l] = weights[w_offsets:w_offsets + num_elems].reshape(shape)
      w_offsets += num_elems
  return W_ih, W_hh, b_ih, b_hh


@partial(custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
def lstm(x: Array, h_0: Array, c_0: Array, weights: Array, seq_lengths: Array,
         input_size: int, hidden_size: int, num_layers: int, dropout: float,
         bidirectional: bool) -> Tuple[Array, Array, Array]:
  """LSTM via CuDNN or HIPDNN (not yet supported).

  Assumes batch-first inputs.

  Arguments:
    x: (batch_size, max_seq_length, input_size)
    h_0: (num_directions * num_layers, batch_size, hidden_size)
    c_0: (num_directions * num_layers, batch_size, hidden_size)
    weights: (num_params,) where num_params = get_num_params_in_lstm(...)
    seq_lengths: (batch_size,)
  Returns: (y, h_n, c_n).
    y: (batch_size, max_seq_length, hidden_size * num_directions)
    h_n: (num_directions * num_layers, batch_size, hidden_size)
    c_n: (num_directions * num_layers, batch_size, hidden_size)
  """
  (y, h_n, c_n), _ = lstm_fwd(
      x,
      h_0,
      c_0,
      weights,
      seq_lengths,
      input_size=input_size,
      hidden_size=hidden_size,
      num_layers=num_layers,
      dropout=dropout,
      bidirectional=bidirectional)
  return y, h_n, c_n


@partial(jax.jit, static_argnums=(8, 9, 10, 11, 12))
def lstm_ref(x: Array, h_0: Array, c_0: Array, W_ih: Dict[int, Array],
             W_hh: Dict[int, Array], b_ih: Dict[int, Array],
             b_hh: Dict[int, Array], seq_lengths: Array, input_size: int,
             hidden_size: int, num_layers: int, dropout: float,
             bidirectional: bool) -> Tuple[Array, Array, Array]:
  """Reference implementation of LSTM.

  See https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#lstm
  https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnRNNMode_t
  """
  if dropout != 0.0:
    raise NotImplementedError(
        'Dropout not supported in LSTM reference because we cannot determine '
        'the CUDNN dropout mask.')

  # TODO(zhangqiaorjc): Handle ragged seq_lengths.
  # batch_size, max_seq_length = x.shape[0], x.shape[1]
  # assert seq_lengths.shape == (batch_size,)
  # for i in range(batch_size):
  #   if int(seq_lengths[i]) != max_seq_length:
  #     raise NotImplementedError('Does not yet support ragged sequences.')

  def lstm_cell(carry, x, *, W_ih, W_hh, b_ih, b_hh):
    h, c = carry
    W_ii, W_if, W_ig, W_io = jnp.split(W_ih, 4, axis=0)
    W_hi, W_hf, W_hg, W_ho = jnp.split(W_hh, 4, axis=0)
    b_ii, b_if, b_ig, b_io = jnp.split(b_ih, 4, axis=0)
    b_hi, b_hf, b_hg, b_ho = jnp.split(b_hh, 4, axis=0)
    i = sigmoid(x @ W_ii.T + b_ii[None] + h @ W_hi.T + b_hi[None])
    f = sigmoid(x @ W_if.T + b_if[None] + h @ W_hf.T + b_hf[None])
    g = tanh(x @ W_ig.T + b_ig[None] + h @ W_hg.T + b_hg[None])
    o = sigmoid(x @ W_io.T + b_io[None] + h @ W_ho.T + b_ho[None])
    c = f * c + i * g
    h = o * tanh(c)
    return (h, c), h

  seq_first_y = x.transpose(1, 0, 2)
  if not bidirectional:
    final_h = []
    final_c = []
    for l in range(num_layers):
      cell = partial(
          lstm_cell, W_ih=W_ih[l], W_hh=W_hh[l], b_ih=b_ih[l], b_hh=b_hh[l])
      (h_t, c_t), seq_first_y = jax.lax.scan(cell, (h_0[l], c_0[l]),
                                             seq_first_y)
      final_h.append(h_t)
      final_c.append(c_t)
    h_n = jnp.stack(final_h)
    c_n = jnp.stack(final_c)
    return seq_first_y.transpose(1, 0, 2), h_n, c_n

  # bidirectional
  final_h = []
  final_c = []
  for l in range(num_layers * 2):
    cell = partial(
        lstm_cell, W_ih=W_ih[l], W_hh=W_hh[l], b_ih=b_ih[l], b_hh=b_hh[l])
    if l % 2 == 0:
      (h_t, c_t), seq_first_y_fwd = jax.lax.scan(cell, (h_0[l], c_0[l]),
                                                 seq_first_y)
    else:
      (h_t, c_t), seq_first_y_bwd = jax.lax.scan(
          cell, (h_0[l], c_0[l]), seq_first_y, reverse=True)
      # Inputs to next layer are concat'ed from fwd and bwd.
      seq_first_y = jnp.concatenate([seq_first_y_fwd, seq_first_y_bwd], axis=-1)  # pytype: disable=name-error
    final_h.append(h_t)
    final_c.append(c_t)
  h_n = jnp.stack(final_h)
  c_n = jnp.stack(final_c)
  return seq_first_y.transpose(1, 0, 2), h_n, c_n


def lstm_fwd(x: Array, h_0: Array, c_0: Array, w: Array, seq_lengths: Array,
             input_size: int, hidden_size: int, num_layers: int, dropout: float,
             bidirectional: bool):
  y, h_n, c_n, workspace, reserve_space = rnn_fwd_p.bind(
      x,
      h_0,
      c_0,
      w,
      seq_lengths,
      input_size=input_size,
      hidden_size=hidden_size,
      num_layers=num_layers,
      dropout=dropout,
      bidirectional=bidirectional)
  return (y, h_n, c_n), (x, h_0, c_0, w, seq_lengths, y, workspace,
                         reserve_space)


def rnn_abstract_eval(x_aval, h_0_aval, c_0_aval, w_aval, seq_lengths_aval,
                      input_size: int, hidden_size: int, num_layers: int,
                      dropout: float, bidirectional: bool):
  batch_size, max_seq_length = x_aval.shape[0], x_aval.shape[1]
  num_directions = 2 if bidirectional else 1
  output_shape = (batch_size, max_seq_length, num_directions * hidden_size)
  output_aval = core.ShapedArray(output_shape, x_aval.dtype)
  workspace_size, reserve_space_size = (
      gpu_rnn.compute_rnn_workspace_reserve_space_sizes(  # pytype: disable=attribute-error
          input_size, hidden_size, num_layers, batch_size, max_seq_length,
          dropout, bidirectional))
  workspace_aval = core.ShapedArray((workspace_size,), jnp.float32)
  reserve_space_aval = core.ShapedArray((reserve_space_size,), jnp.float32)
  return output_aval, h_0_aval, c_0_aval, workspace_aval, reserve_space_aval


rnn_fwd_p = core.Primitive('rnn_fwd')
rnn_fwd_p.multiple_results = True
rnn_fwd_p.def_impl(partial(xla.apply_primitive, rnn_fwd_p))
rnn_fwd_p.def_abstract_eval(rnn_abstract_eval)
if gpu_rnn:
  mlir.register_lowering(rnn_fwd_p, gpu_rnn.cudnn_rnn_lowering, platform='cuda')


def lstm_bwd(input_size: int, hidden_size: int, num_layers: int, dropout: float,
             bidirectional, residuals, gradients):
  x, h_0, c_0, w, seq_lengths, y, workspace, reserve_space = residuals
  dy, dh_n, dc_n = gradients
  dx, dh_0, dc_0, dw = rnn_bwd_p.bind(
      dy,
      dh_n,
      dc_n,
      x,
      h_0,
      c_0,
      w,
      y,
      workspace,
      reserve_space,
      seq_lengths,
      input_size=input_size,
      hidden_size=hidden_size,
      num_layers=num_layers,
      dropout=dropout,
      bidirectional=bidirectional)
  return (dx, dh_0, dc_0, dw, jnp.zeros_like(seq_lengths))


def rnn_bwd_abstract_eval(dy_aval, dhn_aval, dcn_aval, x_aval, h0_aval, c0_aval,
                          w_aval, y_aval, workspace_aval, reserve_space_aval,
                          seq_lengths_aval, input_size: int, hidden_size: int,
                          num_layers: int, dropout: float, bidirectional: bool):
  return x_aval, h0_aval, c0_aval, w_aval


rnn_bwd_p = core.Primitive('rnn_bwd')
rnn_bwd_p.multiple_results = True
rnn_bwd_p.def_impl(partial(xla.apply_primitive, rnn_bwd_p))
rnn_bwd_p.def_abstract_eval(rnn_bwd_abstract_eval)
if gpu_rnn:
  mlir.register_lowering(
      rnn_bwd_p, gpu_rnn.cudnn_rnn_bwd_lowering, platform='cuda')

lstm.defvjp(lstm_fwd, lstm_bwd)
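To make the packed layout and the shape helpers above concrete, here is a small worked example. It only does shape arithmetic with the public helpers from this module, so it does not need a GPU; the specific sizes (input_size=2, hidden_size=3) are illustrative, not from the commit:

```
import jax
from jax.experimental import rnn

# Unidirectional, single-layer LSTM: one pseudo-layer, so the packed vector is
# (W_ih^0, W_hh^0, b_ih^0, b_hh^0). With input_size=2, hidden_size=3:
#   W_ih: (4*3, 2) = 24 elements, W_hh: (4*3, 3) = 36 elements,
#   b_ih: (12,) and b_hh: (12,), so 24 + 36 + 12 + 12 = 84 params in total.
input_size, hidden_size, num_layers, bidirectional = 2, 3, 1, False
assert rnn.get_num_params_in_lstm(
    input_size, hidden_size, num_layers, bidirectional) == 84

weights = rnn.init_lstm_weight(jax.random.PRNGKey(0), input_size, hidden_size,
                               num_layers, bidirectional)
W_ih, W_hh, b_ih, b_hh = rnn.unpack_lstm_weights(
    weights, input_size, hidden_size, num_layers, bidirectional)
assert W_ih[0].shape == (4 * hidden_size, input_size)   # (12, 2)
assert W_hh[0].shape == (4 * hidden_size, hidden_size)  # (12, 3)
assert b_ih[0].shape == b_hh[0].shape == (4 * hidden_size,)
```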
jaxlib/BUILD
@@ -31,6 +31,7 @@ py_library(
        "ducc_fft.py",
        "gpu_linalg.py",
        "gpu_prng.py",
        "gpu_rnn.py",
        "gpu_solver.py",
        "gpu_sparse.py",
        "init.py",
jaxlib/cuda/BUILD
@@ -31,6 +31,7 @@ cc_library(
    ],
    defines = ["JAX_GPU_CUDA=1"],
    deps = [
        "//third_party/gpus/cudnn:cudnn_header",
        "@local_config_cuda//cuda:cuda_headers",
    ],
)
@@ -107,6 +108,43 @@ pybind_extension(
    ],
)

cc_library(
    name = "cudnn_rnn_kernels",
    srcs = ["//jaxlib/gpu:rnn_kernels.cc"],
    hdrs = ["//jaxlib/gpu:rnn_kernels.h"],
    deps = [
        ":cuda_gpu_kernel_helpers",
        ":cuda_vendor",
        "//jaxlib:handle_pool",
        "//jaxlib:kernel_helpers",
        "@org_tensorflow//tensorflow/compiler/xla/service:custom_call_status",
        "@org_tensorflow//tensorflow/compiler/xla/stream_executor/cuda:cudart_stub",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@local_config_cuda//cuda:cuda_headers",
    ],
)

pybind_extension(
    name = "_rnn",
    srcs = ["//jaxlib/gpu:rnn.cc"],
    copts = [
        "-fexceptions",
        "-fno-strict-aliasing",
    ],
    features = ["-use_header_modules"],
    module_name = "_rnn",
    deps = [
        ":cuda_vendor",
        ":cudnn_rnn_kernels",
        "//jaxlib:kernel_pybind11_helpers",
        "//third_party/pybind11_abseil:status_casters",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings:str_format",
        "@pybind11",
    ],
)

cc_library(
    name = "cusolver_kernels",
    srcs = ["//jaxlib/gpu:solver_kernels.cc"],
@@ -321,6 +359,7 @@ py_library(
        ":_blas",
        ":_linalg",
        ":_prng",
        ":_rnn",
        ":_solver",
        ":_sparse",
    ],
jaxlib/gpu/BUILD
@@ -33,6 +33,9 @@ exports_files(srcs = [
    "prng_kernels.cc",
    "prng_kernels.cu.cc",
    "prng_kernels.h",
    "rnn.cc",
    "rnn_kernels.cc",
    "rnn_kernels.h",
    "solver.cc",
    "solver_kernels.cc",
    "solver_kernels.h",
jaxlib/gpu/gpu_kernel_helpers.cc
@@ -89,6 +89,10 @@ std::string ErrorString(gpublasStatus_t status) {
  }
}

std::string ErrorString(gpudnnStatus_t status) {
  return cudnnGetErrorString(status);
}

#else

std::string ErrorString(hipsparseStatus_t status) {
@@ -219,6 +223,13 @@ absl::Status AsStatus(gpublasStatus_t status, const char* file,
  return absl::OkStatus();
}

absl::Status AsStatus(gpudnnStatus_t status, const char* file,
                      std::int64_t line, const char* expr) {
  if (status != GPUDNN_STATUS_SUCCESS)
    return absl::InternalError(ErrorString(status, file, line, expr));
  return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<void*[]>> MakeBatchPointers(
    gpuStream_t stream, void* buffer, void* dev_ptrs, int batch,
    int batch_elem_size) {
jaxlib/gpu/gpu_kernel_helpers.h
@@ -49,6 +49,8 @@ absl::Status AsStatus(gpusparseStatus_t status, const char* file,
                      std::int64_t line, const char* expr);
absl::Status AsStatus(gpublasStatus_t status, const char* file,
                      std::int64_t line, const char* expr);
absl::Status AsStatus(gpudnnStatus_t status, const char* file,
                      std::int64_t line, const char* expr);

// Builds an array of pointers to each array in a batch, in device memory.
// Caution: the return value must be kept alive (e.g., via a stream
jaxlib/gpu/rnn.cc (new file, 54 lines)
@@ -0,0 +1,54 @@
/* Copyright 2022 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "jaxlib/gpu/rnn_kernels.h"
#include "jaxlib/gpu/vendor.h"
#include "jaxlib/kernel_pybind11_helpers.h"
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl.h"
#include "third_party/pybind11_abseil/status_casters.h"

namespace jax {
namespace JAX_GPU_NAMESPACE {
namespace {

namespace py = pybind11;

py::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers,
                             int batch_size, int max_seq_length, float dropout,
                             bool bidirectional, int workspace_size,
                             int reserve_space_size) {
  return PackDescriptor(RnnDescriptor{
      input_size, hidden_size, num_layers, batch_size, max_seq_length, dropout,
      bidirectional, workspace_size, reserve_space_size});
}

py::dict Registrations() {
  py::dict dict;
  dict[JAX_GPU_PREFIX "dnn_rnn"] = EncapsulateFunction(RNNForward);
  dict[JAX_GPU_PREFIX "dnn_rnn_bwd"] = EncapsulateFunction(RNNBackward);
  return dict;
}

PYBIND11_MODULE(_rnn, m) {
  m.def("registrations", &Registrations);
  m.def("build_rnn_descriptor", &BuildRnnDescriptor);
  m.def("compute_rnn_workspace_reserve_space_sizes",
        &RnnComputeWorkspaceReserveSpaceSizes);
}

}  // namespace
}  // namespace JAX_GPU_NAMESPACE
}  // namespace jax
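On CUDA builds `JAX_GPU_PREFIX` expands to "cu", so the targets registered above come out as `cudnn_rnn` and `cudnn_rnn_bwd`, matching the `call_target_name` values used by the lowering rules in `jaxlib/gpu_rnn.py` below. A quick sanity sketch, assuming a CUDA build of jaxlib where the `_rnn` extension was compiled in:

```
# Assumes a CUDA build of jaxlib; the import fails otherwise.
from jaxlib.cuda import _rnn

# JAX_GPU_PREFIX "dnn_rnn" -> "cudnn_rnn", and likewise for the backward pass.
assert sorted(_rnn.registrations()) == ['cudnn_rnn', 'cudnn_rnn_bwd']
```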
jaxlib/gpu/rnn_kernels.cc (new file, 549 lines)
@@ -0,0 +1,549 @@
/* Copyright 2022 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "jaxlib/gpu/rnn_kernels.h"

#include <utility>
#include <vector>

#include "absl/status/status.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/handle_pool.h"
#include "jaxlib/kernel_helpers.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"

namespace jax {

using DnnHandlePool = HandlePool<gpudnnHandle_t, gpuStream_t>;

template <>
/*static*/ absl::StatusOr<DnnHandlePool::Handle> DnnHandlePool::Borrow(
    gpuStream_t stream) {
  DnnHandlePool* pool = Instance();
  absl::MutexLock lock(&pool->mu_);
  gpudnnHandle_t handle;
  if (pool->handles_[stream].empty()) {
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnCreate(&handle)));
  } else {
    handle = pool->handles_[stream].back();
    pool->handles_[stream].pop_back();
  }
  if (stream) {
    JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetStream(handle, stream)));
  }
  return Handle(pool, handle, stream);
}

namespace JAX_GPU_NAMESPACE {

// struct RnnDescriptor {
//   int input_size;
//   int hidden_size;
//   int num_layers;
//   int batch_size;
//   int max_seq_length;
//   float dropout;
//   bool bidirectional;
//   int workspace_size;
//   int reserve_space_size;
// };

static absl::StatusOr<std::pair<int, int>>
DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
                                       int num_layers, int batch_size,
                                       int max_seq_length, float dropout,
                                       bool bidirectional) {
  auto h = DnnHandlePool::Borrow();
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;

  cudnnRNNDescriptor_t rnn_desc;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateRNNDescriptor(&rnn_desc)));

  cudnnDropoutDescriptor_t dropout_desc;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnCreateDropoutDescriptor(&dropout_desc)));
  size_t state_size;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDropoutGetStatesSize(handle.get(), &state_size)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetDropoutDescriptor(
      dropout_desc, handle.get(), dropout, nullptr, state_size, 123)));

  // TODO(zhangqiaorjc): Handle other kinds of RNN.
  cudnnRNNMode_t cell_mode = CUDNN_LSTM;
  cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;
  int num_directions = 1;
  cudnnDirectionMode_t dir_mode = CUDNN_UNIDIRECTIONAL;
  if (bidirectional) {
    dir_mode = CUDNN_BIDIRECTIONAL;
    num_directions = 2;
  }
  cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
  cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
  cudnnDataType_t math_prec = CUDNN_DATA_FLOAT;
  cudnnMathType_t math_type = CUDNN_DEFAULT_MATH;
  int32_t proj_size = hidden_size;
  uint32_t aux_flags = CUDNN_RNN_PADDED_IO_ENABLED;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDescriptor_v8(
      rnn_desc, CUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode,
      input_mode, data_type, math_prec, math_type, input_size, hidden_size,
      proj_size, num_layers, dropout_desc, aux_flags)));

  cudnnForwardMode_t fwdMode = CUDNN_FWD_MODE_TRAINING;
  cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
  float padding = 0.0f;

  std::vector<int32_t> seq_length_vector(batch_size, max_seq_length);
  int32_t* seq_length_array = &seq_length_vector[0];

  cudnnRNNDataDescriptor_t input_data_desc;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&input_data_desc)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor(
      input_data_desc, data_type, layout, max_seq_length, batch_size,
      input_size, seq_length_array, &padding)));

  size_t workSpaceSize;
  size_t reserveSpaceSize;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnGetRNNTempSpaceSizes(
      handle.get(), rnn_desc, fwdMode, input_data_desc, &workSpaceSize,
      &reserveSpaceSize)));

  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDestroyDropoutDescriptor(dropout_desc)));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDestroyRNNDataDescriptor(input_data_desc)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyRNNDescriptor(rnn_desc)));

  // Round up to nearest multiples of 4 so we can return them as f32 arrays.
  workSpaceSize += (workSpaceSize % 4);
  reserveSpaceSize += (reserveSpaceSize % 4);
  return std::make_pair(workSpaceSize, reserveSpaceSize);
}

absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
    int input_size, int hidden_size, int num_layers, int batch_size,
    int max_seq_length, float dropout, bool bidirectional) {
  return DoRnnComputeWorkspaceReserveSpaceSizes(
      input_size, hidden_size, num_layers, batch_size, max_seq_length, dropout,
      bidirectional);
}

static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers,
                                   const char* opaque, size_t opaque_len) {
  auto s = UnpackDescriptor<RnnDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const RnnDescriptor& d = **s;
  auto h = DnnHandlePool::Borrow(stream);
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;

  cudnnRNNDescriptor_t rnn_desc;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateRNNDescriptor(&rnn_desc)));

  cudnnDropoutDescriptor_t dropout_desc;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnCreateDropoutDescriptor(&dropout_desc)));
  size_t state_size;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDropoutGetStatesSize(handle.get(), &state_size)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetDropoutDescriptor(
      dropout_desc, handle.get(), d.dropout, nullptr, state_size, 123)));

  // TODO(zhangqiaorjc): Handle other kinds of RNN.
  cudnnRNNMode_t cell_mode = CUDNN_LSTM;
  cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;
  int num_directions = 1;
  cudnnDirectionMode_t dir_mode = CUDNN_UNIDIRECTIONAL;
  if (d.bidirectional) {
    dir_mode = CUDNN_BIDIRECTIONAL;
    num_directions = 2;
  }
  cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
  cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
  cudnnDataType_t math_prec = CUDNN_DATA_FLOAT;
  cudnnMathType_t math_type = CUDNN_DEFAULT_MATH;
  int32_t proj_size = d.hidden_size;
  uint32_t aux_flags = CUDNN_RNN_PADDED_IO_ENABLED;

  // cudnnStatus_t cudnnSetRNNDescriptor_v8(
  //     cudnnRNNDescriptor_t rnn_desc,
  //     cudnnRNNAlgo_t algo,
  //     cudnnRNNMode_t cell_mode,
  //     cudnnRNNBiasMode_t bias_mode,
  //     cudnnDirectionMode_t dir_mode,
  //     cudnnRNNInputMode_t input_mode,
  //     cudnnDataType_t data_type,
  //     cudnnDataType_t math_prec,
  //     cudnnMathType_t math_type,
  //     int32_t inputSize,
  //     int32_t hiddenSize,
  //     int32_t projSize,
  //     int32_t numLayers,
  //     cudnnDropoutDescriptor_t dropout_desc,
  //     uint32_t auxFlags);
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDescriptor_v8(
      rnn_desc, CUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode,
      input_mode, data_type, math_prec, math_type, d.input_size, d.hidden_size,
      proj_size, d.num_layers, dropout_desc, aux_flags)));

  cudnnForwardMode_t fwdMode = CUDNN_FWD_MODE_TRAINING;
  // cudnnForwardMode_t fwdMode = CUDNN_FWD_MODE_INFERENCE;

  // cudnnStatus_t cudnnSetRNNDataDescriptor(
  //     cudnnRNNDataDescriptor_t RNNDataDesc,
  //     cudnnDataType_t data_type,
  //     cudnnRNNDataLayout_t layout,
  //     int maxSeqLength,
  //     int batchSize,
  //     int vectorSize,
  //     const int seq_length_array[],
  //     void *paddingFill);

  cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
  float padding = 0.0f;

  std::vector<int32_t> seq_length_vector(d.batch_size, d.max_seq_length);
  int32_t* seq_length_array = &seq_length_vector[0];

  cudnnRNNDataDescriptor_t input_data_desc;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&input_data_desc)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor(
      input_data_desc, data_type, layout, d.max_seq_length, d.batch_size,
      d.input_size, seq_length_array, &padding)));

  cudnnRNNDataDescriptor_t output_data_desc;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&output_data_desc)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor(
      output_data_desc, data_type, layout, d.max_seq_length, d.batch_size,
      d.hidden_size * num_directions, seq_length_array, &padding)));

  // cudnnStatus_t cudnnSetTensor4dDescriptor(
  //     cudnnTensorDescriptor_t tensorDesc,
  //     cudnnTensorFormat_t format,
  //     cudnnDataType_t data_type,
  //     int n,
  //     int c,
  //     int h,
  //     int w)

  // Shape is (num_directions * num_layers, batch_size, hidden_size)
  int dims[3];
  dims[0] = num_directions * d.num_layers;
  dims[1] = d.batch_size;
  dims[2] = d.hidden_size;
  int strides[3];
  strides[0] = dims[1] * dims[2];
  strides[1] = dims[2];
  strides[2] = 1;
  cudnnTensorDescriptor_t h_desc;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateTensorDescriptor(&h_desc)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnSetTensorNdDescriptor(h_desc, data_type, 3, dims, strides)));

  cudnnTensorDescriptor_t c_desc;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateTensorDescriptor(&c_desc)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnSetTensorNdDescriptor(c_desc, data_type, 3, dims, strides)));

  // cudnnStatus_t cudnnGetRNNWeightSpaceSize(
  //     cudnnHandle_t handle,
  //     cudnnRNNDescriptor_t rnn_desc,
  //     size_t *weight_space_size);
  size_t weight_space_size;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnGetRNNWeightSpaceSize(handle.get(), rnn_desc, &weight_space_size)));

  // cudnnStatus_t cudnnRNNForward(
  //     cudnnHandle_t handle,
  //     cudnnRNNDescriptor_t rnn_desc,
  //     cudnnForwardMode_t fwdMode,
  //     const int32_t devSeqLengths[],
  //     cudnnRNNDataDescriptor_t xDesc,
  //     const void *x,
  //     cudnnRNNDataDescriptor_t yDesc,
  //     void *y,
  //     cudnnTensorDescriptor_t h_desc,
  //     const void *hx,
  //     void *hy,
  //     cudnnTensorDescriptor_t cDesc,
  //     const void *cx,
  //     void *cy,
  //     size_t weight_space_size,
  //     const void *weightSpace,
  //     size_t workSpaceSize,
  //     void *workSpace,
  //     size_t reserveSpaceSize,
  //     void *reserveSpace);
  auto input_buf = buffers[0];
  auto h_0_buf = buffers[1];
  auto c_0_buf = buffers[2];
  auto weights_buf = buffers[3];
  auto seq_lengths_buf = buffers[4];
  auto output_buf = buffers[5];
  auto h_n_buf = buffers[6];
  auto c_n_buf = buffers[7];
  auto workspace_buf = buffers[8];
  auto reserve_space_buf = buffers[9];
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnRNNForward(
      handle.get(), rnn_desc, fwdMode, (const int32_t*)seq_lengths_buf,
      input_data_desc, input_buf, output_data_desc, output_buf, h_desc, h_0_buf,
      h_n_buf, c_desc, c_0_buf, c_n_buf, weight_space_size, weights_buf,
      d.workspace_size, /*workSpace=*/workspace_buf, d.reserve_space_size,
      /*reserveSpace=*/reserve_space_buf)));

  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyTensorDescriptor(h_desc)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyTensorDescriptor(c_desc)));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDestroyDropoutDescriptor(dropout_desc)));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDestroyRNNDataDescriptor(input_data_desc)));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDestroyRNNDataDescriptor(output_data_desc)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyRNNDescriptor(rnn_desc)));

  return absl::OkStatus();
}

static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
                                    const char* opaque, size_t opaque_len) {
  auto s = UnpackDescriptor<RnnDescriptor>(opaque, opaque_len);
  JAX_RETURN_IF_ERROR(s.status());
  const RnnDescriptor& d = **s;
  auto h = DnnHandlePool::Borrow(stream);
  JAX_RETURN_IF_ERROR(h.status());
  auto& handle = *h;

  cudnnRNNDescriptor_t rnn_desc;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateRNNDescriptor(&rnn_desc)));

  cudnnDropoutDescriptor_t dropout_desc;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnCreateDropoutDescriptor(&dropout_desc)));
  size_t state_size;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDropoutGetStatesSize(handle.get(), &state_size)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetDropoutDescriptor(
      dropout_desc, handle.get(), d.dropout, nullptr, state_size, 123)));

  // TODO(zhangqiaorjc): Handle other kinds of RNN.
  cudnnRNNMode_t cell_mode = CUDNN_LSTM;
  cudnnRNNBiasMode_t bias_mode = CUDNN_RNN_DOUBLE_BIAS;
  int num_directions = 1;
  cudnnDirectionMode_t dir_mode = CUDNN_UNIDIRECTIONAL;
  if (d.bidirectional) {
    dir_mode = CUDNN_BIDIRECTIONAL;
    num_directions = 2;
  }
  cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT;
  cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
  cudnnDataType_t math_prec = CUDNN_DATA_FLOAT;
  cudnnMathType_t math_type = CUDNN_DEFAULT_MATH;
  int32_t proj_size = d.hidden_size;
  uint32_t aux_flags = CUDNN_RNN_PADDED_IO_ENABLED;

  // cudnnStatus_t cudnnSetRNNDescriptor_v8(
  //     cudnnRNNDescriptor_t rnn_desc,
  //     cudnnRNNAlgo_t algo,
  //     cudnnRNNMode_t cell_mode,
  //     cudnnRNNBiasMode_t bias_mode,
  //     cudnnDirectionMode_t dir_mode,
  //     cudnnRNNInputMode_t input_mode,
  //     cudnnDataType_t data_type,
  //     cudnnDataType_t math_prec,
  //     cudnnMathType_t math_type,
  //     int32_t inputSize,
  //     int32_t hiddenSize,
  //     int32_t projSize,
  //     int32_t numLayers,
  //     cudnnDropoutDescriptor_t dropout_desc,
  //     uint32_t auxFlags);
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDescriptor_v8(
      rnn_desc, CUDNN_RNN_ALGO_STANDARD, cell_mode, bias_mode, dir_mode,
      input_mode, data_type, math_prec, math_type, d.input_size, d.hidden_size,
      proj_size, d.num_layers, dropout_desc, aux_flags)));

  // cudnnStatus_t cudnnSetRNNDataDescriptor(
  //     cudnnRNNDataDescriptor_t RNNDataDesc,
  //     cudnnDataType_t data_type,
  //     cudnnRNNDataLayout_t layout,
  //     int maxSeqLength,
  //     int batchSize,
  //     int vectorSize,
  //     const int seq_length_array[],
  //     void *paddingFill);

  cudnnRNNDataLayout_t layout = CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED;
  float padding = 0.0f;

  std::vector<int32_t> seq_length_vector(d.batch_size, d.max_seq_length);
  int32_t* seq_length_array = &seq_length_vector[0];

  cudnnRNNDataDescriptor_t input_data_desc;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&input_data_desc)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor(
      input_data_desc, data_type, layout, d.max_seq_length, d.batch_size,
      d.input_size, seq_length_array, &padding)));

  cudnnRNNDataDescriptor_t output_data_desc;
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnCreateRNNDataDescriptor(&output_data_desc)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnSetRNNDataDescriptor(
      output_data_desc, data_type, layout, d.max_seq_length, d.batch_size,
      d.hidden_size * num_directions, seq_length_array, &padding)));

  // cudnnStatus_t cudnnSetTensor4dDescriptor(
  //     cudnnTensorDescriptor_t tensorDesc,
  //     cudnnTensorFormat_t format,
  //     cudnnDataType_t data_type,
  //     int n,
  //     int c,
  //     int h,
  //     int w)

  // Shape is (num_directions * num_layers, batch_size, hidden_size)
  int dims[3];
  dims[0] = num_directions * d.num_layers;
  dims[1] = d.batch_size;
  dims[2] = d.hidden_size;
  int strides[3];
  strides[0] = dims[1] * dims[2];
  strides[1] = dims[2];
  strides[2] = 1;
  cudnnTensorDescriptor_t h_desc;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateTensorDescriptor(&h_desc)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnSetTensorNdDescriptor(h_desc, data_type, 3, dims, strides)));

  cudnnTensorDescriptor_t c_desc;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnCreateTensorDescriptor(&c_desc)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnSetTensorNdDescriptor(c_desc, data_type, 3, dims, strides)));

  // cudnnStatus_t cudnnGetRNNWeightSpaceSize(
  //     cudnnHandle_t handle,
  //     cudnnRNNDescriptor_t rnn_desc,
  //     size_t *weight_space_size);
  size_t weight_space_size;
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
      cudnnGetRNNWeightSpaceSize(handle.get(), rnn_desc, &weight_space_size)));

  auto dy_buf = buffers[0];
  auto dh_n_buf = buffers[1];
  auto dc_n_buf = buffers[2];
  auto x_buf = buffers[3];
  auto h_0_buf = buffers[4];
  auto c_0_buf = buffers[5];
  auto w_buf = buffers[6];
  auto y_buf = buffers[7];
  auto workspace_buf = buffers[8];
  auto reserve_space_buf = buffers[9];
  auto zeroed_dw_buf = buffers[10];
  auto seq_lengths_buf = buffers[11];
  auto dx_buf = buffers[12];
  auto dh_0_buf = buffers[13];
  auto dc_0_buf = buffers[14];
  // auto dw_buf = buffers[15];

  // cudnnStatus_t cudnnRNNBackwardData_v8(
  //     cudnnHandle_t handle,
  //     cudnnRNNDescriptor_t rnn_desc,
  //     const int32_t devSeqLengths[],
  //     cudnnRNNDataDescriptor_t yDesc,
  //     const void *y,
  //     const void *dy,
  //     cudnnRNNDataDescriptor_t xDesc,
  //     void *dx,
  //     cudnnTensorDescriptor_t h_desc,
  //     const void *hx,
  //     const void *dhy,
  //     void *dhx,
  //     cudnnTensorDescriptor_t c_desc,
  //     const void *cx,
  //     const void *dcy,
  //     void *dcx,
  //     size_t weight_space_size,
  //     const void *weightSpace,
  //     size_t workSpaceSize,
  //     void *workSpace,
  //     size_t reserveSpaceSize,
  //     void *reserveSpace);
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnRNNBackwardData_v8(
      handle.get(), rnn_desc, (const int32_t*)seq_lengths_buf, output_data_desc,
      y_buf, dy_buf, input_data_desc, dx_buf, h_desc, h_0_buf, dh_n_buf,
      dh_0_buf, c_desc, c_0_buf, dc_n_buf, dc_0_buf, weight_space_size, w_buf,
      d.workspace_size, workspace_buf, d.reserve_space_size,
      reserve_space_buf)));

  // cudnnStatus_t cudnnRNNBackwardWeights_v8(
  //     cudnnHandle_t handle,
  //     cudnnRNNDescriptor_t rnn_desc,
  //     cudnnWgradMode_t addGrad,
  //     const int32_t devSeqLengths[],
  //     cudnnRNNDataDescriptor_t xDesc,
  //     const void *x,
  //     cudnnTensorDescriptor_t h_desc,
  //     const void *hx,
  //     cudnnRNNDataDescriptor_t yDesc,
  //     const void *y,
  //     size_t weight_space_size,
  //     void *dweightSpace,
  //     size_t workSpaceSize,
  //     void *workSpace,
  //     size_t reserveSpaceSize,
  //     void *reserveSpace);
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnRNNBackwardWeights_v8(
      handle.get(), rnn_desc, CUDNN_WGRAD_MODE_ADD,
      (const int32_t*)seq_lengths_buf, input_data_desc, x_buf, h_desc, h_0_buf,
      output_data_desc, y_buf, weight_space_size, zeroed_dw_buf,
      d.workspace_size, workspace_buf, d.reserve_space_size,
      reserve_space_buf)));

  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyTensorDescriptor(h_desc)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyTensorDescriptor(c_desc)));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDestroyDropoutDescriptor(dropout_desc)));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDestroyRNNDataDescriptor(input_data_desc)));
  JAX_RETURN_IF_ERROR(
      JAX_AS_STATUS(cudnnDestroyRNNDataDescriptor(output_data_desc)));
  JAX_RETURN_IF_ERROR(JAX_AS_STATUS(cudnnDestroyRNNDescriptor(rnn_desc)));

  return absl::OkStatus();
}

void RNNForward(gpuStream_t stream, void** buffers, const char* opaque,
                size_t opaque_len, XlaCustomCallStatus* status) {
  auto s = DnnRNNForward_(stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
                                  s.message().length());
  }
}

void RNNBackward(gpuStream_t stream, void** buffers, const char* opaque,
                 size_t opaque_len, XlaCustomCallStatus* status) {
  auto s = DnnRNNBackward_(stream, buffers, opaque, opaque_len);
  if (!s.ok()) {
    XlaCustomCallStatusSetFailure(status, std::string(s.message()).c_str(),
                                  s.message().length());
  }
}

}  // namespace JAX_GPU_NAMESPACE
}  // namespace jax
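Both kernels recover their configuration by unpacking the same `RnnDescriptor` bytes that `BuildRnnDescriptor` packs in `rnn.cc`. A minimal sketch of the size query and descriptor round-trip through the pybind module (assumes a CUDA build of jaxlib and a visible GPU, since the query borrows a cuDNN handle; the sizes are illustrative):

```
from jaxlib.cuda import _rnn

# Positional args: input_size, hidden_size, num_layers, batch_size,
# max_seq_length, dropout, bidirectional. Returns the temp-space sizes,
# already rounded so they can be exposed as f32 arrays.
workspace_size, reserve_space_size = (
    _rnn.compute_rnn_workspace_reserve_space_sizes(8, 12, 1, 6, 7, 0.0, False))

# The same numbers are baked into the opaque blob handed to the custom call
# as backend_config and unpacked by DnnRNNForward_/DnnRNNBackward_ above.
opaque = _rnn.build_rnn_descriptor(8, 12, 1, 6, 7, 0.0, False,
                                   workspace_size, reserve_space_size)
assert isinstance(opaque, bytes)
```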
jaxlib/gpu/rnn_kernels.h (new file, 53 lines)
@@ -0,0 +1,53 @@
/* Copyright 2022 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef JAXLIB_GPU_RNN_KERNELS_H_
#define JAXLIB_GPU_RNN_KERNELS_H_

#include "absl/status/statusor.h"
#include "jaxlib/gpu/vendor.h"
#include "tensorflow/compiler/xla/service/custom_call_status.h"

namespace jax {
namespace JAX_GPU_NAMESPACE {

// Compile-time info passed as `opaque` to custom kernel.
struct RnnDescriptor {
  int input_size;
  int hidden_size;
  int num_layers;
  int batch_size;
  int max_seq_length;
  float dropout;
  bool bidirectional;
  int workspace_size;
  int reserve_space_size;
};

// Return (workspace size, reserve space size).
absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
    int input_size, int hidden_size, int num_layers, int batch_size,
    int max_seq_length, float dropout, bool bidirectional);

void RNNForward(gpuStream_t stream, void** buffers, const char* opaque,
                size_t opaque_len, XlaCustomCallStatus* status);

void RNNBackward(gpuStream_t stream, void** buffers, const char* opaque,
                 size_t opaque_len, XlaCustomCallStatus* status);

}  // namespace JAX_GPU_NAMESPACE
}  // namespace jax

#endif  // JAXLIB_GPU_RNN_KERNELS_H_
jaxlib/gpu/vendor.h
@@ -28,6 +28,7 @@ limitations under the License.
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
#include "third_party/gpus/cuda/include/cusolverDn.h"
#include "third_party/gpus/cuda/include/cusparse.h"
#include "third_party/gpus/cudnn/cudnn.h"

// Some sparse functionality is only available in CUSPARSE 11.3 or newer.
#define JAX_GPU_HAVE_SPARSE (CUSPARSE_VERSION >= 11300)
@@ -64,6 +65,8 @@ typedef cublasHandle_t gpublasHandle_t;
typedef cudaDataType gpuDataType;
typedef cudaStream_t gpuStream_t;
typedef cudaError_t gpuError_t;
typedef cudnnHandle_t gpudnnHandle_t;
typedef cudnnStatus_t gpudnnStatus_t;
typedef cusolverDnHandle_t gpusolverDnHandle_t;
typedef cusolverStatus_t gpusolverStatus_t;
typedef cusolverEigMode_t gpusolverEigMode_t;
@@ -97,6 +100,11 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;

#define GPUBLAS_STATUS_SUCCESS CUBLAS_STATUS_SUCCESS

#define gpudnnCreate cudnnCreate
#define gpudnnSetStream cudnnSetStream

#define GPUDNN_STATUS_SUCCESS CUDNN_STATUS_SUCCESS

#define gpusolverDnCreate cusolverDnCreate
#define gpusolverDnSetStream cusolverDnSetStream
#define gpusolverDnCreateSyevjInfo cusolverDnCreateSyevjInfo
@@ -233,6 +241,7 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;

#define gpuGetLastError cudaGetLastError
#define gpuGetErrorString cudaGetErrorString
#define gpuMemcpy cudaMemcpy
#define gpuMemcpyAsync cudaMemcpyAsync
#define gpuMemcpyDeviceToDevice cudaMemcpyDeviceToDevice
#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
jaxlib/gpu_rnn.py (new file, 125 lines)
@@ -0,0 +1,125 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as mhlo

import numpy as np

from jaxlib import xla_client

try:
  from .cuda import _rnn as _rnn
  for _name, _value in _rnn.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform='CUDA')
except ImportError:
  _rnn = None

if _rnn:
  compute_rnn_workspace_reserve_space_sizes = _rnn.compute_rnn_workspace_reserve_space_sizes


def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *,
                       input_size: int, hidden_size: int, num_layers: int,
                       dropout: bool, bidirectional: bool):
  """CuDnn RNN."""
  out_dtype = ctx.avals_out[0].dtype
  if out_dtype == np.float32:
    out_type = ir.F32Type.get()
  elif out_dtype == np.float64:
    out_type = ir.F64Type.get()
  elif out_dtype == np.complex64:
    out_type = ir.ComplexType.get(ir.F32Type.get())
  elif out_dtype == np.complex128:
    out_type = ir.ComplexType.get(ir.F64Type.get())
  else:
    raise ValueError(f'Unknown output type {out_dtype}')

  output_type = ir.RankedTensorType.get(ctx.avals_out[0].shape, out_type)
  batch_size = ctx.avals_in[0].shape[0]
  max_seq_length = ctx.avals_in[0].shape[1]
  workspace_shape = ctx.avals_out[3].shape
  reserve_space_shape = ctx.avals_out[4].shape
  workspace_type = ir.RankedTensorType.get(workspace_shape, ir.F32Type.get())
  reserve_space_type = ir.RankedTensorType.get(reserve_space_shape,
                                               ir.F32Type.get())
  opaque = _rnn.build_rnn_descriptor(input_size, hidden_size, num_layers,
                                     batch_size, max_seq_length, dropout,
                                     bidirectional, workspace_shape[0],
                                     reserve_space_shape[0])

  i32_type = ir.IntegerType.get_signless(32)

  out = mhlo.CustomCallOp(
      [
          ir.TupleType.get_tuple([
              output_type, h_0.type, c_0.type, workspace_type,
              reserve_space_type
          ])
      ],
      [input, h_0, c_0, weights, seq_lengths],
      call_target_name=ir.StringAttr.get('cudnn_rnn'),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
  )
  return [
      mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(5)
  ]


def _mhlo_zeros_f32(shape):
  return mhlo.ConstantOp(
      ir.DenseElementsAttr.get(
          np.zeros(shape, dtype=np.float32), type=ir.F32Type.get())).result


def cudnn_rnn_bwd_lowering(ctx, dy, dhn, dcn, x, h0, c0, w, y, workspace,
                           reserve_space, seq_lengths, *, input_size: int,
                           hidden_size: int, num_layers: int, dropout: bool,
                           bidirectional: bool):
  """CuDnn RNN Backward pass."""
  batch_size = ctx.avals_in[3].shape[0]
  max_seq_length = ctx.avals_in[3].shape[1]
  workspace_shape = ctx.avals_in[8].shape
  reserve_space_shape = ctx.avals_in[9].shape
  opaque = _rnn.build_rnn_descriptor(input_size, hidden_size, num_layers,
                                     batch_size, max_seq_length, dropout,
                                     bidirectional, workspace_shape[0],
                                     reserve_space_shape[0])

  i32_type = ir.IntegerType.get_signless(32)
  zeroed_dw = _mhlo_zeros_f32(ctx.avals_out[3].shape)
  out = mhlo.CustomCallOp(
      [ir.TupleType.get_tuple([x.type, h0.type, c0.type, w.type])], [
          dy, dhn, dcn, x, h0, c0, w, y, workspace, reserve_space, zeroed_dw,
          seq_lengths
      ],
      call_target_name=ir.StringAttr.get('cudnn_rnn_bwd'),
      has_side_effect=ir.BoolAttr.get(False),
      backend_config=ir.StringAttr.get(opaque),
      api_version=ir.IntegerAttr.get(i32_type, 2),
      called_computations=ir.ArrayAttr.get([]),
      output_operand_aliases=ir.ArrayAttr.get([
          mhlo.OutputOperandAlias.get(
              output_tuple_indices=[3],
              operand_index=10,
              operand_tuple_indices=[])
      ]))
  return [
      mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result
      for i in range(4)
  ]
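These rules are attached to the `rnn_fwd`/`rnn_bwd` primitives in `jax/experimental/rnn.py` via `mlir.register_lowering(..., platform='cuda')`. One way to see the custom call appear, sketched under the assumption of a CUDA-enabled jaxlib and a visible GPU (the AOT `lower`/`as_text` API is used only for inspection):

```
import jax
import jax.numpy as jnp
from jax.experimental import rnn

batch_size, seq_len, input_size, hidden_size = 2, 3, 4, 5
x = jnp.zeros((batch_size, seq_len, input_size), jnp.float32)
h_0 = jnp.zeros((1, batch_size, hidden_size), jnp.float32)
c_0 = jnp.zeros((1, batch_size, hidden_size), jnp.float32)
weights = rnn.init_lstm_weight(jax.random.PRNGKey(0), input_size, hidden_size,
                               num_layers=1, bidirectional=False)
seq_lengths = jnp.full((batch_size,), seq_len, jnp.int32)

lowered = jax.jit(rnn.lstm, static_argnums=(5, 6, 7, 8, 9)).lower(
    x, h_0, c_0, weights, seq_lengths, input_size, hidden_size, 1, 0.0, False)
print('cudnn_rnn' in lowered.as_text())  # Expect True on a CUDA backend.
```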
tests/BUILD
@@ -1001,6 +1001,18 @@ jax_test(
    srcs = ["clear_backends_test.py"],
)

jax_test(
    name = "experimental_rnn_test",
    srcs = ["experimental_rnn_test.py"],
    disable_backends = [
        "tpu",
        "cpu",
    ],
    deps = [
        "//jax:rnn",
    ],
)

exports_files(
    [
        "api_test.py",
tests/experimental_rnn_test.py (new file, 103 lines)
@@ -0,0 +1,103 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
from jax.experimental import rnn

from jax._src.config import config

config.parse_flags_with_absl()


class RnnTest(jtu.JaxTestCase):

  @jtu.sample_product(
      batch_size=[1, 4],
      seq_len=[1, 4],
      input_size=[1, 2],
      hidden_size=[1, 6],
      num_layers=[1, 4],
      bidirectional=[True, False],
  )
  def test_lstm(self, batch_size: int, seq_len: int, input_size: int,
                hidden_size: int, num_layers: int, bidirectional: bool):
    if xla_extension_version < 109:
      self.skipTest('rnn module added at xla_extension_version 109')
    batch_size = 6
    seq_len = 7
    input_size = 8
    hidden_size = 12
    num_layers = 5
    num_directions = 2 if bidirectional else 1

    seq_lengths = jnp.ones((batch_size,), dtype=jnp.int32) * seq_len

    root_key = jax.random.PRNGKey(1)
    k1, k2, k3, k4 = jax.random.split(root_key, 4)
    x = jax.random.normal(
        k1, (batch_size, seq_len, input_size), dtype=jnp.float32)
    h_0 = jax.random.normal(
        k2, (num_directions * num_layers, batch_size, hidden_size),
        dtype=jnp.float32)
    c_0 = jax.random.normal(
        k3, (num_directions * num_layers, batch_size, hidden_size),
        dtype=jnp.float32)
    weights = rnn.init_lstm_weight(k4, input_size, hidden_size, num_layers,
                                   bidirectional)

    def f(x, h_0, c_0, weights):
      return rnn.lstm(
          x,
          h_0,
          c_0,
          weights,
          seq_lengths=seq_lengths,
          input_size=input_size,
          hidden_size=hidden_size,
          num_layers=num_layers,
          dropout=False,
          bidirectional=bidirectional)

    y, h_n, c_n = f(x, h_0, c_0, weights)
    jtu.check_grads(f, (x, h_0, c_0, weights), modes=['rev'], order=1)

    W_ih, W_hh, b_ih, b_hh = rnn.unpack_lstm_weights(weights, input_size,
                                                     hidden_size, num_layers,
                                                     bidirectional)
    y_ref, h_n_ref, c_n_ref = rnn.lstm_ref(
        x,
        h_0,
        c_0,
        W_ih,
        W_hh,
        b_ih,
        b_hh,
        seq_lengths=seq_lengths,
        input_size=input_size,
        hidden_size=hidden_size,
        num_layers=num_layers,
        dropout=False,
        bidirectional=bidirectional)

    np.testing.assert_allclose(y_ref, y, rtol=1e-05, atol=1e-5)
    np.testing.assert_allclose(h_n_ref, h_n, rtol=1e-05, atol=1e-5)
    np.testing.assert_allclose(c_n_ref, c_n, rtol=1e-05, atol=1e-5)


if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())