1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 05:16:06 +00:00

Add support for scratch buffers in jax_triton.

This is required to use device-side TMA descriptors.

PiperOrigin-RevId: 735985603
This commit is contained in:
Chris Jones 2025-03-11 20:48:55 -07:00 committed by jax authors
parent ff751ecc7b
commit 74b4d868e3

@ -493,15 +493,7 @@ absl::Status KernelCall::Launch(gpuStream_t stream, void** buffers) {
param.value)));
}
}
// Triton's kernel ABI expects an additional scratchpad global memory.
// For now it is only used for on-device creation of TMA descriptors, which
// we do not use yet, so we are just replacing this argument with a null
// pointer.
// TODO: b/381242007 - Allocate a proper buffer if we want to use
// device-side TMA APIs.
void* scratch_ptr = nullptr; // Alive until kernel_.Launch returns.
params.push_back(&scratch_ptr);
params.push_back(buffers++); // Scratch buffer.
return kernel_.Launch(stream, grid_, params.data());
}