mirror of
https://github.com/ROCm/jax.git
synced 2025-04-23 23:06:06 +00:00

1. Add (limited) precision specifier handling to LSTM. This enables differentiating between TF32 and FP32 math; TF32 math had insufficient precision to reliably pass LSTM correctness tests on A100 and H100. 2. Run the test using FP32. TF32 precision is not sufficient for the test to pass reliably on Ampere+ GPUs such as A100 and H100.
56 lines
1.9 KiB
C++
56 lines
1.9 KiB
C++
/* Copyright 2022 The JAX Authors.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "nanobind/nanobind.h"
|
|
#include "nanobind/stl/pair.h"
|
|
#include "jaxlib/absl_status_casters.h"
|
|
#include "jaxlib/gpu/rnn_kernels.h"
|
|
#include "jaxlib/gpu/vendor.h"
|
|
#include "jaxlib/kernel_nanobind_helpers.h"
|
|
|
|
namespace jax {
|
|
namespace JAX_GPU_NAMESPACE {
|
|
namespace {
|
|
|
|
namespace nb = nanobind;
|
|
|
|
// Bundles the RNN configuration values into an RnnDescriptor and returns the
// result of PackDescriptor on it as a bytes object.
//
// The arguments are forwarded positionally, in declaration order, into the
// aggregate initializer of RnnDescriptor.
// NOTE(review): `cudnn_allow_tf32` presumably selects between TF32 and FP32
// math in the kernels; the consumer is not visible here -- confirm in
// rnn_kernels.
nb::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers,
                             int batch_size, int max_seq_length, float dropout,
                             bool bidirectional, bool cudnn_allow_tf32,
                             int workspace_size, int reserve_space_size) {
  const RnnDescriptor descriptor{
      input_size,
      hidden_size,
      num_layers,
      batch_size,
      max_seq_length,
      dropout,
      bidirectional,
      cudnn_allow_tf32,
      workspace_size,
      reserve_space_size,
  };
  return PackDescriptor(descriptor);
}
|
|
|
|
// Builds the dictionary of kernel entry points exposed to Python.
//
// Each key is JAX_GPU_PREFIX concatenated with the kernel name; each value is
// the corresponding kernel function wrapped by EncapsulateFunction.
nb::dict Registrations() {
  nb::dict registry;

  // Forward pass.
  registry[JAX_GPU_PREFIX "dnn_rnn"] = EncapsulateFunction(RNNForward);

  // Backward pass.
  registry[JAX_GPU_PREFIX "dnn_rnn_bwd"] = EncapsulateFunction(RNNBackward);

  return registry;
}
|
|
|
|
// Defines the `_rnn` extension module and its three Python-visible functions.
NB_MODULE(_rnn, mod) {
  // Dictionary of encapsulated kernel entry points.
  mod.def("registrations", &Registrations);

  // Packs the RNN configuration into a bytes descriptor.
  mod.def("build_rnn_descriptor", &BuildRnnDescriptor);

  // Wrapped in ValueOrThrowWrapper before being bound.
  // NOTE(review): presumably this turns a failed status into a Python
  // exception -- confirm against absl_status_casters.h.
  mod.def("compute_rnn_workspace_reserve_space_sizes",
          ValueOrThrowWrapper(RnnComputeWorkspaceReserveSpaceSizes));
}
|
|
|
|
} // namespace
|
|
} // namespace JAX_GPU_NAMESPACE
|
|
} // namespace jax
|