26079 Commits

Author SHA1 Message Date
jax authors
8492897fd3 Merge pull request #26291 from carlosgmartin:simplify_nn_initializers_orthogonal
PiperOrigin-RevId: 731455939
2025-02-26 14:26:15 -08:00
Nitin Srinivasan
a65de52421 Enable resultstore logging
Tests logged with resulstore are much easier to read and debug

PiperOrigin-RevId: 731448196
2025-02-26 14:04:58 -08:00
carlosgmartin
ba428d8cda Extend random.orthogonal to semi-orthogonal matrices. Simplify initializers.orthogonal by using it. 2025-02-26 16:39:45 -05:00
Shu Wang
7f0a5bc83e
Add apache header. 2025-02-26 15:26:56 -06:00
jax authors
f3fade3b70 Merge pull request #26779 from jakevdp:array-contains
PiperOrigin-RevId: 731430821
2025-02-26 13:17:04 -08:00
Jake VanderPlas
7be7c48985 Implement jnp.ndarray.__contains__
Currently this falls back to a linear scan via __iter__, which is slow
and raises unclear error messages in unsupported cases.
2025-02-26 11:13:45 -08:00
jax authors
8b99ddc022 Merge pull request #26740 from dfm:fix-upstream-nightly-uv
PiperOrigin-RevId: 731379980
2025-02-26 10:56:59 -08:00
Dan Foreman-Mackey
b8f236e64d Add --system to uv commands in upstream-nightly workflow. 2025-02-26 13:21:41 -05:00
William S. Moses
8262987a1c Fix build dependencies
PiperOrigin-RevId: 731330542
2025-02-26 08:38:31 -08:00
jax authors
d7849d5dd6 Merge pull request #26712 from hawkinsp:ph3
PiperOrigin-RevId: 731302211
2025-02-26 07:02:46 -08:00
jax authors
eb55aef5d3 Merge pull request #26762 from hawkinsp:tsan
PiperOrigin-RevId: 731300991
2025-02-26 06:58:49 -08:00
jax authors
c9c7250dd4 Upgrade to Bazel 7.4.1
PiperOrigin-RevId: 731278247
2025-02-26 05:33:24 -08:00
Klaus Greff
5acfc88a00
fix Initializer protocol 2025-02-26 14:25:15 +01:00
Peter Hawkins
66293d8897 Remove code present to support jaxlib < 0.5.1.
The new minimum xla_extension_version is 317 and the new mlir_api_version is 58.
2025-02-26 07:40:40 -05:00
Adam Paszke
99a12ef9ea [Mosaic GPU] Add support for warpgroup lowering of loops with vector carries
PiperOrigin-RevId: 731260912
2025-02-26 04:29:36 -08:00
Adam Paszke
1de2f839d5 [Mosaic GPU] Make sure to relayout FAs when their layouts mismatch in MGPU lowering
PiperOrigin-RevId: 731253431
2025-02-26 04:03:57 -08:00
Adam Paszke
3251b55ef2 [Pallas:MGPU] Don't recreate single_thread_predicate at every rule
While the predicate helps us avoid branching, it can be created once per
block. Its creation uses `*.sync` instructions, which are not DCEd by
LLVM and end up polluting the final code.

PiperOrigin-RevId: 731253109
2025-02-26 04:02:21 -08:00
Benjamin Chetioui
7a34f1cedc [Pallas/Mosaic GPU][NFC] Move thread_semantics to ModuleContext.
This simplifies the propagation of the argument, and is the proper place to
put it.

PiperOrigin-RevId: 731239831
2025-02-26 03:08:42 -08:00
Peter Hawkins
33bbd5f119 Fix failures in TSAN free threading CI. 2025-02-26 06:04:26 -05:00
jax authors
f21eefe112 Update XLA dependency to use revision
41c2b0eda0.

