Mirror of https://github.com/ROCm/jax.git (synced 2025-04-25 01:06:05 +00:00)

This CL refactors Pallas memref indexers into transforms, which can support different kinds of ref transforms: indexing, bitcast (added in this CL), reshape (to be added), and others. As with indexers, users can apply multiple transforms to the same memref, e.g.:

```
ref.bitcast(type1).at[slice1].bitcast(type2).bitcast(type3).at[slice2]...
```

Jaxpr preview (multiple transforms applied to the same ref):

```
{ lambda ; a:MemRef<None>{int32[16,256]} b:MemRef<None>{int32[8,128]}. let
    c:i32[8,128] <- a[:8,:][bitcast(int16[16,256])][bitcast(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:]
    b[:,:] <- c
  in () }
```

Tested:
* DMA with bitcasted ref
* Load from bitcasted ref
* Store to bitcasted ref
* Multiple transforms
* Interpret mode for ref transforms (updated discharge rules)

PiperOrigin-RevId: 674961388
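To illustrate how a chain of ref transforms composes, here is a minimal plain-Python model that only tracks dtype and shape. It is a hypothetical sketch, not the Pallas implementation: `RefType`, `Index`, `Bitcast`, and `TransformedRef` are illustrative names, only slice indexing is supported, and the bitcast shape rule (rescale the second-to-last dimension by the ratio of bit widths) is an assumption inferred from the shapes in the Jaxpr preview in this commit message.

```python
from dataclasses import dataclass
from typing import Tuple

# Assumed bit widths for the dtypes used in the Jaxpr preview.
BITWIDTH = {"int32": 32, "float32": 32, "int16": 16, "float16": 16, "int8": 8}


@dataclass(frozen=True)
class RefType:
    dtype: str
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class Index:
    """Hypothetical indexing transform: basic slices only (no ints, no steps)."""
    slices: Tuple[slice, ...]

    def padded(self, ndim: int) -> Tuple[slice, ...]:
        if len(self.slices) > ndim:
            raise ValueError("too many indices for ref")
        # Trailing dimensions that are not indexed keep their full extent.
        return self.slices + (slice(None),) * (ndim - len(self.slices))

    def apply(self, t: RefType) -> RefType:
        sl = self.padded(len(t.shape))
        shape = tuple(len(range(*s.indices(d))) for s, d in zip(sl, t.shape))
        return RefType(t.dtype, shape)


@dataclass(frozen=True)
class Bitcast:
    """Hypothetical bitcast transform (assumed rule: rescale dim -2 by bit ratio)."""
    dtype: str

    def apply(self, t: RefType) -> RefType:
        if len(t.shape) < 2:
            raise ValueError("this sketch only bitcasts refs of rank >= 2")
        total_bits = t.shape[-2] * BITWIDTH[t.dtype]
        if total_bits % BITWIDTH[self.dtype]:
            raise ValueError("second-to-last dim is not divisible after bitcast")
        rows = total_bits // BITWIDTH[self.dtype]
        return RefType(self.dtype, t.shape[:-2] + (rows, t.shape[-1]))


def _fmt_slice(s: slice) -> str:
    start = "" if s.start is None else str(s.start)
    stop = "" if s.stop is None else str(s.stop)
    return f"{start}:{stop}"


class TransformedRef:
    """A base ref plus an ordered tuple of transforms; each call appends one."""

    def __init__(self, base: RefType, transforms=()):
        self.base = base
        self.transforms = tuple(transforms)

    @property
    def at(self) -> "_Indexer":
        return _Indexer(self)

    def bitcast(self, dtype: str) -> "TransformedRef":
        return TransformedRef(self.base, self.transforms + (Bitcast(dtype),))

    @property
    def type(self) -> RefType:
        # Apply the transforms in the order they were recorded.
        t = self.base
        for tr in self.transforms:
            t = tr.apply(t)
        return t

    def describe(self, name: str) -> str:
        # Render the chain in the style of the Jaxpr preview.
        parts, t = [name], self.base
        for tr in self.transforms:
            nxt = tr.apply(t)
            if isinstance(tr, Index):
                sl = tr.padded(len(t.shape))
                parts.append("[" + ",".join(_fmt_slice(s) for s in sl) + "]")
            else:
                dims = ",".join(map(str, nxt.shape))
                parts.append(f"[bitcast({nxt.dtype}[{dims}])]")
            t = nxt
        return "".join(parts)


class _Indexer:
    def __init__(self, ref: TransformedRef):
        self.ref = ref

    def __getitem__(self, idx) -> TransformedRef:
        idx = idx if isinstance(idx, tuple) else (idx,)
        return TransformedRef(self.ref.base, self.ref.transforms + (Index(idx),))


a = TransformedRef(RefType("int32", (16, 256)))
r = a.at[:8, :].bitcast("int16").bitcast("float16").at[:, :128].bitcast("int32").at[:, :]
print(r.type)  # RefType(dtype='int32', shape=(8, 128))
print(r.describe("a"))
```

Because each call returns a new `TransformedRef` with one more transform appended, the base ref is never modified, and the resulting type is computed only by folding the recorded transforms in order. That ordering mirrors the bracketed sequence in the Jaxpr preview.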