25274 Commits

Author SHA1 Message Date
Peter Hawkins
42fd586e79 Disable pytorch_interoperability_test under asan.
PiperOrigin-RevId: 720189636
2025-01-27 09:00:40 -08:00
George Necula
c61401ab6f Rename debug_info_tests.py copied from api_test.py
PiperOrigin-RevId: 720165679
2025-01-27 07:41:14 -08:00
Dimitar (Mitko) Asenov
a0db6c5cf6 [Mosaic GPU] Use a single instance of the single_thread_predicate in the MLIR dialect lowering.
PiperOrigin-RevId: 720155654
2025-01-27 07:04:06 -08:00
jax authors
9b5cb45bc3 Merge pull request #26099 from gnecula:debug_info_no_pe_debug_info_2
PiperOrigin-RevId: 720153601
2025-01-27 06:55:59 -08:00
George Necula
878272ee3c [better_errors] Refactor more uses of pe.tracing_debug_info (part 2)
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 2 of a series (following #26097) for Pallas.
2025-01-27 16:10:56 +02:00
Dimitar (Mitko) Asenov
a3a285dddc [Mosaic GPU] Handle the swizzle attribute in the lowering of async_store and async_load
PiperOrigin-RevId: 720129408
2025-01-27 05:18:16 -08:00
Dimitar (Mitko) Asenov
101f18d4e3 [Mosaic GPU] Fix error message to make it clearer.
PiperOrigin-RevId: 720111248
2025-01-27 04:07:00 -08:00
jax authors
2a6accd63f Merge pull request #26097 from gnecula:debug_info_no_pe_debug_info
PiperOrigin-RevId: 720106054
2025-01-27 03:46:09 -08:00
George Necula
7361d173a9 [better_errors] Refactor more uses of partial_eval.tracing_debug_info (part 1)
We replace those uses with api_util.tracing_debug_info, which means we
have to move the call further upstream. But this is better because we
have the actual args and kwargs, and we can do a better job, especially
for `arg_names`.

This is part 1 of a series, for: cond, switch, while, scan, composite,
custom_dce, custom_root, custom_linear_solve, saved_residuals.
2025-01-27 12:59:34 +02:00
Peter Hawkins
95cb0eb1c9 Optimize JaxprEqnContext context manager.
* Implement the context manager as a context manager class, rather than using @contextlib.contextmanager. It turns out the contextlib contextmanagers are rather slow.
* Fuse the four child context managers into a single context manager. This saves us a bunch of allocations.
* While we are here, also simplify the xla_metadata context manager to avoid its dual representation of the current metadata.

PiperOrigin-RevId: 719918121
2025-01-26 12:08:44 -08:00
jax authors
84921792e4 Update XLA dependency to use revision
1cbcb65ca0.

PiperOrigin-RevId: 719898295
2025-01-26 10:02:29 -08:00
Peter Hawkins
776327919f Optimize implementation of the compute_on context manager.
* We don't need to keep a separate thread-local stack of objects: the config state already has a thread local.
* We don't need to keep an explicit stack of contexts at all: we can maintain it in the context manager frames.
* When checking for incompatible nested compute_ons, we can just check the current state: no need to look higher in the stack!

PiperOrigin-RevId: 719892989
2025-01-26 09:24:33 -08:00
jax authors
381da3cf6f Merge pull request #26093 from gnecula:debug_info_tests1
PiperOrigin-RevId: 719790208
2025-01-25 22:38:27 -08:00
George Necula
e4d5427d13 [better_errors] Add more debug info test coverage
Try to cover the tracing of almost all JAX higher-order
primitives. Some of the tests added show missing debug info,
marked with TODO. Fixes will come separately.

Had to expand the helper functions _check_tracers_and_jaxprs to
use regular expressions for matching because some debug info
still contains non-deterministic elements.
2025-01-26 08:12:29 +02:00
jax authors
55efd4b225 Update XLA dependency to use revision
4ec7e2a772.

PiperOrigin-RevId: 719672370
2025-01-25 10:31:35 -08:00
jax authors
a8adf75295 Merge pull request #26092 from mattjj:make-array-error
PiperOrigin-RevId: 719667108
2025-01-25 09:56:04 -08:00
Matthew Johnson
1fb4b93d41 improve make_array_from_single_device_arrays error 2025-01-25 17:41:01 +00:00
Peter Hawkins
184aefa493 Optimize the set_xla_metadata context manager.
Key idea: if the argument to the context manager is None, then we don't need to touch any context state.

Also clean up the API by separating the "set a dict" from the "set kwargs" use cases.

PiperOrigin-RevId: 719628089
2025-01-25 05:40:45 -08:00
Nitin Srinivasan
89a9c6c244 Add new Bazel remote cache configs
An example run where the cache configs are used: https://github.com/jax-ml/jax/actions/runs/12940123731

PiperOrigin-RevId: 719627011
2025-01-25 05:32:42 -08:00
Yash Katariya
d28c3fa409 Replace Hidden/Visible/Collective AxisTypes names with Auto/Explicit/Manual.
PiperOrigin-RevId: 719561729
2025-01-24 23:21:13 -08:00
Yang Chen
08d81e45d4 Use backend._get_all_devices() to validate devices.
PiperOrigin-RevId: 719367913
2025-01-24 11:09:16 -08:00
Peter Hawkins
cbc2d623fb Don't computing forwarding information if we're going to inline.
Computing forwarding information is pointless because inlining does everything forwarding would do.

PiperOrigin-RevId: 719367022
2025-01-24 11:06:58 -08:00
jax authors
7a23d1d666 Merge pull request #25963 from dfm:dce-custom-star
PiperOrigin-RevId: 719362579
2025-01-24 10:55:26 -08:00
jax authors
407c5b1c25 Merge pull request #25839 from Rifur13:paged
PiperOrigin-RevId: 719357832
2025-01-24 10:42:19 -08:00
jax authors
726abc9c31 Merge pull request #26082 from Saransh-cpp:index-update-syntax-link
PiperOrigin-RevId: 719356167
2025-01-24 10:37:07 -08:00
jax authors
9fb423ed13 Update XLA dependency to use revision
894a70ef68.

PiperOrigin-RevId: 719346771
2025-01-24 10:12:09 -08:00
Saransh Chopra
9e89ae7a19 docs: outdated link for index update syntax 2025-01-24 17:36:35 +00:00
Gleb Pobudzey
e0b38f4e56 Adding GPU paged attention kernel 2025-01-24 17:13:02 +00:00
Justin Fu
617e79f8b6 [Mosaic GPU] Add implementation of FA3 with pipeline emitter.
PiperOrigin-RevId: 719312197
2025-01-24 08:26:07 -08:00
Adam Paszke
c10b9b88f2 [Pallas:MGPU] Add helpers to make writing core_map kernels less verbose
Also add small "getting started" examples that use the helpers in tests.

PiperOrigin-RevId: 719303512
2025-01-24 07:59:26 -08:00
jax authors
33ec6294b8 Merge pull request #26072 from shoyer:test_util_doc
PiperOrigin-RevId: 719287829
2025-01-24 07:02:08 -08:00
jax authors
1f23253bff Merge pull request #26046 from Cjkkkk:fix_cudnn_sdpa_dbias_error_tolerance
PiperOrigin-RevId: 719285005
2025-01-24 06:52:23 -08:00
jax authors
20f02c4973 Merge pull request #26076 from gnecula:fix_ps
PiperOrigin-RevId: 719284560
2025-01-24 06:50:45 -08:00
Sergei Lebedev
9ee7123c39 [mosaic_gpu] Fixed mosaic_gpu-serde pass registration
We previously registered the pass in the :_mosaic_gpu_ext which didn't work
because the extension has its own pass registry. The fix instead is to move
the registration to :register_jax_dialects in jaxlib.

PiperOrigin-RevId: 719280601
2025-01-24 06:35:54 -08:00
Yash Katariya
46d8cd2a71 Don't pass dtype to lax_internal._zero
PiperOrigin-RevId: 719273092
2025-01-24 06:06:38 -08:00
Adam Paszke
7043b852ec [Mosaic GPU] Add basic support for TMA with sub-byte types
PiperOrigin-RevId: 719240287
2025-01-24 03:54:12 -08:00
George Necula
6dd1234707 [export] Fix mis-used of NamedSharding in export tests 2025-01-24 09:18:02 +02:00
Abhinav Gunjal
313e35a214 Remove all MHLO uses, replace it with StableHLO
PiperOrigin-RevId: 719140874
2025-01-23 21:35:36 -08:00
Jevin Jiang
8e1f956804 [Mosaic TPU] Use vmask pack if possible for mask's bitwidth change and introduce relayout op.
PiperOrigin-RevId: 719089676
2025-01-23 18:15:08 -08:00
Parker Schuh
3864512b72 Move transfer python bindings into jax.
PiperOrigin-RevId: 719082208
2025-01-23 17:52:57 -08:00
Stephan Hoyer
b2afb5bf4f add docstrings for check_vjp and check_jvp 2025-01-23 17:31:13 -08:00
Stephan Hoyer
458f6a6efe Add jax.test_util to public API docs 2025-01-23 16:04:35 -08:00
Hyeontaek Lim
1d016962c5 [JAX] Optimize array shard reordering
This change adds a C++ implementation that uses `xla::ifrt::RemapArrays` to
reorder shards of an array. This avoids creating intermediate single-device
arrays and accelerates reordering shards within `jax.device_put()`
implementation.

PiperOrigin-RevId: 718998621
2025-01-23 13:45:59 -08:00
jax authors
8442d64a02 Merge pull request #25116 from wenscarl:fp8_e8m0fnu
PiperOrigin-RevId: 718996844
2025-01-23 13:41:35 -08:00
Bixia Zheng
e33cc428a3 [jax:custom_partitioning] Extend the custom partitioning API to accept a
Callable object that produces a sharding_rule through inspecting the operands
and results.

PiperOrigin-RevId: 718992627
2025-01-23 13:29:12 -08:00
Dan Foreman-Mackey
28d573354b Add DCE rules for custom_jvp and custom_vjp. 2025-01-23 15:22:43 -05:00
Dan Foreman-Mackey
e3b3b913f7 Add an experimental interface for customizing DCE behavior.
We use dead code elimination (DCE) throughout JAX core to remove unused computations from Jaxprs. This typically works transparently when we're just using `lax` primitives, but opaque calls to `pallas_call` or `ffi_call` can't be cleaned up this way. For many kernels however, the author will know how to generate a more efficient call for specific patterns of used outputs, so it is useful to provide a mechanism for customizing this behavior.

In https://github.com/jax-ml/jax/pull/22735, I attempted to automatically tackle one specific example of this that comes up frequently, but there have been feature requests for a more general API. This version is bare bones and probably rough around the edges, but it could be a useful starting point for iteration.

PiperOrigin-RevId: 718950828
2025-01-23 11:38:47 -08:00
jax authors
f4313bb64d Merge pull request #26063 from jakevdp:ensure-arraylike
PiperOrigin-RevId: 718932322
2025-01-23 10:53:03 -08:00
jax authors
281ce64678 Update XLA dependency to use revision
e795171be9.

PiperOrigin-RevId: 718913088
2025-01-23 10:05:26 -08:00
Bart Chrzaszcz
db8c8fc37c #sdy unskip JAX Shardy tests that are already passing
PiperOrigin-RevId: 718898708
2025-01-23 09:26:38 -08:00