PiperOrigin-RevId: 731216015
2025-02-26 01:42:49 -08:00
shuw
17088e9025 Improve after review # 2 2025-02-26 04:48:25 +00:00
Jacob Burnim
4c7140fa03 [Pallas] Add option for async DMAs in the new TPU interpret mode
When dma_execution_mode='on_wait', we wait to execute DMAs until we are interpreting a `dma_wait` instruction.  In particular, while a device is waiting on a DMA semaphore, we will (partially) execute DMAs that signal that semaphore until the wait operation can succeed.

PiperOrigin-RevId: 731103569
2025-02-25 18:19:20 -08:00
Nitin Srinivasan
7566daba68 Use uv instead of pip for installing Python packages
Missed including these in 4b4f2f9cb9

PiperOrigin-RevId: 731095379
2025-02-25 17:48:22 -08:00
Matthias Kramm
e8543024e5 Add unfused_hbm usage to binary ops and dot_general.
PiperOrigin-RevId: 731066135
2025-02-25 16:10:25 -08:00
Nitin Srinivasan
f57c18ad1b Install uv to fix module not found error on Windows
Ideally, this install should be in the Dockerfile but updating the Windows dockerfile is not straightforward so I'm doing the install here for the time being.

PiperOrigin-RevId: 731055684
2025-02-25 15:39:07 -08:00
Nitin Srinivasan
771306bab3 Use ${{ !cancelled() }} instead of ${{ always() }}
`${{ always() }}` makes it difficult to cancel a workflow. See https://github.com/orgs/community/discussions/26303

PiperOrigin-RevId: 731044750
2025-02-25 15:06:38 -08:00
jax authors
dc1c3f9abd Disable //tests:serialization_test_cpu from TSAN job and remove tensorstore dependency from //jax/experimental/array_serialization:serialization.
`TSAN CPython` is unable to find a compatible version of `tensorstore` wheel, hence the test can not be executed.

PiperOrigin-RevId: 731027518
2025-02-25 14:19:02 -08:00
jax authors
467e0bddb4 Merge pull request #26676 from Rifur13:padding
PiperOrigin-RevId: 731024640
2025-02-25 14:12:14 -08:00
Matthias Kramm
aad178a6f8 roofline: Add support for min_p, max_p, reduce_sum_p.
PiperOrigin-RevId: 731024098
2025-02-25 14:10:15 -08:00
Matthias Kramm
08081c4db6 roofline: Support broadcasting, for binary ops.
PiperOrigin-RevId: 731014250
2025-02-25 13:46:00 -08:00
Nitin Srinivasan
cf01fdfe6a Use the 64 core Windows runner to build artifacts
Now that we have disabled RBE on Windows, we need to use the bigger machine to build fast.

PiperOrigin-RevId: 731012952
2025-02-25 13:42:16 -08:00
jax authors
7c26ab53f6 Use jax.Array as type annotation for pallas random keys
jax_prng.PRNGKeyArray is not exposed to the public jax API, resulting in type check errors when sampling outside of tests.

PiperOrigin-RevId: 731008883
2025-02-25 13:30:58 -08:00
Adam Paszke
cb7402f6de Remove MemoryEffects annotations from async_{load/store} ops
The annotation on async_load didn't indicate its write to SMEM, allowing it
to be DCEd by MLIR canonicalization. We don't get much mileage out of those
annotations, so let's just delete them for simplicity.

PiperOrigin-RevId: 731003033
2025-02-25 13:15:00 -08:00
jax authors
03e2c888e2 Merge pull request #26327 from ksebaz:fix-rocm-with-distributed
PiperOrigin-RevId: 730999556
2025-02-25 13:05:16 -08:00
Nitin Srinivasan
2f6f722150 Disable RBE on Windows
We no longer have a RBE pool with ltsc2019 image and are blocked on upgrading GKE to ltsc2022.

PiperOrigin-RevId: 730997201
2025-02-25 12:58:45 -08:00
Dan Foreman-Mackey
553b441fef Use LAPACK trsm kernel even for batched solves.
Depending on the platform and linked LAPACK library, this change seems to improve (or at least not degrade) performance across a wide range of problem and batch sizes. On colab, the performance is not dramatically improved for most input shapes, but on my Mac, this improves the performance of batched triangular solves by a factor of a few up to an order of magnitude across all the problems that I tried.

