mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Default jax_jit_pjit_api_merge
to True. This means that the implementation of jit and pjit have been merged but they still remain separate APIs due to the semantic difference of how they behave under the Mesh
context manager.
This changes the internals of JAX without affecting any public API. Before, `jit` was a final style primitive. This means that the creation of jaxpr was delayed as much as possible and transformations were stacked on top of each other. With the `jit`-`pjit` implementation merge, `jit` becomes an initial style primitive which means that we trace to jaxpr as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing). Moving to initial style should simplify JAX's internals and make development of features like dynamic shapes, etc easier. PiperOrigin-RevId: 508143501
This commit is contained in:
parent
9a1f9b1ef8
commit
6ec9082cf5
16
CHANGELOG.md
16
CHANGELOG.md
@ -8,6 +8,22 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
|
|
||||||
## jax 0.4.4
|
## jax 0.4.4
|
||||||
|
|
||||||
|
* Changes
|
||||||
|
* The implementation of `jit` and `pjit` has been merged. Merging jit and pjit
|
||||||
|
changes the internals of JAX without affecting the public API of JAX.
|
||||||
|
Before, `jit` was a final style primitive. Final style means that the creation
|
||||||
|
of jaxpr was delayed as much as possible and transformations were stacked
|
||||||
|
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
|
||||||
|
becomes an initial style primitive which means that we trace to jaxpr
|
||||||
|
as early as possible. For more information see
|
||||||
|
[this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).
|
||||||
|
Moving to initial style should simplify JAX's internals and make
|
||||||
|
development of features like dynamic shapes, etc easier.
|
||||||
|
You can disable it only via the environment variable i.e.
|
||||||
|
`os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`.
|
||||||
|
The merge must be disabled via an environment variable since it affects JAX
|
||||||
|
at import time so it needs to be disabled before jax is imported.
|
||||||
|
|
||||||
## jaxlib 0.4.4
|
## jaxlib 0.4.4
|
||||||
|
|
||||||
## jax 0.4.3 (Feb 8, 2023)
|
## jax 0.4.3 (Feb 8, 2023)
|
||||||
|
@ -780,9 +780,13 @@ jax_array = config.define_bool_state(
|
|||||||
|
|
||||||
jit_pjit_api_merge = config.define_bool_state(
|
jit_pjit_api_merge = config.define_bool_state(
|
||||||
name='jax_jit_pjit_api_merge',
|
name='jax_jit_pjit_api_merge',
|
||||||
default=False,
|
default=True,
|
||||||
upgrade=True,
|
upgrade=True,
|
||||||
help=('If True, jit and pjit API will be merged.'))
|
help=('If True, jit and pjit API will be merged. You can only disable it via '
|
||||||
|
"the environment variable i.e. `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`. "
|
||||||
|
"The merge must be disabled via an environment variable since it "
|
||||||
|
"affects JAX at import time so it needs to be disabled before jax is "
|
||||||
|
"imported."))
|
||||||
|
|
||||||
|
|
||||||
spmd_mode = config.define_enum_state(
|
spmd_mode = config.define_enum_state(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user