25710 Commits

Author SHA1 Message Date
Jake VanderPlas
33b989ac9e refactor: import numpy objects directly in jax.numpy 2025-02-14 12:47:58 -08:00
Jake VanderPlas
36d7f8530b Fix the type annotations and don't += a generator (it's confusing)
The code clearly needs those variables to be lists (it mutates, through
`.append` and such).

PiperOrigin-RevId: 727029815
2025-02-14 12:46:01 -08:00
jax authors
4b94665f4f Merge pull request #26528 from jakevdp:lax-docs
PiperOrigin-RevId: 726971041
2025-02-14 10:11:54 -08:00
jax authors
ca87f5f3bf Merge pull request #26542 from jakevdp:fix-breakage
PiperOrigin-RevId: 726944183
2025-02-14 08:50:43 -08:00
Jake VanderPlas
531443c434 jax.lax: improve docs for pow & related functions 2025-02-14 08:40:19 -08:00
Jake VanderPlas
b93934c7fb Fix breakage in indexing refactor 2025-02-14 08:20:56 -08:00
Ayaka
6addf02add Add JAX error checking support
In this PR, only jit and control flows are supported. Support for vmap and multi-device environments will be added in subsequent PRs.

PiperOrigin-RevId: 726920440
2025-02-14 07:28:21 -08:00
Dan Foreman-Mackey
902ebe1bfe Fix segfault when old GPU plugins are installed.
PiperOrigin-RevId: 726919772
2025-02-14 07:26:45 -08:00
Adam Paszke
b287c3924a Ignore ImportError for Triton on Windows
We don't support Windows GPU builds right now and skip all the tests,
but at the moment they can't even skip because of the import failure.

PiperOrigin-RevId: 726917651
2025-02-14 07:17:49 -08:00
jax authors
794ae0f7b7 Merge pull request #26498 from jakevdp:jnp-indexing
PiperOrigin-RevId: 726917490
2025-02-14 07:16:00 -08:00
Adam Paszke
cdcf35fd70 Remove an unused import
PiperOrigin-RevId: 726910300
2025-02-14 06:49:34 -08:00
jax authors
4df596165d Merge pull request #26523 from andportnoy:aportnoy/mosaic-gpu-dialect-hasattr-import
PiperOrigin-RevId: 726900934
2025-02-14 06:15:34 -08:00
jax authors
12d533f635 Merge pull request #26522 from andportnoy:aportnoy/mosaic-gpu-test-sm90a
PiperOrigin-RevId: 726899717
2025-02-14 06:13:53 -08:00
Adam Paszke
5ab8c5a8a4 Make sure that tests don't change the state of the compilation cache
If it was initialized before the test, it should stay so after. And the other
way around too.

PiperOrigin-RevId: 726899671
2025-02-14 06:12:02 -08:00
Christos Perivolaropoulos
49ad24152c [pallas:mgpu] Change FA3 kernel bc lax.div doesn't like mixed types anymore.
PiperOrigin-RevId: 726883573
2025-02-14 05:10:49 -08:00
Sergei Lebedev
3162cc4d0d [pallas:triton] Added basic support for lax.concatenate
The corresponding Triton op is restricted to `jnp.stack([x, y], axis=-1)`,
so the lowering only supports that case for now.

See #25321.

PiperOrigin-RevId: 726881284
2025-02-14 05:02:53 -08:00
jax authors
80dcb7b5af Update XLA dependency to use revision
a16d96f0a0.

PiperOrigin-RevId: 726863997
2025-02-14 04:01:23 -08:00
Adam Paszke
4a8023fe1e [Mosaic GPU] Define TMEMLayout without referring to the PTX guide
The PTX guide talks about a few layouts by assigning them different
letters, which do not have an obvious meaning. We redefine the layout
by parameterizing it with a 2D tile size which, as far as I can tell,
is sufficient to represent all layouts we care about.

PiperOrigin-RevId: 726833412
2025-02-14 02:06:17 -08:00
George Necula
a0812cd57e [better_errors] Make it explicit that debug_info is not None.
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.