PiperOrigin-RevId: 730971127
2025-02-25 11:49:01 -08:00
Gleb Pobudzey
a35494e020 Allow query and keys that aren’t multiples of 128 2025-02-25 19:13:24 +00:00
Dan Foreman-Mackey
525cb4bde4 Rename top level build file to BUILD.bazel.
PiperOrigin-RevId: 730957694
2025-02-25 11:13:17 -08:00
Peter Hawkins
256e37af5f Port many uses of contextlib.contextdecorator to explicit context manager classes.
contextdecorator turns out to be slower than just writing a decorator class explicitly. Since we use many decorators per-equation, this causes a measurable speed difference in certain benchmarks.

PiperOrigin-RevId: 730939406
2025-02-25 10:31:05 -08:00
Dan Foreman-Mackey
2ce88c950a Deprecate alpha argument to trsm LAPACK kernel.
(Part of general cleanups of the lax.linalg submodule.)

This is always set to 1 and I don't see any benefit to keeping this argument around. This can be done in a forward and backward compatible way following these docs: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility

We start by updating the FFI handler to remove the explicit alpha argument, but allow it to accept (but ignore) extra input arguments. Then we only pass alpha when lowering in forward compatibility mode, or when the jaxlib version is old (I'm using >0.5.1 as the cutoff assuming that this change doesn't make it into the upcoming release).

Then, the forward compatibility lowering can be removed after at least 21 days, and the kernel can be updated at least 180 days after 0.5.2 is released.

PiperOrigin-RevId: 730928808
2025-02-25 10:04:29 -08:00
jax authors
05614edc7d Expose pallas.mosaic.random.sample_block to pltpu interface
PiperOrigin-RevId: 730923727
2025-02-25 09:49:34 -08:00
Nitin Srinivasan
8ac2969759 Pass JAXLIB_* env variables to docker container
PiperOrigin-RevId: 730922129
2025-02-25 09:46:30 -08:00
Adam Paszke
ced28167e8 [Mosaic GPU] Use explicit recursion in rules instead of doing it automatically
Control-flow ops that have vector inputs or outputs will need to be specially adjusted.

PiperOrigin-RevId: 730922072
2025-02-25 09:44:57 -08:00
jax authors
eb912ad0d9 Create jax wheel build target.
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126)

Previously `jax` wheel was built via `python3 -m build` command. The resulting wheel contained the python packages files in `jax` folder (e.g. the files in the subdirs that have `__init__.py` file).

You can still build the `jax` wheel with `python3 -m build` command.

Bazel `jax` wheel target: `//:jax_wheel`

Environment variables combinations for creating wheels with different versions:
  * self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * release: `--repo_env=ML_WHEEL_TYPE=release`
  * release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
  * nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`

PiperOrigin-RevId: 730916743
2025-02-25 09:30:08 -08:00
Nitin Srinivasan
7a162f2abc Fix incorrect line separator
On Windows, we are supposed to use "`" instead of "\"

PiperOrigin-RevId: 730916160
2025-02-25 09:28:06 -08:00
jax authors
c06c7851d6 Merge pull request #26668 from jakevdp:sharp-bits
PiperOrigin-RevId: 730915911
2025-02-25 09:26:29 -08:00
jax authors
69a6aaa32e Merge pull request #26734 from hawkinsp:tsan
PiperOrigin-RevId: 730912120
2025-02-25 09:17:07 -08:00
jax authors
a6b8384aed Merge pull request #26564 from gspschmid:gschmid/mini_mpmd
PiperOrigin-RevId: 730912043
2025-02-25 09:15:40 -08:00
shuw
681ee18436 Fix CI 2025-02-25 17:15:31 +00:00
jax authors
0f8e6b996d Typecheck pallas.CostEstimate
Passing a float can lead to miscompilations

PiperOrigin-RevId: 730909635
2025-02-25 09:08:29 -08:00