Default jax_jit_pjit_api_merge to True. This means that the implementation of jit and pjit have been merged but they still remain separate APIs due to the semantic difference of how they behave under the Mesh context manager.

This changes the internals of JAX without affecting any public API.

Before, `jit` was a final style primitive. This means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).

Moving to initial style should simplify JAX's internals and make
development of features like dynamic shapes, etc easier.

PiperOrigin-RevId: 508143501
This commit is contained in:
Yash Katariya 2023-02-08 11:55:10 -08:00 committed by jax authors
parent 9a1f9b1ef8
commit 6ec9082cf5
2 changed files with 22 additions and 2 deletions

View File

@ -8,6 +8,22 @@ Remember to align the itemized text with the first line of an item within a list
## jax 0.4.4
* Changes
* The implementation of `jit` and `pjit` has been merged. Merging jit and pjit
changes the internals of JAX without affecting the public API of JAX.
Before, `jit` was a final style primitive. Final style means that the creation
of jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial style primitive which means that we trace to jaxpr
as early as possible. For more information see
[this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).
Moving to initial style should simplify JAX's internals and make
development of features like dynamic shapes, etc easier.
You can disable it only via the environment variable i.e.
`os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`.
The merge must be disabled via an environment variable since it affects JAX
at import time so it needs to be disabled before jax is imported.
## jaxlib 0.4.4
## jax 0.4.3 (Feb 8, 2023)

View File

@ -780,9 +780,13 @@ jax_array = config.define_bool_state(
jit_pjit_api_merge = config.define_bool_state(
name='jax_jit_pjit_api_merge',
default=False,
default=True,
upgrade=True,
help=('If True, jit and pjit API will be merged.'))
help=('If True, jit and pjit API will be merged. You can only disable it via '
"the environment variable i.e. `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`. "
"The merge must be disabled via an environment variable since it "
"affects JAX at import time so it needs to be disabled before jax is "
"imported."))
spmd_mode = config.define_enum_state(