mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 18:36:07 +00:00

This change lets users distinguish Mosaic GPU kernels from other kernels in profilers such as Nsight Systems. Previously, every Mosaic GPU kernel was named `main_kernel`.

By default, the kernel is now named `mosaic_gpu_<def_name>_kernel`, where `<def_name>` is the name of the Mosaic GPU Python kernel function passed to `as_gpu_kernel` or `as_torch_gpu_kernel`.

This change also adds an optional `kernel_name` argument to `as_gpu_kernel` and `as_torch_gpu_kernel`. If `kernel_name` is not `None`, the resulting kernel is named `mosaic_gpu_<kernel_name>_kernel`. This is useful when the Mosaic GPU Python kernel function is built through metaprogramming: each specialized kernel can then carry a distinct, meaningful name that reflects its metaparameters.
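The naming rule can be illustrated with a small pure-Python sketch. It does not call Mosaic GPU; `mosaic_kernel_name` and `make_matmul_body` are hypothetical helpers written only to mirror the behavior described above, not the library's implementation. In real code the explicit name would be passed as `as_gpu_kernel(..., kernel_name=...)`, with the other arguments omitted here. The sketch also shows why `kernel_name` matters for metaprogramming: every specialization produced by a factory shares the same Python function name.

```python
from typing import Callable, Optional


def mosaic_kernel_name(body: Callable, kernel_name: Optional[str] = None) -> str:
    """Hypothetical helper mirroring the naming rule in this commit."""
    # Use the explicit name if given, else the Python function's own name.
    base = kernel_name if kernel_name is not None else body.__name__
    return f"mosaic_gpu_{base}_kernel"


def make_matmul_body(tile_m: int, tile_n: int) -> Callable:
    """Hypothetical kernel factory: every specialization it returns
    is a function named `body`, regardless of the tile sizes."""
    def body(*args):
        pass  # kernel body elided; only the function name matters here
    return body


def add(*args):
    pass  # stand-in for a Mosaic GPU kernel function


print(mosaic_kernel_name(add))  # mosaic_gpu_add_kernel

for tm, tn in [(64, 128), (128, 256)]:
    body = make_matmul_body(tm, tn)
    # Without kernel_name, both specializations collide on the same name.
    print(mosaic_kernel_name(body))  # mosaic_gpu_body_kernel
    # With kernel_name, each specialization is distinguishable in a profiler.
    print(mosaic_kernel_name(body, kernel_name=f"matmul_{tm}x{tn}"))
```

The second loop prints `mosaic_gpu_matmul_64x128_kernel` and `mosaic_gpu_matmul_128x256_kernel`, which is the kind of distinction that would otherwise be lost when all specializations share one function name.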