211 Commits

Author SHA1 Message Date
Peter Hawkins
474dcd409d Remove code to support jaxlib < v0.6.
New minimum jaxlib_extension_version is 330.

PiperOrigin-RevId: 748853497
2025-04-17 16:44:41 -07:00
Sergei Lebedev
f2f9152d57 Moved the jax.Array baseclass to C++
This allows `ArrayImpl` to directly subclass `jax.Array` without relying on
the expensive virtual subclasses machinery from `abc`.

PiperOrigin-RevId: 743573028
2025-04-03 08:28:02 -07:00
Yash Katariya
ea7fa29be7 Allow tuple(arrays) as an input to make_array_from_single_device_arrays. Fixes https://github.com/jax-ml/jax/issues/27303
PiperOrigin-RevId: 738917340
2025-03-20 13:23:44 -07:00
Yash Katariya
133a885e3b use_mesh and use_concrete_mesh should error when used under jit
PiperOrigin-RevId: 738376533
2025-03-19 06:45:18 -07:00
Yash Katariya
663ef7ae01 Check the type of mesh in use_abstract_mesh and use_concrete_mesh
PiperOrigin-RevId: 738190879
2025-03-18 16:57:40 -07:00
Emily Fertig
bdb6d03322 Allow make_array_from_callback to construct nonaddressable arrays.
PiperOrigin-RevId: 736922870
2025-03-14 11:10:32 -07:00
Emily Fertig
d79472101d Plumb layout through the creation of IFRT Arrays (roll-forward with fix).
Reverts 7f9e7473cfe7e2b3c4eb43ce6df916b3159c1cff

PiperOrigin-RevId: 736739556
2025-03-13 21:32:52 -07:00
Yash Katariya
2d01226b3b Rename some internal APIs (set_abstract_mesh -> use_abstract_mesh and set_concrete_mesh -> use_concrete_mesh)
PiperOrigin-RevId: 736382641
2025-03-12 22:30:05 -07:00
Yash Katariya
c94ec0eb0d Use batched_device_put for token shard_arg handler
PiperOrigin-RevId: 731800613
2025-02-27 11:30:22 -08:00
Emily Fertig
7f9e7473cf Rolling back a commit that caused a 50-90% performance regression in most MaxText workloads.
Reverts 9d421c9149a1db006444adeea87464bd6b8c0743

PiperOrigin-RevId: 731506280
2025-02-26 16:57:18 -08:00
Peter Hawkins
66293d8897 Remove code present to support jaxlib < 0.5.1.
The new minimum xla_extension_version is 317 and the new mlir_api_version is 58.
2025-02-26 07:40:40 -05:00
Emily Fertig
9d421c9149 Plumb layout through the creation of PjRtArrays.
This is in preparation to support arrays with no local shards, so that layout may not be accessible from a buffer.

PiperOrigin-RevId: 730469597
2025-02-24 08:53:43 -08:00
Yash Katariya
8bcbf585df Make device_put resharding on single device array input work under use_mesh. Fixes https://github.com/jax-ml/jax/issues/26552
PiperOrigin-RevId: 728382461
2025-02-18 15:22:39 -08:00
jax authors
794ae0f7b7 Merge pull request #26498 from jakevdp:jnp-indexing
PiperOrigin-RevId: 726917490
2025-02-14 07:16:00 -08:00
jax authors
60dcded2af Merge pull request #26518 from superbobry:maint-2
PiperOrigin-RevId: 726663977
2025-02-13 15:44:19 -08:00
Sergei Lebedev
a73456d54d Removed unused `# type: ignore` comments
For future reference, this can be done via

    python -m mypy jax --warn-unused-ignores > /tmp/unused.txt
    while IFS=: read file line rest; do
      echo "$file:$line";
      gsed -i "${line}s/ *\# type: ignore\(\[[^]]*\]\)*//" "$file"
    done < /tmp/unused.txt
2025-02-13 21:12:27 +00:00
Jake VanderPlas
f750d0b855 refactor: move lax_numpy indexing routines to their own submodule 2025-02-13 12:03:07 -08:00
Dan Foreman-Mackey
7f999298ac Only cache jax.Array._npy_value when a copy is required.
As discovered in https://github.com/jax-ml/jax/issues/26216, for non-standard dtypes, calling `np.array` on a JAX array will unnecessarily cache the constructed `_npy_value` even when a copy isn't required. This change updates the logic to only save the cached value when it is a copy.