For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.

See https://github.com/jax-ml/jax/issues/26480 for more details.

PiperOrigin-RevId: 726770483
2025-02-13 22:07:04 -08:00
jax authors
60dcded2af Merge pull request #26518 from superbobry:maint-2
PiperOrigin-RevId: 726663977
2025-02-13 15:44:19 -08:00
jax authors
f0cd1686ec Merge pull request #26509 from andportnoy:aportnoy/pallas-mosaic-gpu-test-sm90a
PiperOrigin-RevId: 726624339
2025-02-13 13:52:31 -08:00
Dan Foreman-Mackey
14afb73241 Move info!=0 logic into lax.linalg.tridiagonal lowering rule.
PiperOrigin-RevId: 726617102
2025-02-13 13:33:22 -08:00
jax authors
91c6e449ae Merge pull request #26461 from ROCm:run-less-rocm-tests
PiperOrigin-RevId: 726614797
2025-02-13 13:27:47 -08:00
Sergei Lebedev
a73456d54d Removed unused `# type: ignore` comments
For future reference, this can be done via

    python -m mypy jax --warn-unused-ignores > /tmp/unused.txt
    while IFS=: read file line rest; do
      echo "$file:$line";
      gsed -i "${line}s/ *\# type: ignore\(\[[^]]*\]\)*//" "$file"
    done < /tmp/unused.txt
2025-02-13 21:12:27 +00:00
Dan Foreman-Mackey
c6c38fb852 Reorder top-level functions in lax.linalg, and add/expand docstrings.
PiperOrigin-RevId: 726603731
2025-02-13 12:57:55 -08:00
Andrey Portnoy
ea5eb49aa9 [Mosaic GPU] Use gettatr to import version-specific dialect ops 2025-02-13 15:08:42 -05:00
Andrey Portnoy
ae9389dc0f [Mosaic GPU] Factor out Mosaic GPU dialect arch-specific tests 2025-02-13 15:04:34 -05:00
Jake VanderPlas
f750d0b855 refactor: move lax_numpy indexing routines to their own submodule 2025-02-13 12:03:07 -08:00
jax authors
5ebb7eb55d Merge pull request #26472 from jakevdp:jnp-einsum
PiperOrigin-RevId: 726580373
2025-02-13 11:55:07 -08:00
Yash Katariya
229aa65a3e Split NamedSharding into a separate file called named_sharding.py so that we can import it in core.py and break the cyclic dependency.
PiperOrigin-RevId: 726566863
2025-02-13 11:22:54 -08:00
Dan Foreman-Mackey
ea4e324fe4 Fix some busted batching rules in lax.linalg.
PiperOrigin-RevId: 726543703
2025-02-13 10:28:39 -08:00
Dan Foreman-Mackey
7f999298ac Only cache jax.Array._npy_value when a copy is required.
As discovered in https://github.com/jax-ml/jax/issues/26216, for non-standard dtypes, calling `np.array` on a JAX array will unnecessarily cache the constructed `_npy_value` even when a copy isn't required. This change updates the logic to only save the cached value when it is a copy.

This fixes https://github.com/jax-ml/jax/issues/26216 by making the behavior consistent across dtypes, but we probably also want to expose a mechanism for clearing this cached value regardless.

PiperOrigin-RevId: 726522955
2025-02-13 09:36:55 -08:00
Dan Foreman-Mackey
7efbda6244 Rewrite generic LU pivots to permutation implementation using vmap instead of explicit broadcasting.
I'm working on implementing sharding logic across all of `lax.linalg`, and I've found that the previous implementation of this loop using explicit broadcasted iotas was confounding the partitioner, but this version using vmap batch partitions properly and I don't anticipate any performance differences.

PiperOrigin-RevId: 726518677
2025-02-13 09:24:38 -08:00
Adam Paszke
845e3f8fe8 [Mosaic GPU] Add support for Blackwell MMA with n=512
This requires unrolling into two instructions sequences over n, since the largest
tcgen05.mma instructions can only handle n=256.

