// rocm_jax/jaxlib/gpu/triton.proto
// 2024-01-19 05:55:41 -08:00
//
// 68 lines
// 1.4 KiB
// Protocol Buffer

// This schema uses proto3 syntax.
syntax = "proto3";
// Every message in this file is declared in the `jax_triton` package.
package jax_triton;
// Describes one Triton kernel: its function name plus the sizing and
// program-text fields listed below.
message TritonKernel {
  // Kernel function name within module.
  string kernel_name = 1;

  // Warp count for the kernel.
  // NOTE(review): presumably warps per block -- confirm with the producer.
  uint32 num_warps = 2;

  // Shared-memory size in bytes, per the field name.
  // NOTE(review): static vs. dynamic vs. total is not visible here -- confirm.
  uint32 shared_mem_bytes = 3;

  // Presumably the kernel's PTX text; verify against the code that fills it.
  string ptx = 4;

  // Presumably the kernel's Triton IR (TTIR) text; verify against the code
  // that fills it.
  string ttir = 5;

  // Compute capability the kernel targets.
  // NOTE(review): the numeric encoding is not defined in this file -- confirm.
  uint32 compute_capability = 6;

  // Cluster dimensions, one field per axis (0, 1, 2).
  uint32 cluster_dim_0 = 7;
  uint32 cluster_dim_1 = 8;
  uint32 cluster_dim_2 = 9;
}
// A call of a TritonKernel: the kernel itself, three grid extents, and the
// list of parameters for the call.
message TritonKernelCall {
  // One parameter of the call. At most one member of `value` is set.
  message Parameter {
    // Descriptor used when the parameter is an array rather than a scalar.
    message Array {
      // NOTE(review): presumably the number of bytes of the buffer to zero
      // before the kernel runs -- confirm against the consumer.
      uint64 bytes_to_zero = 1;

      // NOTE(review): presumably a value the array pointer is known to be
      // divisible by -- confirm against the consumer.
      uint64 ptr_divisibility = 2;
    }

    // The parameter payload: either an array descriptor or a single scalar of
    // the named type.
    oneof value {
      Array array = 1;
      bool bool_ = 2;
      int32 i32 = 3;
      uint32 u32 = 4;
      int64 i64 = 5;
      uint64 u64 = 6;
      float f32 = 7;
      double f64 = 8;
    }
  }

  // The kernel being called.
  TritonKernel kernel = 1;

  // Grid extents, one field per axis (0, 1, 2).
  uint32 grid_0 = 2;
  uint32 grid_1 = 3;
  uint32 grid_2 = 4;

  // Parameters of the call.
  // NOTE(review): presumably in kernel-argument order -- confirm.
  repeated Parameter parameters = 5;
}
// A named set of kernel-call configurations together with input/output
// buffer alias records.
message TritonAutotunedKernelCall {
  // A single configuration: a kernel call paired with a text description.
  // NOTE(review): presumably one auto-tuning candidate -- confirm.
  message Config {
    TritonKernelCall kernel_call = 1;
    string description = 2;
  }

  // Pairs an input buffer index with an output buffer index and records a
  // buffer size in bytes.
  // NOTE(review): presumably the two buffers share storage -- confirm.
  message InputOutputAlias {
    uint32 input_buffer_idx = 1;
    uint32 output_buffer_idx = 2;
    uint64 buffer_size_bytes = 3;
  }

  // Name used in auto-tuning log messages.
  string name = 1;

  // The configurations carried by this call.
  repeated Config configs = 2;

  // The input/output alias records carried by this call.
  repeated InputOutputAlias input_output_aliases = 3;
}
// Wrapper holding either a plain kernel call or an autotuned kernel call,
// plus a metadata blob and a name.
message TritonAnyKernelCall {
  // Which kind of call this is; at most one member is set.
  oneof value {
    TritonKernelCall kernel_call = 1;
    TritonAutotunedKernelCall autotuned_kernel_call = 2;
  }

  // Opaque bytes.
  // NOTE(review): the format is not defined in this file -- confirm.
  bytes metadata = 3;

  // User assigned name.
  string name = 4;
}