Change the representation of both units and tokens at the runtime level to be a single buffer with shape pred[0]. While the MLIR lowering is happy to have a non 1:1 mapping between avals and IR values, the XLA lowering is not, so until we remove the XLA lowering it's easiest just to keep the mapping 1:1.
PiperOrigin-RevId: 412957231
* jax._src.device_array, which contains the definition of DeviceArray.
* jax.interpreters.xla, which contains code for lowering jaxprs into XLA computations.
* jax._src.dispatch, which contains code for executing primitives and jit-compiled functions (xla_call_p's impl logic).
The purpose of splitting up this file is that I would like to treat jax.interpreters.mlir lowering as an alternative to jax.interpreters.xla, but we wish to share the device_array and computation dispatch pieces. Currently jax.interpreters.mlir duplicates most of the dispatch logic. (That refactoring is for a future change; this change just moves the existing code around.)
PiperOrigin-RevId: 411565432