9761 Commits

Author SHA1 Message Date
George Necula
9f797990b5 Remove old backward compatibility mode for old PRNG custom call on GPU
The backend support for the new custom call was added on June 28th, 2024 (#20997).

PiperOrigin-RevId: 723077990
2025-02-04 07:34:52 -08:00
Peter Buchlovsky
c7d535d3c9 [pallas_mgpu] Add a test for emit_pipeline with wgmma.
PiperOrigin-RevId: 723012611
2025-02-04 03:25:02 -08:00
George Necula
d12aead696 [better_errors] Add debug info to more Jaxprs and WrappedFun (step 1)
The plan is for all `core.Jaxpr` and `lu.WrappedFun` to carry
non-None debug info.

We change `lu.wrap_init` to construct the result paths thunk
whenever it is passed a `debug_info`. The goal is to make sure that
all `WrappedFun` have a debug info with result paths support.

We change some calling conventions for internal functions to not
pass along a separate debug_info if we have a `WrappedFun` or
a `Jaxpr`.

We obtain several improvements in presence of debug infos
in debug_info_test.py
2025-02-04 10:02:35 +02:00
Jevin Jiang
124e123946 [Pallas] Support promise_in_bounds mode in jnp.take_along_axis.
Change is also applied to jax because we don't need to normalize index if the mode is already "promise_in_bounds".

PiperOrigin-RevId: 722930215
2025-02-03 22:06:19 -08:00
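The point about skipping index normalization can be illustrated with a small pure-Python model. This is only a sketch of the idea, not the code from this commit: the helper name `take_along_axis_1d`, its 1-D restriction, and the `fill_value` behavior are assumptions made for illustration.

```python
def take_along_axis_1d(values, indices, mode="fill", fill_value=None):
    """Hypothetical 1-D gather showing where index normalization happens."""
    n = len(values)
    out = []
    for i in indices:
        if mode == "promise_in_bounds":
            # The caller promises 0 <= i < n, so the index is used as-is:
            # no wrapping of negative indices and no bounds check.
            out.append(values[i])
            continue
        # Other modes: normalize negative indices before the bounds check.
        j = i + n if i < 0 else i
        out.append(values[j] if 0 <= j < n else fill_value)
    return out
```

With `mode="promise_in_bounds"` the per-index normalization work disappears, which is the saving the commit message refers to.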
Yash Katariya
bc1a706688 [sharding_in_types] Add a canonicalize_value step before dispatching bind so that we can insert mesh_casts under the following conditions:
* When current_mesh is Manual and aval mesh is Auto

* When current mesh is set and aval mesh is unset

Notes:

* Final style primitives skip this canonicalization and they are free to add it in their own `bind` method.

* `mesh_cast` is skipped from this canonicalization to avoid recursion errors.

This is required to make sure that after we hit abstract_eval rule and check_jaxpr, everything is properly typed in JAX's type system.

`Auto` right now is a bit more permissive because we need to keep the current code at HEAD working but `Explicit` and `Manual` are very strict.

PiperOrigin-RevId: 722868091
2025-02-03 18:00:19 -08:00
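The conditions listed in this commit message can be read as a small decision function. The following is a rough sketch under simplifying assumptions: a mesh is reduced to a single axis type (or `None` when unset), and the names `AxisType` and `needs_mesh_cast` are hypothetical, not JAX APIs.

```python
from enum import Enum


class AxisType(Enum):
    # Hypothetical stand-in for the mesh modes named in the commit message.
    AUTO = "auto"
    EXPLICIT = "explicit"
    MANUAL = "manual"


def needs_mesh_cast(current_mesh, aval_mesh, *, final_style=False, is_mesh_cast=False):
    """Return True if a mesh_cast would be inserted before bind (sketch only)."""
    # Final style primitives and mesh_cast itself skip the canonicalization.
    if final_style or is_mesh_cast:
        return False
    # Condition 1: the current mesh is Manual and the aval mesh is Auto.
    if current_mesh is AxisType.MANUAL and aval_mesh is AxisType.AUTO:
        return True
    # Condition 2: the current mesh is set and the aval mesh is unset.
    if current_mesh is not None and aval_mesh is None:
        return True
    return False
```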
jax authors
363f1e6944 Merge pull request #26290 from mattjj:linearize-name-stack-fixes-2
PiperOrigin-RevId: 722856808
2025-02-03 17:20:19 -08:00
Matthew Johnson
8f967c5171 [direct-linearize] fix name stack tests
Co-authored-by: Sharad Vikram <sharadmv@google.com>
2025-02-04 00:47:10 +00:00
Bill Varcho
0abd9538ce [JAX] disable flaky parameter permutations for sparse_bcoo_bcsr test.
PiperOrigin-RevId: 722832212
2025-02-03 16:02:06 -08:00
Sergei Lebedev
7929cd8410 [pallas:triton] The lowering now uses PTX instead of Triton IR
This change improves the stability and backward compatibility of Pallas Triton
calls, because unlike PTX, the Triton dialect has no stability guarantees
and does change in practice.

See #25196.

A few notes

* Pallas Triton no longer delegates compilation to PTX to XLA:GPU. Instead,
  compilation is done via a new PjRt extension, which uses its own compilation
  pipeline mirrored after the one in the Triton Python bindings.
* The implementation of the old custom call used by Pallas Triton is
  deprecated and will be removed after 6 months as per
  [compatibility guarantees] [*]

[*]: https://jax.readthedocs.io/en/latest/export/export.html#compatibility-guarantees

PiperOrigin-RevId: 722773884
2025-02-03 13:21:40 -08:00
jax authors
17d0b86c7c Merge pull request #26275 from dfm:effects-in-custom-linear-solve
PiperOrigin-RevId: 722751986
2025-02-03 12:18:55 -08:00
jax authors
95535df13b Merge pull request #25688 from carlosgmartin:random_multinomial
PiperOrigin-RevId: 722741835
2025-02-03 11:52:43 -08:00
Sergei Lebedev
2d7e4ab2dc [mosaic_gpu] LayoutTest now correctly resets the value of MOSAIC_GPU_DUMP_SASS
PiperOrigin-RevId: 722711341
2025-02-03 10:32:14 -08:00
jax authors
7e353913f2 Merge pull request #26262 from gnecula:debug_info_one
PiperOrigin-RevId: 722684417
2025-02-03 09:17:13 -08:00
Dan Foreman-Mackey
d42e3650d0 Handle effects in lax.custom_linear_solve.
2025-02-03 11:14:48 -05:00
jax authors
7164c6ba3e Merge pull request #25812 from Cjkkkk:segment_ids
PiperOrigin-RevId: 722650439
2025-02-03 07:28:25 -08:00
Parker Schuh
cb188a0cb1 Reject invalid None in jax.NamedSharding(spec=None).
PiperOrigin-RevId: 722500631
2025-02-02 21:29:33 -08:00
Christos Perivolaropoulos
b48d15d788 [pallas_mgpu] For loops can have **non-ref** accumulators for carries.
The user has access only to accumulator references and can't pass them as carries to loops. However, when they are discharged, these accumulators become values and become part of the carry. Before this CL this would surprise the loop lowering code.

This was never a problem for pallas mgpu until we added pipelining loops instead of sequential block axes.

PiperOrigin-RevId: 722495749
2025-02-02 21:03:26 -08:00
Parker Schuh
da97ee2591 Stop passing None into jax.NamedSharding in preparation for a followup which bans passing None in (in favor of PartitionSpec())
PiperOrigin-RevId: 722477002
2025-02-02 19:35:33 -08:00
George Necula
c70de6deed [better_errors] Merge the JaxprDebugInfo and TracingDebugInfo into core.DebugInfo
Previously, we had two almost identical classes: `TracingDebugInfo` and
`JaxprDebugInfo`. The only difference was that `TracingDebugInfo` had
a thunk to return the result paths, while `JaxprDebugInfo` had the
result paths resolved to a tuple. The separation of these types
provided some clarity, but also led to code duplication and
required conversions as the debugging info goes from `WrappedFun`
to a `Jaxpr` and then to `WrappedFun` again.
2025-02-02 06:23:03 +02:00
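The difference between the two former classes (a thunk for result paths versus an already resolved tuple) can be modeled with one dataclass. This is a hypothetical sketch, not the actual `core.DebugInfo` definition; the class name, field names, and method name below are assumptions.

```python
from dataclasses import dataclass, replace
from typing import Callable, Optional, Tuple, Union


@dataclass(frozen=True)
class DebugInfoSketch:
    """Hypothetical merged debug info: result_paths is a thunk or a tuple."""
    func_name: str
    arg_names: Tuple[str, ...]
    result_paths: Optional[Union[Callable[[], Tuple[str, ...]], Tuple[str, ...]]]

    def resolve_result_paths(self) -> "DebugInfoSketch":
        # While tracing, result paths are only available lazily via a thunk.
        # Once tracing is done, the thunk is called and the tuple is stored.
        if callable(self.result_paths):
            return replace(self, result_paths=tuple(self.result_paths()))
        return self
```

With a single type, moving debugging info from a `WrappedFun` to a `Jaxpr` and back needs no conversion between classes, only an optional resolution step.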
Qazalbash
42b64fc06c feat(gh-13291): Add exponential distribution functions: cdf, logcdf, sf, logsf, and ppf
2025-02-01 12:51:11 +05:00
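For reference, the five functions named in this commit have simple closed forms for the exponential distribution. The sketch below uses only the Python standard library and a `scale` parameter; it illustrates the math and is not the code added by the commit. The function names are chosen for illustration.

```python
import math


def expon_cdf(x, scale=1.0):
    # P(X <= x) = 1 - exp(-x / scale) for x >= 0.
    return -math.expm1(-x / scale) if x >= 0 else 0.0


def expon_logcdf(x, scale=1.0):
    return math.log(-math.expm1(-x / scale)) if x > 0 else -math.inf


def expon_sf(x, scale=1.0):
    # Survival function: P(X > x) = exp(-x / scale) for x >= 0.
    return math.exp(-x / scale) if x >= 0 else 1.0


def expon_logsf(x, scale=1.0):
    return -x / scale if x >= 0 else 0.0


def expon_ppf(q, scale=1.0):
    # Inverse of the cdf: x = -scale * log(1 - q) for 0 <= q < 1.
    if not 0.0 <= q <= 1.0:
        return math.nan
    if q == 1.0:
        return math.inf
    return -scale * math.log1p(-q)
```

Using `expm1` and `log1p` keeps the results accurate for small `x` and small `q`, where `1 - exp(-x)` and `log(1 - q)` would lose precision.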
Jevin Jiang
ed952c8e65 [Pallas TPU] Support jnp.take_along_axis for 32-bit vreg-sized vector.
PiperOrigin-RevId: 722015152
2025-01-31 21:27:08 -08:00
Gunhyun Park
c4e176328f Move ragged_all_to_all test under appropriate test file
PiperOrigin-RevId: 721947980
2025-01-31 16:44:04 -08:00
carlosgmartin
32411a430f Add jax.random.multinomial.
2025-01-31 18:45:55 -05:00
jax authors
872e6c0ec4 Merge pull request #25766 from carlosgmartin:nn_initializers_variance_scaling_mode_fan_geo_avg
PiperOrigin-RevId: 721928532
2025-01-31 15:41:50 -08:00
jax authors
a9f4dd7182 Merge pull request #26249 from jakevdp:fix-sterling
PiperOrigin-RevId: 721922732
2025-01-31 15:26:37 -08:00
carlosgmartin
96d3447e89 Add mode='fan_geo_avg' to nn.initializers.variance_scaling.
2025-01-31 17:52:22 -05:00
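As a rough illustration of what the new mode means, the sketch below computes the target variance for each fan mode. It assumes that `fan_geo_avg` denotes the geometric mean of `fan_in` and `fan_out`, as the name suggests. The helper `variance_scaling_variance` is hypothetical and is not part of `nn.initializers`.

```python
import math


def variance_scaling_variance(scale, fan_in, fan_out, mode):
    """Target variance scale / denominator for a given fan mode (sketch)."""
    if mode == "fan_in":
        denominator = fan_in
    elif mode == "fan_out":
        denominator = fan_out
    elif mode == "fan_avg":
        # Arithmetic mean of the two fans.
        denominator = (fan_in + fan_out) / 2
    elif mode == "fan_geo_avg":
        # Assumed meaning: geometric mean of the two fans.
        denominator = math.sqrt(fan_in * fan_out)
    else:
        raise ValueError(f"unknown mode: {mode!r}")
    return scale / denominator
```

For example, with `fan_in=4` and `fan_out=16` the arithmetic mean is 10 while the geometric mean is 8, so the two modes give different variances for non-square weight shapes.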
Emily Fertig
3b2410f77c Reverts bb951136e9b91a584bb422119ada76cc69c86024
PiperOrigin-RevId: 721908669
2025-01-31 14:42:22 -08:00
Jake VanderPlas
216bd9a6cc Fix dtype issue in Stirling approximation
2025-01-31 14:13:02 -08:00
jax authors
0ef2ccfdb4 Merge pull request #26238 from hawkinsp:coretest
PiperOrigin-RevId: 721884935
2025-01-31 13:33:04 -08:00
jax authors
b039f976d5 Merge pull request #26239 from jakevdp:rotation-test
PiperOrigin-RevId: 721884090
2025-01-31 13:30:43 -08:00
Gunhyun Park
20555f63da Lower np.ndarray to DenseElementsAttr instead of ArrayAttr.
PiperOrigin-RevId: 721833949
2025-01-31 11:06:06 -08:00
Jake VanderPlas
4d433d063e test: change random matrix generation for Rotation
2025-01-31 09:21:46 -08:00
Peter Hawkins
2d9cd86ae1 Disable two tests of GC behavior if there are multiple threads per process.
These don't seem to work reliably with multiple threads per process, even though the test is marked thread unsafe.
2025-01-31 09:14:49 -08:00
Christos Perivolaropoulos
bf9671731c [mgpu] Correct instruction for conversion of unsigned int types.
PiperOrigin-RevId: 721793849
2025-01-31 09:06:40 -08:00
Adam Paszke
c1e136058c Re-enable Pallas bf16 exp2 tests on TPU
PiperOrigin-RevId: 721784841
2025-01-31 08:36:57 -08:00
Vladimir Belitskiy
1bfdd504ed Reverts 86643a1b3e0516e1a2ddbdabbb714cf8c0301f18
PiperOrigin-RevId: 721776251
2025-01-31 08:05:46 -08:00
Adam Paszke
cadfcc7a1b [Mosaic GPU] Allow uneven partitioning of dimensions into tiles in TileTransform
PiperOrigin-RevId: 721705218
2025-01-31 03:05:44 -08:00
Adam Paszke
10ac6b7e12 [Mosaic GPU] Add support for tiled swizzle=16 (i.e. no swizzle) loads and stores
The tiling still makes it possible to do it without bank conflicts.

PiperOrigin-RevId: 721701635
2025-01-31 02:49:59 -08:00
Peter Hawkins
a2f7824c98 Disable a debug_info_test test that fails in CI.
This test is sometimes reporting 4 warnings, probably because of tracing cache hits. To be correct, this test probably needs to use its own unique functions that are not shared with other test cases.

PiperOrigin-RevId: 721571459
2025-01-30 17:25:18 -08:00
cjkkkk
ba6b1fdd09 address lint and typecheck
2025-01-30 22:12:26 +00:00
Justin Fu
834e0d7c87 Disable source mapper test for optimized hlo
2025-01-30 14:07:54 -08:00
Yash Katariya
9107ee4a22 Do automatic casting from auto -> manual when the context mesh is manual and avals are in auto mode. This happens when values are being closed over in a shard_map. The casting is happening at lax level but we can move this to a different place later on.
PiperOrigin-RevId: 721495804
2025-01-30 13:14:04 -08:00
Gunhyun Park
a8df383ccf Fix lax.ragged_all_to_all degenerate case
In a singleton group case, unlike regular all_to_all, the ragged op becomes a generic equivalent of DynamicUpdateSlice, except update size is not statically known. This operation can't be expressed with standard HLO instructions -- the backend will handle this case separately.

Added small improvement to error messages.

PiperOrigin-RevId: 721473063
2025-01-30 12:05:02 -08:00
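The singleton-group case described in this commit message can be pictured as a slice update whose length is only known at run time. The following pure-Python model is a sketch under that reading; the function name and the scalar arguments are assumptions and do not mirror the `lax.ragged_all_to_all` signature.

```python
def ragged_all_to_all_singleton(operand, output, input_offset, send_size, output_offset):
    """Hypothetical model: copy a runtime-sized chunk of operand into output."""
    result = list(output)
    # Like DynamicUpdateSlice, except send_size is a runtime value rather
    # than a static shape, which standard HLO instructions cannot express.
    chunk = operand[input_offset:input_offset + send_size]
    result[output_offset:output_offset + send_size] = chunk
    return result
```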
Yash Katariya
f4e2c6c34c Try to match out_spec with in_spec if both shardings are full auto and they are equivalent to each other. This is because of backwards compatibility reasons where tests expect the in and out shardings to match.
PiperOrigin-RevId: 721470917
2025-01-30 11:59:57 -08:00
jax authors
2e40549c38 Merge pull request #26208 from dfm:disable-ragged-test
PiperOrigin-RevId: 721433612
2025-01-30 10:16:15 -08:00
Emily Fertig
bb951136e9 Return arrays from ArrayImpl._check_and_rearrange.
This is in preparation for a larger change, so that input buffers can be checked before Array creation in XLA and the user gets more helpful JAX error messages instead of XLA errors.

PiperOrigin-RevId: 721412760
2025-01-30 09:10:50 -08:00
Benjamin Chetioui
d8f3b33ae4 [Mosaic GPU] Eliminate the arrive attribute from mosaic_gpu.async_load.
We plan to explicitly issue an `expect_tx` operation all the time when using
the dialect.

PiperOrigin-RevId: 721411949
2025-01-30 09:08:45 -08:00
Dan Foreman-Mackey
9442f90cb2 [XLA:CPU] Add CPU client support for layout modes.
The main motivation for this change is to support user-specified input and output layouts for JAX interoperability with other libraries. For example, https://github.com/jax-ml/jax/issues/25066.

The logic is more-or-less a direct copy of the implementation in `PjRtStreamExecutorClient`.

PiperOrigin-RevId: 721382281
2025-01-30 07:27:02 -08:00
Dan Foreman-Mackey
19c17bb28b Skip ragged collective tests on CPU.
2025-01-30 10:03:53 -05:00
Dimitar (Mitko) Asenov
6214c25a6d [Mosaic GPU] Add ArriveExpect and Wait ops on dialect barriers with explicit handling of parities
This makes dialect tests in mgpu_test.py truly express the entire computation at the warpgroup level.

PiperOrigin-RevId: 721371327
2025-01-30 06:44:32 -08:00