mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00

When it is possible to annotate an operation using both a `strided` and a `splat` layout, we pick the `strided` layout. This is the correct choice when propagating layouts down from parameters to the root; e.g.

```
? = add(strided, splat)
```

becomes

```
strided = add(strided, strided)
```

and requires a relayout for the right-hand side argument.

The logic needs to be reversed to handle propagation in the opposite direction, from uses back up to definitions. For example, code like

```
c : ?
use(c) : strided
use(c) : splat
```

should resolve to

```
c : splat
use(c) : strided
use(c) : splat
```

and incur a relayout in the `strided` use of `c`.

This direction of propagation is left as a `TODO` for now, to limit the amount of changes in a single commit.

PiperOrigin-RevId: 714056648
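The two selection rules described in this commit can be modeled in a few lines of Python. This is a minimal, hypothetical sketch, not the Mosaic GPU layout inference code: the `Layout` enum and the `pick_forward`, `pick_backward`, and `relayouts_needed` helpers are illustrative names. It assumes, as the examples suggest, that a relayout only ever converts a `splat` value into a `strided` one. The backward rule models the direction this commit leaves as a `TODO`.

```python
from enum import Enum


class Layout(Enum):
    # Hypothetical stand-ins for the splat and strided layouts.
    SPLAT = "splat"
    STRIDED = "strided"


def pick_forward(operand_layouts):
    """Pick a result layout from operand layouts (parameters -> root).

    Prefer STRIDED: any splat operand is then relaid out to strided.
    """
    if Layout.STRIDED in operand_layouts:
        return Layout.STRIDED
    return Layout.SPLAT


def pick_backward(use_layouts):
    """Pick a definition's layout from the layouts its uses expect (uses -> definition).

    Prefer SPLAT: the definition stays splat and each strided use pays for a relayout.
    """
    if Layout.SPLAT in use_layouts:
        return Layout.SPLAT
    return Layout.STRIDED


def relayouts_needed(chosen, required):
    # Count operands/uses whose expected layout differs from the chosen one.
    return sum(1 for layout in required if layout != chosen)


# ? = add(strided, splat)  ->  strided = add(strided, strided)
operands = [Layout.STRIDED, Layout.SPLAT]
result = pick_forward(operands)
print(result.value, relayouts_needed(result, operands))  # strided 1

# c : ?, used once as strided and once as splat  ->  c : splat
uses = [Layout.STRIDED, Layout.SPLAT]
c = pick_backward(uses)
print(c.value, relayouts_needed(c, uses))  # splat 1
```

In both directions the sketch charges exactly one relayout, and it is always the `splat` to `strided` conversion, which is why the preferred layout flips between the two directions.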