9119 Commits

Adam Paszke
676acebafa [Pallas:MGPU] Enable lowering for .astype and scalar broadcasts
PiperOrigin-RevId: 730805326
2025-02-25 03:01:11 -08:00
Adam Paszke
71c7622037 [Pallas:MGPU] Change WG semantics convention to represent scalar arrays using scalars
Previously every ShapedArray got converted to an MLIR vector which was more annoying
than helpful.

PiperOrigin-RevId: 730795455
2025-02-25 02:24:06 -08:00
Adam Paszke
d0d5bba645 [Pallas:MGPU] Avoid SMEM->GMEM wait if no outputs are transferred in the pipeline loop
The TMA wait itself does not add much overhead, but skipping it lets us save an unnecessary warpgroup barrier.

PiperOrigin-RevId: 730795234
2025-02-25 02:22:25 -08:00
Sergei Lebedev
7eadc64b5a [pallas:mosaic_gpu] Added WG lowering rules for TMA primitives and run_scoped_p
PiperOrigin-RevId: 730780335
2025-02-25 01:32:43 -08:00
Adam Paszke
80848ad859 [Pallas:MGPU] Consistently use i32 as the grid index type in emit_pipeline
That's more consistent with Pallas semantics and avoids generating a slightly
different kernel depending on x32/x64 mode.
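
For context, a minimal sketch (plain JAX, unrelated to emit_pipeline internals) of how the default index dtype shifts with the x64 flag, which is the variation a fixed i32 grid index sidesteps:

```python
import jax
import jax.numpy as jnp

# Under the default x32 mode, integers default to 32 bits.
print(jnp.arange(4).dtype)   # int32

# With x64 enabled, the same expression defaults to 64 bits, so an index
# type derived from it would change the generated kernel.
jax.config.update("jax_enable_x64", True)
print(jnp.arange(4).dtype)   # int64
```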

PiperOrigin-RevId: 730778314
2025-02-25 01:24:50 -08:00
Sebastian Kehl
6abc76c874 Check "jax_rocm_visible_devices" at client creation.
This aligns ROCm with CUDA when using jax.distributed in combination
with one of the cluster-autodetection mechanisms that set visible
devices via the "jax_rocm_visible_devices" flag.
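
A rough sketch of the intended setup (how the flag is set here is an assumption; a cluster-autodetection mechanism may set it for you):

```python
import jax

# Assumption: the flag can be set via jax.config before the backend client is
# created; under e.g. Slurm auto-detection it may be populated automatically.
jax.config.update("jax_rocm_visible_devices", "0")

# Initialize the distributed runtime before touching any devices so the
# visible-devices restriction is honored at client creation.
jax.distributed.initialize()

print(jax.local_devices())  # expect only the single visible ROCm device
```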

Fixes #26298
2025-02-25 09:59:59 +01:00
Yash Katariya
b707f0bdbb [sharding_in_types] Error out when using auto_axes or explicit_axes API when there is no context mesh.
Those APIs don't support that right now anyway, and they raised an ugly KeyError. Instead, we now raise a better error here.

I have added a TODO to get the mesh from the arguments so that computation-follows-data works; we can do that in the future if a lot of users request it and don't want to use `use_mesh`.
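
A minimal sketch of the supported pattern; the import paths and the out_sharding keyword are assumptions, since these APIs were experimental at the time:

```python
import jax
import numpy as np
from jax.sharding import PartitionSpec as P

# Assumed import locations for the experimental APIs mentioned above.
from jax.experimental.shard import auto_axes
from jax.sharding import use_mesh

mesh = jax.make_mesh((2,), ("x",))   # assumes at least two devices
x = np.arange(8.0)

def double(a):
    return a * 2

# Calling auto_axes with no context mesh now raises a clear error instead of
# a KeyError; installing the mesh via `use_mesh` is the supported pattern.
with use_mesh(mesh):
    y = auto_axes(double, out_sharding=P("x"))(x)
```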

PiperOrigin-RevId: 730687231
2025-02-24 19:19:49 -08:00
Yash Katariya
6f8bab3c92 Add sharding mismatch to explain_tracing_cache_miss
PiperOrigin-RevId: 730645598
2025-02-24 16:49:49 -08:00
Yash Katariya
8739232fe5 Share the logic of allowing propagation to input/output between JAX and australis
PiperOrigin-RevId: 730535564
2025-02-24 11:42:24 -08:00
Yash Katariya
6d8be966a0 Fix shard_map debug_nan leakage of manual out_avals into the impl rule of jit, i.e. the impl rule of jit saw a manual out_aval, which is not expected. This is a band-aid for now, with a TODO to do a proper fix.
PiperOrigin-RevId: 730499532
2025-02-24 10:15:21 -08:00
Dan Foreman-Mackey
62530d5922 Update JVP rule for lax.linalg.lu to use vmap instead of broadcasted_iotas.
PiperOrigin-RevId: 730497540
2025-02-24 10:09:41 -08:00
Dan Foreman-Mackey
6bd99207d5 Fix rank promotion error in JVP of batched eigh.
PiperOrigin-RevId: 730475017
2025-02-24 09:08:55 -08:00
Emily Fertig
9d421c9149 Plumb layout through the creation of PjRtArrays.
This is in preparation for supporting arrays with no local shards, where the layout may not be accessible from a buffer.

PiperOrigin-RevId: 730469597
2025-02-24 08:53:43 -08:00
Dan Foreman-Mackey
ae656e1574 Update lax.linalg.svd primitive to use registration helper functions.
PiperOrigin-RevId: 730466560
2025-02-24 08:44:06 -08:00
jax authors
c74f497eaf Merge pull request #25053 from JanLuca:gesvd
PiperOrigin-RevId: 730445233
2025-02-24 07:38:15 -08:00
jax authors
c17ea805f3 Merge pull request #26569 from gnecula:debug_info_arg_names
PiperOrigin-RevId: 730432019
2025-02-24 06:48:41 -08:00
Yash Katariya
7d3c63eded [sharding_in_types] Add more reshape sharding support
* Allow merging and splitting only if the major-most dim is sharded, since that involves no data movement. This only happens if `dimensions` is None, i.e. if the input array is in **row-major order** (see the schematic sketch after this list).

  * Merging: If **only** the major-most dim of the merge block is sharded, then that sharding is propagated to the merge block's output.

  * Splitting: If the dimension being split is sharded, then the sharding is propagated to the major-most dimension post split, but only if the spec divides the new shape exactly.
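
A schematic sketch (plain Python, not the actual implementation; helper names are made up) of the bookkeeping described above:

```python
# Hypothetical helpers mirroring the merge/split propagation rule.

def merge_spec(in_spec):
    """Merging, e.g. (4@x, 8) -> (32,): propagate only a major-most sharding."""
    if in_spec[0] is not None and all(s is None for s in in_spec[1:]):
        return (in_spec[0],)      # sharding moves to the merged dimension
    return (None,)                # anything else would require data movement

def split_spec(dim_spec, new_major_size, axis_size):
    """Splitting, e.g. (32@x,) -> (4, 8): keep the sharding on the major-most
    output dimension, but only if the axis size divides it exactly."""
    if dim_spec is not None and new_major_size % axis_size == 0:
        return (dim_spec, None)
    return (None, None)

print(merge_spec(("x", None)))   # ('x',)  -- only the major-most dim was sharded
print(split_spec("x", 4, 2))     # ('x', None)
```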

PiperOrigin-RevId: 730291595
2025-02-23 21:39:23 -08:00
Sergei Lebedev
74b2e0203f [pallas:mosaic_gpu] Use {min,max}imumf instead of {min,max}numf
PiperOrigin-RevId: 730154865
2025-02-23 09:52:48 -08:00
George Necula
1be801bac8 [better_errors] Cleanup use of DebugInfo.arg_names and result_paths
Previously, we represented a missing arg name with `None`,
and a missing result path with the empty string. We now
adopt the same convention for arg names and use empty strings.
This simplifies the typing, and prevents the string "None" from
appearing in error messages.

I changed how we encode the result paths. Previously for a
function that returns a single array the path was the empty
string (the same as for an unknown path). And for a function
that returns a pair of arrays it was `([0], [1])`. Now we
add the "result" prefix: `("result",)` for a function returning a
single array and `(result[0], result[1])` for a function returning
a pair of arrays.
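
This is not the internal implementation, but the new convention can be illustrated with the public pytree utilities (the helper name is made up):

```python
import jax

def result_paths(out):
    # Prefix every leaf's pytree key path with "result", per the new convention.
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(out)
    return tuple("result" + jax.tree_util.keystr(path)
                 for path, _ in leaves_with_paths)

print(result_paths(1.0))         # ('result',)
print(result_paths((1.0, 2.0)))  # ('result[0]', 'result[1]')
```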

Finally, in debug_info_test, I removed the `check_tracer_arg_name`
so that all spied tracers are printed with the argument name they
depend on.
2025-02-23 08:27:56 +02:00
Yash Katariya
d695aa4c63 [sharding_in_types] Add sharding rules for the following primitives:
* `bitcast_convert_element_type`
  * `cumsum`
  * `cumlogsumexp`
  * `cumprod`
  * `cummax`
  * `cummin`
  * `reduce_window`
  * `reduce_window_sum`
  * `reduce_window_max`
  * `reduce_window_min`
  * `select_and_gather_add`

For `reduce_window_...` primitives, only trivial windowing is supported along non-replicated dimensions. We can relax the other NotImplemented case in the future.
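
Sharding itself is not shown here, but as a sketch of what "trivial windowing" along a dimension means (window size 1, stride 1, no padding along that dimension, so no cross-shard data is needed):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((8, 128))

# Window and stride of 1 along the first dimension leave it untouched, so a
# sharding along that dimension would not require any halo exchange.
y = jax.lax.reduce_window(
    x, 0.0, jax.lax.add,
    window_dimensions=(1, 4),
    window_strides=(1, 4),
    padding="VALID",
)
print(y.shape)  # (8, 32)
```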

PiperOrigin-RevId: 729910108
2025-02-22 10:45:58 -08:00
Jan Naumann
e03fe3a06d Implement SVD algorithm based on QR for CPU targets
In a recent jax release the SvdAlgorithm parameter has been added
to the jax.lax.linalg.svd function. Currently, CPU targets still
support only the divide-and-conquer algorithm from LAPACK (gesdd).

This commit adds the ability to also select the QR-based
algorithm on CPU. Mainly, it adds the wrapper code to call
LAPACK's gesvd function using the FFI interface.
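
Typical usage after this change might look like the following sketch (the keyword name for selecting the algorithm is an assumption):

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.arange(12.0).reshape(4, 3)

# Select the QR-based path (LAPACK gesvd) instead of the default
# divide-and-conquer one (gesdd) on CPU.
u, s, vt = lax_linalg.svd(
    a, full_matrices=False, algorithm=lax_linalg.SvdAlgorithm.QR
)
print(s.shape)  # (3,)
```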

Signed-off-by: Jan Naumann <j.naumann@fu-berlin.de>
2025-02-22 15:24:57 +01:00
Yash Katariya
7c4fe2a7cc [sharding_in_types] Allow auto_axes and explicit_axes to take numpy arrays, python scalars.
PiperOrigin-RevId: 729729215
2025-02-21 18:49:02 -08:00
Yash Katariya
34077851d8 Reverts 6e83de5909b5dad6827c1682726c840c9688b32d
PiperOrigin-RevId: 729725907
2025-02-21 18:34:16 -08:00
Yash Katariya
80f18ded23 [sharding_in_types] Make slice and ellipsis work with .at[...].get(out_sharding=P(...))
PiperOrigin-RevId: 729723470
2025-02-21 18:25:11 -08:00
Yash Katariya
629426f89c Allow casting to the same axis type
PiperOrigin-RevId: 729667271
2025-02-21 14:53:28 -08:00
Daniel Suo
6e83de5909 Implement Jax CPU/GPU callbacks with XLA's FFI.
- Change 4 of 4 addressing #3 in https://github.com/jax-ml/jax/issues/25842.

PiperOrigin-RevId: 729641154
2025-02-21 13:34:48 -08:00
Dan Foreman-Mackey
c4418c1010 Update several remaining lax.linalg primitives to use registration helper functions.
In this change, we update schur, triangular_solve, tridiagonal, and tridiagonal_solve. I batched these ones since they're all pretty straightforward.

PiperOrigin-RevId: 729572705
2025-02-21 10:18:30 -08:00
Daniel Suo
2d1bc5c2a0 Refactor Jax FFI lowering to prepare for implementing CPU/GPU callbacks using XLA's FFI.
- This refactor just moves code around and should have no impact on tests or public-facing APIs.
- `mlir.emit_python_callback` would eventually depend on `ffi.ffi_lowering`, which in turn depends on definitions in `mlir.py`. We break this circular dependency.

PiperOrigin-RevId: 729561359
2025-02-21 09:45:59 -08:00
shuw
bfb9d3ca4b Improve based on comment # 1 2025-02-21 17:32:57 +00:00
Dan Foreman-Mackey
ed10003adc Update lax.linalg.qr primitives to use registration helper functions.
PiperOrigin-RevId: 729551997
2025-02-21 09:15:01 -08:00
Yash Katariya
66037d10e7 Set the mesh of the sharding during broadcast in vmap so that we don't hit an error during canonicalization. This is similar to bcd4048dd5
PiperOrigin-RevId: 729532213
2025-02-21 08:05:42 -08:00
jax authors
237b7941a8 Merge pull request #26649 from jakevdp:wrapped-fun
PiperOrigin-RevId: 729500034
2025-02-21 05:57:53 -08:00
Dan Foreman-Mackey
09325d925f Update internal unop primitive helper to pass kwargs to dtype rule.
To be consistent with other rule registration helpers, `unop_dtype_rule` should pass through its kwargs to the `result_dtype` callable.

PiperOrigin-RevId: 729483613
2025-02-21 04:52:51 -08:00
Dan Foreman-Mackey
126909b62a Update lax.linalg.lu primitive to use registration helper functions.
PiperOrigin-RevId: 729483456
2025-02-21 04:50:46 -08:00
Dan Foreman-Mackey
a981e1c4b9 Start adding primitive registration helper functions to lax.linalg.
As part of my efforts to simplify the primitive implementations in lax.linalg, I've found that all of the primitives share some common logic when it comes to impls, abstract_evals, and batching. This change adds some helper functions and starts the process of abstracting the primitive definitions to simplify and reduce duplication. I will continue with the rest of the primitives in lax.linalg, but I didn't want to overload the first diff.
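
A rough sketch of the kind of helper this refers to (names and structure are illustrative, not the actual lax.linalg internals), built on the public primitive machinery:

```python
from jax.extend import core
from jax.interpreters import batching, mlir

def register_linalg_primitive(name, impl, abstract_eval, batch_rule, lowering):
    """Illustrative helper bundling the registrations shared by linalg primitives."""
    prim = core.Primitive(name)
    prim.def_impl(impl)
    prim.def_abstract_eval(abstract_eval)
    batching.primitive_batchers[prim] = batch_rule
    mlir.register_lowering(prim, lowering)
    return prim
```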

PiperOrigin-RevId: 729471970
2025-02-21 04:05:34 -08:00
Yash Katariya
bcd4048dd5 Set the mesh of tangent.aval when creating zeros_like_aval, because when you close over an unused array, we error out during canonicalization.
PiperOrigin-RevId: 729340808
2025-02-20 19:32:34 -08:00
Yash Katariya
250e2ee7da Use the mesh of out_aval when converting GSPMDSharding to NamedSharding. This makes sure that the axis types of the corresponding output are correct.
Also, if all axes of an out_aval are auto, set the corresponding out_sharding to Unspecified during lowering; otherwise things go horribly wrong. This is actually an XLA bug, but we can work around it in JAX for now.

PiperOrigin-RevId: 729307115
2025-02-20 17:13:24 -08:00
Jake VanderPlas
fe00aa0f65 Internal: improved type annotations for lu.WrappedFun 2025-02-20 15:12:39 -08:00
Sergei Lebedev
7438976e79 [pallas:mosaic_gpu] Added support for binary/comparison ops with WG semantics
PiperOrigin-RevId: 729266484
2025-02-20 15:06:27 -08:00
Robert David
08de0128b6 Fix head comment: was referring to nonexistent parameters.
PiperOrigin-RevId: 729231457
2025-02-20 13:29:40 -08:00
jax authors
b7968474c2 [Pallas][Mosaic] Support float8_e4m3b11fnuz
PiperOrigin-RevId: 729169181
2025-02-20 10:44:33 -08:00
Yash Katariya
262aab74f0 Canonicalize closed-over values if **at least** 1 mesh axis is Manual and **all other mesh axes** are Manual or Auto. This makes canonicalization work properly with shmap partial-auto.
If a mesh axis is Explicit, we don't canonicalize closed-over values yet, since that may require shape changes. The workaround is for users to pass those arrays as arguments instead of closing over them in a shard_map (see the sketch below).
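
The workaround in minimal form (a sketch using the public shard_map API; the mesh setup is illustrative and assumes at least two devices):

```python
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

mesh = jax.make_mesh((2,), ("x",))
w = jnp.ones((4,))        # array we would otherwise close over
x = jnp.arange(8.0)

# Instead of closing over `w`, pass it as an explicit argument with its own
# in_spec, so no closed-over value needs to be canonicalized.
f = shard_map(lambda w, x: w.sum() + x, mesh=mesh,
              in_specs=(P(), P("x")), out_specs=P("x"))
print(f(w, x))
```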

PiperOrigin-RevId: 728956512
2025-02-19 22:18:56 -08:00
Yash Katariya
8305803b76 [sharding_in_types] Initial support for partial-auto/explicit shard_map + sharding-in-types. If an axis in shmap(..., auto=...) is an explicit axis in the outer mesh context, then that axis is treated as Explicit instead of Auto.
PiperOrigin-RevId: 728920514
2025-02-19 20:04:54 -08:00
jax authors
cb0d326e16 Merge pull request #26591 from jakevdp:lax-docs
PiperOrigin-RevId: 728908919
2025-02-19 19:22:48 -08:00
cjkkkk
3a80080392 fix unit tests to not use fmha rewriter 2025-02-20 00:41:04 +00:00
Shu Wang
ae111f7c97 Rename custom-call name. 2025-02-19 16:46:44 -06:00
Yash Katariya
dbb46e9214 Relax one more check in partial_eval_jaxpr_nounits
PiperOrigin-RevId: 728788472
2025-02-19 13:14:08 -08:00
Jacob Burnim
ac74857d27 [Pallas] Support dynamic grids in the new TPU interpret mode
PiperOrigin-RevId: 728786896
2025-02-19 13:09:23 -08:00
jax authors
eef829cbcb Merge pull request #26615 from froystig:scan-docfix
PiperOrigin-RevId: 728753989
2025-02-19 11:41:35 -08:00
Yash Katariya
1081c1f11a Relax the check in _mapped_axis_spec to allow () and None to be treated the same
PiperOrigin-RevId: 728746291
2025-02-19 11:23:17 -08:00