PiperOrigin-RevId: 726496900
2025-02-13 08:24:21 -08:00
jax authors
5889fd0d22 Merge pull request #26486 from superbobry:maint-2
PiperOrigin-RevId: 726490849
2025-02-13 08:07:30 -08:00
Adam Paszke
5c7caa3126 [Mosaic GPU] Simplify the collective barrier test to avoid GMEM atomics
It happens rarely, but this test seems to be flaky, probably because we don't
properly synchronize the memory accesses somehow. It's not important, so we
just avoid the data races now.

PiperOrigin-RevId: 726489606
2025-02-13 08:04:21 -08:00
Yash Katariya
2062e986a6 Fix the error message to say out_sharding instead of sharding in lax.reshape sharding rule
PiperOrigin-RevId: 726484167
2025-02-13 07:49:54 -08:00
Adam Paszke
b0b1fa7dad Skip pipeline mode args in tests with older libTPU
PiperOrigin-RevId: 726480896
2025-02-13 07:39:16 -08:00
Sergei Lebedev
194884d311 Migrated to mypy 1.14.1 with --allow_redefinition
I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,

   def f(x: int) -> str: ...
   def g(x: int) -> str: ...

   callback = f if ... else g  # has type object!
2025-02-13 15:38:28 +00:00
Andrey Portnoy
54fa1b9aa5 [Mosaic GPU] Factor out arch specific Pallas Mosaic GPU tests 2025-02-13 10:29:22 -05:00
jax authors
af2fa9bcde Merge pull request #25056 from chaserileyroberts:chase/compute_on_stream
PiperOrigin-RevId: 726477912
2025-02-13 07:28:46 -08:00
Adam Paszke
f3f54dee52 [Mosaic GPU] Put all inline asm output constraints before input constraints in optimization barrier
Apparently it's a requirement that LLVM has, but it only complains about it when compiled with
assertions enabled, so it went unnoticed for a while.

PiperOrigin-RevId: 726468259
2025-02-13 06:59:49 -08:00
Adam Paszke
a493df4dd8 Fix Windows build for Mosaic GPU extension
We only export symbols that being with `mlir` and a few other prefixes, so this renames our C API functions for consistency with that.

PiperOrigin-RevId: 726468092
2025-02-13 06:58:17 -08:00
Christos Perivolaropoulos
305e55f323 [pallas:mgpu] Fix and test multiple indexers where one is a dynamic selection index.
PiperOrigin-RevId: 726447413
2025-02-13 05:53:08 -08:00
George Necula
7161cad6a7 Update oryx for JAX debug_info
See https://github.com/jax-ml/jax/issues/26480.

PiperOrigin-RevId: 726437113
2025-02-13 05:12:38 -08:00
Adam Paszke
cbe102df85 [Mosaic GPU] Reorganize tests to make sure WGMMA tests are skipped on Blackwell
WGMMA only works on Hopper.

PiperOrigin-RevId: 726432108
2025-02-13 04:56:00 -08:00
jax authors
0941ec9f14 Update XLA dependency to use revision
6b470af698.

PiperOrigin-RevId: 726426766
2025-02-13 04:38:14 -08:00
chaserileyroberts
60f0184637 Added stream annotation support via @compute_on('gpu_stream:#') 2025-02-13 07:15:18 +00:00
Yash Katariya
3ec7a67e51 [sharding_in_types] Make sharding arg to ShapedArray kwarg only
PiperOrigin-RevId: 726272943
2025-02-12 18:22:50 -08:00
Yash Katariya
15cd83ae00 [sharding_in_types] Error out when PartitionSpec is passed to APIs that take out_sharding like einsum when context_mesh is unset.
This change is raising a better error because doing `NamedSharding(empty_mesh, P('x'))` will raise an error on construction but it is uglier than the current error added in this change.

PiperOrigin-RevId: 726253654
2025-02-12 17:13:14 -08:00