This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. With the latter being used in JAX code as of now, the change to Python builtin logging is advised for the following reasons (among others):
- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.
Logging in JAX is ported over to take place at the module level. While previously, some Python namespaces within JAX already used module-scoped logging via absl.vlog, the following idiom was adopted to provide the same functionality in Python builtin logging:
```py
import logging
logger = logging.getLogger(__name__)
logger.debug(...)
logger.info(...)
```
The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.
The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
A recent change broke jax.distributed initialization, which was unsurprising because those APIs were not tested. In particular, we need to only initialize the service from the first process.
Fix it and add some tests that use the distributed service from multiple threads within a unit test. Move the state of jax.distributed into an object so it can be instantiated multiple times from a test case in parallel rather than being process-global.
[XLA:Python] Add gil release guards around distributed system init/shutdown. This allows testing using multiple threads.
PiperOrigin-RevId: 453480351
There was a version compatibility problem that meant that the distributed system did not correctly initialize when mixing the current released versions of jax and jaxlib. We don't have tests for multiple PGUs in our CI, so this was not caught.
In particular, separate "cuda" from "rocm" in MHLO lowering rules. This change is in preparation for refactoring how GPU-specific lowering rules are implemented in JAX, allowing both kind of rules to coexist.
[PJRT] [XLA:Python] Allow the user to specify a particular platform (e.g., "cuda" or "rocm") when creating a GPU device.
PiperOrigin-RevId: 446737518