jax authors 1f0b5728a4 Add a memory-saving index rewrite step to vmap with ragged inputs over pallas_call.
The approach here is to add a new notion to jax: ragged_prop. Ragged prop computes the dynamism/raggedness of an output from the raggedness of its inputs. In the limit, if we decide that this is a useful property to have in jax as a first-class citizen, we could fold raggedness into the type system. For now, however, it is just a small set of rules implemented per op.
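To make the "small set of rules implemented per op" idea concrete, here is a minimal sketch in plain Python. It is not jax's implementation. The rule registry, `register_ragged_rule`, `propagate`, the toy equation format, and the op names are all hypothetical. Raggedness is also simplified to a single boolean per value, meaning "this value's size varies across the vmapped batch".

```python
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical: a ragged_prop rule maps per-input raggedness flags
# to per-output raggedness flags.
RaggedRule = Callable[[Sequence[bool]], List[bool]]

# Hypothetical registry of per-op rules, keyed by op name.
_RAGGED_PROP_RULES: Dict[str, RaggedRule] = {}


def register_ragged_rule(op_name: str):
    """Register a ragged_prop rule for one op (decorators can be stacked)."""
    def decorator(rule: RaggedRule) -> RaggedRule:
        _RAGGED_PROP_RULES[op_name] = rule
        return rule
    return decorator


@register_ragged_rule("add")
@register_ragged_rule("mul")
def _elementwise_rule(in_ragged: Sequence[bool]) -> List[bool]:
    # Elementwise ops keep the ragged axis: output is ragged if any input is.
    return [any(in_ragged)]


@register_ragged_rule("reduce_sum")
def _full_reduce_rule(in_ragged: Sequence[bool]) -> List[bool]:
    # A full reduction collapses the ragged axis, so the result is not ragged.
    return [False]


# Toy program: a list of (op_name, input_names, output_names) equations.
Eqn = Tuple[str, List[str], List[str]]


def propagate(eqns: List[Eqn], in_ragged: Dict[str, bool]) -> Dict[str, bool]:
    """Walk the equations in order, applying each op's rule."""
    env = dict(in_ragged)
    for op_name, invars, outvars in eqns:
        rule = _RAGGED_PROP_RULES.get(op_name)
        if rule is None:
            # Only ops with an explicit rule are supported.
            raise NotImplementedError(f"no ragged_prop rule for {op_name!r}")
        env.update(zip(outvars, rule([env[v] for v in invars])))
    return env


program = [
    ("mul", ["x", "w"], ["y"]),
    ("add", ["y", "b"], ["z"]),
    ("reduce_sum", ["z"], ["s"]),
]
result = propagate(program, {"x": True, "w": False, "b": False})
```

Here `y` and `z` inherit raggedness from `x`, while `s` does not, because the reduction removes the ragged axis. Raising on ops without a rule mirrors the per-op nature of the scheme: raggedness is only known where someone has written a rule for it.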

PiperOrigin-RevId: 685827096
2024-10-14 14:01:42 -07:00