mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add support for scratch buffers in jax_triton
.
This is required to use device-side TMA descriptors. PiperOrigin-RevId: 735985603
This commit is contained in:
parent
ff751ecc7b
commit
74b4d868e3
@ -493,15 +493,7 @@ absl::Status KernelCall::Launch(gpuStream_t stream, void** buffers) {
|
||||
param.value)));
|
||||
}
|
||||
}
|
||||
// Triton's kernel ABI expects an additional scratchpad global memory.
|
||||
// For now it is only used for on-device creation of TMA descriptors, which
|
||||
// we do not use yet, so we are just replacing this argument with a null
|
||||
// pointer.
|
||||
// TODO: b/381242007 - Allocate a proper buffer if we want to use
|
||||
// device-side TMA APIs.
|
||||
void* scratch_ptr = nullptr; // Alive until kernel_.Launch returns.
|
||||
params.push_back(&scratch_ptr);
|
||||
|
||||
params.push_back(buffers++); // Scratch buffer.
|
||||
return kernel_.Launch(stream, grid_, params.data());
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user