9731 Commits

Author SHA1 Message Date
Adam Paszke
bb96226dd8 [Mosaic GPU] Add support for small RHS tile sizes in WGMMA
This is useful for more fine-grained autotuning and can help avoid
wave quantization effects.

PiperOrigin-RevId: 732105219
2025-02-28 05:41:30 -08:00
Benjamin Chetioui
a9ab614123 [Pallas/Mosaic GPU] Add an abstraction to obtain a slice of dynamic shared memory when using warpgroup semantics.
Explicitly make the assumption that `runtime_smem` starts at `0` in the Pallas
module context---which should be enforced by Mosaic GPU.

This is in preparation of changes implementing transform inference.

PiperOrigin-RevId: 732091266
2025-02-28 04:38:25 -08:00
Adam Paszke
832f5a3aff [Mosaic GPU] Remove TMA inputs
This test configuration dates back to the time when we were very unsure
about how to use TMA. At this point we have plenty of experience and it
makes more sense to focus the test in question on verifying WGMMA. This
also simplifies adding support for smaller RHS tiling.

PiperOrigin-RevId: 732040900
2025-02-28 01:19:28 -08:00
jax authors
d8953e5311 Remove a spurious zero_to_zero conversion. The conversion is used for backcompat types as a way of supporting (best effort) types that are unsupported on hardware, to make debugging easier.
In general this is a good feature, but it assumed that the packing type used here was exclusively for backcompat, and so it always applied the adjustment.

PiperOrigin-RevId: 731954456
2025-02-27 19:21:37 -08:00
Dan Foreman-Mackey
810445b10d Temporarily skip linalg sharding tests on GPU.
PiperOrigin-RevId: 731877477
2025-02-27 14:56:16 -08:00
Yash Katariya
dda62f576f Make sure the default layout is None for input and output layouts in all codepaths
PiperOrigin-RevId: 731865511
2025-02-27 14:26:25 -08:00
jax authors
c7ca35fe32 Merge pull request #26345 from wenscarl:scaled_matmul
PiperOrigin-RevId: 731865430
2025-02-27 14:24:48 -08:00
jax authors
3450e2cee0 Disable certain tests on V4 and below.
PiperOrigin-RevId: 731812726
2025-02-27 12:02:52 -08:00
Yash Katariya
d69da3b012 More cleanups around ParsedPartitionSpec. In a follow-up CL, I can remove it from the NamedSharding constructor. Deleting ParsedPartitionSpec entirely still remains, but that will happen after the 0.5.2 release.
PiperOrigin-RevId: 731785005
2025-02-27 10:51:04 -08:00
Yash Katariya
034a827a4d Remove _parsed_pspec from everywhere in JAX except for the NamedSharding constructor. I'll do that in the next CL, since it has a dependency on C++ and so needs guards.
PiperOrigin-RevId: 731772222
2025-02-27 10:17:06 -08:00
Yash Katariya
177e1f6ed9 Canonicalize PartitionSpec so that we can delete ParsedPartitionSpec. We need to do this after sharding-in-types to speed up NamedSharding construction and remove a lot of tech debt and unnecessary complexity.
* `_partitions` is now canonicalized and only contains tuples, plain strings, `None`, or `UNCONSTRAINED`. No more empty tuples (`P((), 'x')`) or singleton tuples (see the sketch below).

* Cache the creation of the sharding on ShapedArray, since it's expensive to do it repeatedly.

* Change the `__hash__` and `__eq__` of `NamedSharding` to depend on `self.spec` instead of `self._parsed_pspec`.

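A minimal illustration of the canonical entry kinds listed above (the axis names here are made up, not from this change):

```python
from jax.sharding import PartitionSpec as P

# Canonical entries are plain strings, tuples of strings, None, or
# PartitionSpec.UNCONSTRAINED; no empty tuples like P((), 'x') and no
# singleton tuples.
spec = P('x', ('y', 'z'), None, P.UNCONSTRAINED)
print(spec)
```
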
PiperOrigin-RevId: 731745062
2025-02-27 08:59:25 -08:00
Dan Foreman-Mackey
f93c2a1aa5 Add and test support for partitioning of batch dimensions in lax.linalg.
On CPU and GPU, almost all of the primitives in lax.linalg are backed by custom calls that support simple semantics when batch dimensions are sharded. Before this change, all linalg operations on CPU and GPU would insert an `all-gather` before being executed when called on sharded inputs, even when that shouldn't be necessary. This change adds support for this type of partitioning, to cover a wide range of use cases.

There are a few remaining GPU ops that don't support partitioning, either because they are backed by HLO ops that don't partition properly (Cholesky factorization and triangular solves), or because they're still using descriptors with problem dimensions baked into the kernel. I'm going to fix these in follow-up changes.
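
A hypothetical sketch of the kind of use case this enables (the mesh axis name and sizes below are illustrative, not taken from the change itself):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis spanning all available devices; 'batch' is an illustrative name.
mesh = Mesh(jax.devices(), ('batch',))

# A batch of small matrices with the batch dimension sharded across devices.
x = jnp.stack([jnp.eye(4) + 0.1 * i for i in range(8)])
x = jax.device_put(x, NamedSharding(mesh, P('batch', None, None)))

# With batch-dimension partitioning supported, a batched factorization like LU
# should no longer force a gather of the batch dimension before the custom call.
lu, pivots, permutation = jax.lax.linalg.lu(x)
```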

PiperOrigin-RevId: 731732301
2025-02-27 08:16:16 -08:00
Bart Chrzaszcz
4997e45743 #sdy close any partially sharded dimensions if using auto axes in a shard_map.
PiperOrigin-RevId: 731724837
2025-02-27 07:53:18 -08:00
jax authors
a8738a069e Merge pull request #26804 from hawkinsp:tsan
PiperOrigin-RevId: 731721390
2025-02-27 07:39:35 -08:00
jax authors
07f5d7a475 Reverts f3fade3b70443b6cf87f01f360e6a1cb85d4b1fb
PiperOrigin-RevId: 731658204
2025-02-27 03:26:37 -08:00
Peter Hawkins
6e73637888 Fix a test failure under multi-threading.
Remove a tsan suppression for a CPython race that is fixed.
2025-02-27 06:07:05 -05:00
Henning Becker
b3f7c93cb2 Fix cudnn version skipping in fused_attention_stablehlo_test
The CUDNN_VERSION is defined as (CUDNN_MAJOR * 10000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL).

Therefore cuDNN 9.1.0 is represented as 90100 - not as 91000.
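
As a tiny illustration of the encoding above (mirroring the macro arithmetic, not code from the change):

```python
def cudnn_version_number(major: int, minor: int, patch: int) -> int:
    # CUDNN_VERSION = CUDNN_MAJOR * 10000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL
    return major * 10000 + minor * 100 + patch

assert cudnn_version_number(9, 1, 0) == 90100  # not 91000
```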

PiperOrigin-RevId: 731641814
2025-02-27 02:19:43 -08:00
Chris Jones
d6752e9267 [pallas:triton] Generate more efficient code for loading contiguous slices of int4 values.
The existing `int4` loading code is very generic. When reading contiguous data, it will read with offsets like `0, 0, 1, 1, ...`. Triton doesn't consider these to be contiguous in memory and emits much less efficient code than when reading contiguous blocks.
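
As a plain-Python illustration of that offset pattern (assuming the usual packing of two int4 values per byte; this is not code from the change):

```python
# Logical element i of a packed int4 array lives at byte offset i // 2, which
# yields the 0, 0, 1, 1, ... pattern for a contiguous slice.
byte_offsets = [i // 2 for i in range(8)]
print(byte_offsets)  # [0, 0, 1, 1, 2, 2, 3, 3]
```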

PiperOrigin-RevId: 731635736
2025-02-27 01:57:47 -08:00
Tom Hennigan
1becb57ac9 Add jax.copy_to_host_async(tree).
A relatively common pattern I've observed is the following:

```python
_, metrics = some_jax_function()

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

We are missing an opportunity here to more eagerly begin the d2h copy of
the metrics (e.g. overlap it with closing the "compute_metrics" context
manager, etc.). The intention of `jax.copy_to_host_async(x)` is to make it
simple to begin d2h transfers as early as possible. Adapting the above code:

```python
_, metrics = some_jax_function()

# Begin D2H copies as early as we can.
jax.copy_to_host_async(metrics)

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

PiperOrigin-RevId: 731626446
2025-02-27 01:22:15 -08:00
Emily Fertig
7f9e7473cf Rolling back a commit that caused a 50-90% performance regression in most MaxText workloads.
Reverts 9d421c9149a1db006444adeea87464bd6b8c0743

PiperOrigin-RevId: 731506280
2025-02-26 16:57:18 -08:00
jax authors
615219b1f6 Remove tensorstore dependency from //jax/experimental/array_serialization:serialization in OSS (see https://github.com/google/tensorstore/issues/218)
Disable serialization_test in OSS.

PiperOrigin-RevId: 731463136
2025-02-26 14:47:16 -08:00
carlosgmartin
ba428d8cda Extend random.orthogonal to semi-orthogonal matrices. Simplify initializers.orthogonal by using it. 2025-02-26 16:39:45 -05:00
Shu Wang
7f0a5bc83e Add Apache header. 2025-02-26 15:26:56 -06:00
Jake VanderPlas
7be7c48985 Implement jnp.ndarray.__contains__
Currently this falls back to a linear scan via __iter__, which is slow
and raises unclear error messages in unsupported cases.
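
A minimal usage sketch; membership checks like the following are assumed to behave NumPy-style (an elementwise comparison plus a reduction) under the new implementation, rather than going through the slow `__iter__`-based scan:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
print(3 in x)  # True
print(7 in x)  # False
```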
2025-02-26 11:13:45 -08:00
jax authors
d7849d5dd6 Merge pull request #26712 from hawkinsp:ph3
PiperOrigin-RevId: 731302211
2025-02-26 07:02:46 -08:00
jax authors
eb55aef5d3 Merge pull request #26762 from hawkinsp:tsan
PiperOrigin-RevId: 731300991
2025-02-26 06:58:49 -08:00
Peter Hawkins
66293d8897 Remove code present to support jaxlib < 0.5.1.
The new minimum xla_extension_version is 317 and the new mlir_api_version is 58.
2025-02-26 07:40:40 -05:00
Adam Paszke
99a12ef9ea [Mosaic GPU] Add support for warpgroup lowering of loops with vector carries
PiperOrigin-RevId: 731260912
2025-02-26 04:29:36 -08:00
Adam Paszke
1de2f839d5 [Mosaic GPU] Make sure to relayout FAs when their layouts mismatch in MGPU lowering
PiperOrigin-RevId: 731253431
2025-02-26 04:03:57 -08:00
Peter Hawkins
33bbd5f119 Fix failures in TSAN free threading CI. 2025-02-26 06:04:26 -05:00
shuw
17088e9025 Improve after review # 2 2025-02-26 04:48:25 +00:00
Jacob Burnim
4c7140fa03 [Pallas] Add option for async DMAs in the new TPU interpret mode
When dma_execution_mode='on_wait', we wait to execute DMAs until we are interpreting a `dma_wait` instruction.  In particular, while a device is waiting on a DMA semaphore, we will (partially) execute DMAs that signal that semaphore until the wait operation can succeed.

PiperOrigin-RevId: 731103569
2025-02-25 18:19:20 -08:00
Matthias Kramm
e8543024e5 Add unfused_hbm usage to binary ops and dot_general.
PiperOrigin-RevId: 731066135
2025-02-25 16:10:25 -08:00
jax authors
467e0bddb4 Merge pull request #26676 from Rifur13:padding
PiperOrigin-RevId: 731024640
2025-02-25 14:12:14 -08:00
Matthias Kramm
aad178a6f8 roofline: Add support for min_p, max_p, reduce_sum_p.
PiperOrigin-RevId: 731024098
2025-02-25 14:10:15 -08:00
Matthias Kramm
08081c4db6 roofline: Support broadcasting, for binary ops.
PiperOrigin-RevId: 731014250
2025-02-25 13:46:00 -08:00
Gleb Pobudzey
a35494e020 Allow query and keys that aren’t multiples of 128 2025-02-25 19:13:24 +00:00
jax authors
eb912ad0d9 Create jax wheel build target.
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126)

Previously the `jax` wheel was built via the `python3 -m build` command. The resulting wheel contained the Python package files in the `jax` folder (e.g. the files in the subdirectories that have an `__init__.py` file).

You can still build the `jax` wheel with `python3 -m build` command.

Bazel `jax` wheel target: `//:jax_wheel`

Environment variable combinations for creating wheels with different versions (an example invocation follows the list):
  * self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * release: `--repo_env=ML_WHEEL_TYPE=release`
  * release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
  * nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`
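
For example, a nightly wheel could presumably be built with `bazel build //:jax_wheel --repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=20250225 --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`; the exact combination of target and flags is an assumption based on the options listed above.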

PiperOrigin-RevId: 730916743
2025-02-25 09:30:08 -08:00
shuw
681ee18436 Fix CI 2025-02-25 17:15:31 +00:00
Adam Paszke
3d87a01bea [Pallas:MGPU] Adjust warpgroup lowering to the recent emit_pipeline changes
The Pallas-level pipelining generates a number of ops we haven't had to deal with before,
like conditionals, scans, etc.

PiperOrigin-RevId: 730899808
2025-02-25 08:39:44 -08:00
Dan Foreman-Mackey
a3a48af105 Add tests for lax.linalg.svd algorithm specification.
PiperOrigin-RevId: 730890907
2025-02-25 08:11:52 -08:00
Bixia Zheng
30348e90e7 [jax:custom_partitioning] Propagate static arguments to sharding_rule callback.
PiperOrigin-RevId: 730885306
2025-02-25 07:55:00 -08:00
Yash Katariya
9deb7e3d96 [sharding_in_types] physical_aval should set the correct sharding on ShapedArray so that lowering and compilation don't crash
PiperOrigin-RevId: 730885084
2025-02-25 07:53:14 -08:00
Benjamin Chetioui
5024ef213f [Mosaic GPU] Add layout inference for scf.ForOp and scf.YieldOp.
PiperOrigin-RevId: 730873769
2025-02-25 07:13:25 -08:00
jax authors
7acd60c867 Merge pull request #26058 from gnecula:debug_info_pallas
PiperOrigin-RevId: 730866402
2025-02-25 06:47:56 -08:00
Benjamin Chetioui
6f966397e0 [Mosaic GPU][NFC] Remove unnecessary lambda wrappers from test.
PiperOrigin-RevId: 730847574
2025-02-25 05:43:59 -08:00
George Necula
c4e0db6f8a [better_errors] Port the Pallas debug info mechanisms to the new JAX DebugInfo.
Now that we carry debug information in Jaxpr, we can remove the Pallas-specific
tracking of the `func_src_info`, e.g., `NameAndSrcInfo`.
2025-02-25 14:43:17 +01:00
Benjamin Chetioui
5b13883f8e [Mosaic GPU] Add dialect lowering logic for splat constants.
PiperOrigin-RevId: 730842871
2025-02-25 05:25:56 -08:00
Sergei Lebedev
c13a2f95d5 [pallas:mosaic_gpu] Use emit_pipeline for pipelining in the lowering
This shaves off a lot of complexity from our lowering code, while retaining
all of the functionality, except the arrive_tx optimization: `emit_pipeline`
arrives once per buffer, whereas the pipelining in the lowering used to
arrive once for all buffers.

PiperOrigin-RevId: 730824239
2025-02-25 04:14:10 -08:00
Adam Paszke
676acebafa [Pallas:MGPU] Enable lowering for .astype and scalar broadcasts
PiperOrigin-RevId: 730805326
2025-02-25 03:01:11 -08:00