mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
[ROCm]: Add get_arch_details for triton kernel call
This commit is contained in:
parent
f17d0f382a
commit
4b7c198a1c
@ -132,6 +132,18 @@ NB_MODULE(_triton, m) {
|
||||
return major * 10 + minor;
|
||||
}));
|
||||
|
||||
m.def(
|
||||
"get_arch_details",
|
||||
ValueOrThrowWrapper([](int device) -> absl::StatusOr<absl::string_view> {
|
||||
#ifdef JAX_GPU_HIP
|
||||
hipDeviceProp_t prop;
|
||||
hipGetDeviceProperties(&prop, 0);
|
||||
return prop.gcnArchName;
|
||||
#else
|
||||
return absl::UnimplementedError("Not a HIP GPU");
|
||||
#endif
|
||||
}));
|
||||
|
||||
m.def("get_serialized_metadata",
|
||||
ValueOrThrowWrapper(
|
||||
[](nb::bytes opaque) -> absl::StatusOr<nb::bytes> {
|
||||
|
@ -35,6 +35,7 @@ if _cuda_triton:
|
||||
create_array_parameter = _cuda_triton.create_array_parameter
|
||||
create_scalar_parameter = _cuda_triton.create_scalar_parameter
|
||||
get_compute_capability = _cuda_triton.get_compute_capability
|
||||
get_arch_details = _cuda_triton.get_arch_details
|
||||
get_custom_call = _cuda_triton.get_custom_call
|
||||
get_serialized_metadata = _cuda_triton.get_serialized_metadata
|
||||
|
||||
@ -58,5 +59,6 @@ if _hip_triton:
|
||||
create_array_parameter = _hip_triton.create_array_parameter
|
||||
create_scalar_parameter = _hip_triton.create_scalar_parameter
|
||||
get_compute_capability = _hip_triton.get_compute_capability
|
||||
get_arch_details = _hip_triton.get_arch_details
|
||||
get_custom_call = _hip_triton.get_custom_call
|
||||
get_serialized_metadata = _hip_triton.get_serialized_metadata
|
||||
|
Loading…
x
Reference in New Issue
Block a user