25929 Commits

Author SHA1 Message Date
jax authors
a8738a069e Merge pull request #26804 from hawkinsp:tsan
PiperOrigin-RevId: 731721390
2025-02-27 07:39:35 -08:00
jax authors
07f5d7a475 Reverts f3fade3b70443b6cf87f01f360e6a1cb85d4b1fb
PiperOrigin-RevId: 731658204
2025-02-27 03:26:37 -08:00
Peter Hawkins
6e73637888 Fix a test failure under multi-threading.
Remove a tsan suppression for a CPython race that is fixed.
2025-02-27 06:07:05 -05:00
jax authors
0fbc453d94 Update XLA dependency to use revision
fb6241ad51.

PiperOrigin-RevId: 731643649
2025-02-27 02:27:17 -08:00
Henning Becker
b3f7c93cb2 Fix cudnn version skipping in fused_attention_stablehlo_test
The CUDNN_VERSION is defined as (CUDNN_MAJOR * 10000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL).

Therefore cuDNN 9.1.0 is represented as 90100 - not as 91000.
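
The encoding above can be checked with a short sketch. The helper name `cudnn_version` is hypothetical and is not part of the test or of cuDNN; it only mirrors the formula quoted in this message.

```python
def cudnn_version(major: int, minor: int, patch: int) -> int:
    """Encode a version the way the CUDNN_VERSION formula above does.

    Hypothetical helper for illustration only.
    """
    return major * 10000 + minor * 100 + patch


# cuDNN 9.1.0 encodes to 90100; the value 91000 would correspond to 9.10.0.
required = cudnn_version(9, 1, 0)
mistaken = cudnn_version(9, 10, 0)
```

A version gate written against 91000 would therefore skip the test on every 9.1.x through 9.9.x release.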

PiperOrigin-RevId: 731641814
2025-02-27 02:19:43 -08:00
Chris Jones
d6752e9267 [pallas:triton] Generate more efficient code for loading contiguous slices of int4 values.
The existing `int4` loading code is very generic. When reading contiguous data, it will read with offsets like `0, 0, 1, 1, ...`. Triton doesn't consider these to be contiguous in memory and emits much less efficient code than when reading contiguous blocks.
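
The offset pattern described above can be sketched in plain Python. This assumes two `int4` values are packed per byte with the low nibble first; the packing order and all function names here are illustrative assumptions, not the Pallas or Triton implementation.

```python
def generic_byte_offsets(num_elements: int) -> list[int]:
    # One byte offset per int4 element: 0, 0, 1, 1, 2, 2, ...
    return [i // 2 for i in range(num_elements)]


def contiguous_byte_offsets(num_elements: int) -> list[int]:
    # One byte offset per packed byte: 0, 1, 2, ...
    return list(range((num_elements + 1) // 2))


def unpack_int4_pair(byte: int) -> tuple[int, int]:
    # Split one byte into two signed 4-bit values (low nibble first, assumed).
    def sign_extend(nibble: int) -> int:
        return nibble - 16 if nibble >= 8 else nibble

    return sign_extend(byte & 0xF), sign_extend((byte >> 4) & 0xF)
```

Reading each byte once at offsets `0, 1, 2, ...` and unpacking afterwards yields the same values as the per-element offsets, but the access pattern is visibly contiguous.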

PiperOrigin-RevId: 731635736
2025-02-27 01:57:47 -08:00
Tom Hennigan
1becb57ac9 Add jax.copy_to_host_async(tree).
A relatively common pattern I've observed is the following:

```python
_, metrics = some_jax_function()

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

We are missing an opportunity here to begin the D2H copy of
the metrics more eagerly (e.g. overlap it with closing the "compute_metrics"
context manager). The intention of `jax.copy_to_host_async(x)` is to make it
simple to begin D2H transfers as early as possible. Adapting the above code:

```python
_, metrics = some_jax_function()

# Begin D2H copies as early as we can.
jax.copy_to_host_async(metrics)

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

PiperOrigin-RevId: 731626446
2025-02-27 01:22:15 -08:00
Sharad Vikram
2646b8d4ad [Pallas TPU] Add support for GridDimensionSemantics to pallas_call
PiperOrigin-RevId: 731543938
2025-02-26 19:34:36 -08:00
Sharad Vikram
b5fcffadd4 Add swap as method to TransformedRef
PiperOrigin-RevId: 731541165
2025-02-26 19:19:10 -08:00
Sharad Vikram
1ecbac9702 [Pallas] Add name parameter to core_map
PiperOrigin-RevId: 731536152
2025-02-26 18:59:01 -08:00
Sharad Vikram
0f0d5e90ef Add support for TPU v5 2x2 tray configuration
PiperOrigin-RevId: 731529917
2025-02-26 18:33:49 -08:00
Emily Fertig
82124da5cd Redefine is_fully_addressable in shardings to support zero local devices for McJAX.
PiperOrigin-RevId: 731526750
2025-02-26 18:17:35 -08:00
Emily Fertig
7f9e7473cf Rolling back a commit that caused a 50-90% performance regression in most MaxText workloads.
Reverts 9d421c9149a1db006444adeea87464bd6b8c0743

PiperOrigin-RevId: 731506280
2025-02-26 16:57:18 -08:00
jax authors
615219b1f6 Remove tensorstore dependency from //jax/experimental/array_serialization:serialization in OSS (see https://github.com/google/tensorstore/issues/218)
Disable serialization_test in OSS.

PiperOrigin-RevId: 731463136
2025-02-26 14:47:16 -08:00
jax authors
8492897fd3 Merge pull request #26291 from carlosgmartin:simplify_nn_initializers_orthogonal
PiperOrigin-RevId: 731455939
2025-02-26 14:26:15 -08:00
Nitin Srinivasan
a65de52421 Enable resultstore logging
Tests logged with resultstore are much easier to read and debug.

PiperOrigin-RevId: 731448196
2025-02-26 14:04:58 -08:00
carlosgmartin
ba428d8cda Extend random.orthogonal to semi-orthogonal matrices. Simplify initializers.orthogonal by using it.
2025-02-26 16:39:45 -05:00
jax authors
f3fade3b70 Merge pull request #26779 from jakevdp:array-contains
PiperOrigin-RevId: 731430821
2025-02-26 13:17:04 -08:00
Jake VanderPlas
7be7c48985 Implement jnp.ndarray.__contains__
Currently this falls back to a linear scan via __iter__, which is slow
and raises unclear error messages in unsupported cases.
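
The fallback mentioned above is general Python behaviour: when a class defines no `__contains__`, the `in` operator iterates via `__iter__`. A minimal sketch with hypothetical stand-in classes (not JAX code) shows the difference:

```python
class IterOnly:
    """Container without __contains__; `in` falls back to a linear scan."""

    def __init__(self, values):
        self.values = list(values)
        self.steps = 0  # counts elements visited through __iter__

    def __iter__(self):
        for v in self.values:
            self.steps += 1
            yield v


class WithContains(IterOnly):
    """Same container with an explicit membership check."""

    def __contains__(self, item):
        # Stand-in for a single bulk comparison; never touches __iter__.
        return item in self.values


scanned = IterOnly([1, 2, 3])
direct = WithContains([1, 2, 3])
found_by_scan = 3 in scanned   # visits elements one at a time
found_directly = 3 in direct   # dispatches to __contains__
```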
2025-02-26 11:13:45 -08:00
jax authors
8b99ddc022 Merge pull request #26740 from dfm:fix-upstream-nightly-uv
PiperOrigin-RevId: 731379980
2025-02-26 10:56:59 -08:00
Dan Foreman-Mackey
b8f236e64d Add --system to uv commands in upstream-nightly workflow.
2025-02-26 13:21:41 -05:00
William S. Moses
8262987a1c Fix build dependencies
PiperOrigin-RevId: 731330542
2025-02-26 08:38:31 -08:00
jax authors
d7849d5dd6 Merge pull request #26712 from hawkinsp:ph3
PiperOrigin-RevId: 731302211
2025-02-26 07:02:46 -08:00
jax authors
eb55aef5d3 Merge pull request #26762 from hawkinsp:tsan
PiperOrigin-RevId: 731300991
2025-02-26 06:58:49 -08:00
jax authors
c9c7250dd4 Upgrade to Bazel 7.4.1
PiperOrigin-RevId: 731278247
2025-02-26 05:33:24 -08:00
Peter Hawkins
66293d8897 Remove code present to support jaxlib < 0.5.1.
The new minimum xla_extension_version is 317 and the new mlir_api_version is 58.
2025-02-26 07:40:40 -05:00
Adam Paszke
99a12ef9ea [Mosaic GPU] Add support for warpgroup lowering of loops with vector carries
PiperOrigin-RevId: 731260912
2025-02-26 04:29:36 -08:00
Adam Paszke
1de2f839d5 [Mosaic GPU] Make sure to relayout FAs when their layouts mismatch in MGPU lowering
PiperOrigin-RevId: 731253431
2025-02-26 04:03:57 -08:00
Adam Paszke
3251b55ef2 [Pallas:MGPU] Don't recreate single_thread_predicate at every rule
While the predicate helps us avoid branching, it can be created once per
block. Its creation uses `*.sync` instructions, which are not DCEd by
LLVM and end up polluting the final code.

PiperOrigin-RevId: 731253109
2025-02-26 04:02:21 -08:00
Benjamin Chetioui
7a34f1cedc [Pallas/Mosaic GPU][NFC] Move thread_semantics to ModuleContext.
This simplifies the propagation of the argument, and is the proper place to
put it.

PiperOrigin-RevId: 731239831
2025-02-26 03:08:42 -08:00
Peter Hawkins
33bbd5f119 Fix failures in TSAN free threading CI.
2025-02-26 06:04:26 -05:00
jax authors
f21eefe112 Update XLA dependency to use revision
41c2b0eda0.

PiperOrigin-RevId: 731216015
2025-02-26 01:42:49 -08:00
Jacob Burnim
4c7140fa03 [Pallas] Add option for async DMAs in the new TPU interpret mode
When dma_execution_mode='on_wait', we wait to execute DMAs until we are interpreting a `dma_wait` instruction.  In particular, while a device is waiting on a DMA semaphore, we will (partially) execute DMAs that signal that semaphore until the wait operation can succeed.

PiperOrigin-RevId: 731103569
2025-02-25 18:19:20 -08:00
Nitin Srinivasan
7566daba68 Use uv instead of pip for installing Python packages
Missed including these in 4b4f2f9cb9

PiperOrigin-RevId: 731095379
2025-02-25 17:48:22 -08:00
Matthias Kramm
e8543024e5 Add unfused_hbm usage to binary ops and dot_general.
PiperOrigin-RevId: 731066135
2025-02-25 16:10:25 -08:00
Nitin Srinivasan
f57c18ad1b Install uv to fix module not found error on Windows
Ideally, this install should be in the Dockerfile, but updating the Windows Dockerfile is not straightforward, so I'm doing the install here for the time being.

PiperOrigin-RevId: 731055684
2025-02-25 15:39:07 -08:00
Nitin Srinivasan
771306bab3 Use ${{ !cancelled() }} instead of ${{ always() }}
`${{ always() }}` makes it difficult to cancel a workflow. See https://github.com/orgs/community/discussions/26303

PiperOrigin-RevId: 731044750
2025-02-25 15:06:38 -08:00
jax authors
dc1c3f9abd Disable //tests:serialization_test_cpu from TSAN job and remove tensorstore dependency from //jax/experimental/array_serialization:serialization.
`TSAN CPython` is unable to find a compatible version of the `tensorstore` wheel, hence the test cannot be executed.

PiperOrigin-RevId: 731027518
2025-02-25 14:19:02 -08:00
jax authors
467e0bddb4 Merge pull request #26676 from Rifur13:padding
PiperOrigin-RevId: 731024640
2025-02-25 14:12:14 -08:00
Matthias Kramm
aad178a6f8 roofline: Add support for min_p, max_p, reduce_sum_p.
PiperOrigin-RevId: 731024098
2025-02-25 14:10:15 -08:00
Matthias Kramm
08081c4db6 roofline: Support broadcasting, for binary ops.
PiperOrigin-RevId: 731014250
2025-02-25 13:46:00 -08:00
Nitin Srinivasan
cf01fdfe6a Use the 64 core Windows runner to build artifacts
Now that we have disabled RBE on Windows, we need to use the bigger machine to build fast.

PiperOrigin-RevId: 731012952
2025-02-25 13:42:16 -08:00
jax authors
7c26ab53f6 Use jax.Array as type annotation for pallas random keys
jax_prng.PRNGKeyArray is not exposed to the public jax API, resulting in type check errors when sampling outside of tests.

PiperOrigin-RevId: 731008883
2025-02-25 13:30:58 -08:00
Adam Paszke
cb7402f6de Remove MemoryEffects annotations from async_{load/store} ops
The annotation on async_load didn't indicate its write to SMEM, allowing it
to be DCEd by MLIR canonicalization. We don't get much mileage out of those
annotations, so let's just delete them for simplicity.

PiperOrigin-RevId: 731003033
2025-02-25 13:15:00 -08:00
jax authors
03e2c888e2 Merge pull request #26327 from ksebaz:fix-rocm-with-distributed
PiperOrigin-RevId: 730999556
2025-02-25 13:05:16 -08:00
Nitin Srinivasan
2f6f722150 Disable RBE on Windows
We no longer have an RBE pool with the ltsc2019 image and are blocked on upgrading GKE to ltsc2022.

PiperOrigin-RevId: 730997201
2025-02-25 12:58:45 -08:00
Dan Foreman-Mackey
553b441fef Use LAPACK trsm kernel even for batched solves.
Depending on the platform and linked LAPACK library, this change seems to improve (or at least not degrade) performance across a wide range of problem and batch sizes. On colab, the performance is not dramatically improved for most input shapes, but on my Mac, this improves the performance of batched triangular solves by a factor of a few up to an order of magnitude across all the problems that I tried.

PiperOrigin-RevId: 730971127
2025-02-25 11:49:01 -08:00
Gleb Pobudzey
a35494e020 Allow query and keys that aren’t multiples of 128
2025-02-25 19:13:24 +00:00
Dan Foreman-Mackey
525cb4bde4 Rename top level build file to BUILD.bazel.
PiperOrigin-RevId: 730957694
2025-02-25 11:13:17 -08:00
Peter Hawkins
256e37af5f Port many uses of contextlib.contextdecorator to explicit context manager classes.
contextdecorator turns out to be slower than just writing a decorator class explicitly. Since we use many decorators per-equation, this causes a measurable speed difference in certain benchmarks.

PiperOrigin-RevId: 730939406
2025-02-25 10:31:05 -08:00