18 Commits

Author SHA1 Message Date
Sergei Lebedev
36f6b52e42 Upgrade most .py sources to 3.9
This commit was generated by running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-08 12:23:15 +00:00
Matthew Johnson
64cb53f624 improve an error message during Mesh creation 2023-12-06 16:43:36 -08:00
Tom Hennigan
1b504bb68e Allow threads to race setting attributes on Mesh.
PiperOrigin-RevId: 584602313
2023-11-22 05:47:56 -08:00
Peter Hawkins
30a0136813 Increase minimum jaxlib version to 0.4.19.
0.4.19 has xla_extension version 207 and mlir_api_version 54.

PiperOrigin-RevId: 583412447
2023-11-17 09:38:31 -08:00
Yash Katariya
8ee58117e2 Don't print all the devices in the mesh during ResourceEnv's repr. Just print the mesh shape.
PiperOrigin-RevId: 577305337
2023-10-27 14:25:34 -07:00
Jake VanderPlas
f1fc2adfbd Fix mypy error 2023-08-29 13:25:12 -07:00
Yash Katariya
6072d5993e Any devices passed to jax.sharding.Mesh are required to be hashable.
This is true for mock devices or user specific devices and jax.devices() too.

Fix the tests so that the mock devices are hashable.

PiperOrigin-RevId: 561103167
2023-08-29 12:20:54 -07:00
Yash Katariya
a37e2159b3 Don't drop out of C++ fast path if mesh pointers are not equal.
This is done by returning the same object when constructing mesh if devices.shape, axis_names and flat device list matches.

PiperOrigin-RevId: 560828993
2023-08-28 15:04:05 -07:00
Yash Katariya
242c2c1b52 Use _internal_device_list in __hash__ and __eq__ of Shardings and Mesh to speed them up.
PiperOrigin-RevId: 557665385
2023-08-16 18:41:58 -07:00
Hyeontaek Lim
423c8d8d4f [JAX] Use DeviceList in JAX Sharding implementations
XLA-compatible `Sharding` implementations keep a `DeviceList` object as
`_internal_device_list`. This is used for finding the default memory kind more
quickly in C++, and enables caching of the default memory kind between multiple
`NamedSharding` objects that shares the same `Mesh`. Also it uses an
addressable device within `DeviceList`, which will be required for supporting
multiple device types with different default memory kinds.

PiperOrigin-RevId: 556969789
2023-08-14 18:11:23 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Yash Katariya
fdbad53b15 Make _device_assignment a Tuple[Device] so that we don't convert a list to a tuple and vice-versa everywhere
PiperOrigin-RevId: 524002310
2023-04-13 08:03:27 -07:00
Yash Katariya
a3ce08cf1d Override addressable_devices for NamedSharding since the mesh can be the same throughout the program.
PiperOrigin-RevId: 522677209
2023-04-07 13:54:37 -07:00
Peter Hawkins
0f368e4428 Cache __repr__ and device_ids properties on Mesh.
PiperOrigin-RevId: 522653188
2023-04-07 12:12:14 -07:00
Yash Katariya
8838039287 Override is_fully_addressable() for NamedSharding.
The intent of this change is to speed up is_fully_addressable() when computing it repeatedly over the same mesh.

PiperOrigin-RevId: 522500766
2023-04-06 19:46:29 -07:00
Peter Hawkins
623282715d Split Mesh and ResourceEnv into a new module jax._src.mesh.
This work is an effort to reduce cyclic dependencies in JAX internals.

Move the _global_to_local and _local_to_global methods out of Mesh and into pxla as free functions. This removes the need for jax._src.mesh to depend on things like avals.

PiperOrigin-RevId: 515667671
2023-03-10 10:08:21 -08:00