This fixes https://github.com/jax-ml/jax/issues/26216 by making the behavior consistent across dtypes, but we probably also want to expose a mechanism for clearing this cached value regardless.

PiperOrigin-RevId: 726522955
2025-02-13 09:36:55 -08:00
Emily Fertig
4ae7fcf376 Return arrays from ArrayImpl._check_and_rearrange.
This is in preparation for a larger change, so that input buffers can be checked before Array creation in XLA and the user gets more helpful JAX error messages instead of XLA errors.

Reverts 3b2410f77cdb0acc6951e1770c1229e6689b7409

PiperOrigin-RevId: 723539592
2025-02-05 09:24:22 -08:00
Yash Katariya
bc1a706688 [sharding_in_types] Add a canonicalize_value step before dispatching bind so that we can insert mesh_casts under the following conditions:
* When current_mesh is Manual and aval mesh is Auto

* When current mesh is set and aval mesh is unset

* Final style primitives skip this canonicalization and they are free to add it in their own `bind` method.

* `mesh_cast` is skipped from this canonicalization to avoid recursion errors.

This is required to make sure that after we hit abstract_eval rule and check_jaxpr, everything is properly typed in JAX's type system.

`Auto` right now is a bit more permissive because we need to keep the current code at HEAD working but `Explicit` and `Manual` are very strict.

PiperOrigin-RevId: 722868091
2025-02-03 18:00:19 -08:00
Emily Fertig
3b2410f77c Reverts bb951136e9b91a584bb422119ada76cc69c86024
PiperOrigin-RevId: 721908669
2025-01-31 14:42:22 -08:00
Emily Fertig
bb951136e9 Return arrays from ArrayImpl._check_and_rearrange.
This is in preparation for a larger change, so that input buffers can be checked before Array creation in XLA and the user gets more helpful JAX error messages instead of XLA errors.

PiperOrigin-RevId: 721412760
2025-01-30 09:10:50 -08:00
Matthew Johnson
1fb4b93d41 improve make_array_from_single_device_arrays error 2025-01-25 17:41:01 +00:00
Yash Katariya
d50d1e2c40 Don't allow users to query tracer.sharding even under sharding in types mode.
Instead, users should do `tracer.aval.sharding` so that code behaves the same under jit and eager mode.

PiperOrigin-RevId: 717638986
2025-01-20 15:12:47 -08:00
Parker Schuh
f2f552c108 Allow resharding between tokens on a single device
and multiple devices.

Whenever this happens we can essentially introduce an effects barrier
instead of doing the normal device -> host -> device transfer.

Fixes https://github.com/jax-ml/jax/issues/25671.

PiperOrigin-RevId: 716309978
2025-01-16 11:24:22 -08:00
Jake VanderPlas
ccc3a29537 Internal: use a single registry for abstractify APIs 2024-12-23 08:44:35 -08:00
Jake VanderPlas
5dc37d3f70 Remove internal uses of api_util.shaped_abstractify 2024-12-19 07:06:36 -08:00
Jake VanderPlas
676070f4cd Refactor: move shaped_abstractify to core 2024-12-18 19:14:46 -08:00
Jake VanderPlas
89a54a9e85 Re-land changes from https://github.com/jax-ml/jax/pull/25555
Reverts 25524abc67d82281e8a4093480637785c03a0150

PiperOrigin-RevId: 707679094
2024-12-18 15:02:54 -08:00
jax authors
25524abc67 Reverts b56dc63160eaccd7df05d03b1c38f804ff85f564
PiperOrigin-RevId: 707501925
2024-12-18 04:43:57 -08:00
Jake VanderPlas
3cecbf34f2 Remove core.concrete_aval and replace with abstractify 2024-12-17 18:18:25 -08:00
Jake VanderPlas
c9afc89c23 Always use the same code for array avals 2024-12-17 13:47:58 -08:00
Jake VanderPlas
2c722d9b13 Cleanup: toward merging core.concrete_aval & xla.abstractify 2024-12-17 09:27:00 -08:00
Peter Hawkins
62e66b684b Don't monkey-patch functions in test_utils to count events for tests.
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.

Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
2024-12-12 09:58:14 -05:00
Peter Hawkins
79318a08cf Remove dead code after minimum jaxlib version bump to v0.4.36.
New minimum xla_extension_version is 299, and the new mlir_api_version is 57.

PiperOrigin-RevId: 704280856
2024-12-09 07:35:05 -08:00
Sergei Lebedev
1ac6b762dd Ensured that JAX type checks under pytype on Python 3.12
Some errors uncovered by pytype look genuine and need to be revisited in
the in the future.

