// Mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 20:06:05 +00:00.
// (Source page metadata: 68 lines, 1.4 KiB, Protocol Buffer.)

syntax = "proto3";

package jax_triton;

// A Triton kernel and the attributes stored alongside it.
//
// All scalar fields use proto3 implicit presence: a zero/empty value is not
// serialized and cannot be distinguished from "unset" by readers.
// NOTE(review): most field semantics below are inferred from field names;
// confirm against the code that populates and consumes this message.
message TritonKernel {
  // Kernel function name within module.
  string kernel_name = 1;

  // Number of warps; presumably determines the launch block size — confirm.
  uint32 num_warps = 2;

  // Shared memory size, in bytes (per the field name).
  uint32 shared_mem_bytes = 3;

  // Kernel code, presumably as PTX text — confirm.
  string ptx = 4;

  // Kernel in Triton IR (TTIR), presumably as text — confirm.
  string ttir = 5;

  // Target compute capability.
  // TODO(review): document the integer encoding used (not shown here).
  uint32 compute_capability = 6;

  // Cluster dimensions, one field per axis (0, 1, 2).
  // NOTE(review): presumably the thread-block cluster shape; verify whether
  // 0 (i.e. unset) is treated as "no clustering".
  uint32 cluster_dim_0 = 7;
  uint32 cluster_dim_1 = 8;
  uint32 cluster_dim_2 = 9;
}

// A call of a TritonKernel: the kernel, its launch grid and its parameters.
message TritonKernelCall {
  // A single kernel argument: either an array argument or a scalar value.
  message Parameter {
    // Attributes of an array argument.
    message Array {
      // Number of bytes to zero; presumably the buffer is zeroed before the
      // launch — confirm against the launcher.
      uint64 bytes_to_zero = 1;

      // Pointer divisibility hint for the array's address.
      // TODO(review): confirm the unit (bytes?) and whether 0 means "no hint".
      uint64 ptr_divisibility = 2;
    }

    // The argument. At most one member is set; the member chosen fixes both
    // the argument kind and, for scalars, its type (bool, signed/unsigned
    // 32/64-bit integers, 32/64-bit floats).
    oneof value {
      Array array = 1;
      // The trailing underscore presumably avoids a clash with the `bool`
      // type name in generated code. Do not rename: that would break
      // generated accessors and the JSON/text-format field name.
      bool bool_ = 2;
      int32 i32 = 3;
      uint32 u32 = 4;
      int64 i64 = 5;
      uint64 u64 = 6;
      float f32 = 7;
      double f64 = 8;
    }
  }

  // The kernel to call.
  TritonKernel kernel = 1;

  // Launch grid extent along axes 0, 1 and 2 (presumably in blocks — confirm).
  uint32 grid_0 = 2;
  uint32 grid_1 = 3;
  uint32 grid_2 = 4;

  // Kernel arguments, presumably in the kernel's positional order — confirm.
  repeated Parameter parameters = 5;
}

// A kernel call with several candidate configurations, for auto-tuning.
message TritonAutotunedKernelCall {
  // One candidate configuration.
  message Config {
    // The complete kernel call for this candidate.
    TritonKernelCall kernel_call = 1;

    // Human-readable description of the candidate (presumably used in
    // auto-tuning logs — confirm).
    string description = 2;
  }

  // Pairs an input buffer with the output buffer that aliases it.
  // NOTE(review): presumably lets the auto-tuner preserve or restore aliased
  // buffer contents across trial runs; confirm against the tuner.
  message InputOutputAlias {
    // Index of the aliased input buffer (index space not defined here —
    // TODO confirm).
    uint32 input_buffer_idx = 1;

    // Index of the output buffer aliasing the input.
    uint32 output_buffer_idx = 2;

    // Size of the aliased buffer, in bytes (per the field name).
    uint64 buffer_size_bytes = 3;
  }

  // Name used in auto-tuning log messages.
  string name = 1;

  // Candidate configurations.
  repeated Config configs = 2;

  // Input/output buffer aliases of the call.
  repeated InputOutputAlias input_output_aliases = 3;
}

// Either a plain or an auto-tuned kernel call, plus metadata and a name.
message TritonAnyKernelCall {
  // The call. At most one member is set.
  oneof value {
    TritonKernelCall kernel_call = 1;
    TritonAutotunedKernelCall autotuned_kernel_call = 2;
  }

  // Opaque metadata. Its encoding is not defined by this schema.
  // TODO(review): document who produces and consumes these bytes.
  bytes metadata = 3;

  // User assigned name.
  string name = 4;
}