25661 Commits

Author SHA1 Message Date
Sergei Lebedev
194884d311 Migrated to mypy 1.14.1 with --allow_redefinition
I initially wanted to upgrade to 1.15, but it seems to have a bug in how
ternary expressions are type checked. For example,

   def f(x: int) -> str: ...
   def g(x: int) -> str: ...

   callback = f if ... else g  # has type object!
2025-02-13 15:38:28 +00:00
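One way to avoid depending on how a checker infers the type of a conditional expression is to annotate the target explicitly. The sketch below assumes nothing about mypy 1.15 itself and has not been checked against it. The function bodies and the `pick` helper are hypothetical and exist only to make the snippet runnable.

```python
from typing import Callable


def f(x: int) -> str:
    return f"f:{x}"


def g(x: int) -> str:
    return f"g:{x}"


def pick(use_f: bool) -> Callable[[int], str]:
    # The explicit annotation states the intended type instead of
    # leaving it to the checker's inference over both branches.
    callback: Callable[[int], str] = f if use_f else g
    return callback
```

At runtime the annotation has no effect; it only documents the type that both branches are expected to share.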
Yash Katariya
3ec7a67e51 [sharding_in_types] Make sharding arg to ShapedArray kwarg only
PiperOrigin-RevId: 726272943
2025-02-12 18:22:50 -08:00
Yash Katariya
15cd83ae00 [sharding_in_types] Error out when PartitionSpec is passed to APIs that take out_sharding like einsum when context_mesh is unset.
This change raises a clearer error: constructing `NamedSharding(empty_mesh, P('x'))` already fails on construction, but that error is uglier than the one added in this change.

PiperOrigin-RevId: 726253654
2025-02-12 17:13:14 -08:00
Jevin Jiang
876668faa1 [Mosaic TPU] Support bf16 div if HW does not directly support.
PiperOrigin-RevId: 726212286
2025-02-12 15:04:09 -08:00
jax authors
153a7cf913 Merge pull request #26373 from jax-ml:autodidax-stackless
PiperOrigin-RevId: 726211770
2025-02-12 15:02:41 -08:00
jax authors
73c626d95e Merge pull request #26503 from garymm:patch-1
PiperOrigin-RevId: 726197571
2025-02-12 14:24:42 -08:00
Yash Katariya
0944e5202e Create _BaseMesh so that Mesh and AbstractMesh can share properties without duplicating code
PiperOrigin-RevId: 726193613
2025-02-12 14:14:48 -08:00
Yash Katariya
1a62df1ac0 Rename sharding argument to out_sharding for lax.reshape, lax.broadcast_in_dim, lax.broadcast and lax.broadcasted_iota. The .bind of these APIs still takes sharding as a parameter, though (but that's fine since it's internal and not public facing).
PiperOrigin-RevId: 726187934
2025-02-12 13:59:23 -08:00
Yash Katariya
d58c3a4722 [sharding_in_types] Fix some properties that assumed axis_types always existed.
PiperOrigin-RevId: 726187278
2025-02-12 13:57:19 -08:00
Daniel Suo
8c685be688 [xla:cpu] Implement XLA FFI handlers for CPU Jax callbacks.
PiperOrigin-RevId: 726185954
2025-02-12 13:53:36 -08:00
Gary Miguel
e231a35ad3 Fix doc string for PmapSharding
Lack of indent was resulting in extra parameter being shown in the HTML generated docs
2025-02-12 13:39:53 -08:00
Dan Foreman-Mackey
9298018afa Enable shardy batch partitionable FFI test.
PiperOrigin-RevId: 726171678
2025-02-12 13:17:40 -08:00
jax authors
4f1c67e6c0 Merge pull request #26403 from jakevdp:bf16-mean
PiperOrigin-RevId: 726157721
2025-02-12 12:41:20 -08:00
Dougal
9145366f6f Part 1 of a new autodidax based on "stackless"
2025-02-12 15:23:06 -05:00
Jake VanderPlas
b5e7b60d6a jax.numpy reductions: avoid upcast of f16 when dtype is specified by user
2025-02-12 11:49:39 -08:00
jax authors
5b697728c7 Merge pull request #24910 from olupton:expect-pgle
PiperOrigin-RevId: 726106211
2025-02-12 10:26:08 -08:00
jax authors
f7e2901e8b Merge pull request #25955 from tttc3:magma_qr
PiperOrigin-RevId: 726098235
2025-02-12 10:05:02 -08:00
Yash Katariya
2d01df760b [sharding_in_types] Make the typing checks and sharding rule checks less strict when the current or aval mesh is empty/unset. Further changes are listed below:
* get_aval is not context dependent

* canonicalization does not happen for avals on an empty mesh

* jax.jit does not set abstract mesh context anymore before tracing

* sharding checks have been relaxed for all modes (`Auto`, `Explicit` and `Manual`). This means that `f = lambda x, y: x * y; f(explicit_sharded_arr, np_array)` will be allowed without inserting any mesh_casts even in `Explicit` sharding mode

* Even if use_mesh is not used in explicit sharding mode, computation follows data works!

* Higher order primitives skip canonicalization (pjit_p, while_p, cond_p, for_loop_p, scan_p)

* The check in partial_eval that compares jaxpr_known.outvars == jaxpr.out_avals has been relaxed to not check shardings if any one of the avals has an empty mesh.

As mentioned in https://github.com/jax-ml/jax/issues/26474 we need to relax the typing and sharding rule checks because if we insert `mesh_cast`s, those lead to creation of unnecessary residuals (for literals, numpy arrays, basically anything that has an empty mesh) which is not good.

PiperOrigin-RevId: 726097292
2025-02-12 10:03:01 -08:00
Nitin Srinivasan
93831bdde7 Download and use jax wheels from GCS bucket for nightly/release test workflows
Unlike continuous workflows, when testing nightly/release artifacts, we want to download and install the `jax` wheels found in the GCS bucket instead of installing `jax` from HEAD.

It looks like the `env` setting in the calling workflow isn't passed over to the called workflows, so we define a new workflow input, `install-jax-current-commit`, to control the `jax` install behavior.

PiperOrigin-RevId: 726086522
2025-02-12 09:32:05 -08:00
Benjamin Chetioui
837418c652 [Mosaic GPU] Remove old jaxlib version guards.
PiperOrigin-RevId: 726071956
2025-02-12 08:49:40 -08:00
Yash Katariya
b4b4a98db7 [sharding_in_types] When caching mesh with axis_types, make sure the data structure is (axis_size, axis_names, tuple(axis_types))
PiperOrigin-RevId: 726064530
2025-02-12 08:23:52 -08:00
Adam Paszke
f1ab7514db Make sure we take libTPU version into account in the Pallas lowering
Also, strengthen the presubmit to make sure we catch more errors.

PiperOrigin-RevId: 726061633
2025-02-12 08:15:57 -08:00
tttc3
b1b56ea0b0 Enable pivoted QR on GPU via MAGMA.
Originally noted in #20282, this commit provides a GPU compatible
implementation of `geqp3` via MAGMA.
2025-02-12 16:12:42 +00:00
jax authors
e14466a8fb Merge pull request #26447 from jakevdp:refactor-contractions
PiperOrigin-RevId: 726043463
2025-02-12 07:14:30 -08:00
Adam Paszke
6662ea96ba Fix the string array test to not assume that there will always be exactly 2 CPU devices
PiperOrigin-RevId: 726038266
2025-02-12 06:55:01 -08:00
Benjamin Chetioui
c7199fe8a5 [Pallas/Mosaic GPU] Enable progressive lowering for integer addition.
The helpers `_fragmented_array_to_ir` and `_fragmented_array_from_ir` in
`dialect_lowering.py` have been modified, such that a fragmented array's
signedness no longer appears in its IR representation.

This is because signedness is a reflection of how we make use of the value,
and not an inherent property of it. The appropriate signedness value to use
to reload a fragmented array from IR must be provided by the caller.

PiperOrigin-RevId: 726030853
2025-02-12 06:29:25 -08:00
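The point that signedness describes how a value is used, rather than the value itself, can be illustrated outside of Mosaic GPU. The sketch below is plain Python and unrelated to the `dialect_lowering.py` helpers. The `reload` function is a hypothetical stand-in for a loader whose caller must say which interpretation it wants.

```python
# The stored representation carries no sign information: it is just bits.
raw = bytes([0xFF])


def reload(data: bytes, *, is_signed: bool) -> int:
    """Hypothetical loader: the caller supplies the signedness."""
    return int.from_bytes(data, "little", signed=is_signed)


as_unsigned = reload(raw, is_signed=False)  # 255
as_signed = reload(raw, is_signed=True)     # -1
```

The same byte yields two different integers, so the stored form alone cannot determine which one is meant.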
jax authors
1e2a5770c9 Merge pull request #26455 from gnecula:debug_info_jaxpr_8
PiperOrigin-RevId: 726023315
2025-02-12 06:03:32 -08:00
George Necula
faa0ad6f33 [better_errors] Continue adding debug info to Jaxprs (step 8)
This follows in a series, starting with #26078 and #26313, adding debug_info to more calls to lu.wrap_init.

These are some leftover changes, in particular those needed when
running with `JAX_USE_DIRECT_LINEARIZE=1`.
2025-02-12 14:23:52 +01:00
jax authors
0409d1c38c Update XLA dependency to use revision 0fc06fb596.

PiperOrigin-RevId: 726013938
2025-02-12 05:22:11 -08:00
Benjamin Chetioui
dd4f396d90 [Mosaic GPU][NFC] Add missing wgmma_{commit,wait}_group_sync_aligned to test.
We need to wait for the results to be available before we can copy them to
`smem`, and these instructions are not issued in the lowering of
`mgpu_dialect.wgmma`.

PiperOrigin-RevId: 725989759
2025-02-12 03:45:47 -08:00
Adam Paszke
8eea88626f Skip CPU ASAN in scipy special function tests
They just take too long under ASAN

PiperOrigin-RevId: 725965329
2025-02-12 02:13:25 -08:00
Benjamin Chetioui
5ad89006c3 [Pallas/Mosaic GPU] Add initial support for warpgroup semantics in lowering.
This will allow us to lower Pallas kernels using the Mosaic GPU dialect, and
in turn to perform layout inference and optimization automatically.

The change contains lowering rules for `get` and `swap` (which are necessary
to get a basic example to run), as well as for `add`.

The new lowering path can be used by specifying the `Warpgroup` thread
semantics as part of `pallas_call`'s compiler params.

PiperOrigin-RevId: 725958027
2025-02-12 01:47:49 -08:00
jax authors
72e7b93b4d Merge pull request #26478 from froystig:aot-doc-traced2
PiperOrigin-RevId: 725917018
2025-02-11 23:03:00 -08:00
Roy Frostig
8720a9c0cd docstrings and API reference doc listing for the traced AOT stage
2025-02-11 22:30:50 -08:00
jax authors
914adaf60c Merge pull request #26476 from froystig:aot-doc-traced
PiperOrigin-RevId: 725902103
2025-02-11 22:01:21 -08:00
Roy Frostig
af381a73a3 update AOT walkthrough to cover explicit tracing stage
2025-02-11 21:26:05 -08:00
Yash Katariya
675be0121b Add a custom __reduce__ for UnconstrainedSingleton because it can be pickled and then loaded back, and we need the id of P.UNCONSTRAINED to match before and after loading.
PiperOrigin-RevId: 725874879
2025-02-11 20:05:48 -08:00
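The identity problem described in the `__reduce__` commit above can be reproduced with any sentinel object. The following is a minimal sketch, not JAX's actual code: the class names, `_get_sentinel`, and `SENTINEL` are hypothetical. By default, unpickling builds a fresh instance; a `__reduce__` that points back at a module-level accessor makes loading return the existing object.

```python
import pickle


class PlainSentinel:
    """No custom __reduce__: unpickling creates a new instance."""


class IdentitySentinel:
    """Hypothetical singleton whose identity survives pickling."""

    def __reduce__(self):
        # On load, pickle calls _get_sentinel() rather than constructing
        # a new IdentitySentinel, so the original object comes back.
        return (_get_sentinel, ())


SENTINEL = IdentitySentinel()


def _get_sentinel() -> IdentitySentinel:
    return SENTINEL


plain = PlainSentinel()
plain_roundtrip = pickle.loads(pickle.dumps(plain))
sentinel_roundtrip = pickle.loads(pickle.dumps(SENTINEL))
```

With this arrangement, checks of the form `x is SENTINEL` keep working on values that have been through a pickle round trip in the same process.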
Dan Foreman-Mackey
bba09137dc Match output container to result_shape_dtypes in ffi_call.
Previously, ffi_call would always return a list for multiple results, but if the input `result_shape_dtypes` is a tuple, we should return a tuple.

PiperOrigin-RevId: 725834048
2025-02-11 17:33:32 -08:00
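The container-matching behavior described in the `ffi_call` commit above can be sketched in plain Python. This is not the `ffi_call` implementation; `match_container` is a hypothetical helper that only illustrates returning results in the same kind of sequence the caller used for its specification.

```python
from typing import Any, Sequence, Union


def match_container(
    results: Sequence[Any], spec: Union[list, tuple]
) -> Union[list, tuple]:
    """Return `results` as a tuple if `spec` is a tuple, else as a list."""
    if isinstance(spec, tuple):
        return tuple(results)
    return list(results)


# A tuple specification yields a tuple of results; a list yields a list.
from_tuple_spec = match_container([1.0, 2.0], ("a", "b"))
from_list_spec = match_container([1.0, 2.0], ["a", "b"])
```

Matching the container lets callers unpack or compare outputs against the structure they passed in, without converting.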
jax authors
fd12f30011 Merge pull request #26446 from jakevdp:lax-reductions
PiperOrigin-RevId: 725812561
2025-02-11 16:22:03 -08:00
Jake VanderPlas
e389b707ba Add public APIs for jax.lax monoidal reductions
2025-02-11 16:00:03 -08:00
jax authors
d0b6c677b0 Merge pull request #26470 from jakevdp:lax-docs
PiperOrigin-RevId: 725804083
2025-02-11 15:58:52 -08:00
Gunhyun Park
6b19bb2091 Allow composites to provide default kwargs with None value
The current behavior crashes when trying to convert NoneType to an MLIR attribute. This change allows a composite to have optional attributes that are omitted when they are not provided, similar to how default values in MLIR are not shown in the IR.

PiperOrigin-RevId: 725786442
2025-02-11 15:05:50 -08:00
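The idea of omitting unset optional attributes, as in the composite commit above, can be sketched as a filtering step before conversion. This does not use or reproduce the JAX composite API; `drop_unset_attributes` and the attribute names are hypothetical.

```python
from typing import Any, Dict


def drop_unset_attributes(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the attributes that were actually provided.

    None marks an optional attribute that was left unset; dropping it
    here means a later conversion step never sees a NoneType value.
    """
    return {name: value for name, value in kwargs.items() if value is not None}


attrs = drop_unset_attributes({"scale": 2.0, "bias": None, "axis": 0})
```

Note that falsy values such as `0` are kept; only `None` is treated as "not provided".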
Jake VanderPlas
e488956092 jax.lax: improve docs for real, imag, complex, conj, and abs.
2025-02-11 14:12:22 -08:00
Nitin Srinivasan
30acd383fb Run test jobs irrespective of whether the build job succeeds or fails
This lets us avoid losing test coverage if a single unrelated build job fails. E.g., if the Windows build job fails but everything else succeeds, we still want to run the tests for the other platforms.

Also, if a build job fails, its corresponding test job will also report a failure as a result of not being able to download the wheel artifact so we should still be able to tell the source of job failure easily.

PiperOrigin-RevId: 725754098
2025-02-11 13:37:30 -08:00
Marcello Maggioni
6c6b5ec582 [JAX/Pallas] Add has_side_effect parameter to CompilerParams to stop CSE of operations.
Some Pallas kernels shouldn't be CSEd even if they share the same inputs.
For example, in async Pallas scenarios where a kernel starts some DMAs
that are waited on in the user of the kernel (to perform async copies), we can't CSE, or kernels
might wait multiple times on a DMA that happens only once.

PiperOrigin-RevId: 725752913
2025-02-11 13:33:01 -08:00
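Why a side-effect flag has to block common subexpression elimination can be seen in a toy pass. The sketch below is not the Pallas or XLA implementation; `Op`, `cse`, and the op names are hypothetical. Ops with identical name and inputs are merged unless they are marked as side-effecting.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class Op:
    name: str
    inputs: Tuple[str, ...]
    has_side_effect: bool = False


def cse(ops: List[Op]) -> List[Op]:
    """Drop repeated pure ops; always keep side-effecting ones."""
    seen: Dict[Tuple[str, Tuple[str, ...]], Op] = {}
    kept: List[Op] = []
    for op in ops:
        key = (op.name, op.inputs)
        if not op.has_side_effect:
            if key in seen:
                continue  # identical pure op: reuse the earlier result
            seen[key] = op
        kept.append(op)
    return kept


pure = [Op("add", ("x", "y")), Op("add", ("x", "y"))]
effectful = [
    Op("start_dma", ("buf",), has_side_effect=True),
    Op("start_dma", ("buf",), has_side_effect=True),
]
```

Here `cse(pure)` keeps one op, while `cse(effectful)` keeps both, so each later wait still has its own DMA start to pair with.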
Adam Paszke
e987ce2b77 [Mosaic GPU] Add autotuning to the Blackwell matmul kernel
PiperOrigin-RevId: 725726073
2025-02-11 12:17:46 -08:00
Gunhyun Park
7994aa82f8 Delete unused code in _dot_batch_rule
PiperOrigin-RevId: 725725676
2025-02-11 12:16:01 -08:00
Dimitar (Mitko) Asenov
6fc1c61520 [Mosaic GPU] Use the memref layout to encode transforms (only swizzle for now).
Tile and Transpose transforms to follow.

PiperOrigin-RevId: 725716812
2025-02-11 11:51:25 -08:00
Adam Paszke
c2bd1576da [Mosaic GPU] Add support for tiling the M grid dimension in Blackwell matmul
This lets us make better use of L2.

PiperOrigin-RevId: 725715193
2025-02-11 11:47:08 -08:00
Yash Katariya
005c14b4da [sharding_in_types] Error out if the sharding's specs passed to with_sharding_constraint don't refer to Auto axes.
PiperOrigin-RevId: 725679220
2025-02-11 10:16:52 -08:00