mirror of
https://github.com/ROCm/jax.git
synced 2025-04-24 19:56:05 +00:00

There was an earlier attempt to handle consts captured by the kernel, but it was incomplete and buggy: the calling convention was wrong, and consts were not fully supported in combination with scalar prefetch and scratch values.

I expanded the tests: one in pallas_tests.py and two in tpu_pallas_test.py (covering scalar prefetch, with and without scratch inputs).

The calling convention is now `*scalar_refs, *consts, *ins, *outs, *scratch`. This differs from the previous order (`*consts, *scalar_refs, *ins, ...`): the new order keeps the block arguments (consts, ins, outs) together, which makes the lowering easier to write.

I tried to keep the changes here minimal, and I will follow up with a cleanup PR for the handling of grid_mapping.
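To illustrate the calling convention described above, here is a minimal sketch in plain Python of how a flat argument list in the order `*scalar_refs, *consts, *ins, *outs, *scratch` could be split into its groups. The names `KernelArgs` and `split_kernel_args` and the count parameters are hypothetical and exist only for this example; they are not part of Pallas or of this change.

```python
from typing import Any, NamedTuple, Sequence


class KernelArgs(NamedTuple):
    """Hypothetical container for the five argument groups."""
    scalar_refs: tuple
    consts: tuple
    ins: tuple
    outs: tuple
    scratch: tuple


def split_kernel_args(
    args: Sequence[Any],
    *,
    num_scalar_prefetch: int,
    num_consts: int,
    num_ins: int,
    num_outs: int,
) -> KernelArgs:
    """Split flat kernel args ordered as
    (*scalar_refs, *consts, *ins, *outs, *scratch).

    Everything after the first four groups is treated as scratch.
    """
    args = tuple(args)
    counts = (num_scalar_prefetch, num_consts, num_ins, num_outs)
    if any(c < 0 for c in counts):
        raise ValueError("group sizes must be non-negative")
    if len(args) < sum(counts):
        raise ValueError(
            f"expected at least {sum(counts)} arguments, got {len(args)}"
        )
    groups = []
    start = 0
    for count in counts:
        groups.append(args[start:start + count])
        start += count
    # Whatever remains after scalar_refs, consts, ins, outs is scratch.
    groups.append(args[start:])
    return KernelArgs(*groups)


# Example: 1 scalar-prefetch ref, 2 consts, 2 inputs, 1 output, 1 scratch.
flat = ["s0", "c0", "c1", "x", "y", "o", "t0"]
parts = split_kernel_args(
    flat, num_scalar_prefetch=1, num_consts=2, num_ins=2, num_outs=1
)
# The block arguments (consts, ins, outs) form one contiguous slice.
block_args = parts.consts + parts.ins + parts.outs
```

With this ordering the block arguments occupy a single contiguous slice of the flat list (here `flat[1:6]`), which is the property the new convention is meant to provide for the lowering.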