diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index 1274eeba4..500034af3 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -132,6 +132,18 @@ NB_MODULE(_triton, m) { return major * 10 + minor; })); + m.def( + "get_arch_details", + ValueOrThrowWrapper([](int device) -> absl::StatusOr { +#ifdef JAX_GPU_HIP + hipDeviceProp_t prop; + hipGetDeviceProperties(&prop, 0); + return prop.gcnArchName; +#else + return absl::UnimplementedError("Not a HIP GPU"); +#endif + })); + m.def("get_serialized_metadata", ValueOrThrowWrapper( [](nb::bytes opaque) -> absl::StatusOr { diff --git a/jaxlib/gpu_triton.py b/jaxlib/gpu_triton.py index f2d37bfec..77f315e5b 100644 --- a/jaxlib/gpu_triton.py +++ b/jaxlib/gpu_triton.py @@ -35,6 +35,7 @@ if _cuda_triton: create_array_parameter = _cuda_triton.create_array_parameter create_scalar_parameter = _cuda_triton.create_scalar_parameter get_compute_capability = _cuda_triton.get_compute_capability + get_arch_details = _cuda_triton.get_arch_details get_custom_call = _cuda_triton.get_custom_call get_serialized_metadata = _cuda_triton.get_serialized_metadata @@ -58,5 +59,6 @@ if _hip_triton: create_array_parameter = _hip_triton.create_array_parameter create_scalar_parameter = _hip_triton.create_scalar_parameter get_compute_capability = _hip_triton.get_compute_capability + get_arch_details = _hip_triton.get_arch_details get_custom_call = _hip_triton.get_custom_call get_serialized_metadata = _hip_triton.get_serialized_metadata