175 Commits

Author SHA1 Message Date
Yash Katariya
e904c177f7 Delete _normalized_spec from NamedSharding
PiperOrigin-RevId: 697779844
2024-11-18 15:35:38 -08:00
Dan Foreman-Mackey
4a365670f7 Fix pre-commit to run on all files in CI. 2024-11-08 13:47:27 -05:00
Yash Katariya
0bb30f0777 Propagate CopySemantics from python to C++ transfer APIs so that device_put works correctly in presence of copy/donate options that user specified.
This change only supports pinned_host -> pinned_host copies on the same device. HBM -> HBM copies don't work yet and donation also doesn't work in PJRT.

This CL also sets up the plumbing from JAX to PJRT so that in the future support for missing features can be added easily.

Fixes https://github.com/jax-ml/jax/issues/24521

PiperOrigin-RevId: 694274616
2024-11-07 15:51:54 -08:00
Dougal Maclaurin
48f24b6acb Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
PiperOrigin-RevId: 691929385
2024-10-31 14:06:54 -07:00
Yash Katariya
66c6292e6a Make committed a public property of jax.Array.
Why?

Because users need to know if an array is committed or not since JAX raises errors based on committedness of a jax.Array. JAX also makes decisions about dispatching based on committedness of a jax.Array.
But the placement of such arrays on devices is an internal implementation detail.

