25689 Commits

Author SHA1 Message Date
George Necula
a0812cd57e [better_errors] Make it explicit that debug_info is not None.
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.

For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.

See https://github.com/jax-ml/jax/issues/26480 for more details.

PiperOrigin-RevId: 726770483
2025-02-13 22:07:04 -08:00
jax authors
60dcded2af Merge pull request #26518 from superbobry:maint-2
PiperOrigin-RevId: 726663977
2025-02-13 15:44:19 -08:00
jax authors
f0cd1686ec Merge pull request #26509 from andportnoy:aportnoy/pallas-mosaic-gpu-test-sm90a
PiperOrigin-RevId: 726624339
2025-02-13 13:52:31 -08:00
Dan Foreman-Mackey
14afb73241 Move info!=0 logic into lax.linalg.tridiagonal lowering rule.
PiperOrigin-RevId: 726617102
2025-02-13 13:33:22 -08:00
jax authors
91c6e449ae Merge pull request #26461 from ROCm:run-less-rocm-tests
PiperOrigin-RevId: 726614797
2025-02-13 13:27:47 -08:00
Sergei Lebedev
a73456d54d Removed unused `# type: ignore` comments
For future reference, this can be done via

    python -m mypy jax --warn-unused-ignores > /tmp/unused.txt
    while IFS=: read file line rest; do
      echo "$file:$line";
      gsed -i "${line}s/ *\# type: ignore\(\[[^]]*\]\)*//" "$file"
    done < /tmp/unused.txt
2025-02-13 21:12:27 +00:00
Dan Foreman-Mackey
c6c38fb852 Reorder top-level functions in lax.linalg, and add/expand docstrings.
PiperOrigin-RevId: 726603731
2025-02-13 12:57:55 -08:00
jax authors
5ebb7eb55d Merge pull request #26472 from jakevdp:jnp-einsum
PiperOrigin-RevId: 726580373
2025-02-13 11:55:07 -08:00
Yash Katariya
229aa65a3e Split NamedSharding into a separate file called named_sharding.py so that we can import it in core.py and break the cyclic dependency.
PiperOrigin-RevId: 726566863
2025-02-13 11:22:54 -08:00
Dan Foreman-Mackey
ea4e324fe4 Fix some busted batching rules in lax.linalg.
PiperOrigin-RevId: 726543703
2025-02-13 10:28:39 -08:00
Dan Foreman-Mackey
7f999298ac Only cache jax.Array._npy_value when a copy is required.
As discovered in https://github.com/jax-ml/jax/issues/26216, for non-standard dtypes, calling `np.array` on a JAX array will unnecessarily cache the constructed `_npy_value` even when a copy isn't required. This change updates the logic to only save the cached value when it is a copy.

This fixes https://github.com/jax-ml/jax/issues/26216 by making the behavior consistent across dtypes, but we probably also want to expose a mechanism for clearing this cached value regardless.

PiperOrigin-RevId: 726522955
2025-02-13 09:36:55 -08:00
Dan Foreman-Mackey
7efbda6244 Rewrite generic LU pivots to permutation implementation using vmap instead of explicit broadcasting.
I'm working on implementing sharding logic across all of `lax.linalg`, and I've found that the previous implementation of this loop using explicit broadcasted iotas was confounding the partitioner, but this version using vmap batch partitions properly and I don't anticipate any performance differences.

PiperOrigin-RevId: 726518677
2025-02-13 09:24:38 -08:00
Adam Paszke
845e3f8fe8 [Mosaic GPU] Add support for Blackwell MMA with n=512
This requires unrolling into two instructions sequences over n, since the largest
tcgen05.mma instructions can only handle n=256.

PiperOrigin-RevId: 726496900
2025-02-13 08:24:21 -08:00
jax authors
5889fd0d22 Merge pull request #26486 from superbobry:maint-2
PiperOrigin-RevId: 726490849
2025-02-13 08:07:30 -08:00
Adam Paszke
5c7caa3126 [Mosaic GPU] Simplify the collective barrier test to avoid GMEM atomics
It happens rarely, but this test seems to be flaky, probably because we don't
properly synchronize the memory accesses somehow. It's not important, so we
just avoid the data races now.

