jax authors d42d2e88b4 [Pallas] Interpret dimensions with parallel semantics by traversing the corresponding grid coordinates in randomized order.
Note that dynamic grid dimensions with 'parallel' semantics are disallowed. This restriction makes it possible to compute the grid points, with randomized coordinates along 'parallel' dimensions, in JAX (i.e., on device).
If grid dimensions with dynamic sizes (i.e., sizes not known at JAX trace time) could be randomized, the randomized orders would have to be computed on the host (CPU), where arrays may have dynamic shapes.
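A minimal sketch of the idea in JAX, under stated assumptions. `randomized_grid_points` is a hypothetical helper and not the Pallas interpreter's code. The `'parallel'`/`'arbitrary'` strings mirror Pallas TPU dimension semantics. The sketch uses one random permutation per 'parallel' dimension, shared across all outer iterations, and the real interpreter may randomize differently. It shows why static sizes matter: `jax.random.permutation` needs a concrete length to return a fixed-shape array, so the randomized coordinates can be computed on device only when each 'parallel' dimension's size is known at trace time. For simplicity the sketch materializes the whole grid, so it also assumes the other dimensions are static.

```python
import jax
import jax.numpy as jnp


def randomized_grid_points(key, grid, dimension_semantics):
    """Hypothetical helper (not the Pallas API).

    Returns an integer array of shape (prod(grid), len(grid)) listing every grid
    point in row-major order, where the coordinates along 'parallel' dimensions
    are traversed in a random order instead of 0, 1, ..., n - 1.
    """
    # Mirror the rule from the commit: a 'parallel' dimension must have a static
    # (trace-time) size, because jax.random.permutation needs a concrete length
    # to produce a fixed-shape array on device.
    for n, sem in zip(grid, dimension_semantics):
        if sem == "parallel" and not isinstance(n, int):
            raise ValueError("'parallel' grid dimensions must have static sizes")

    keys = jax.random.split(key, len(grid))
    axes = []
    for k, n, sem in zip(keys, grid, dimension_semantics):
        if sem == "parallel":
            # Random visiting order for this dimension.
            axes.append(jax.random.permutation(k, n))
        else:
            # Other dimensions keep their sequential order.
            axes.append(jnp.arange(n))

    # Cartesian product of the per-dimension coordinates, in row-major ("ij") order.
    mesh = jnp.meshgrid(*axes, indexing="ij")
    return jnp.stack([m.reshape(-1) for m in mesh], axis=-1)
```

Because `grid` and `dimension_semantics` are plain Python tuples, the helper can be jitted with `jax.jit(randomized_grid_points, static_argnums=(1, 2))`, in which case the permutations are generated on device as part of the compiled computation.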

PiperOrigin-RevId: 746365669
2025-04-11 01:54:11 -07:00