Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)

Note that dynamic grid dimensions with 'parallel' semantics are disallowed. This restriction lets the grid points, with randomized coordinates along 'parallel' dimensions, be computed in JAX on the device. If grid dimensions with dynamic sizes (i.e., sizes not known at JAX trace time) could be randomized, the randomization would have to be computed on the host (CPU), where arrays can have dynamic shapes.

PiperOrigin-RevId: 746365669
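To show why static sizes matter, here is a minimal sketch of randomizing grid coordinates on device. It is not the Pallas implementation. The helper `randomized_grid_points` is hypothetical, and the sketch assumes the grid sizes are Python ints known at trace time. Each 'parallel' dimension gets a random permutation of its indices, and the grid points are the Cartesian product of the per-dimension coordinates. Because every array length is static, the whole computation can be traced and run on the device.

```python
import jax
import jax.numpy as jnp


def randomized_grid_points(key, grid, parallel):
    """Hypothetical helper: enumerate every point of a grid with static sizes.

    Coordinates along dimensions marked 'parallel' are visited in a random
    order. Other dimensions keep their natural 0..size-1 order.

    key: a JAX PRNG key.
    grid: tuple of Python ints (static sizes, known at trace time).
    parallel: tuple of bools, True for dimensions with 'parallel' semantics.
    Returns an int array of shape (prod(grid), len(grid)).
    """
    keys = jax.random.split(key, len(grid))
    coords = []
    for k, size, is_parallel in zip(keys, grid, parallel):
        if is_parallel:
            # `size` is a Python int, so the permutation has a static shape
            # and can be computed on device inside a traced function.
            idx = jax.random.permutation(k, size)
        else:
            idx = jnp.arange(size)
        coords.append(idx)
    # Cartesian product of the (possibly permuted) per-dimension coordinates.
    mesh = jnp.meshgrid(*coords, indexing="ij")
    return jnp.stack([m.reshape(-1) for m in mesh], axis=-1)


# The grid sizes and semantics must be static arguments under jit.
randomized_grid_points_jit = jax.jit(randomized_grid_points, static_argnums=(1, 2))
```

For example, `randomized_grid_points_jit(jax.random.PRNGKey(0), (3, 4), (True, False))` returns all 12 points of a 3x4 grid, with the first ('parallel') coordinate shuffled. If a size were instead a traced runtime value, `jnp.arange` and `jax.random.permutation` would fail at trace time, because their output shapes would depend on that value. This mirrors why dynamic 'parallel' dimensions would force the randomization onto the host.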