mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Update RNN kernels to use FFI.
PiperOrigin-RevId: 724151647
This commit is contained in:
parent
cce3df1071
commit
c6e83903de
@ -154,12 +154,14 @@ cc_library(
|
||||
deps = [
|
||||
":cuda_gpu_kernel_helpers",
|
||||
":cuda_vendor",
|
||||
":ffi_wrapper",
|
||||
"//jaxlib:handle_pool",
|
||||
"//jaxlib:kernel_helpers",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
"@xla//xla/tsl/cuda:cudart",
|
||||
"@xla//xla/tsl/cuda:cudnn",
|
||||
|
@ -32,14 +32,16 @@ nb::bytes BuildRnnDescriptor(int input_size, int hidden_size, int num_layers,
|
||||
int workspace_size, int reserve_space_size) {
|
||||
return PackDescriptor(RnnDescriptor{
|
||||
input_size, hidden_size, num_layers, batch_size, max_seq_length, dropout,
|
||||
bidirectional, cudnn_allow_tf32, workspace_size, reserve_space_size
|
||||
});
|
||||
bidirectional, cudnn_allow_tf32, workspace_size, reserve_space_size});
|
||||
}
|
||||
|
||||
nb::dict Registrations() {
|
||||
nb::dict dict;
|
||||
dict[JAX_GPU_PREFIX "dnn_rnn"] = EncapsulateFunction(RNNForward);
|
||||
dict[JAX_GPU_PREFIX "dnn_rnn_bwd"] = EncapsulateFunction(RNNBackward);
|
||||
dict[JAX_GPU_PREFIX "dnn_rnn_ffi"] = EncapsulateFfiHandler(RNNForwardFfi);
|
||||
dict[JAX_GPU_PREFIX "dnn_rnn_bwd_ffi"] =
|
||||
EncapsulateFfiHandler(RNNBackwardFfi);
|
||||
return dict;
|
||||
}
|
||||
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/status/status.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "jaxlib/gpu/ffi_wrapper.h"
|
||||
#include "jaxlib/gpu/gpu_kernel_helpers.h"
|
||||
#include "jaxlib/handle_pool.h"
|
||||
#include "jaxlib/kernel_helpers.h"
|
||||
@ -92,14 +93,17 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
|
||||
|
||||
#ifdef JAX_GPU_HIP
|
||||
void* dropout_states_dev = nullptr;
|
||||
// Allocate minimal memory for dropout states (can be very small since it's not used)
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size)));
|
||||
// Allocate minimal memory for dropout states (can be very small since it's
|
||||
// not used)
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size)));
|
||||
if (!dropout_states_dev) {
|
||||
return absl::InternalError("Failed to allocate minimal GPU memory for dropout states.");
|
||||
return absl::InternalError(
|
||||
"Failed to allocate minimal GPU memory for dropout states.");
|
||||
}
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor(
|
||||
dropout_desc, handle.get(), dropout, dropout_states_dev, state_size, 123, false, false,
|
||||
MIOPEN_RNG_PSEUDO_XORWOW)));
|
||||
dropout_desc, handle.get(), dropout, dropout_states_dev, state_size, 123,
|
||||
false, false, MIOPEN_RNG_PSEUDO_XORWOW)));
|
||||
#else // JAX_GPU_CUDA
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor(
|
||||
dropout_desc, handle.get(), dropout, nullptr, state_size, 123)));
|
||||
@ -123,7 +127,8 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
|
||||
cell_mode, bias_mode, GPUDNN_RNN_ALGO_STANDARD, data_type)));
|
||||
#else // JAX_GPU_CUDA
|
||||
gpudnnDataType_t math_prec = GPUDNN_DATA_FLOAT;
|
||||
gpudnnMathType_t math_type = cudnn_allow_tf32? GPUDNN_DEFAULT_MATH: GPUDNN_FMA_MATH;
|
||||
gpudnnMathType_t math_type =
|
||||
cudnn_allow_tf32 ? GPUDNN_DEFAULT_MATH : GPUDNN_FMA_MATH;
|
||||
int32_t proj_size = hidden_size;
|
||||
uint32_t aux_flags = GPUDNN_RNN_PADDED_IO_ENABLED;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor(
|
||||
@ -149,9 +154,9 @@ DoRnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
|
||||
size_t workSpaceSize;
|
||||
size_t reserveSpaceSize;
|
||||
#ifdef JAX_GPU_HIP
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnGetRNNTempSpaceSizes(
|
||||
handle.get(), rnn_desc, input_data_desc, fwdMode, &workSpaceSize,
|
||||
&reserveSpaceSize)));
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(
|
||||
gpudnnGetRNNTempSpaceSizes(handle.get(), rnn_desc, input_data_desc,
|
||||
fwdMode, &workSpaceSize, &reserveSpaceSize)));
|
||||
#else // JAX_GPU_CUDA
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnGetRNNTempSpaceSizes(
|
||||
handle.get(), rnn_desc, fwdMode, input_data_desc, &workSpaceSize,
|
||||
@ -199,14 +204,17 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers,
|
||||
|
||||
#ifdef JAX_GPU_HIP
|
||||
void* dropout_states_dev = nullptr;
|
||||
// Allocate minimal memory for dropout states (can be very small since it's not used).
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size)));
|
||||
// Allocate minimal memory for dropout states (can be very small since it's
|
||||
// not used).
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size)));
|
||||
if (!dropout_states_dev) {
|
||||
return absl::InternalError("Failed to allocate minimal GPU memory for dropout states.");
|
||||
return absl::InternalError(
|
||||
"Failed to allocate minimal GPU memory for dropout states.");
|
||||
}
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor(
|
||||
dropout_desc, handle.get(), d.dropout, dropout_states_dev, state_size, 123, false, false,
|
||||
MIOPEN_RNG_PSEUDO_XORWOW)));
|
||||
dropout_desc, handle.get(), d.dropout, dropout_states_dev, state_size,
|
||||
123, false, false, MIOPEN_RNG_PSEUDO_XORWOW)));
|
||||
#else // JAX_GPU_CUDA
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor(
|
||||
dropout_desc, handle.get(), d.dropout, nullptr, state_size, 123)));
|
||||
@ -230,7 +238,8 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers,
|
||||
cell_mode, bias_mode, GPUDNN_RNN_ALGO_STANDARD, data_type)));
|
||||
#else // JAX_GPU_CUDA
|
||||
gpudnnDataType_t math_prec = GPUDNN_DATA_FLOAT;
|
||||
gpudnnMathType_t math_type = d.cudnn_allow_tf32? GPUDNN_DEFAULT_MATH: GPUDNN_FMA_MATH;
|
||||
gpudnnMathType_t math_type =
|
||||
d.cudnn_allow_tf32 ? GPUDNN_DEFAULT_MATH : GPUDNN_FMA_MATH;
|
||||
int32_t proj_size = d.hidden_size;
|
||||
uint32_t aux_flags = GPUDNN_RNN_PADDED_IO_ENABLED;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor(
|
||||
@ -288,8 +297,9 @@ static absl::Status DnnRNNForward_(gpuStream_t stream, void** buffers,
|
||||
|
||||
size_t weight_space_size;
|
||||
#ifdef JAX_GPU_HIP
|
||||
miopenTensorDescriptor_t input_tensor_desc;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(miopenCreateTensorDescriptor(&input_tensor_desc)));
|
||||
miopenTensorDescriptor_t input_tensor_desc;
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(miopenCreateTensorDescriptor(&input_tensor_desc)));
|
||||
int dimsA[2] = {d.batch_size, d.input_size};
|
||||
int stridesA[2] = {dimsA[1], 1}; // Row-major order, similar to GPUDNN
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(miopenSetTensorDescriptor(
|
||||
@ -314,10 +324,10 @@ miopenTensorDescriptor_t input_tensor_desc;
|
||||
|
||||
#ifdef JAX_GPU_HIP
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNForward(
|
||||
handle.get(), rnn_desc, fwdMode, input_data_desc, input_buf,
|
||||
h_desc, h_0_buf, h_n_buf, c_desc, c_0_buf, c_n_buf,
|
||||
output_data_desc, output_buf, weights_buf, weight_space_size,
|
||||
workspace_buf, d.workspace_size, reserve_space_buf, d.reserve_space_size)));
|
||||
handle.get(), rnn_desc, fwdMode, input_data_desc, input_buf, h_desc,
|
||||
h_0_buf, h_n_buf, c_desc, c_0_buf, c_n_buf, output_data_desc, output_buf,
|
||||
weights_buf, weight_space_size, workspace_buf, d.workspace_size,
|
||||
reserve_space_buf, d.reserve_space_size)));
|
||||
#else // JAX_GPU_CUDA
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNForward(
|
||||
handle.get(), rnn_desc, fwdMode, (const int32_t*)seq_lengths_buf,
|
||||
@ -361,14 +371,17 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
|
||||
|
||||
#ifdef JAX_GPU_HIP
|
||||
void* dropout_states_dev = nullptr;
|
||||
// Allocate minimal memory for dropout states (can be very small since it's not used)
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size)));
|
||||
// Allocate minimal memory for dropout states (can be very small since it's
|
||||
// not used)
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(gpuMalloc(&dropout_states_dev, state_size)));
|
||||
if (!dropout_states_dev) {
|
||||
return absl::InternalError("Failed to allocate minimal GPU memory for dropout states.");
|
||||
return absl::InternalError(
|
||||
"Failed to allocate minimal GPU memory for dropout states.");
|
||||
}
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor(
|
||||
dropout_desc, handle.get(), d.dropout, dropout_states_dev, state_size, 123, false, false,
|
||||
MIOPEN_RNG_PSEUDO_XORWOW)));
|
||||
dropout_desc, handle.get(), d.dropout, dropout_states_dev, state_size,
|
||||
123, false, false, MIOPEN_RNG_PSEUDO_XORWOW)));
|
||||
#else // JAX_GPU_CUDA
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetDropoutDescriptor(
|
||||
dropout_desc, handle.get(), d.dropout, nullptr, state_size, 123)));
|
||||
@ -392,7 +405,8 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
|
||||
cell_mode, bias_mode, GPUDNN_RNN_ALGO_STANDARD, data_type)));
|
||||
#else // JAX_GPU_CUDA
|
||||
gpudnnDataType_t math_prec = GPUDNN_DATA_FLOAT;
|
||||
gpudnnMathType_t math_type = d.cudnn_allow_tf32? GPUDNN_DEFAULT_MATH: GPUDNN_FMA_MATH;
|
||||
gpudnnMathType_t math_type =
|
||||
d.cudnn_allow_tf32 ? GPUDNN_DEFAULT_MATH : GPUDNN_FMA_MATH;
|
||||
int32_t proj_size = d.hidden_size;
|
||||
uint32_t aux_flags = GPUDNN_RNN_PADDED_IO_ENABLED;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnSetRNNDescriptor(
|
||||
@ -449,7 +463,8 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
|
||||
size_t weight_space_size;
|
||||
#ifdef JAX_GPU_HIP
|
||||
miopenTensorDescriptor_t input_tensor_desc;
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(miopenCreateTensorDescriptor(&input_tensor_desc)));
|
||||
JAX_RETURN_IF_ERROR(
|
||||
JAX_AS_STATUS(miopenCreateTensorDescriptor(&input_tensor_desc)));
|
||||
int input_dims[2] = {d.batch_size, d.input_size};
|
||||
int input_strides[2] = {input_dims[1], 1}; // row-major order
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(miopenSetTensorDescriptor(
|
||||
@ -482,16 +497,15 @@ static absl::Status DnnRNNBackward_(gpuStream_t stream, void** buffers,
|
||||
|
||||
#ifdef JAX_GPU_HIP
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNBackwardData(
|
||||
handle.get(), rnn_desc, output_data_desc, y_buf, dy_buf,
|
||||
h_desc, h_0_buf, dh_n_buf, dh_0_buf,
|
||||
c_desc, c_0_buf, dc_n_buf, dc_0_buf,
|
||||
input_data_desc, dx_buf, w_buf, weight_space_size,
|
||||
workspace_buf, d.workspace_size, reserve_space_buf, d.reserve_space_size)));
|
||||
handle.get(), rnn_desc, output_data_desc, y_buf, dy_buf, h_desc, h_0_buf,
|
||||
dh_n_buf, dh_0_buf, c_desc, c_0_buf, dc_n_buf, dc_0_buf, input_data_desc,
|
||||
dx_buf, w_buf, weight_space_size, workspace_buf, d.workspace_size,
|
||||
reserve_space_buf, d.reserve_space_size)));
|
||||
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNBackwardWeights(
|
||||
handle.get(), rnn_desc, input_data_desc, x_buf, h_desc, h_0_buf,
|
||||
output_data_desc, y_buf, zeroed_dw_buf, weight_space_size,
|
||||
workspace_buf, d.workspace_size, reserve_space_buf, d.reserve_space_size)));
|
||||
output_data_desc, y_buf, zeroed_dw_buf, weight_space_size, workspace_buf,
|
||||
d.workspace_size, reserve_space_buf, d.reserve_space_size)));
|
||||
#else // JAX_GPU_CUDA
|
||||
JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpudnnRNNBackwardData(
|
||||
handle.get(), rnn_desc, (const int32_t*)seq_lengths_buf, output_data_desc,
|
||||
@ -539,5 +553,8 @@ void RNNBackward(gpuStream_t stream, void** buffers, const char* opaque,
|
||||
}
|
||||
}
|
||||
|
||||
JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(RNNForwardFfi, DnnRNNForward_);
|
||||
JAX_GPU_REGISTER_WRAPPED_LEGACY_KERNEL(RNNBackwardFfi, DnnRNNBackward_);
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "jaxlib/gpu/vendor.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
#include "xla/service/custom_call_status.h"
|
||||
|
||||
namespace jax {
|
||||
@ -38,11 +39,10 @@ struct RnnDescriptor {
|
||||
};
|
||||
|
||||
// Return (workspace size, reserve space size).
|
||||
absl::StatusOr<std::pair<int, int>>
|
||||
RnnComputeWorkspaceReserveSpaceSizes(int input_size, int hidden_size,
|
||||
int num_layers, int batch_size,
|
||||
int max_seq_length, float dropout,
|
||||
bool bidirectional, bool cudnn_allow_tf32);
|
||||
absl::StatusOr<std::pair<int, int>> RnnComputeWorkspaceReserveSpaceSizes(
|
||||
int input_size, int hidden_size, int num_layers, int batch_size,
|
||||
int max_seq_length, float dropout, bool bidirectional,
|
||||
bool cudnn_allow_tf32);
|
||||
|
||||
void RNNForward(gpuStream_t stream, void **buffers, const char *opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus *status);
|
||||
@ -50,6 +50,9 @@ void RNNForward(gpuStream_t stream, void **buffers, const char *opaque,
|
||||
void RNNBackward(gpuStream_t stream, void **buffers, const char *opaque,
|
||||
size_t opaque_len, XlaCustomCallStatus *status);
|
||||
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(RNNForwardFfi);
|
||||
XLA_FFI_DECLARE_HANDLER_SYMBOL(RNNBackwardFfi);
|
||||
|
||||
} // namespace JAX_GPU_NAMESPACE
|
||||
} // namespace jax
|
||||
|
||||
|
@ -33,7 +33,9 @@ for cuda_module_name in [".cuda", "jax_cuda12_plugin"]:
|
||||
|
||||
if _cuda_rnn:
|
||||
for _name, _value in _cuda_rnn.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform='CUDA')
|
||||
api_version = 1 if _name.endswith("_ffi") else 0
|
||||
xla_client.register_custom_call_target(_name, _value, platform='CUDA',
|
||||
api_version=api_version)
|
||||
compute_rnn_workspace_reserve_space_sizes = _cuda_rnn.compute_rnn_workspace_reserve_space_sizes
|
||||
|
||||
|
||||
@ -47,7 +49,9 @@ for rocm_module_name in [".rocm", "jax_rocm60_plugin"]:
|
||||
|
||||
if _hip_rnn:
|
||||
for _name, _value in _hip_rnn.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform='ROCM')
|
||||
api_version = 1 if _name.endswith("_ffi") else 0
|
||||
xla_client.register_custom_call_target(_name, _value, platform='ROCM',
|
||||
api_version=api_version)
|
||||
compute_rnn_workspace_reserve_space_sizes = _hip_rnn.compute_rnn_workspace_reserve_space_sizes
|
||||
|
||||
|
||||
@ -93,10 +97,10 @@ def _rnn_fwd_lowering(_rnn, platform, ctx, input, h_0, c_0, weights, seq_lengths
|
||||
out = hlo.CustomCallOp(
|
||||
[output_type, h_0.type, c_0.type, workspace_type, reserve_space_type],
|
||||
[input, h_0, c_0, weights, seq_lengths],
|
||||
call_target_name=ir.StringAttr.get(f"{platform}dnn_rnn"),
|
||||
call_target_name=ir.StringAttr.get(f"{platform}dnn_rnn_ffi"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(opaque),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
|
||||
backend_config=ir.DictAttr.get({"opaque": ir.StringAttr.get(opaque)}),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 4),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
)
|
||||
return out.results[:-2] + out.results[-1:] # drop workspace output
|
||||
@ -140,10 +144,10 @@ def _rnn_bwd_lowering(_rnn, platform, ctx, dy, dhn, dcn, x, h0, c0, w, y,
|
||||
dy, dhn, dcn, x, h0, c0, w, y, reserve_space, zeroed_dw,
|
||||
seq_lengths
|
||||
],
|
||||
call_target_name=ir.StringAttr.get(f"{platform}dnn_rnn_bwd"),
|
||||
call_target_name=ir.StringAttr.get(f"{platform}dnn_rnn_bwd_ffi"),
|
||||
has_side_effect=ir.BoolAttr.get(False),
|
||||
backend_config=ir.StringAttr.get(opaque),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 2),
|
||||
backend_config=ir.DictAttr.get({"opaque": ir.StringAttr.get(opaque)}),
|
||||
api_version=ir.IntegerAttr.get(i32_type, 4),
|
||||
called_computations=ir.ArrayAttr.get([]),
|
||||
output_operand_aliases=ir.ArrayAttr.get([
|
||||
hlo.OutputOperandAlias.get(
|
||||
|
@ -140,6 +140,7 @@ cc_library(
|
||||
srcs = ["//jaxlib/gpu:rnn_kernels.cc"],
|
||||
hdrs = ["//jaxlib/gpu:rnn_kernels.h"],
|
||||
deps = [
|
||||
":ffi_wrapper",
|
||||
":hip_gpu_kernel_helpers",
|
||||
":hip_vendor",
|
||||
"//jaxlib:handle_pool",
|
||||
@ -149,6 +150,7 @@ cc_library(
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@local_config_rocm//rocm:miopen",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@xla//xla/ffi/api:ffi",
|
||||
"@xla//xla/service:custom_call_status",
|
||||
],
|
||||
)
|
||||
|
@ -213,9 +213,7 @@ class RnnTest(jtu.JaxTestCase):
|
||||
|
||||
k = jax.random.split(jax.random.PRNGKey(1), 4)
|
||||
stablehlo = jax.jit(f).lower(*k).as_text("stablehlo")
|
||||
self.assertIn('stablehlo.custom_call @cudnn_rnn(%0, %1, %2, %6, %5) '
|
||||
'{api_version = 2 : i32, backend_config = '
|
||||
'"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"}',
|
||||
self.assertIn('"\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\01\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\00\\01\\00\\00\\00@\\03\\80\\00@\\01\\00\\00"',
|
||||
stablehlo)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
Loading…
x
Reference in New Issue
Block a user