25338 Commits

Author SHA1 Message Date
Justin Fu
b01111d96c Add skeleton for a multi-pass source mapper for Jaxprs/HLO to jax.experimental.
PiperOrigin-RevId: 721119935
2025-01-29 15:01:43 -08:00
jax authors
152099ee0e Merge pull request #26188 from dfm:revert-callback-docs
PiperOrigin-RevId: 721077634
2025-01-29 12:53:57 -08:00
jax authors
a1e4121bae Merge pull request #26185 from hawkinsp:tsan
PiperOrigin-RevId: 721063414
2025-01-29 12:09:19 -08:00
Sergei Lebedev
d4ced960ab Pulled DLDeviceType to XLA backend mapping into a global
I also updated `to_dlpack` and `from_dlpack` to handle `KeyError` instead of `TypeError`, because I think `TypeError` was never actually raised.

PiperOrigin-RevId: 721052736
2025-01-29 11:38:50 -08:00
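The pattern described in the commit above can be sketched as follows. This is an illustration only, not the actual JAX code: the names `_DL_DEVICE_TO_BACKEND` and `backend_for_device_type`, the backend strings, and the choice of `ValueError` are all assumptions. The numeric codes follow the DLPack `DLDeviceType` enum. The point is that a failed dict lookup on a hashable key raises `KeyError`, not `TypeError`, so that is the exception worth handling.

```python
import enum


class DLDeviceType(enum.IntEnum):
    # Subset of the DLPack device type codes.
    kDLCPU = 1
    kDLCUDA = 2
    kDLROCM = 10


# Hypothetical module-level ("global") mapping; the backend names are assumed.
_DL_DEVICE_TO_BACKEND = {
    DLDeviceType.kDLCPU: "cpu",
    DLDeviceType.kDLCUDA: "cuda",
    DLDeviceType.kDLROCM: "rocm",
}


def backend_for_device_type(device_type: int) -> str:
    """Look up the backend name for a DLPack device type code."""
    try:
        return _DL_DEVICE_TO_BACKEND[device_type]
    except KeyError:
        # An unknown code surfaces as KeyError from the dict lookup.
        raise ValueError(
            f"Unsupported DLPack device type: {device_type}") from None
```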
Gleb Pobudzey
8c02731a06 Increasing shard count and removing asan builds to prevent timeouts.
PiperOrigin-RevId: 721038112
2025-01-29 10:58:51 -08:00
jax authors
0a30ef3c67 Merge pull request #25980 from codinglover222:jit-vmap-compile-test
PiperOrigin-RevId: 721035412
2025-01-29 10:52:15 -08:00
Peter Hawkins
4f00d451aa Update the list of tsan suppressions.
Lower --local_test_jobs in the bazel runner, in the hope that this lowers the number of test timeouts. I suspect we are simply oversubscribing the machine with multiple threads in each test shard.
2025-01-29 10:46:36 -08:00
Dan Foreman-Mackey
2ae018ed8e Unconditionally skip async deadlock test for pure_callback.
PiperOrigin-RevId: 721012451
2025-01-29 09:49:01 -08:00
Yash Katariya
dcb28f1218 [sharding_in_types] Add vmap + explicit sharding support. The main changes are:
* Track `explicit_mesh_axis` on `AxisData`.
* Modify `unmapped_aval` to take the above explicit mesh axis and insert it into the right place in the sharding so that out_shardings are correct.
* Make `matchaxis` also handle shardings correctly
* All mapped dimensions should be sharded the same way
* spmd_axis_name and explicit sharded arrays cannot be used together
* `out_shardings` parameter on `dot_general`, `broadcast_in_dim`, `reshape`, `reshard` and `mesh_cast` is handled correctly in presence of vmap.

This should eventually help us get rid of `spmd_axis_name` from `vmap`.

PiperOrigin-RevId: 721007659
2025-01-29 09:34:27 -08:00
Bixia Zheng
20843643ab [jax:custom_partitioning] Make propagate_user_sharding default to None.
Revise documentation for sharding_rule and add a link to jax-shardy-guide.

PiperOrigin-RevId: 721001922
2025-01-29 09:14:35 -08:00
Jake VanderPlas
955e7c4793 Internal: avoid adding _DimExpr to dtypes._weak_types
This causes problems because internal code assumes it will not be modified. We replace this with an internal registration mechanism.

PiperOrigin-RevId: 721000907
2025-01-29 09:11:02 -08:00
jax authors
c01273c603 Update XLA dependency to use revision
621aa467ae.

PiperOrigin-RevId: 720992604
2025-01-29 08:45:51 -08:00
Dan Foreman-Mackey
e2eff1f8d5 Revert https://github.com/jax-ml/jax/pull/25982 since callbacks can now use JAX functions.
2025-01-29 11:12:32 -05:00
George Necula
8720e95570 [export] Fixes for export_harnesses_multi_platform_test.
The test was mistakenly skipped on slow tests. This is a highly-parameterized test, and if there are some individual instances that are slow, only those should be skipped. The slowest of all instances takes 3s.

I have also ensured that when running natively, we also use jit, like in export mode, to reduce chances that we see numerical discrepancies between eager and jit mode. This fixes a failure on GPU in Kokoro.

PiperOrigin-RevId: 720946449
2025-01-29 06:12:33 -08:00
Adam Paszke
c9dfdb4e23 Relax static offset restriction on memref_ptr for sub-byte types
It simply assumes that the base offset is a multiple of byte packing

PiperOrigin-RevId: 720919148
2025-01-29 04:27:36 -08:00
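The arithmetic behind the byte-packing assumption in the commit above can be sketched as follows. This is an illustration, not the Mosaic code: `byte_offset` is a hypothetical helper. The commit assumes the base offset is a multiple of the packing factor rather than checking it; the sketch checks explicitly so the assumption is visible.

```python
def byte_offset(element_offset: int, bitwidth: int) -> int:
    """Convert an offset in elements to an offset in bytes."""
    if bitwidth >= 8:
        if bitwidth % 8 != 0:
            raise ValueError("bitwidth must be a whole number of bytes")
        return element_offset * (bitwidth // 8)
    if 8 % bitwidth != 0:
        raise ValueError("sub-byte bitwidth must divide 8")
    packing = 8 // bitwidth  # elements stored per byte, e.g. 2 for int4
    if element_offset % packing != 0:
        # A pointer can only address whole bytes.
        raise ValueError("offset is not a multiple of the byte packing")
    return element_offset // packing
```

For example, with 4-bit elements an element offset of 6 lands on byte 3, while an element offset of 3 falls in the middle of a byte and cannot be expressed as a pointer.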
Dan Foreman-Mackey
9d39ab305a Disable async dispatch within the body of a host callback.
This is a follow up to https://github.com/jax-ml/jax/pull/26160 and https://github.com/openxla/xla/pull/21980. See those PRs for more discussion of the motivation for this change.

In this PR, we disable CPU asynchronous execution when running within the body of a host callback, because this can cause deadlocks.

PiperOrigin-RevId: 720918318
2025-01-29 04:24:33 -08:00
jax authors
a459e7e4cd Merge pull request #26151 from gnecula:debug_info_collect_lowered_jaxprs
PiperOrigin-RevId: 720911587
2025-01-29 04:00:03 -08:00
Christos Perivolaropoulos
f2f7a150f9 [mosaic_gpu] Allow tiled array instead of wgmma.
PiperOrigin-RevId: 720908864
2025-01-29 03:48:14 -08:00
Dan Foreman-Mackey
83457c115a Always dispatch CPU executables synchronously when they include callbacks.
As discussed in https://github.com/jax-ml/jax/issues/25861 and https://github.com/jax-ml/jax/issues/24255, using host callbacks within an asynchronously-dispatched CPU executable can deadlock when the body of the callback itself asynchronously dispatches JAX CPU code. My rough understanding of the problem is that the XLA intra op thread pool gets filled up with callbacks waiting for their body to execute, but there aren't enough resources to schedule the inner computations.

There's probably a better way to fix this within XLA:CPU, but the temporary fix that I've come up with is to disable asynchronous dispatch on CPU when either:

1. Executing a program that includes any host callbacks, or
2. when running within the body of a callback.

It seems like both of these conditions are needed in general because I was able to find test cases that failed with just one or the other implemented.

This PR includes just the first change, and the second will be implemented in a follow-up.

PiperOrigin-RevId: 720777713
2025-01-28 18:23:35 -08:00
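The starvation mechanism described in the commit above can be sketched with a plain Python thread pool. This is an analogy only, not XLA or JAX code: `run_callbacks`, `inner_work`, and the `_in_callback` flag are hypothetical, and a timeout stands in for what would otherwise be a permanent deadlock. Every worker is occupied by a callback that waits for inner work queued on the same pool, so the inner work never gets a thread. Running the inner work inline when already inside a callback (the analogue of condition 2) avoids the problem.

```python
import concurrent.futures as cf
import threading

POOL_SIZE = 2
_in_callback = threading.local()  # per-thread "inside a callback" flag


def inner_work(x):
    return x + 1


def run_callbacks(sync_inside_callback: bool, wait_s: float = 0.2):
    pool = cf.ThreadPoolExecutor(max_workers=POOL_SIZE)
    rendezvous = threading.Barrier(POOL_SIZE)

    def dispatch(x):
        if sync_inside_callback and getattr(_in_callback, "active", False):
            return inner_work(x)  # run inline: no pool thread needed
        # "Async dispatch": queue on the same pool and wait for the result.
        return pool.submit(inner_work, x).result(timeout=wait_s)

    def callback(x):
        _in_callback.active = True
        try:
            rendezvous.wait(timeout=5)  # every worker now holds a callback
            try:
                result = dispatch(x)
            except cf.TimeoutError:
                result = "starved"  # would block forever without the timeout
            rendezvous.wait(timeout=5)  # hold the worker until all are done
            return result
        finally:
            _in_callback.active = False

    with pool:
        futures = [pool.submit(callback, i) for i in range(POOL_SIZE)]
        return [f.result() for f in futures]
```

With `sync_inside_callback=False` every callback reports `"starved"`; with `sync_inside_callback=True` the callbacks return their inner results.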
jax authors
bf22b53cf4 Merge pull request #26154 from jakevdp:pure-callback-doc
PiperOrigin-RevId: 720763192
2025-01-28 17:28:02 -08:00
Gunhyun Park
809e1133c8 Add support for axis_name and axis_index_groups to lax.ragged_all_to_all
PiperOrigin-RevId: 720738861
2025-01-28 16:02:03 -08:00
Bixia Zheng
9cbff64251 #sdy Enable test_partial_auto_of_random_keys under Shardy.
PiperOrigin-RevId: 720731202
2025-01-28 15:36:52 -08:00
Anselm Levskaya
23c8607bab disable kernel test due to races.
PiperOrigin-RevId: 720715578
2025-01-28 14:47:36 -08:00
jax authors
fa9bb231f1 Merge pull request #26157 from jakevdp:callbacks-vectorized
PiperOrigin-RevId: 720701900
2025-01-28 14:06:44 -08:00
George Necula
f8673cde94 [better_errors] Expand debug info testing with eager mode, and MLIR module checking.
Made several improvements to the debug info tests:

 * added support for eager mode, which sometimes uses
   different code paths for the debug info, e.g., for
   `jvp(pmap)`. To check the debugging info in these cases we add
   instrumentation to collect the lowered Jaxprs and MLIR modules right
   after lowering, and we check the debugging information there.
 * added support for checking for the presence of regular expressions
   and strings in the lowered module, to check that the location
   information, arg_names, and result_paths are present. This
   is now enabled only for a subset of the tests.
 * simplified the pretty-printing of the arg_names and result_paths
   in the debug info, to remove a layer of parentheses and string,
   so that instead of `arg_names=("x", "y")` we now pretty-print
   just `arg_names=x,y`
 * added support for checking the provenance information in
   leaked tracers
2025-01-28 23:51:06 +02:00
jax authors
95558a9f63 Merge pull request #24898 from vfdev-5:add-tsan-ft-ci-job
PiperOrigin-RevId: 720692325
2025-01-28 13:38:41 -08:00
Dan Foreman-Mackey
09392d8160 Simplify dtype inference in lax.linalg.eig abstract eval rule.
I came across this when working on an unrelated issue, but the explicit use of `finfo` was causing some `UserWarning`s, and it was really unnecessary.

PiperOrigin-RevId: 720691470
2025-01-28 13:35:53 -08:00
Dimitar (Mitko) Asenov
d9f67ffe13 [Mosaic GPU] Implement a lowering for the dialect WGMMA op
PiperOrigin-RevId: 720663200
2025-01-28 12:08:45 -08:00
Jake VanderPlas
25aa5a3008 DOC: avoid deprecated argument in external callbacks
2025-01-28 11:34:56 -08:00
Yash Katariya
8f248fe626 [sharding_in_types] Upstream changes from defaulting sharding_in_types config to True experiment. There aren't a lot of failures in TGP but we can at least upstream these changes until we work on the failures.
PiperOrigin-RevId: 720639755
2025-01-28 11:04:42 -08:00
Gleb Pobudzey
7a4a53ad9e Add win32 guard to fix imports on Windows
PiperOrigin-RevId: 720625818
2025-01-28 10:32:19 -08:00
jax authors
5625b2fde6 Merge pull request #26131 from jakevdp:fix-nightly-ci
PiperOrigin-RevId: 720625005
2025-01-28 10:29:47 -08:00
Jake VanderPlas
ba2858f834 DOC: add discussion of exceptions in pure_callback
2025-01-28 09:53:47 -08:00
jax authors
ddc8beff5b Update XLA dependency to use revision
bb8a7f3809.

PiperOrigin-RevId: 720604622
2025-01-28 09:33:54 -08:00
Adam Paszke
a4fe5c1ac2 [Mosaic GPU] Add specialized support for some int4 -> bfloat16 casts
PiperOrigin-RevId: 720601356
2025-01-28 09:21:40 -08:00
Adam Paszke
29b658b358 [Mosaic TPU] Optimize clipping implementation in arith.fptosi
We can use maxf/minf to avoid extra comparisons

PiperOrigin-RevId: 720601304
2025-01-28 09:20:16 -08:00
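The idea in the commit above, clamping with max/min instead of separate comparisons, can be sketched in plain Python. This is an illustration and not the Mosaic TPU lowering: the function names and the int32 target range are assumptions, and NaN handling is out of scope.

```python
import math

INT32_MIN = -(2**31)
INT32_MAX = 2**31 - 1


def fptosi_compare_select(x: float) -> int:
    """Saturating float-to-int32 conversion using explicit comparisons."""
    if x < INT32_MIN:
        return INT32_MIN
    if x > INT32_MAX:
        return INT32_MAX
    return math.trunc(x)


def fptosi_min_max(x: float) -> int:
    """Same result, with the clamp expressed as max followed by min."""
    return math.trunc(min(max(x, INT32_MIN), INT32_MAX))
```

Both versions truncate toward zero and saturate at the int32 bounds. The second maps onto a max and a min operation rather than two compare-and-select pairs.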
Adam Paszke
f504d32492 [Mosaic GPU] Add support for tiled loads/stores with sub-byte types
Apparently MLIR and LLVM love to pad sub-byte types to whole bytes, so only
the code where we do address arithmetic ourselves is easy to adapt.

PiperOrigin-RevId: 720593538
2025-01-28 08:57:21 -08:00
Dmitri Gribenko
e332b94f19 Integrate LLVM at llvm/llvm-project@2e5a5237da
Updates LLVM usage to match
[2e5a5237daf8](https://github.com/llvm/llvm-project/commit/2e5a5237daf8)

PiperOrigin-RevId: 720516860
2025-01-28 04:03:02 -08:00
Yash Katariya
7ed7e0b5b1 [sharding_in_types] Add clamp_p sharding rule.
PiperOrigin-RevId: 720428881
2025-01-27 21:58:08 -08:00
Yash Katariya
ae705fef9c [sharding_in_types] Add support for svd_p
PiperOrigin-RevId: 720409750
2025-01-27 20:31:54 -08:00
jax authors
24987a90dc Merge pull request #26134 from justinjfu:pallas_accum_bugfix
PiperOrigin-RevId: 720374819
2025-01-27 18:05:57 -08:00
Justin Fu
54ac172b4c [Pallas] Refactor Pallas HLO interpret mode to a standalone file.
Also replaces the interpreter context (used only for handling extended dtypes) with a physicalize Jaxpr pass.

PiperOrigin-RevId: 720371033
2025-01-27 17:52:27 -08:00
jax authors
bc130c7ba6 Merge pull request #25213 from Rifur13:dynamic_mask
PiperOrigin-RevId: 720361301
2025-01-27 17:16:12 -08:00
Nitin Srinivasan
36679d89e3 Add retry logic to avoid flakes when downloading Bazel
PiperOrigin-RevId: 720349541
2025-01-27 16:36:01 -08:00
Gleb Pobudzey
4fe937683e Fix import for Windows platforms
PiperOrigin-RevId: 720348679
2025-01-27 16:33:37 -08:00
Peter Hawkins
faaaf82974 Disable pytorch_interoperability_test under asan on all backends.
It wasn't sufficient to disable this only on GPU.

PiperOrigin-RevId: 720344366
2025-01-27 16:18:28 -08:00
Gleb Pobudzey
c0d23af42c Support dynamic masks in splash attention
2025-01-28 00:14:53 +00:00
jax authors
727d0367a4 Update --config=cuda to add direct dependencies on CUDA libraries both for bazel build and bazel test phases.
With this configuration the same cache is used both for `bazel build` and `bazel test` commands (provided the same target is specified).

Add `--config=no_cuda_libs` for building targets with CUDA libraries from stubs.

PiperOrigin-RevId: 720334587
2025-01-27 15:46:17 -08:00
Justin Fu
7ace72fb3a [Pallas] Be explicit about accumulation dtype in reference implementations
2025-01-27 22:09:29 +00:00
Jake VanderPlas
36c6c74c8a CI: simplify nightly workflow definition
2025-01-27 13:14:41 -08:00