PiperOrigin-RevId: 726489606
2025-02-13 08:04:21 -08:00
Yash Katariya
2062e986a6 Fix the error message to say out_sharding instead of sharding in lax.reshape sharding rule
PiperOrigin-RevId: 726484167
2025-02-13 07:49:54 -08:00
Adam Paszke
b0b1fa7dad Skip pipeline mode args in tests with older libTPU
PiperOrigin-RevId: 726480896
2025-02-13 07:39:16 -08:00
Sergei Lebedev
194884d311 Migrated to mypy 1.14.1 with --allow_redefinition
I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,

   def f(x: int) -> str: ...
   def g(x: int) -> str: ...

   callback = f if ... else g  # has type object!
2025-02-13 15:38:28 +00:00
Andrey Portnoy
54fa1b9aa5 [Mosaic GPU] Factor out arch specific Pallas Mosaic GPU tests 2025-02-13 10:29:22 -05:00
jax authors
af2fa9bcde Merge pull request #25056 from chaserileyroberts:chase/compute_on_stream
PiperOrigin-RevId: 726477912
2025-02-13 07:28:46 -08:00
Adam Paszke
f3f54dee52 [Mosaic GPU] Put all inline asm output constraints before input constraints in optimization barrier
Apparently it's a requirement that LLVM has, but it only complains about it when compiled with
assertions enabled, so it went unnoticed for a while.

PiperOrigin-RevId: 726468259
2025-02-13 06:59:49 -08:00
Adam Paszke
a493df4dd8 Fix Windows build for Mosaic GPU extension
We only export symbols that being with `mlir` and a few other prefixes, so this renames our C API functions for consistency with that.

PiperOrigin-RevId: 726468092
2025-02-13 06:58:17 -08:00
Christos Perivolaropoulos
305e55f323 [pallas:mgpu] Fix and test multiple indexers where one is a dynamic selection index.
PiperOrigin-RevId: 726447413
2025-02-13 05:53:08 -08:00
George Necula
7161cad6a7 Update oryx for JAX debug_info
See https://github.com/jax-ml/jax/issues/26480.

PiperOrigin-RevId: 726437113
2025-02-13 05:12:38 -08:00
Adam Paszke
cbe102df85 [Mosaic GPU] Reorganize tests to make sure WGMMA tests are skipped on Blackwell
WGMMA only works on Hopper.

PiperOrigin-RevId: 726432108
2025-02-13 04:56:00 -08:00
jax authors
0941ec9f14 Update XLA dependency to use revision
6b470af698.

PiperOrigin-RevId: 726426766
2025-02-13 04:38:14 -08:00
chaserileyroberts
60f0184637 Added stream annotation support via @compute_on('gpu_stream:#') 2025-02-13 07:15:18 +00:00
Yash Katariya
3ec7a67e51 [sharding_in_types] Make sharding arg to ShapedArray kwarg only
PiperOrigin-RevId: 726272943
2025-02-12 18:22:50 -08:00
Yash Katariya
15cd83ae00 [sharding_in_types] Error out when PartitionSpec is passed to APIs that take out_sharding like einsum when context_mesh is unset.
This change is raising a better error because doing `NamedSharding(empty_mesh, P('x'))` will raise an error on construction but it is uglier than the current error added in this change.

