`jax.make_mesh` is the stable API endpoint of `mesh_utils` but without all the extra options. If you want those, you can still use the experimental endpoint in `mesh_utils`.
PiperOrigin-RevId: 670707995
This is required to allow the use of subslices: e.g., the two halves
of a TPU slice. One of them will not include the device at
coordinates (0, 0, 0).
E.g., assume we have a TPU v4 1x2x1 slice.
BEFORE THIS CL, if we call _get_physical_tpu_mesh() (an auxiliary for
the public create_device_mesh()) with
jax_devices=[device(0,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]
we get the expected result
[[[device(0,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]]]
However, if we call it with
jax_devices=[device(1,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]
we get the wrong mesh
[[[None]
[device(1,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]]]
That's because the code before this CL assumed the the incoming
jax_devices are arranged in a cuboid that starts at (0, 0, 0). When
working with subslices (e.g., half of a TPU slice) that is not always
the case.
AFTER THIS CL, the second case will return
[[[device(1,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]]]
For each dimension from the TPU coordinates, this CL computes the min
/ max; we expect the provided devices to fill the [min, max] interval
(in that dimension). By requesting this for each dimension, we
request that the set of provided devices constitute a cuboid, but,
unlike before this CL, that cuboid does not need to include (0, 0, 0):
it can be "translated", which allows e.g., both half-slices of a big
slice.
PiperOrigin-RevId: 657902201
This PR is a follow up to #18881.
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
When sorting by granule key is disabled, the granules are used to create the mesh in the order in which they appear in the sequence of devices.
PiperOrigin-RevId: 590228169
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
It turns out np.array(...) has a bad interaction with certain pybind11-wrapped objects, in which it repeatedly calls getattr() and that fails in an expensive way in pybind11 involving C++ exceptions.
PiperOrigin-RevId: 522607230
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. With the latter being used in JAX code as of now, the change to Python builtin logging is advised for the following reasons (among others):
- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.
Logging in JAX is ported over to take place at the module level. While previously, some Python namespaces within JAX already used module-scoped logging via absl.vlog, the following idiom was adopted to provide the same functionality in Python builtin logging:
```py
import logging
logger = logging.getLogger(__name__)
logger.debug(...)
logger.info(...)
```
The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.
The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
Unless you're using GlobalDeviceArrays, the device mesh passed to pjit
must be composed of contiguous submeshes for each process (i.e. each
process's local devices must all be next to each other in the full
mesh and form a rectangular submesh). This change teaches
`create_device_mesh` how to output meshes that satisfy this
constraint in some common cases.
This isn't the default behavior because the resulting meshes are a
little awkward and magical, and eventually we'd like using
GlobalDeviceArrays to be the common use case.