PiperOrigin-RevId: 704268742
2024-12-09 06:53:08 -08:00
Yash Katariya
e904c177f7 Delete _normalized_spec from NamedSharding
PiperOrigin-RevId: 697779844
2024-11-18 15:35:38 -08:00
Dan Foreman-Mackey
4a365670f7 Fix pre-commit to run on all files in CI. 2024-11-08 13:47:27 -05:00
Yash Katariya
0bb30f0777 Propagate CopySemantics from python to C++ transfer APIs so that device_put works correctly in presence of copy/donate options that user specified.
This change only supports pinned_host -> pinned_host copies on the same device. HBM -> HBM copies don't work yet and donation also doesn't work in PJRT.

This CL also sets up the plumbing from JAX to PJRT so that in the future support for missing features can be added easily.

Fixes https://github.com/jax-ml/jax/issues/24521

PiperOrigin-RevId: 694274616
2024-11-07 15:51:54 -08:00
Dougal Maclaurin
48f24b6acb Remove ConcreteArray from JAX. It's easy to do trace-time concretization without it.
PiperOrigin-RevId: 691929385
2024-10-31 14:06:54 -07:00
Yash Katariya
66c6292e6a Make committed a public property of jax.Array.
Why?

Because users need to know if an array is committed or not since JAX raises errors based on committedness of a jax.Array. JAX also makes decisions about dispatching based on committedness of a jax.Array.
But the placement of such arrays on devices is an internal implementation detail.

PiperOrigin-RevId: 686329828
2024-10-15 19:46:10 -07:00
Yash Katariya
89fcd9f1f1 Better repr of aval when shardings are present
Example: (for array for shape (8, 2) with dtype float32

```
P('x', 'y') -- float32[8@x,2@y]

P('x', None) -- float32[8@x,2]

P(('x', 'y'), None) -- float32[8@xy,2]

P(None, None) -- float32[8, 2]
```

PiperOrigin-RevId: 684996577
2024-10-11 16:48:13 -07:00
Yash Katariya
8ef41a6e14 [sharding_in_types] Normalize partition specs when creating avals so that P(None, None) and P() are treated as replicated and equivalent. Shardings on avals are always normalized.
PiperOrigin-RevId: 684465123
2024-10-10 09:07:44 -07:00
Yash Katariya
a2b39192d2 Make make_array_from_process_local_data go via device_put if there is only 1 process.
PiperOrigin-RevId: 677232996
2024-09-21 10:23:20 -07:00
Sergei Lebedev
b886bd7300 Removed the named_shape argument from jex.core.ShapedArray and jax.ShapeDtypeStruct
It is unused and was only kept around to avoid breaking internal users.

PiperOrigin-RevId: 674310795
2024-09-13 08:38:15 -07:00
Yash Katariya
bcfe95e98e Initial integration of sharding in types in JAX. Currently we just support nary ops in forward only sharding propagation. Currently this functionality is experimental and hidden behind jax_sharding_in_types config flag.
There will be more improvements and semantics clarification coming in the future as we integrate it more into JAX.

Co-authored-by: Dougal Maclaurin <dougalm@google.com>
PiperOrigin-RevId: 668991384
2024-08-29 10:50:04 -07:00
Yash Katariya
ef33cf5ace Standardize default layout to None in internals (dispatch, lowering and compilation) and non-default layouts to concrete layouts.
This massively simplifies the amount of checks we need and improves dispatch time too. It also fixes a donation bug being hit in serving code related to layouts and non-standardization of default layout in JAX.

PiperOrigin-RevId: 668527139
2024-08-28 11:06:37 -07:00
Yash Katariya
6e1c23610d If input layouts are specified via in_shardings to jit and the array that the jitted function is called with is uncommitted, reshard the input array to the layout specified by the user.
Not doing the resharding, leads to incorrect outputs on GPU and a crash on TPU which is not good.

Fixes: https://github.com/google/jax/issues/23100
PiperOrigin-RevId: 665000157
2024-08-19 15:10:32 -07:00
Yash Katariya
0c543aef1d Match the argument name with the name in Args section in docstring
PiperOrigin-RevId: 663926739
2024-08-16 17:22:02 -07:00
Yash Katariya
229cbae5ea Add num_devices to Sharding interface so that it works with NamedSharding containing AbstractMesh too.
PiperOrigin-RevId: 662938823
2024-08-14 09:03:17 -07:00