PiperOrigin-RevId: 686329828
2024-10-15 19:46:10 -07:00
Yash Katariya
89fcd9f1f1 Better repr of aval when shardings are present
Example: (for array for shape (8, 2) with dtype float32

```
P('x', 'y') -- float32[8@x,2@y]

P('x', None) -- float32[8@x,2]

P(('x', 'y'), None) -- float32[8@xy,2]

P(None, None) -- float32[8, 2]
```

PiperOrigin-RevId: 684996577
2024-10-11 16:48:13 -07:00
Yash Katariya
8ef41a6e14 [sharding_in_types] Normalize partition specs when creating avals so that P(None, None) and P() are treated as replicated and equivalent. Shardings on avals are always normalized.
PiperOrigin-RevId: 684465123
2024-10-10 09:07:44 -07:00
Yash Katariya
a2b39192d2 Make make_array_from_process_local_data go via device_put if there is only 1 process.
PiperOrigin-RevId: 677232996
2024-09-21 10:23:20 -07:00
Sergei Lebedev
b886bd7300 Removed the named_shape argument from jex.core.ShapedArray and jax.ShapeDtypeStruct
It is unused and was only kept around to avoid breaking internal users.

PiperOrigin-RevId: 674310795
2024-09-13 08:38:15 -07:00
Yash Katariya
bcfe95e98e Initial integration of sharding in types in JAX. Currently we just support nary ops in forward only sharding propagation. Currently this functionality is experimental and hidden behind jax_sharding_in_types config flag.
There will be more improvements and semantics clarification coming in the future as we integrate it more into JAX.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
PiperOrigin-RevId: 668991384
2024-08-29 10:50:04 -07:00
Yash Katariya
ef33cf5ace Standardize default layout to None in internals (dispatch, lowering and compilation) and non-default layouts to concrete layouts.
This massively simplifies the amount of checks we need and improves dispatch time too. It also fixes a donation bug being hit in serving code related to layouts and non-standardization of default layout in JAX.

PiperOrigin-RevId: 668527139
2024-08-28 11:06:37 -07:00
Yash Katariya
6e1c23610d If input layouts are specified via in_shardings to jit and the array that the jitted function is called with is uncommitted, reshard the input array to the layout specified by the user.
Not doing the resharding, leads to incorrect outputs on GPU and a crash on TPU which is not good.

Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
2024-08-19 15:10:32 -07:00
Yash Katariya
0c543aef1d Match the argument name with the name in Args section in docstring
PiperOrigin-RevId: 663926739
2024-08-16 17:22:02 -07:00
Yash Katariya
229cbae5ea Add num_devices to Sharding interface so that it works with NamedSharding containing AbstractMesh too.
PiperOrigin-RevId: 662938823
2024-08-14 09:03:17 -07:00
Parker Schuh
4863a568f9 Fix array_test.py when jax_pmap_no_rank_reduction is flipped to true.
The problem is that squeezing was happening on noncommitted arrays
so list(x) was moving all the shards to device 0. This will potentially
cause ooms.

PiperOrigin-RevId: 661408226
2024-08-09 14:40:52 -07:00
Sergei Lebedev
8d33a6c9a6 Bumped jaxlib version mypy uses on the CI
I also enabled unnecessary cast checking, because turns out we have quite
a few of those.
2024-07-26 11:22:39 +01:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
Jake VanderPlas
613a00044c [array API] add device property & to_device method 2024-07-23 11:12:35 -07:00
Yash Katariya
82c608674a Fix the efficient reshard path in device_put when you want to go from 1 mesh to another with different device assignments.
The old code lead to the wrong answer as shown in the test added in this PR.

PiperOrigin-RevId: 654318251
2024-07-20 09:09:05 -07:00
Yash Katariya
0dfb206088 Reference make_array_from_process_local_data in make_array_from_single_device_arrays docstring.
PiperOrigin-RevId: 651937263
2024-07-12 18:10:15 -07:00
Peter Hawkins
8ab0c07edc Don't wrap singleton ir.Values with tuples during HLO lowering.
In general a JAX value might correspond to multiple HLO values, which is why the HLO lowering represents each value as a tuple of zero or more ir.Values. However, the common case is that there is exactly one value, and almost all such lists are singletons.

To reduce the number of singleton list and tuple objects allocated during MLIR lowering, instead represent singleton values as unwrapped ir.Values, and only use a tuple if there is not exactly one ir.Value backing a JAX value.
2024-07-01 16:11:00 -04:00
Yash Katariya
89c404e703 Improve error message when a global jax.Array is closed over a jitted function in McJAX.
PiperOrigin-RevId: 648010704
2024-06-29 14:36:44 -07:00
Mark Sandler
fdb1c14433 Switches make_array_from_callback to use batched_device_put
PiperOrigin-RevId: 647537267
2024-06-27 21:00:05 -07:00
Yash Katariya
e1a496d3b6 Add concrete layout API to JAX. The API takes major_to_minor: tuple[int, ...] and tiling: tuple[tuple[int, ...], ...] as the arguments. Allows users to pass layouts to with_sharding_constraint to constrain the layout + sharding.
`sub_byte_element_size_in_bits` is a lowering only thing for now (since we know the dtype of the aval so JAX can add the appropriate value). We can expose it to the user API if required.

memory space is exposed via JAX memories API so it doesn't have to be in the layout API.

Also expose `_xla_layout` as a private API from `PJRTLayout` so that we can access fields to create JAX layouts.

Add construtors to `xla::Layout` so that JAX can create Layouts with minor_to_major and tiling information.

PiperOrigin-RevId: 647487510
2024-06-27 16:47:31 -07:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Dan Foreman-Mackey
6d35b109fd Rename "Example" to "Examples" in docstrings.
This PR updates all docstrings that previously had a section heading
called "Example" and replaces that with "Examples" to be consistent.
2024-06-21 11:43:16 -04:00
Peter Hawkins
07d24e7dcc Bump minimum jaxlib version to v0.4.30.
This corresponds to xla_extension_version 271 and mlir_api_version 57.
2024-06-18 12:35:08 -04:00
Junwhan Ahn
5046cedbfc Make pxla.shard_arg batch calls to xc.copy_array_to_devices_with_sharding
This CL changes `shard_arg_handlers` to be batched, in that it now receives a list of objects and a list of shardings and returns a list of array. This makes it possible to batch backend calls whenever it's beneficial to do so.

Based on the above, the batched shard arg for arrays leverages the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copy cheaper in some backend implementations. Since `Client::CopyArrays()` requires batched arrays to have the same set of source/destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is pushed all the way to C++ for performance in case we have lots of arrays.

PiperOrigin-RevId: 643097852
2024-06-13 13:10:10 -07:00
Yash Katariya
6c34a56b87 Add util.cache to jax.clear_caches and move pjit, sharding, array, etc uses of functools.lru_cache to util.cache so that those caches will be cleared if jax.clear_caches is called.
PiperOrigin-RevId: 642359226
2024-06-11 12:46:47 -07:00
Mark Sandler
2c246df439 Reverts dfe61285093ff826e1ad23bb36b77a42c01040b4
PiperOrigin-RevId: 640987745
2024-06-06 12:41:17 -07:00
Peter Hawkins
dfe6128509 Reverts da816d34eaad6a1c6536959ccb4bfee4466c037d
PiperOrigin-RevId: 640886105
2024-06-06 07:10:09 -07:00
Mark Sandler
da816d34ea Makes global_shape optional for jax.make_array_from_process_local_data.
PiperOrigin-RevId: 640695090
2024-06-05 16:58:08 -07:00
Yash Katariya
1edd649de4 Deprecate XLACompatibleSharding in favor of jax.sharding.Sharding.
PiperOrigin-RevId: 640544939
2024-06-05 09:07:27 -07:00
Yash Katariya
b527e1ec07 Improve error message when trying to fetch value of non-addressable array.
PiperOrigin-RevId: 636642130
2024-05-23 12:41:34 -07:00
jax authors
b5583742b5 Merge pull request #21273 from superbobry:mypy-ruff
PiperOrigin-RevId: 636146344
2024-05-22 06:35:38 -07:00
Sergei Lebedev
f5617d7323 Removed noop # type: ignore comments
mypy should now flag these by default.
2024-05-19 21:01:29 +01:00
Mark Sandler
8f045cafd2 Add jax.make_array_from_process_local_data to create a distributed tensor from host data and supporting scaffolding in sharding to be able to figure out dimensions of host data required.
PiperOrigin-RevId: 634205261
2024-05-15 22:06:45 -07:00
Yash Katariya
bac3a6fa8f Allow tokens being passed to jit and through dispatch and being returned from the jitted function.
Fixes https://github.com/google/jax/issues/21160

PiperOrigin-RevId: 632531105
2024-05-10 10:12:48 -07:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
Jake VanderPlas
cbe48cad1e Finalize deprecation of arr.device_buffer and arr.device_buffers
PiperOrigin-RevId: 627899901
2024-04-24 17:27:25 -07:00
Junwhan Ahn
4be25d7151 Optimize jax.device_put() dispatch for 1:1 device-to-device transfers
* Cache the sharding index comparison in addition to sharding index calculation. This helps when the list of indices is expensive to compare.
* Remove caching from `pxla.get_addressable_devices_for_shard_arg()` since `sharding._addressable_device_assignment` is already a cached property.
* Use `a is b` instead of `id(a) == id(b)` since the former is more concise.

PiperOrigin-RevId: 627080325
2024-04-22 10:24:35 -07:00
Yue Sheng
c2d4373535 Make core.Token a non-trivial class which wraps a jax.Array. Currently, we use a singleton and empty core.token object everywhere. After the change, tokens could be created and threaded in and out of computations to build up dependency.
Also update ordered side-effects to use the new `core.Token` class (NFC for this part, just to unify token usage).

PiperOrigin-RevId: 626091210
2024-04-18 11:09:55 -07:00
Yash Katariya
eb92a5c711 Add layout support to make_array_from_callback.
PiperOrigin-RevId: 625048520
2024-04-15 12:38:34 -07:00
Junwhan Ahn
ac1a53d8e4 Optimize _create_copy_plan by iterating over only the shards that are needed for materialization
For arrays that are fully or partially replicated, it is more efficient to (pre-)construct a list of addressable array shards that participate in materialization rather than going over all array shards. This is particularly useful for single-controller JAX.

The implementation assumes that addressable arrays appear in the same order as the corresponding addressable devices in `sharding.addressable_devices_indices_map()`.

PiperOrigin-RevId: 624969222
2024-04-15 08:29:47 -07:00
Yash Katariya
2c85ca6fec If callback returns a fully replicated global array, return it as is.
Also take the batched_device_put fast path for non-jax.Array's since slicing can return arrays on multiple devices which batched_device_put doesn't support.

PiperOrigin-RevId: 624763603
2024-04-14 14:35:57 -07:00
Yash Katariya
70dca30395 Remove the dead code now that jax.Array is the only array we have
PiperOrigin-RevId: 624390245
2024-04-12 21:41:42 -07:00
Yash Katariya
001732086b Use _internal_device_list in _get_device so that all places accessing _get_device get a speedup.
PiperOrigin-RevId: 624320655
2024-04-12 16:17:34 -07:00
Junwhan Ahn
3245455900 Optimize _create_copy_plan in array.py
* `_get_device` is called from many tight loops, so it's worth avoiding unnecessary work as much as possible.
* `_create_copy_plan` now uses sharding's `_internal_device_list` instead of querying the device of every shard in a loop.

PiperOrigin-RevId: 624288637
2024-04-12 14:12:31 -07:00
Meekail Zain
a2feff2e54 Add support for max_version, dl_device, copy kwargs in __dlpack__ 2024-04-11 16:44:19 +00:00