This continues the series that started with #26078 and #26313, adding `debug_info` to more calls to `lu.wrap_init`.
Here I changed `custom_jvp_call` to replace the parameter
`jvp_jaxpr_thunk` (a callable) with `jvp_jaxpr_fun` (a `lu.WrappedFun`
that can carry debug info).
Also fixed uses in shard_map, checkify, sparse, attrs, and jax2tf.
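To illustrate the motivation, here is a minimal sketch in plain Python of why a wrapper object is preferable to a bare callable. `DebugInfo`, `WrappedCallable`, and `describe` are hypothetical stand-ins invented for this example; they are not the real `lu.WrappedFun` or JAX's debug-info types, and `jvp_rule` is just a placeholder function.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass(frozen=True)
class DebugInfo:
    """Hypothetical stand-in: records where a traced function came from."""
    traced_for: str
    func_name: str


@dataclass(frozen=True)
class WrappedCallable:
    """Hypothetical stand-in for a wrapper pairing a callable with debug info."""
    f: Callable[..., Any]
    debug_info: Optional[DebugInfo] = None

    def call_wrapped(self, *args: Any) -> Any:
        return self.f(*args)


def describe(fun: Any) -> str:
    """Build an error-message prefix; a bare callable has nothing to report."""
    info = getattr(fun, "debug_info", None)
    if info is None:
        return "<unknown function>"
    return f"{info.func_name} (traced for {info.traced_for})"


def jvp_rule(x, t):
    # Placeholder computation: tangent of x**2.
    return 2.0 * x * t


bare_thunk = jvp_rule  # before: a plain callable parameter
wrapped = WrappedCallable(jvp_rule, DebugInfo("custom_jvp", "jvp_rule"))  # after

print(describe(bare_thunk))
print(describe(wrapped))
```

The wrapper still calls through to the same function; the only difference is that code handling the parameter can report where it came from.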
This CL refactors Pallas memref indexers into transforms, which can support different kinds of ref transform: indexing, bitcast (added in this CL), reshape (to be added), and others. As with indexers, a user can apply multiple transforms to the same memref, e.g.:
```
ref.bitcast(type1).at[slice1].bitcast(type2).bitcast(type3).at[slice2]...
```
Jaxpr Preview (apply multiple transforms to same ref):
```
{ lambda ; a:MemRef<None>{int32[16,256]} b:MemRef<None>{int32[8,128]}. let
    c:i32[8,128] <- a[:8,:][bitcast(int16[16,256])][bitcast(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:]
    b[:,:] <- c
  in () }
```
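The shape bookkeeping in the preview above can be checked with a small sketch in plain Python. `RefView` is a hypothetical tracker invented for this example, not a Pallas type, and it holds no data. It assumes that a bitcast preserves the total bit count by rescaling the second-minor dimension, which is inferred from the shapes in the preview rather than stated in this description.

```python
from dataclasses import dataclass
from typing import Tuple

# Bit widths for the dtypes that appear in the jaxpr preview.
BITWIDTH = {"int16": 16, "float16": 16, "int32": 32}


@dataclass(frozen=True)
class RefView:
    """Hypothetical shape/dtype tracker for a chain of ref transforms."""
    shape: Tuple[int, ...]
    dtype: str
    transforms: Tuple[str, ...] = ()

    def bitcast(self, dtype: str) -> "RefView":
        # Assumption: the second-minor dimension absorbs the bit-width change.
        row_bits = self.shape[-2] * BITWIDTH[self.dtype]
        if row_bits % BITWIDTH[dtype]:
            raise ValueError("second-minor dimension does not divide evenly")
        rows = row_bits // BITWIDTH[dtype]
        new_shape = self.shape[:-2] + (rows, self.shape[-1])
        return RefView(new_shape, dtype, self.transforms + ("bitcast",))

    def index(self, *slices: slice) -> "RefView":
        # This sketch only handles one slice per dimension.
        if len(slices) != len(self.shape):
            raise ValueError("expected one slice per dimension")
        new_shape = tuple(
            len(range(*s.indices(n))) for s, n in zip(slices, self.shape)
        )
        return RefView(new_shape, self.dtype, self.transforms + ("index",))


# Mirror the chain a[:8,:] -> int16 -> float16 -> [:,:128] -> int32 -> [:,:].
a = RefView((16, 256), "int32")
v = (
    a.index(slice(None, 8), slice(None))
    .bitcast("int16")
    .bitcast("float16")
    .index(slice(None), slice(None, 128))
    .bitcast("int32")
    .index(slice(None), slice(None))
)
print(v.shape, v.dtype, v.transforms)
```

Under that assumption the chain ends at `int32[8,128]`, matching the type of `c` and the destination ref `b` in the preview.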
Tested:
* DMA with bitcasted ref
* Load from bitcasted ref
* Store to bitcasted ref
* Multiple transforms
* Interpret Mode for ref transforms (updated discharge rules)
PiperOrigin-RevId: 674961388
The previous attempt to handle consts captured by the kernel
was incomplete and contained errors: the calling convention was
wrong, and consts were not fully supported in combination with
scalar prefetch and scratch values.
I expanded the tests: one in pallas_tests.py and two
in tpu_pallas_test.py (covering scalar prefetch, with and
without scratch inputs).
The calling convention is now `*scalar_refs, *consts, *ins, *outs, *scratch`.
This differs from the previous order (`*consts, *scalar_refs, *ins, ...`):
it keeps the block arguments (consts, ins, outs) together, which makes
the lowering easier to write.
I will follow up with a cleanup PR for the handling of grid_mapping.
Here I attempted to minimize the changes.
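
As a rough illustration of the new ordering, the sketch below splits a flat argument list into the five groups in plain Python. `KernelArgs` and `split_kernel_args` are hypothetical names invented for this example and are not part of the Pallas lowering; the only thing taken from the description above is the order `*scalar_refs, *consts, *ins, *outs, *scratch`.

```python
from typing import Any, NamedTuple, Sequence, Tuple


class KernelArgs(NamedTuple):
    """Hypothetical container for the five argument groups."""
    scalar_refs: Tuple[Any, ...]
    consts: Tuple[Any, ...]
    ins: Tuple[Any, ...]
    outs: Tuple[Any, ...]
    scratch: Tuple[Any, ...]

    @property
    def block_args(self) -> Tuple[Any, ...]:
        # consts, ins, and outs are adjacent in the flat list.
        return self.consts + self.ins + self.outs


def split_kernel_args(
    args: Sequence[Any],
    *,
    num_scalar_refs: int,
    num_consts: int,
    num_ins: int,
    num_outs: int,
) -> KernelArgs:
    """Split a flat list ordered as scalar_refs, consts, ins, outs, scratch."""
    counts = (num_scalar_refs, num_consts, num_ins, num_outs)
    if any(c < 0 for c in counts) or sum(counts) > len(args):
        raise ValueError("group sizes do not fit the argument list")
    groups = []
    start = 0
    for count in counts:
        groups.append(tuple(args[start:start + count]))
        start += count
    groups.append(tuple(args[start:]))  # whatever remains is scratch
    return KernelArgs(*groups)


flat = ["s0", "c0", "c1", "x", "y", "o", "scratch0"]
parts = split_kernel_args(
    flat, num_scalar_refs=1, num_consts=2, num_ins=2, num_outs=1
)
print(parts.block_args)
```

Because the block arguments form one contiguous run, they can be recovered with a single slice of the flat list, which is the convenience the reordering is meant to provide.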