PiperOrigin-RevId: 726253654
2025-02-12 17:13:14 -08:00
Jevin Jiang
876668faa1 [Mosaic TPU] Support bf16 div if HW does not directly support.
PiperOrigin-RevId: 726212286
2025-02-12 15:04:09 -08:00
jax authors
153a7cf913 Merge pull request #26373 from jax-ml:autodidax-stackless
PiperOrigin-RevId: 726211770
2025-02-12 15:02:41 -08:00
jax authors
73c626d95e Merge pull request #26503 from garymm:patch-1
PiperOrigin-RevId: 726197571
2025-02-12 14:24:42 -08:00
Yash Katariya
0944e5202e Create _BaseMesh so that properties can be shared between Mesh and AbstractMesh so that code is not duplicated
PiperOrigin-RevId: 726193613
2025-02-12 14:14:48 -08:00
Yash Katariya
1a62df1ac0 Rename sharding argument to out_sharding for lax.reshape, lax.broadcast_in_dim, lax.broadcast and lax.broadcasted_iota. .bind of these APIs still take sharding as a parameter though (but that's fine since it's internal and not public facing)
PiperOrigin-RevId: 726187934
2025-02-12 13:59:23 -08:00
Yash Katariya
d58c3a4722 [sharding_in_types] Fix some properties that assumed axis_types always existed.
PiperOrigin-RevId: 726187278
2025-02-12 13:57:19 -08:00
Daniel Suo
8c685be688 [xla:cpu] Implement XLA FFI handlers for CPU Jax callbacks.
PiperOrigin-RevId: 726185954
2025-02-12 13:53:36 -08:00
Gary Miguel
e231a35ad3
Fix doc string for PmapSharding
Lack of indent was resulting in extra parameter being shown in the HTML generated docs
2025-02-12 13:39:53 -08:00
Dan Foreman-Mackey
9298018afa Enable shardy batch partitionable FFI test.
PiperOrigin-RevId: 726171678
2025-02-12 13:17:40 -08:00
jax authors
4f1c67e6c0 Merge pull request #26403 from jakevdp:bf16-mean
PiperOrigin-RevId: 726157721
2025-02-12 12:41:20 -08:00
Dougal
9145366f6f Part 1 of a new autodidax based on "stackless" 2025-02-12 15:23:06 -05:00
Jake VanderPlas
b5e7b60d6a jax.numpy reductions: avoid upcast of f16 when dtype is specified by user 2025-02-12 11:49:39 -08:00
jax authors
5b697728c7 Merge pull request #24910 from olupton:expect-pgle
PiperOrigin-RevId: 726106211
2025-02-12 10:26:08 -08:00
jax authors
f7e2901e8b Merge pull request #25955 from tttc3:magma_qr
PiperOrigin-RevId: 726098235
2025-02-12 10:05:02 -08:00
Yash Katariya
2d01df760b [sharding_in_types] Make the typing checks and sharding rule checks a little bit less strict when the current or aval mesh is empty/unset. Also some more changes as listed below:
* get_aval is not context dependent

* canonicalization does not happen for avals on an empty mesh

* jax.jit does not set abstract mesh context anymore before tracing

* sharding checks have been relaxed for all modes (`Auto`, `Explicit` and `Manual`). This means that `f = lambda x, y: x * y; f(explicit_sharded_arr, np_array)` will be allowed without inserting any mesh_casts even in `Explicit` sharding mode

* Even if use_mesh is not used in explicit sharding mode, computation follows data works!

* Higher order primitives skip canonicalization (pjit_p, while_p, cond_p, for_loop_p, scan_p)

* Check in partial_eval which compares jaxpr_known.outvars == jaxpr.out_avals has been relaxed to not check shardings if any one of the aval has an empty mesh.

As mentioned in https://github.com/jax-ml/jax/issues/26474 we need to relax the typing and sharding rule checks because if we insert `mesh_cast`s, those lead to creation of unnecessary residuals (for literals, numpy arrays, basically anything that has an empty mesh) which is not good.

PiperOrigin-RevId: 726097292
2025-02-12 10:03:01 -08:00
Nitin Srinivasan
93831bdde7 Download and use jax wheels from GCS bucket for nightly/release test workflows
Unlike continuous workflows, when testing nightly/release artifacts, we want to download and install the `jax` wheels found in the GCS bucket instead of installing it from HEAD.

It looks like `env` setting in the calling workflow isn't passed over to the called workflows so we define a new workflow input, `install-jax-current-commit`, to control the `jax` install behavior.

PiperOrigin-RevId: 726086522
2025-02-12 09:32:05 -08:00
Jake VanderPlas
7ab7b214ac refactor: move jnp.einsum impl into its own submodule 2025-02-12 09:05:30 -08:00
Benjamin Chetioui
837418c652 [Mosaic GPU] Remove old jaxlib version guards.
PiperOrigin-RevId: 726071956
2025-02-12 08:49:40 -08:00
Yash Katariya
b4b4a98db7 [sharding_in_types] When caching mesh with axis_types, make sure the data structure is (axis_size, axis_names, tuple(axis_types))
PiperOrigin-RevId: 726064530
2025-02-12 08:23:52 -08:00
Adam Paszke
f1ab7514db Make sure we take libTPU version into account in the Pallas lowering
Also, strengthen the presubmit to make sure we catch more errors.

PiperOrigin-RevId: 726061633
2025-02-12 08:15:57 -08:00
tttc3
b1b56ea0b0 Enable pivoted QR on GPU via MAGMA.
Originally noted in #20282, this commit provides a GPU compatible
implementation of `geqp3` via MAGMA.
2025-02-12 16:12:42 +00:00