16405 Commits

jax authors
7c26ab53f6 Use jax.Array as type annotation for pallas random keys
jax_prng.PRNGKeyArray is not exposed to the public jax API, resulting in type check errors when sampling outside of tests.
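A minimal illustration of the annotation pattern (generic typed-key usage, not the pallas-specific module):

```python
# Minimal sketch: typed PRNG keys are jax.Array instances, so jax.Array is a
# valid public annotation where jax_prng.PRNGKeyArray cannot be imported.
import jax

def sample(key: jax.Array) -> jax.Array:
    return jax.random.normal(key, (8,))

out = sample(jax.random.key(0))
```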

PiperOrigin-RevId: 731008883
2025-02-25 13:30:58 -08:00
jax authors
03e2c888e2 Merge pull request #26327 from ksebaz:fix-rocm-with-distributed
PiperOrigin-RevId: 730999556
2025-02-25 13:05:16 -08:00
Dan Foreman-Mackey
553b441fef Use LAPACK trsm kernel even for batched solves.
Depending on the platform and linked LAPACK library, this change seems to improve (or at least not degrade) performance across a wide range of problem and batch sizes. On Colab, the performance is not dramatically improved for most input shapes, but on my Mac, this improves the performance of batched triangular solves by a factor of a few up to an order of magnitude across all the problems that I tried.
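As a rough illustration of the kind of call affected (shapes and values made up), a batched solve routes through lax.linalg.triangular_solve:

```python
# Hedged example of a batched triangular solve; all shapes here are arbitrary.
import jax
import jax.numpy as jnp
from jax import lax

a = jnp.tril(jax.random.normal(jax.random.key(0), (32, 16, 16))) + 4 * jnp.eye(16)
b = jax.random.normal(jax.random.key(1), (32, 16, 8))
x = lax.linalg.triangular_solve(a, b, left_side=True, lower=True)  # batch of 32 solves
```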

PiperOrigin-RevId: 730971127
2025-02-25 11:49:01 -08:00
Gleb Pobudzey
a35494e020 Allow query and key lengths that aren’t multiples of 128 2025-02-25 19:13:24 +00:00
Peter Hawkins
256e37af5f Port many uses of contextlib.contextdecorator to explicit context manager classes.
contextdecorator turns out to be slower than just writing a decorator class explicitly. Since we use many decorators per-equation, this causes a measurable speed difference in certain benchmarks.
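An illustrative sketch of the pattern (hypothetical names, not the actual JAX classes):

```python
# Illustrative sketch: an explicit decorator class avoids the per-call overhead
# of the generator-based machinery behind contextlib.contextmanager.
import contextlib
import functools


@contextlib.contextmanager
def slow_scope():
    # Each use builds a fresh _GeneratorContextManager object.
    yield


class fast_scope:
    """Explicit context manager that can also be used as a decorator."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, tb):
        return None

    def __call__(self, fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            with self:
                return fn(*args, **kwargs)
        return wrapper
```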

PiperOrigin-RevId: 730939406
2025-02-25 10:31:05 -08:00
Dan Foreman-Mackey
2ce88c950a Deprecate alpha argument to trsm LAPACK kernel.
(Part of general cleanups of the lax.linalg submodule.)

This is always set to 1 and I don't see any benefit to keeping this argument around. This can be done in a forward and backward compatible way following these docs: https://docs.jax.dev/en/latest/export/export.html#ensuring-forward-and-backward-compatibility

We start by updating the FFI handler to remove the explicit alpha argument, but allow it to accept (but ignore) extra input arguments. Then we only pass alpha when lowering in forward compatibility mode, or when the jaxlib version is old (I'm using >0.5.1 as the cutoff assuming that this change doesn't make it into the upcoming release).

Then, the forward compatibility lowering can be removed after at least 21 days, and the kernel can be updated at least 180 days after 0.5.2 is released.
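A hedged sketch of the gating step (the helper names and version check are assumptions standing in for JAX's internal forward-compat and jaxlib-version checks, not the actual implementation):

```python
# Hypothetical sketch of the lowering-side gating described above.
def trsm_operands(is_forward_compat: bool, jaxlib_version: tuple, a, b):
    if is_forward_compat or jaxlib_version <= (0, 5, 1):
        # Older kernels expect an explicit alpha; the updated FFI handler
        # accepts (and ignores) the extra argument, so passing it stays safe.
        return (1.0, a, b)
    # New lowering path: alpha is dropped entirely.
    return (a, b)
```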

PiperOrigin-RevId: 730928808
2025-02-25 10:04:29 -08:00
jax authors
05614edc7d Expose pallas.mosaic.random.sample_block to pltpu interface
PiperOrigin-RevId: 730923727
2025-02-25 09:49:34 -08:00
Adam Paszke
ced28167e8 [Mosaic GPU] Use explicit recursion in rules instead of doing it automatically
Control-flow ops that have vector inputs or outputs will need to be specially adjusted.

PiperOrigin-RevId: 730922072
2025-02-25 09:44:57 -08:00
jax authors
eb912ad0d9 Create jax wheel build target.
This change introduces a uniform way of building the artifacts and controlling the filename version suffixes (see the changes for `jaxlib`, `jax-cuda-plugin` and `jax-cuda-pjrt` in https://github.com/jax-ml/jax/pull/25126).

Previously the `jax` wheel was built via the `python3 -m build` command. The resulting wheel contained the Python package files in the `jax` folder (e.g. the files in the subdirectories that have an `__init__.py` file).

You can still build the `jax` wheel with `python3 -m build` command.

Bazel `jax` wheel target: `//:jax_wheel`

Environment variable combinations for creating wheels with different versions:
  * self-built wheel (default build rule behavior): `--repo_env=ML_WHEEL_TYPE=snapshot`
  * release: `--repo_env=ML_WHEEL_TYPE=release`
  * release candidate: `--repo_env=ML_WHEEL_TYPE=release --repo_env=ML_WHEEL_VERSION_SUFFIX=-rc1`
  * nightly build: `--repo_env=ML_WHEEL_TYPE=custom --repo_env=ML_WHEEL_BUILD_DATE=<YYYYmmdd> --repo_env=ML_WHEEL_GIT_HASH=$(git rev-parse HEAD)`
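For example, a release wheel could then be built with `bazel build //:jax_wheel --repo_env=ML_WHEEL_TYPE=release` (an illustrative invocation combining the target and flags above; the actual CI command may pass additional options).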

PiperOrigin-RevId: 730916743
2025-02-25 09:30:08 -08:00
jax authors
a6b8384aed Merge pull request #26564 from gspschmid:gschmid/mini_mpmd
PiperOrigin-RevId: 730912043
2025-02-25 09:15:40 -08:00
shuw
681ee18436 Fix CI 2025-02-25 17:15:31 +00:00
jax authors
0f8e6b996d Typecheck pallas.CostEstimate
Passing a float can lead to miscompilations
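A hedged usage sketch (made-up counts) showing the integer fields:

```python
# Hedged sketch: CostEstimate fields are integer operation/byte counts, so pass
# ints (round() if needed) rather than floats; the numbers below are made up.
import jax.experimental.pallas as pl

cost = pl.CostEstimate(
    flops=2 * 128 * 128 * 128,
    transcendentals=0,
    bytes_accessed=3 * 128 * 128 * 4,
)
# Typically passed to pl.pallas_call(..., cost_estimate=cost).
```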

PiperOrigin-RevId: 730909635
2025-02-25 09:08:29 -08:00
Georg Stefan Schmid
2b4c455af5 Add jax.experimental._mini_mpmd 2025-02-25 17:02:29 +00:00
Adam Paszke
3d87a01bea [Pallas:MGPU] Adjust warpgroup lowering to the recent emit_pipeline changes
The Pallas-level pipelining generates a number of ops we haven't had to deal with before, like conditionals, scans, etc.

PiperOrigin-RevId: 730899808
2025-02-25 08:39:44 -08:00
Bixia Zheng
30348e90e7 [jax:custom_partitioning] Propagate static arguments to sharding_rule callback.
PiperOrigin-RevId: 730885306
2025-02-25 07:55:00 -08:00
Yash Katariya
9deb7e3d96 [sharding_in_types] physical_aval should set the correct sharding on ShapedArray so that lowering and compilation don't crash
PiperOrigin-RevId: 730885084
2025-02-25 07:53:14 -08:00
Benjamin Chetioui
5024ef213f [Mosaic GPU] Add layout inference for scf.ForOp and scf.YieldOp.
PiperOrigin-RevId: 730873769
2025-02-25 07:13:25 -08:00
jax authors
7acd60c867 Merge pull request #26058 from gnecula:debug_info_pallas
PiperOrigin-RevId: 730866402
2025-02-25 06:47:56 -08:00
Benjamin Chetioui
5312b5e35a [Mosaic GPU] Add layout inference for arith.Ext{F,SI,UI}Op and arith.Trunc{F,I}Op.
PiperOrigin-RevId: 730851596
2025-02-25 05:59:40 -08:00
George Necula
c4e0db6f8a [better_errors] Port the Pallas debug info mechanisms to the new JAX DebugInfo.
Now that we carry debug information in Jaxprs we can remove the Pallas-specific
tracking of the `func_src_info`, e.g., `NameAndSrcInfo`.
2025-02-25 14:43:17 +01:00
Benjamin Chetioui
5b13883f8e [Mosaic GPU] Add dialect lowering logic for splat constants.
PiperOrigin-RevId: 730842871
2025-02-25 05:25:56 -08:00
Sergei Lebedev
c13a2f95d5 [pallas:mosaic_gpu] Use emit_pipeline for pipelining in the lowering
This shaves off a lot of complexity from our lowering code, while retaining
all of the functionality, except the arrive_tx optimization: `emit_pipeline`
arrives once per buffer, whereas the pipelining in the lowering used to
arrive once for all buffers.

PiperOrigin-RevId: 730824239
2025-02-25 04:14:10 -08:00
Adam Paszke
676acebafa [Pallas:MGPU] Enable lowering for .astype and scalar broadcasts
PiperOrigin-RevId: 730805326
2025-02-25 03:01:11 -08:00
Adam Paszke
71c7622037 [Pallas:MGPU] Change WG semantics convention to represent scalar arrays using scalars
Previously every ShapedArray got converted to an MLIR vector which was more annoying
than helpful.

PiperOrigin-RevId: 730795455
2025-02-25 02:24:06 -08:00
Adam Paszke
d0d5bba645 [Pallas:MGPU] Avoid SMEM->GMEM wait if no outputs are transferred in the pipeline loop
The TMA wait does not add much overhead, but it lets us save on an unnecessary warpgroup barrier.

PiperOrigin-RevId: 730795234
2025-02-25 02:22:25 -08:00
Sergei Lebedev
7eadc64b5a [pallas:mosaic_gpu] Added WG lowering rules for TMA primitives and run_scoped_p
PiperOrigin-RevId: 730780335
2025-02-25 01:32:43 -08:00
Adam Paszke
80848ad859 [Pallas:MGPU] Consistently use i32 as the grid index type in emit_pipeline
That's more consistent with Pallas semantics and avoids generating a slightly
different kernel depending on x32/x64 mode.

PiperOrigin-RevId: 730778314
2025-02-25 01:24:50 -08:00
Sebastian Kehl
6abc76c874 Check "jax_rocm_visible_devices" at client creation.
This aligns ROCm with CUDA when using jax.distributed in combination
with one of the cluster-autodetection mechanisms that set visible
devices via the "jax_rocm_visible_devices" flag.
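A hedged sketch of the setup this affects (the flag value and explicit update below are illustrative; in practice cluster auto-detection sets the flag per process):

```python
# Illustrative multi-process ROCm setup; the flag value here is made up, and a
# launcher (e.g. under SLURM) would normally set it for each process.
import jax

jax.config.update("jax_rocm_visible_devices", "0")
jax.distributed.initialize()  # visible-devices flag must now be honored at client creation
print(jax.local_devices())
```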

Fixes #26298
2025-02-25 09:59:59 +01:00
Yash Katariya
b707f0bdbb [sharding_in_types] Error out when using auto_axes or explicit_axes API when there is no context mesh.
Those APIs don't support that right now anyway, and they raise an ugly KeyError. Instead we raise a better error here.

I have added a TODO to get the mesh from the args so that computation-follows-data works, but we can decide to do that in the future if many users request it and don't want to use `use_mesh`.

PiperOrigin-RevId: 730687231
2025-02-24 19:19:49 -08:00
jax authors
41faf51a16 Merge pull request #26715 from Rifur13:normalize
PiperOrigin-RevId: 730667196
2025-02-24 17:54:06 -08:00
Gleb Pobudzey
8f1dd02e5f Add a param to not normalize the attention weights 2025-02-25 00:55:58 +00:00
Yash Katariya
6f8bab3c92 Add sharding mismatch to explain_tracing_cache_miss
PiperOrigin-RevId: 730645598
2025-02-24 16:49:49 -08:00
Peter Hawkins
c8c4cfa04e Update version numbers after 0.5.1 release. 2025-02-24 16:18:25 -05:00
Peter Hawkins
54f707240a Merge branch 'release/0.5.1' into main 2025-02-24 16:06:15 -05:00
Yash Katariya
8739232fe5 Share the logic of allowing propagation to input/output between JAX and australis
PiperOrigin-RevId: 730535564
2025-02-24 11:42:24 -08:00
Matthias Kramm
79e1e1fcee Make mesh and *_spec parameters optional.
PiperOrigin-RevId: 730499695
2025-02-24 10:15:38 -08:00
Yash Katariya
6d8be966a0 Fix shard_map debug_nan leakage of manual out_avals into the impl rules of jit, i.e. the impl rule of jit saw a manual out_aval, which is not expected. This is a band-aid for now, with a TODO to do a proper fix.
PiperOrigin-RevId: 730499532
2025-02-24 10:15:21 -08:00
Dan Foreman-Mackey
62530d5922 Update JVP rule for lax.linalg.lu to use vmap instead of broadcasted_iotas.
PiperOrigin-RevId: 730497540
2025-02-24 10:09:41 -08:00
Dan Foreman-Mackey
6bd99207d5 Fix rank promotion error in JVP of batched eigh.
PiperOrigin-RevId: 730475017
2025-02-24 09:08:55 -08:00
Emily Fertig
9d421c9149 Plumb layout through the creation of PjRtArrays.
This is in preparation to support arrays with no local shards, for which the layout may not be accessible from a buffer.

PiperOrigin-RevId: 730469597
2025-02-24 08:53:43 -08:00
Dan Foreman-Mackey
ae656e1574 Update lax.linalg.svd primitive to use registration helper functions.
PiperOrigin-RevId: 730466560
2025-02-24 08:44:06 -08:00
Yash Katariya
07440f4afa Prepare for JAX release 0.5.1 2025-02-24 10:59:04 -05:00
jax authors
c74f497eaf Merge pull request #25053 from JanLuca:gesvd
PiperOrigin-RevId: 730445233
2025-02-24 07:38:15 -08:00
jax authors
c17ea805f3 Merge pull request #26569 from gnecula:debug_info_arg_names
PiperOrigin-RevId: 730432019
2025-02-24 06:48:41 -08:00
Yash Katariya
7d3c63eded [sharding_in_types] Add more reshape sharding support
* Allow merging and splitting only if the major-most dim is sharded, since that involves no data movement. This only happens if `dimensions` is None, i.e. if the input array is in **row-major order** (see the sketch after this list).

  * Merging: if **only** the major-most dim of the merge block is sharded, then that sharding is propagated to the merge block's output.

  * Splitting: if the dimension being split is sharded, then the sharding is propagated to the major-most dimension post-split, but only if the spec divides the new shape exactly.
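A shape-level sketch of the merge case (hedged: it uses a plain NamedSharding on a two-device mesh just to show the shapes involved; the change itself concerns the experimental sharding-in-types reshape rule):

```python
# Requires >= 2 devices; mesh and axis names are made up for the example.
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2,), ("x",))
x = jax.device_put(jnp.zeros((4, 6)), NamedSharding(mesh, P("x", None)))
y = jax.jit(lambda a: a.reshape(24))(x)
# Under the rule above, merging (4, 6) -> (24,) propagates the major-most "x"
# sharding to the output, i.e. effectively P("x").
```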

PiperOrigin-RevId: 730291595
2025-02-23 21:39:23 -08:00
Sergei Lebedev
908ff49e22 [mosaic_gpu] Warmup the kernel when doing CUPTI based profiling
Closes #26144.

PiperOrigin-RevId: 730159757
2025-02-23 10:20:37 -08:00
Sergei Lebedev
74b2e0203f [pallas:mosaic_gpu] Use {min,max}imumf instead of {min,max}numf
PiperOrigin-RevId: 730154865
2025-02-23 09:52:48 -08:00
George Necula
1be801bac8 [better_errors] Cleanup use of DebugInfo.arg_names and result_paths
Previously, we represented a missing arg name with `None`,
and a missing result path with the empty string. We now
adopt the same convention for arg names and use empty strings.
This simplifies the typing, and prevents the string "None" from
appearing in error messages.

I changed how we encode the result paths. Previously for a
function that returns a single array the path was the empty
string (the same as for an unknown path). And for a function
that returns a pair of arrays it was `([0], [1])`. Now we
add the "result" prefix: `("result",)` for a function returning a
single array and `(result[0], result[1])` for a function returning
a pair of arrays.

Finally, in debug_info_test, I removed the `check_tracer_arg_name`
so that all spied tracers are printed with the argument name they
depend on.
2025-02-23 08:27:56 +02:00
Yash Katariya
d695aa4c63 [sharding_in_types] Add sharding rules for the following primitives:
* `bitcast_convert_element_type`
  * `cumsum`
  * `cumlogsumexp`
  * `cumprod`
  * `cummax`
  * `cummin`
  * `reduce_window`
  * `reduce_window_sum`
  * `reduce_window_max`
  * `reduce_window_min`
  * `select_and_gather_add`

For `reduce_window_...` primitives only trivial windowing is supported along non-replicated dimensions. We can relax the other NotImplemented case in the future.

PiperOrigin-RevId: 729910108
2025-02-22 10:45:58 -08:00
Jan Naumann
e03fe3a06d Implement SVD algorithm based on QR for CPU targets
In a recent jax release the SvdAlgorithm parameter was added
to the jax.lax.linalg.svd function. Currently, for CPU targets
only the divide-and-conquer algorithm from LAPACK (gesdd) is
supported.

This commit adds the functionality to select the QR-based
algorithm on CPU as well. Mainly, it adds the wrapper code
to call the gesvd function of LAPACK using the FFI interface.
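A hedged usage sketch of selecting the QR-based algorithm (parameter and enum spellings follow the jax.lax.linalg.svd docs for this release and may differ in other versions):

```python
# Illustrative only: pick the QR-based LAPACK SVD (gesvd) on CPU.
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.arange(12.0).reshape(4, 3)
u, s, vt = lax_linalg.svd(a, full_matrices=False, algorithm=lax_linalg.SvdAlgorithm.QR)
```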

Signed-off-by: Jan Naumann <j.naumann@fu-berlin.de>
2025-02-22 15:24:57 +01:00