45 Commits

Author SHA1 Message Date
Yash Katariya
252caebce3 Create jax.make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], devices: Sequence[jax.Device] | None = None) API to make it easier to create a mesh and reduce a ton of boilerplate.
`jax.make_mesh` is the stable API endpoint of `mesh_utils` but without all the extra options. If you want those, you can still use the experimental endpoint in `mesh_utils`.

PiperOrigin-RevId: 670707995
2024-09-03 14:32:03 -07:00
Colin Gaffney
276c87eba0 Add a more helpful error message in create_hybrid_device_mesh for missing attribute `process_index` or `slice_index`.
PiperOrigin-RevId: 666928476
2024-08-23 14:42:48 -07:00
jax authors
5c9bb612a7 mesh_utils: allow meshes that do not include device at (0, 0, 0).
This is required to allow the use of subslices: e.g., the two halves
of a TPU slice.  One of them will not include the device at
coordinates (0, 0, 0).

E.g., assume we have a TPU v4 1x2x1 slice.

BEFORE THIS CL, if we call _get_physical_tpu_mesh() (an auxiliary for
the public create_device_mesh()) with

jax_devices=[device(0,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]

we get the expected result

[[[device(0,TPU_DEVICE,coords=[0,0,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]]]

However, if we call it with

jax_devices=[device(1,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]

we get the wrong mesh

[[[None]
  [device(1,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]]]

That's because the code before this CL assumed the incoming
jax_devices are arranged in a cuboid that starts at (0, 0, 0).  When
working with subslices (e.g., half of a TPU slice) that is not always
the case.

AFTER THIS CL, the second case will return
[[[device(1,TPU_DEVICE,coords=[0,1,0,0],vtask=0,slice=0,default_mem=device,mem_spaces=3)]]]

For each dimension from the TPU coordinates, this CL computes the min
/ max; we expect the provided devices to fill the [min, max] interval
(in that dimension).  By requesting this for each dimension, we
request that the set of provided devices constitute a cuboid, but,
unlike before this CL, that cuboid does not need to include (0, 0, 0):
it can be "translated", which allows e.g., both half-slices of a big
slice.
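
The per-dimension min / max check described above can be sketched in plain Python. This is a minimal illustration, not the code from this CL: the helper names `cuboid_bounds` and `translate` are hypothetical, and coordinates are simplified to 3-tuples.

```python
import itertools


def cuboid_bounds(coords):
    """Return (mins, dims) if `coords` exactly fill an axis-aligned cuboid.

    Hypothetical helper: the cuboid may start anywhere, not only at (0, 0, 0).
    """
    coords = [tuple(c) for c in coords]
    ndim = len(coords[0])
    mins = tuple(min(c[d] for c in coords) for d in range(ndim))
    maxs = tuple(max(c[d] for c in coords) for d in range(ndim))
    # Every point of the [min, max] box must be present exactly once.
    expected = set(itertools.product(
        *(range(lo, hi + 1) for lo, hi in zip(mins, maxs))))
    if len(coords) != len(expected) or set(coords) != expected:
        raise ValueError(f"coords do not fill the cuboid {mins}..{maxs}")
    dims = tuple(hi - lo + 1 for lo, hi in zip(mins, maxs))
    return mins, dims


def translate(coord, mins):
    """Shift a coordinate so that the cuboid's corner lands on the origin."""
    return tuple(c - lo for c, lo in zip(coord, mins))


# The single-device subslice from the example above, at coords (0, 1, 0).
mins, dims = cuboid_bounds([(0, 1, 0)])
index = translate((0, 1, 0), mins)
```

Indexing the output mesh by the translated coordinate rather than the raw one is what avoids the `None` entry shown in the "before" case.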

PiperOrigin-RevId: 657902201
2024-07-31 01:11:31 -07:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Yash Katariya
023bc7856b Add registration handler for TPU v5e in mesh_utils.
PiperOrigin-RevId: 643092629
2024-06-13 12:52:33 -07:00
Changhui Lin
b709925e97 Update the arg description.
`slice_index` attribute has been added for GPU.

PiperOrigin-RevId: 618354455
2024-03-22 20:10:18 -07:00
jax authors
8a2ba76b66 Enable Creating Device Mesh with Physical Axes Splits.
PiperOrigin-RevId: 617994892
2024-03-21 16:21:25 -07:00
jax authors
77d41289aa Remove an excessive log statement from the create_device_mesh function.
PiperOrigin-RevId: 604270125
2024-02-05 03:49:50 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
jax authors
94d58b7270 mesh_utils.create_hybrid_device_mesh: make sorting granules by key user configurable.
When sorting by granule key is disabled, the granules are used to create the mesh in the order in which they appear in the sequence of devices.
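
A rough sketch of the two orderings, assuming granules are keyed by a `slice_index`-like attribute. `FakeDevice`, `group_into_granules`, and the `sort_by_key` flag are hypothetical names for illustration and are not the `mesh_utils` API.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeDevice:
    """Hypothetical stand-in for a device object."""
    id: int
    slice_index: int


def group_into_granules(devices, sort_by_key=True):
    """Group devices by slice_index, optionally sorting granules by key."""
    granules = defaultdict(list)  # dicts keep first-seen key order
    for dev in devices:
        granules[dev.slice_index].append(dev)
    keys = sorted(granules) if sort_by_key else list(granules)
    return [granules[k] for k in keys]


devices = [FakeDevice(0, 1), FakeDevice(1, 0), FakeDevice(2, 1), FakeDevice(3, 0)]
sorted_ids = [[d.id for d in g] for g in group_into_granules(devices)]
unsorted_ids = [[d.id for d in g]
                for g in group_into_granules(devices, sort_by_key=False)]
```

With sorting enabled the granule for key 0 comes first. With sorting disabled the granule for key 1 comes first, because a device from that granule appears first in the input sequence.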

PiperOrigin-RevId: 590228169
2023-12-12 09:16:41 -08:00
Utku Evci
83b6c3f450 Removing unused description
PiperOrigin-RevId: 584395995
2023-11-21 12:17:20 -08:00
Neil Girdhar
3c920c0120 Switch from flake8 to Ruff
2023-11-15 22:35:52 -05:00
Skye Wanderman-Milne
01372fedca Clarify that mesh_utils.create_device_mesh's contiguous_submeshes arg isn't necessary with jax.Array
PiperOrigin-RevId: 570751299
2023-10-04 11:24:23 -07:00
jax authors
a23bc36d9a Document ValueError raised from mesh util functions
PiperOrigin-RevId: 563158659
2023-09-06 11:10:34 -07:00
jax authors
079ecfbf20 Allow device mesh handler to return None and use the default logic
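
The fallback behavior can be sketched as follows. This is an illustration under assumptions: the registry name `HANDLERS`, the device kind string, and all function names are hypothetical, and the "default logic" here is a plain row-major reshape of a 2D mesh.

```python
def special_handler(mesh_shape, devices):
    """Hypothetical handler that only knows how to lay out a 2x2 mesh."""
    if tuple(mesh_shape) != (2, 2):
        return None  # decline, so the caller falls back to the default logic
    # Ring-like ordering for the 2x2 case.
    return [[devices[0], devices[1]], [devices[3], devices[2]]]


# Hypothetical registry mapping a device kind to its handler.
HANDLERS = {"hypothetical-kind": special_handler}


def default_logic(mesh_shape, devices):
    """Row-major reshape of a flat device list into a 2D mesh."""
    rows, cols = mesh_shape
    return [list(devices[r * cols:(r + 1) * cols]) for r in range(rows)]


def create_mesh(mesh_shape, devices, device_kind):
    handler = HANDLERS.get(device_kind)
    if handler is not None:
        result = handler(mesh_shape, devices)
        if result is not None:
            return result
    return default_logic(mesh_shape, devices)


handled = create_mesh((2, 2), ["d0", "d1", "d2", "d3"], "hypothetical-kind")
declined = create_mesh((1, 4), ["d0", "d1", "d2", "d3"], "hypothetical-kind")
```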
PiperOrigin-RevId: 554563151
2023-08-07 12:50:43 -07:00
Yash Katariya
109ed5023d Don't depend on jax in mesh_utils to remove circular dependency.
PiperOrigin-RevId: 552799864
2023-08-01 07:45:42 -07:00
jax authors
99f1fcb0d0 Move device_kind-specific logic to create device mesh into a handler function dict. Doesn't change any existing behavior.
PiperOrigin-RevId: 551674832
2023-07-27 16:35:38 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
jax authors
4c800f5a8a Improve error message to point the way to Megacore.
PiperOrigin-RevId: 547194562
2023-07-11 08:16:33 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Jake VanderPlas
592833e02a Change uses of np.product to np.prod
product is deprecated as of numpy 1.25.0.
2023-06-02 03:57:30 -07:00
jax authors
2a1de46527 update error message
PiperOrigin-RevId: 532809831
2023-05-17 09:18:53 -07:00
jax authors
818805f6f9 Improve error message in _create_device_mesh_for_nd_torus
PiperOrigin-RevId: 527834640
2023-04-28 03:15:06 -07:00
Jean-Baptiste Lespiau
6ca249da78 Improve the error message.
PiperOrigin-RevId: 525138471
2023-04-18 07:16:00 -07:00
Peter Hawkins
c1c8257285 Speed up TPU physical mesh construction in mesh_utils.
It turns out np.array(...) has a bad interaction with certain pybind11-wrapped objects, in which it repeatedly calls getattr() and that fails in an expensive way in pybind11 involving C++ exceptions.

PiperOrigin-RevId: 522607230
2023-04-07 08:52:18 -07:00
Peter Hawkins
87a1fea1c7 Improve algorithmic complexity of hybrid mesh construction.
PiperOrigin-RevId: 522583802
2023-04-07 06:11:35 -07:00
Zafarali Ahmed
6e00ba8bad Enable more mesh shape assignment
We now sort the mesh dims by size first. Smaller dims have fewer choices so
they should be assigned first.
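
The effect of the ordering can be seen in a toy greedy assignment of logical mesh dims to physical axes. This is only a sketch of the idea, not the `mesh_utils` algorithm: `assign_axes` is a hypothetical helper that tries larger groups of physical axes first and takes the first group whose product matches.

```python
import itertools
import math


def assign_axes(logical_shape, physical_shape, sort_by_size):
    """Greedily map each logical dim to a group of unused physical axes.

    Returns {logical index: tuple of physical axis indices}, or None on failure.
    """
    order = list(range(len(logical_shape)))
    if sort_by_size:
        order.sort(key=lambda i: logical_shape[i])  # smallest dims first
    remaining = dict(enumerate(physical_shape))  # physical axis index -> size
    assignment = {}
    for i in order:
        found = None
        for k in range(len(remaining), -1, -1):
            for combo in itertools.combinations(sorted(remaining), k):
                if math.prod(remaining[a] for a in combo) == logical_shape[i]:
                    found = combo
                    break
            if found is not None:
                break
        if found is None:
            return None
        assignment[i] = found
        for a in found:
            del remaining[a]
    return assignment


unsorted_result = assign_axes((4, 2, 2), (2, 2, 4), sort_by_size=False)
sorted_result = assign_axes((4, 2, 2), (2, 2, 4), sort_by_size=True)
```

In the given order, the dim of size 4 grabs both physical axes of size 2, leaving nothing for the two dims of size 2, so the assignment fails. Assigning the small dims first succeeds.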

PiperOrigin-RevId: 520942700
2023-03-31 09:36:16 -07:00
Yuanzhong Xu
2002d49230 Enable more mesh shape assignment
We now sort the mesh dims by size first. Smaller dims have fewer choices so
they should be assigned first.

PiperOrigin-RevId: 518093398
2023-03-20 15:26:55 -07:00
Peter Hawkins
428189f8fb Replace uses of deprecated JAX sharding APIs with their new names in jax.sharding.
This change updates:
* {jax.experimental.maps.Mesh, jax.interpreters.pxla.Mesh} to jax.sharding.Mesh
* {jax.experimental.PartitionSpec, jax.experimental.pjit.PartitionSpec, jax.interpreters.pxla.PartitionSpec, jax.pxla.PartitionSpec} to jax.sharding.PartitionSpec
* jax.experimental.maps.NamedSharding to jax.sharding.NamedSharding.

PiperOrigin-RevId: 506994892
2023-02-03 14:28:45 -08:00
Yash Katariya
12f7cdeeae Delete the pjit 101 tutorial which is subsumed by https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
PiperOrigin-RevId: 503581557
2023-01-20 20:35:47 -08:00
Yash Katariya
68c43e6c99 Update the non-contiguous error message to not say GDA anymore
PiperOrigin-RevId: 501396344
2023-01-11 15:35:15 -08:00
Yuanzhong Xu
2cf0791635 Allow more cases in _TRANSPOSE_TRICKS by ignoring leading 1s in the mesh shape.
PiperOrigin-RevId: 498201820
2022-12-28 09:52:24 -08:00
jax authors
7890ec8164 Generalize TPU mesh computations.
PiperOrigin-RevId: 489936718
2022-11-21 03:22:24 -08:00
James Bradbury
bdde0f0cc2 [mesh_utils] Support single-core 2D meshes
PiperOrigin-RevId: 484026013
2022-10-26 11:32:50 -07:00
Nicholas Junge
efd61b73f6 Migrate JAX internals to builtin Python logging
This commit changes the JAX codebase to use Python's builtin logging instead of ABSL logging. With the latter being used in JAX code as of now, the change to Python builtin logging is advised for the following reasons (among others):

- absl-py can be removed as an external dependency of JAX.
- Builtin logging brings the option of adding more log handlers, for example file handlers for log dumps or writers to different IO streams.

Logging in JAX is ported over to take place at the module level. While some Python namespaces within JAX previously used module-scoped logging via absl.vlog, the following idiom was adopted to provide the same functionality with Python builtin logging:

```py
import logging
logger = logging.getLogger(__name__)

logger.debug(...)
logger.info(...)
```

The builtin root logger is left untouched, which is beneficial for downstream users planning to customize the Python root logger. All JAX internal code promises to log to descendants of the top-level "jax" logger by virtue of log propagation.
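
As a sketch of what this enables downstream, a handler attached to the top-level "jax" logger receives records emitted by module-level loggers beneath it, without touching the root logger. The child logger name below is hypothetical, and the example uses only the standard library, so it does not require JAX to be installed.

```python
import io
import logging

# Capture output in memory; a logging.FileHandler would work the same way.
stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter("%(name)s:%(levelname)s:%(message)s"))

jax_logger = logging.getLogger("jax")
jax_logger.setLevel(logging.DEBUG)
jax_logger.addHandler(handler)

# Hypothetical module-level logger, standing in for logging.getLogger(__name__)
# inside a JAX module.
module_logger = logging.getLogger("jax.hypothetical_module")
module_logger.debug("building mesh")

captured = stream.getvalue()
```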

The package `absl-py` was removed from JAX's install requirements, and added into its test requirements.
2022-10-13 21:32:44 +02:00
James Bradbury
bc03e29afa [mesh_utils] Avoid relying on process-tiled device order
PiperOrigin-RevId: 453270849
2022-06-06 13:30:56 -07:00
jax authors
c72154474f better error message in mesh_utils.py
PiperOrigin-RevId: 446119369
2022-05-02 23:24:56 -07:00
James Bradbury
38e754585f [mesh_utils] Add device/slice count checks
PiperOrigin-RevId: 443178279
2022-04-20 13:30:01 -07:00
James Bradbury
3f9e45e0c5 [mesh_utils] Support creating device meshes for hybrid networks
Also makes some NFCs to other mesh_utils code.

PiperOrigin-RevId: 442581767
2022-04-18 11:00:30 -07:00
Yilei Yang
7ad1120da0 Remove unused comments related to Python 2 compatibility.
PiperOrigin-RevId: 441831488
2022-04-14 12:52:51 -07:00
Yuanzhong Xu
3a949acccb Allow a single logical mesh dim to take all devices.
PiperOrigin-RevId: 435240241
2022-03-16 20:58:59 -07:00
Skye Wanderman-Milne
bcee442390 Improve TPU v2 and v3 mesh_utils.create_device_mesh logic.
* Fixes a bug when a non-3D mesh was requested
* Adds new logic when requesting a single-host mesh
* Extends logic to v2 as well as v3
2022-03-08 22:47:10 +00:00
Yash Katariya
687a7630ee Deprecate maps.mesh and replace it with maps.Mesh.
PiperOrigin-RevId: 430489855
2022-02-23 10:47:06 -08:00
Skye Wanderman-Milne
17b0866bbe Add contiguous_submeshes option to mesh_utils.create_device_mesh().
Unless you're using GlobalDeviceArrays, the device mesh passed to pjit
must be composed of contiguous submeshes for each process (i.e. each
process's local devices must all be next to each other in the full
mesh and form a rectangular submesh). This change teaches
`create_device_mesh` how to output meshes that satisfy this
constraint in some common cases.

This isn't the default behavior because the resulting meshes are a
little awkward and magical, and eventually we'd like using
GlobalDeviceArrays to be the common use case.
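
The contiguity constraint described above can be checked with a short sketch. It assumes a 2D mesh given as a nested list in which each entry is the index of the process owning that mesh position. `has_contiguous_submeshes` is a hypothetical helper, not part of `mesh_utils`.

```python
def has_contiguous_submeshes(process_mesh):
    """True if every process's positions form a filled rectangle in the mesh."""
    positions = {}
    for i, row in enumerate(process_mesh):
        for j, proc in enumerate(row):
            positions.setdefault(proc, []).append((i, j))
    for cells in positions.values():
        rows = [i for i, _ in cells]
        cols = [j for _, j in cells]
        # Distinct cells fill their bounding box iff the box area equals
        # the number of cells.
        area = (max(rows) - min(rows) + 1) * (max(cols) - min(cols) + 1)
        if area != len(cells):
            return False
    return True


contiguous = has_contiguous_submeshes([[0, 0, 1, 1],
                                       [0, 0, 1, 1]])
interleaved = has_contiguous_submeshes([[0, 1, 0, 1],
                                        [0, 1, 0, 1]])
```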
2021-12-10 00:01:12 +00:00
Qiao Zhang
64569abb46 Upstream mesh utils to JAX core.
Co-authored-by: James Bradbury <jekbradbury@google.com>
Co-authored-by: Anselm Levskaya <levskaya@google.com>
PiperOrigin-RevId: 415136597
2021-12-08 17:21:58 -08:00