9384 Commits

Author SHA1 Message Date
carlosgmartin
32411a430f Add jax.random.multinomial. 2025-01-31 18:45:55 -05:00
jax authors
a9f4dd7182 Merge pull request #26249 from jakevdp:fix-sterling
PiperOrigin-RevId: 721922732
2025-01-31 15:26:37 -08:00
Emily Fertig
3b2410f77c Reverts bb951136e9b91a584bb422119ada76cc69c86024
PiperOrigin-RevId: 721908669
2025-01-31 14:42:22 -08:00
Jake VanderPlas
216bd9a6cc Fix dtype issue in stirling approximation 2025-01-31 14:13:02 -08:00
jax authors
0ef2ccfdb4 Merge pull request #26238 from hawkinsp:coretest
PiperOrigin-RevId: 721884935
2025-01-31 13:33:04 -08:00
jax authors
b039f976d5 Merge pull request #26239 from jakevdp:rotation-test
PiperOrigin-RevId: 721884090
2025-01-31 13:30:43 -08:00
Gunhyun Park
20555f63da Lower np.ndarray to DenseElementsAttr instead of ArrayAttr.
PiperOrigin-RevId: 721833949
2025-01-31 11:06:06 -08:00
Jake VanderPlas
4d433d063e test: change random matrix generation for Rotation 2025-01-31 09:21:46 -08:00
Peter Hawkins
2d9cd86ae1 Disable two tests of GC behavior if there are multiple threads per process.
These don't seem to work reliably with multiple threads per process, even though the test is marked thread unsafe.
2025-01-31 09:14:49 -08:00
Christos Perivolaropoulos
bf9671731c [mgpu] Correct instruction for conversion of unsigned int types.
PiperOrigin-RevId: 721793849
2025-01-31 09:06:40 -08:00
Adam Paszke
c1e136058c Re-enable Pallas bf16 exp2 tests on TPU
PiperOrigin-RevId: 721784841
2025-01-31 08:36:57 -08:00
Vladimir Belitskiy
1bfdd504ed Reverts 86643a1b3e0516e1a2ddbdabbb714cf8c0301f18
PiperOrigin-RevId: 721776251
2025-01-31 08:05:46 -08:00
Adam Paszke
cadfcc7a1b [Mosaic GPU] Allow uneven partitioning of dimensions into tiles in TileTransform
PiperOrigin-RevId: 721705218
2025-01-31 03:05:44 -08:00
Adam Paszke
10ac6b7e12 [Mosaic GPU] Add support for tiled swizzle=16 (i.e. no swizzle) loads and stores
The tiling still makes it possible to do it without bank conflicts.

PiperOrigin-RevId: 721701635
2025-01-31 02:49:59 -08:00
Peter Hawkins
a2f7824c98 Disable a debug_info_test test that fails in CI.
This test is sometimes reporting 4 warnings, probably because of tracing cache hits. To be correct, this test probably needs to use its own unique functions that are not shared with other test cases.

PiperOrigin-RevId: 721571459
2025-01-30 17:25:18 -08:00
Justin Fu
834e0d7c87 Disable source mapper test for optimized hlo 2025-01-30 14:07:54 -08:00
Yash Katariya
9107ee4a22 Do automatic casting from auto -> manual when the context mesh is manual and avals are in auto mode. This happens when values are being closed over in a shard_map. The casting is happening at lax level but we can move this to a different place later on.
PiperOrigin-RevId: 721495804
2025-01-30 13:14:04 -08:00
Gunhyun Park
a8df383ccf Fix lax.ragged_all_to_all degenerate case
In a singleton group case, unlike regular all_to_all, the ragged op becomes a generic equivalent of DynamicUpdateSlice, except update size is not statically known. This operation can't be expressed with standard HLO instructions -- the backend will handle this case separately.

Added small improvement to error messages.

PiperOrigin-RevId: 721473063
2025-01-30 12:05:02 -08:00
Yash Katariya
f4e2c6c34c Try to match out_spec with in_spec if both shardings are full auto and they are equivalent to each other. This is because of backwards compatibility reasons where tests expect the in and out shardings to match.
PiperOrigin-RevId: 721470917
2025-01-30 11:59:57 -08:00
jax authors
2e40549c38 Merge pull request #26208 from dfm:disable-ragged-test
PiperOrigin-RevId: 721433612
2025-01-30 10:16:15 -08:00
Emily Fertig
bb951136e9 Return arrays from ArrayImpl._check_and_rearrange.
This is in preparation for a larger change, so that input buffers can be checked before Array creation in XLA and the user gets more helpful JAX error messages instead of XLA errors.

PiperOrigin-RevId: 721412760
2025-01-30 09:10:50 -08:00
Benjamin Chetioui
d8f3b33ae4 [Mosaic GPU] Eliminate the arrive attribute from mosaic_gpu.async_load.
We plan to explicitly issue an `expect_tx` operation all the time when using
the dialect.

