We're switching to the new terminology to avoid confusion in cases
where multiple jax processes are running on a single host, and each
process has a unique process_index/host_id.
This keeps aliases for the old `host_id` APIs for now, but these will
eventually be removed.
We make sure that both the inputs and the outputs of
callbacks can contain empty arrays.
Most platforms do not support empty infeed, so we ensure
we do not send those.
* Moved CHANGELOG to docs
This puts the documentation also on RTD, with TOC.
Also changed its format to .rst, for consistency.
Added GitHub links to the change log.
* Actually add the CHANGELOG.rst
* Added reminder comments to the CHANGELOG.rst
* Adding `static_argnums` to `pmap` for similar behaviour to `static_argnums` of `jit`.
* Removed check for ShardedDeviceArray
* Final clean up and rename.
The goal is to make the Jaxpr language more uniform: all higher-order
primitives carry sub-Jaxprs that are part of the parameters, and they
are all called xxx_jaxpr. As a side-effect, some code is simplified
(e.g., the code that searches for sub-jaxprs).
For now the code assumes that all the `call` (final-style) primitives
carry exactly one subjaxpr with the parameter name `call_jaxpr`. These
primitives are still processed differently in the internal code, but
there is no reason any external consumer of a Jaxpr needs to know this.