Follow-up on #15677, basically undoing it. Some training runs experienced
mysterious failures after many steps. We may leave this disabled until we
diagnose the cause of the failures.
Implicit jit and apply_primitive will still raise an error, though (this case is recognized via the `inline` parameter). The majority of jnp operations in JAX should be inlined.
PiperOrigin-RevId: 527398394
After the changes in shard_map, there are 75 failures left to be resolved (not counting the EagerPmap tests).
TODO:
* Move shard_map to _src so that the circular import can be removed from api.py
PiperOrigin-RevId: 525930416
Metadata, in particular code location information, is present in the HLO generated by JAX. The compilation cache uses the serialized HLO as a cache key, which raises the question: should code location information be part of that key? Simply changing the line number on which a function appears shouldn't necessarily cause a cache miss.
There are pros and cons: the main advantage of excluding metadata is that we will get more cache hits, and the main disadvantage is that debug information and profiling data in the HLO might become confusing, since it may refer to a different program entirely, or to a version of a program that does not correspond to the current state of the source tree. We argue that saving compilation time is the more important concern.
This change adds a tiny MLIR pass that strips Locations from a StableHLO module, and applies it in the compilation cache if metadata stripping is enabled.
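To illustrate why stripping locations yields more cache hits, here is a minimal sketch in plain Python. It is not the MLIR pass itself: `strip_locations` and `cache_key` are hypothetical helpers that operate on a toy textual module, whereas the real pass works on a StableHLO module rather than on text.

```python
import hashlib


def _starts_location(text: str, i: int) -> bool:
    """True if a `loc(` token (not e.g. `alloc(`) starts at index i."""
    if not text.startswith("loc(", i):
        return False
    return i == 0 or not (text[i - 1].isalnum() or text[i - 1] == "_")


def strip_locations(module_text: str) -> str:
    """Remove every balanced `loc(...)` span from a toy textual module.

    Assumption: parentheses inside string literals are balanced, which is
    good enough for this illustration but not for real MLIR text.
    """
    out = []
    i, n = 0, len(module_text)
    while i < n:
        if _starts_location(module_text, i):
            depth = 0
            while i < n:
                ch = module_text[i]
                i += 1
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
                    if depth == 0:
                        break
        else:
            out.append(module_text[i])
            i += 1
    return "".join(out)


def cache_key(module_text: str, strip_metadata: bool = True) -> str:
    """Hash the module, optionally ignoring location metadata."""
    text = strip_locations(module_text) if strip_metadata else module_text
    return hashlib.sha256(text.encode("utf-8")).hexdigest()
```

With stripping enabled, two modules that differ only in the line number recorded in a `loc(...)` annotation hash to the same key; with `strip_metadata=False` they do not.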
PiperOrigin-RevId: 525534901
The main idea here is to improve tooling for knowing what residuals are being
saved and why. There's a lot more that can be done here (e.g. naming the
arguments, explaining what JVP rule produced these residuals, explaining what
consumed them, etc.), but this is a start.
Co-authored-by: Qiao Zhang <zhangqiaorjc@google.com>
We refer to the feature as serialization rather than just lowering,
because the former term is more widely understood and also more
accurate: jax2tf both lowers to StableHLO and then serializes to
StableHLO with compatibility guarantees.
This is part of launching the new version of jax2tf with native
serialization.
For now we also keep the parameter `experimental_native_lowering` and
the flag `jax2tf_default_experimental_native_lowering`, until the projects
using them transition to the new ones (separate change).
PiperOrigin-RevId: 516864636
Add the HLO source path canonicalization regex to the trace state key; otherwise `MetadataTest.test_source_file_prefix_removal` fails due to caching of lowerings with different canonicalization regexes.
PiperOrigin-RevId: 509975754
This changes the internals of JAX without affecting any public API.
Before, `jit` was a final-style primitive. This means that the creation
of the jaxpr was delayed as much as possible and transformations were stacked
on top of each other. With the `jit`-`pjit` implementation merge, `jit`
becomes an initial-style primitive, which means that we trace to a jaxpr
as early as possible. For more information see [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).
Moving to initial style should simplify JAX's internals and make it
easier to develop features such as dynamic shapes.
PiperOrigin-RevId: 508143501
Consider the following code where static type checkers can report an
error:
```python
import jax

CPU = jax.devices('cpu')[0]
with jax.default_device(CPU):  # the type checker flags the `CPU` argument here
    ...
```
Error message:
```
Pyright: Argument of type "Device" cannot be assigned to parameter "new_val" of type "NoDefault"
"Device" is incompatible with "NoDefault" (reportGeneralTypeIssues)
```
This is because `_StateContextManager.__call__` does not have a proper
type annotation on the parameter, unlike the attribute `_default_value`
which has a type annotation. Adding an `Any` annotation to the parameter
makes the error disappear.
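The pattern can be sketched with a simplified, hypothetical stand-in for the context manager (the class below is illustrative and is not JAX's actual `_StateContextManager`). It assumes the checker infers the type of an unannotated parameter from its default value, which is what produces the `NoDefault` complaint.

```python
from contextlib import contextmanager
from typing import Any, Iterator


class NoDefault:
    """Sentinel type meaning "no value was passed"."""


no_default = NoDefault()


class StateContextManager:
    """Hypothetical, simplified config-state context manager."""

    def __init__(self, default_value: Any) -> None:
        self._default_value = default_value
        self.value = default_value

    @contextmanager
    def __call__(self, new_val: Any = no_default) -> Iterator[None]:
        # Without the explicit `Any`, a checker may infer `NoDefault` from
        # the default value and reject arguments of any other type.
        previous = self.value
        self.value = self._default_value if new_val is no_default else new_val
        try:
            yield
        finally:
            self.value = previous  # restore on exit, even after an exception
```

With the annotation in place, `with manager("some device"):` type-checks, and the previous value is restored when the block exits.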
`getattr` turns out to be a tiny bit slower than `__get__()` on `__dict__` in the case that the attribute is absent. `getattr` appears to form an error message that is thrown away if a default is present.
Improves the device_put benchmark:
```
name old cpu/op new cpu/op delta
device_put 51.4µs ± 1% 48.9µs ± 3% -4.87% (p=0.000 n=8+9)
name old time/op new time/op delta
device_put 51.4µs ± 1% 48.9µs ± 3% -4.87% (p=0.000 n=8+9)
```
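As a rough illustration of the two lookup styles compared above, the sketch below uses a hypothetical `Array` class and reads the instance dictionary with `dict.get`. This is an assumption about the shape of the change, not the exact code that was modified.

```python
class Array:
    """Hypothetical object that may or may not carry a cached attribute."""

    def __init__(self) -> None:
        self.shape = (2, 3)


def lookup_with_getattr(obj, name, default=None):
    # Full attribute protocol: instance dict, class attributes, descriptors.
    return getattr(obj, name, default)


def lookup_with_dict(obj, name, default=None):
    # Consults only the instance dictionary; class attributes, properties,
    # and `__slots__` entries are not seen by this lookup.
    return obj.__dict__.get(name, default)
```

For plain instance attributes both functions return the same result, so the dictionary lookup is only a safe substitute when the attribute is known to live in `__dict__`. The standard `timeit` module can be used to compare the two on a given interpreter.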
PiperOrigin-RevId: 493108288
Currently when JAX config values are configured via ABSL, we use the ABSL flags as a source of truth: if we read or write the JAX config option, we read or write the corresponding ABSL flag. This works but has the unfortunate downside that ABSL flags are relatively slow to read, which slows down JAX every time we read a configuration option.
However, there's fundamentally no reason we are mirroring the JAX configuration options back to ABSL in the first place. We can use ABSL flag parsing as a way only to populate the JAX configuration values. The downside is that if someone changes the ABSL flag values after parsing, that change will not be reflected in JAX's config values. JAX config changes after ABSL flags have been parsed must be made via the `jax.config.update()` API.
This gives a decent improvement on the device_put benchmark:
```
name old cpu/op new cpu/op delta
device_put 79.5µs ± 6% 69.4µs ± 7% -12.73% (p=0.000 n=10+9)
name old time/op new time/op delta
device_put 79.5µs ± 6% 69.4µs ± 7% -12.73% (p=0.000 n=10+9)
```
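The one-way flow from flags to config values can be sketched as follows. `FlagHolder` and `Config` are hypothetical stand-ins for a parsed ABSL flag and for JAX's config object; only the direction of data flow is meant to match the description above.

```python
from typing import Any, Dict


class FlagHolder:
    """Hypothetical stand-in for a parsed command-line flag."""

    def __init__(self, value: Any) -> None:
        self.value = value


class Config:
    """Hypothetical config object that copies flag values once."""

    def __init__(self) -> None:
        self._values: Dict[str, Any] = {}

    def populate_from_flags(self, flags: Dict[str, FlagHolder]) -> None:
        # Called once, right after flag parsing. Values are copied, so
        # later reads never touch the (slow) flag objects again.
        for name, holder in flags.items():
            self._values[name] = holder.value

    def read(self, name: str) -> Any:
        return self._values[name]  # plain dict lookup

    def update(self, name: str, value: Any) -> None:
        self._values[name] = value


flags = {"jax_enable_x64": FlagHolder(False)}
config = Config()
config.populate_from_flags(flags)

flags["jax_enable_x64"].value = True   # changing the flag after parsing...
stale = config.read("jax_enable_x64")  # ...is not reflected in the config

config.update("jax_enable_x64", True)  # changes must go through update()
fresh = config.read("jax_enable_x64")
```

In this sketch `stale` stays `False` and `fresh` is `True`, mirroring the stated trade-off: reads are cheap, but post-parse changes must go through the update API.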
PiperOrigin-RevId: 492519085
Implicitly jitted functions will **always** require a `jax.spmd_mode` context manager for operating on non-fully addressable jax.Array.
Explicitly jitted functions will require the `jax.spmd_mode` config to begin with as we roll out jax.Array, since it's a new behavior for `jit` (previously `jit` only worked on single-device arrays).
* Over time (via docs), and as users become more familiar with the new parallelism APIs, we can relax this restriction and allow explicit `jit` to work without needing the config. This can happen when we merge the frontend of `jit` and `pjit`.
PiperOrigin-RevId: 485075693
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging, which JAX code has used until now. The change is advised for the following reasons (among others):
- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.
Logging in JAX is ported over to take place at the module level. Some Python namespaces within JAX already used module-scoped logging via absl.vlog; the following idiom was adopted to provide the same functionality with Python builtin logging:
```py
import logging
logger = logging.getLogger(__name__)
logger.debug(...)
logger.info(...)
```
The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.
The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
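A downstream user could then capture JAX's log output by attaching a handler to the top-level "jax" logger, as in the sketch below. The module name `jax._src.example_module` is a hypothetical stand-in for any module that calls `logging.getLogger(__name__)`.

```python
import io
import logging

# Send everything logged under "jax" to an in-memory stream; a
# logging.FileHandler would work the same way for log dumps.
stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter("%(name)s:%(levelname)s:%(message)s"))

jax_logger = logging.getLogger("jax")
jax_logger.setLevel(logging.DEBUG)
jax_logger.addHandler(handler)

# Stand-in for a module inside JAX (hypothetical name).
module_logger = logging.getLogger("jax._src.example_module")
module_logger.debug("compiling %s", "f")

captured = stream.getvalue()
```

The record reaches the handler through propagation from the child logger to "jax", and the root logger's configuration is never modified.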
Fixes #12582. Setting the env var `JAX_RAISE_PERSISTENT_CACHE_ERRORS=true` will revert to the original behavior of raising an exception instead of emitting a warning.
Also makes JAX_DUMP_IR_TO work when the persistent cache is enabled.
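The warn-by-default behavior can be sketched as below. `read_from_cache` and `_raise_cache_errors` are hypothetical helpers, and the way the environment variable is parsed here is an assumption rather than JAX's actual logic.

```python
import os
import warnings


def _raise_cache_errors() -> bool:
    # Assumption: "1" or "true" (case-insensitive) enables raising.
    value = os.environ.get("JAX_RAISE_PERSISTENT_CACHE_ERRORS", "")
    return value.lower() in ("1", "true")


def read_from_cache(load, key):
    """Call load(key); on failure, warn and return None unless raising is on."""
    try:
        return load(key)
    except Exception as err:
        if _raise_cache_errors():
            raise
        warnings.warn(f"Error reading persistent cache entry {key!r}: {err}")
        return None
```

Returning `None` on failure lets the caller fall back to a normal compilation, so a corrupted or unreadable cache entry costs time but does not abort the program.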