PiperOrigin-RevId: 721411949
2025-01-30 09:08:45 -08:00
Dan Foreman-Mackey
9442f90cb2 [XLA:CPU] Add CPU client support for layout modes.
The main motivation for this change is to support user-specified input and output layouts for JAX interoperability with other libraries. For example, https://github.com/jax-ml/jax/issues/25066.

The logic is more-or-less a direct copy of the implementation in `PjRtStreamExecutorClient`.

PiperOrigin-RevId: 721382281
2025-01-30 07:27:02 -08:00
Dan Foreman-Mackey
19c17bb28b Skip ragged collective tests on CPU. 2025-01-30 10:03:53 -05:00
Dimitar (Mitko) Asenov
6214c25a6d [Mosaic GPU] Add ArriveExpect and Wait ops on dialect barriers with explicit handling of parities
This makes dialect tests in mgpu_test.py truly express the entire computation at the warpgroup level.

PiperOrigin-RevId: 721371327
2025-01-30 06:44:32 -08:00
George Necula
32c98b9a76 [better_errors] Refactor more uses of pe.tracing_debug_info (part 3)
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 3 of a series (following #26097, #26099) for jit, pmap, checkify,
and the custom_partitioning (the last few uses).

In order to land this, I had to remove a safety check that the number of
`arg_names` and `result_paths` in a Jaxpr's debug info match the number
of Jaxpr invars and outvars, respectively. Additionally, I added two
accessors `safe_arg_names` and `safe_result_paths` to ensure that
the arg names and result paths match the expected length. These accessors
return no-op results when the lengths are not as expected.
From my testint, this happens only in Jaxprs that
are not used for lowering, hence there is no actual user-visible
change here. Simply, more internal Jaxprs are getting debug_info
and in some cases the `arg_names` and `result_paths` are not correct.
Still, this change is worth it because the `func_src_info` is the most
useful part of the debug info (used for leaked tracers), and that is
accurate. We will fix the `arg_names` and `result_paths` in a future change.

One can see in the changes in debug_info_test.py the improvements in the
user-visible debug info, including for `pjit` and `pmap` cases when
it was wrong.
2025-01-30 07:40:05 +02:00
Justin Fu
b01111d96c Add skeleton for a multi-pass source mapper for Jaxprs/HLO to jax.experimental.
PiperOrigin-RevId: 721119935
2025-01-29 15:01:43 -08:00
Gleb Pobudzey
8c02731a06 Increasing shard count and removing asan builds to prevent timeouts.
PiperOrigin-RevId: 721038112
2025-01-29 10:58:51 -08:00
jax authors
0a30ef3c67 Merge pull request #25980 from codinglover222:jit-vmap-compile-test
PiperOrigin-RevId: 721035412
2025-01-29 10:52:15 -08:00
Dan Foreman-Mackey
2ae018ed8e Unconditionally skip async deadlock test for pure_callback.
PiperOrigin-RevId: 721012451
2025-01-29 09:49:01 -08:00
Yash Katariya
dcb28f1218 [sharding_in_types] Add vmap + explicit sharding support. The main changes are:
* Track `explicit_mesh_axis` on `AxisData`.
* Modify `unmapped_aval` to the the above explicit mesh axis and insert it into the right place in the sharding so out_shardings are correct.
* Make `matchaxis` also handle shardings correctly
* All mapped dimensions should be sharded the same way
* spmd_axis_name and explicit sharded arrays cannot be used together
* `out_shardings` parameter on `dot_general`, `broadcast_in_dim`, `reshape`, `reshard` and `mesh_cast` is handled correctly in presence of vmap.

This should eventually help us get rid of `spmd_axis_name` from `vmap`.

PiperOrigin-RevId: 721007659
2025-01-29 09:34:27 -08:00
Jake VanderPlas
955e7c4793 Internal: avoid adding _DimExpr to dtypes._weak_types
This causes problems because internal code assumes it will not be modified. We replace this with an internal registration mechanism.

PiperOrigin-RevId: 721000907
2025-01-29 09:11:02 -08:00
George Necula
8720e95570 [export] Fixes for export_harnesses_multi_platform_test.
The test was mistakenly skipped on slow tests. This is a highly-parameterized test, and if there are some individual instances that are slow, only those should be skipped. The slowest of all instances takes 3s.

I have also ensured that when running natively, we also use jit, like in export mode, to reduce chances that we see numerical discrepancies between eager and jit mode. This fixes a failure on GPU in Kokoro.

PiperOrigin-RevId: 720946449
2025-01-29 06:12:33 -08:00
Dan Foreman-Mackey
9d39ab305a Disable async dispatch within the body of a host callback.
This is a follow up to https://github.com/jax-ml/jax/pull/26160 and https://github.com/openxla/xla/pull/21980. See those PRs for more discussion of the motivation for this change.

In this PR, we disable CPU asynchronous execution when running within the body of a host callback, because this can cause deadlocks.

PiperOrigin-RevId: 720918318
2025-01-29 04:24:33 -08:00
jax authors
a459e7e4cd Merge pull request #26151 from gnecula:debug_info_collect_lowered_jaxprs
PiperOrigin-RevId: 720911587
2025-01-29 04:00:03 -08:00
Christos Perivolaropoulos
f2f7a150f9 [mosaic_gpu] Allow tiled array instead of wgmma.
PiperOrigin-RevId: 720908864
2025-01-29 03:48:14 -08:00
Dan Foreman-Mackey
83457c115a Always dispatch CPU executables synchronously when they include callbacks.
As discussed in https://github.com/jax-ml/jax/issues/25861 and https://github.com/jax-ml/jax/issues/24255, using host callbacks within an asynchronously-dispatched CPU executable can deadlock when the body of the callback itself asynchronously dispatches JAX CPU code. My rough understanding of the problem is that the XLA intra op thread pool gets filled up with callbacks waiting for their body to execute, but there aren't enough resources to schedule the inner computations.

There's probably a better way to fix this within XLA:CPU, but the temporary fix that I've come up with is to disable asynchronous dispatch on CPU when either:

1. Executing a program that includes any host callbacks, or
2. when running within the body of a callback.

It seems like both of these conditions are needed in general because I was able to find test cases that failed with just one or the other implemented.

This PR includes just the first change, and the second will be implemented in a follow-up.

PiperOrigin-RevId: 720777713
2025-01-28 18:23:35 -08:00
Gunhyun Park
809e1133c8 Add support for axis_name and axis_index_groups to lax.ragged_all_to_all
PiperOrigin-RevId: 720738861
2025-01-28 16:02:03 -08:00
Bixia Zheng
9cbff64251 #sdy Enable test_partial_auto_of_random_keys under Shardy.
PiperOrigin-RevId: 720731202
2025-01-28 15:36:52 -08:00
Anselm Levskaya
23c8607bab disable kernel test due to races.
PiperOrigin-RevId: 720715578
2025-01-28 14:47:36 -08:00
George Necula
f8673cde94 [better_errors] Expand debug info testing with eager mode, and MLIR module checking.
Made several improvements to the debug info tests:

 * added support for eager mode, which sometimes uses
   different code paths for the debug info, e.g., for
   `jvp(pmap)`. To check the debugging info in these cases we add
   instrumentation to collect the lowered Jaxprs and MLIR modules right
   after lowering, and we check the debugging information there.
 * added support for checking for the presence of regular expressions
   and strings in the lowered module, to check that the location
   information and arg_names and result_paths is present. This
   is now enabled only for a subset of the tests.
 * simplified the pretty-printing of the arg_names and result_paths
   in the debug info, to remove a layer of parentheses and string,
   so that instead of `arg_names=("x", "y")` we now pretty-print
   just `arg_names=x,y"
 * added support for checking the provenance information in
   leaked tracers
2025-01-28 23:51:06 +02:00
Dimitar (Mitko) Asenov
d9f67ffe13 [Mosaic GPU] Implement a lowering for the dialect WGMMA op
PiperOrigin-RevId: 720663200
2025-01-28 12:08:45 -08:00
Yash Katariya
8f248fe626 [sharding_in_types] Upstream changes from defaulting sharding_in_types config to True experiment. There aren't a lot of failures in TGP but we can atleast upstream these changes until we work on the failures.
PiperOrigin-RevId: 720639755
2025-01-28 11:04:42 -08:00
Gleb Pobudzey
7a4a53ad9e Add win32 guard to fix imports on Windows
PiperOrigin-RevId: 720625818
2025-01-28 10:32:19 -08:00
Adam Paszke
a4fe5c1ac2 [Mosaic GPU] Add specialized support for some int4 -> bfloat16 casts
PiperOrigin-RevId: 720601356
2025-01-28 09:21:40 -08:00
Adam Paszke
f504d32492 [Mosaic GPU] Add support for tiled loads/stores with sub-byte types
Apparently MLIR and LLVM love to pad sub-byte types to whole bytes, so only
the code where we do address arithmetic ourselves is easy to adapt.

PiperOrigin-RevId: 720593538
2025-01-28 08:57:21 -08:00
Dmitri Gribenko
e332b94f19 Integrate LLVM at llvm/llvm-project@2e5a5237da
Updates LLVM usage to match
[2e5a5237daf8](https://github.com/llvm/llvm-project/commit/2e5a5237daf8)

PiperOrigin-RevId: 720516860
2025-01-28 04:03:02 -08:00
Yash Katariya
ae705fef9c [sharding_in_types] Add support for svd_p
PiperOrigin-RevId: 720409750
2025-01-27 20:31:54 -08:00
jax authors
24987a90dc Merge pull request #26134 from justinjfu:pallas_accum_bugfix
PiperOrigin-RevId: 720374819
2025-01-27 18:05:57 -08:00
Justin Fu
54ac172b4c [Pallas] Refactor Pallas HLO interpret mode to a standalone file.
Also replaces the interpreter context (used only for handling extended dtypes) with a physicalize Jaxpr pass.

PiperOrigin-RevId: 720371033
2025-01-27 17:52:27 -08:00