25370 Commits

Author SHA1 Message Date
jax authors
e6e7621f0b Update XLA dependency to use revision
c8ebe14129.

PiperOrigin-RevId: 721764550
2025-01-31 07:19:07 -08:00
Michael Hudgins
873db96f9f [CI] Update naming comments in workflow files.
Comments have been added to indicate fields that if changed require a blocking presubmit check to be updated.

PiperOrigin-RevId: 721753931
2025-01-31 06:36:16 -08:00
Adam Paszke
cadfcc7a1b [Mosaic GPU] Allow uneven partitioning of dimensions into tiles in TileTransform
PiperOrigin-RevId: 721705218
2025-01-31 03:05:44 -08:00
Adam Paszke
10ac6b7e12 [Mosaic GPU] Add support for tiled swizzle=16 (i.e. no swizzle) loads and stores
The tiling still makes it possible to do it without bank conflicts.

PiperOrigin-RevId: 721701635
2025-01-31 02:49:59 -08:00
jax authors
efacec4cfb Merge pull request #26218 from jakevdp:bump-array-api
PiperOrigin-RevId: 721572376
2025-01-30 17:30:42 -08:00
Peter Hawkins
60fad99c9c Fix CI failures in xla_metadata_test.
PiperOrigin-RevId: 721572019
2025-01-30 17:28:54 -08:00
Peter Hawkins
0705ec2ca4 Pass filter=data to tar extractall to avoid a warning under Python 3.12+
PiperOrigin-RevId: 721571944
2025-01-30 17:27:24 -08:00
Peter Hawkins
a2f7824c98 Disable a debug_info_test test that fails in CI.
This test is sometimes reporting 4 warnings, probably because of tracing cache hits. To be correct, this test probably needs to use its own unique functions that are not shared with other test cases.

PiperOrigin-RevId: 721571459
2025-01-30 17:25:18 -08:00
Jevin Jiang
785a63ad0f [Mosaic TPU] Support non-32 bit mask relayout
PiperOrigin-RevId: 721552594
2025-01-30 16:13:23 -08:00
jax authors
9dfe03c5ea Merge pull request #26221 from justinjfu:remove_smap_test
PiperOrigin-RevId: 721541175
2025-01-30 15:34:33 -08:00
Yash Katariya
1f33cad321 remove checks since they are redundant and we can change out_aval because of various reasons
PiperOrigin-RevId: 721535417
2025-01-30 15:14:34 -08:00
Justin Fu
834e0d7c87 Disable source mapper test for optimized hlo 2025-01-30 14:07:54 -08:00
Jake VanderPlas
7f4796a2fc [array api] bump test suite to latest commit 2025-01-30 13:36:29 -08:00
Yash Katariya
9107ee4a22 Do automatic casting from auto -> manual when the context mesh is manual and avals are in auto mode. This happens when values are being closed over in a shard_map. The casting is happening at lax level but we can move this to a different place later on.
PiperOrigin-RevId: 721495804
2025-01-30 13:14:04 -08:00
Gunhyun Park
a8df383ccf Fix lax.ragged_all_to_all degenerate case
In a singleton group case, unlike regular all_to_all, the ragged op becomes a generic equivalent of DynamicUpdateSlice, except update size is not statically known. This operation can't be expressed with standard HLO instructions -- the backend will handle this case separately.

Added small improvement to error messages.

PiperOrigin-RevId: 721473063
2025-01-30 12:05:02 -08:00
Yash Katariya
f4e2c6c34c Try to match out_spec with in_spec if both shardings are full auto and they are equivalent to each other. This is because of backwards compatibility reasons where tests expect the in and out shardings to match.
PiperOrigin-RevId: 721470917
2025-01-30 11:59:57 -08:00
jax authors
2e40549c38 Merge pull request #26208 from dfm:disable-ragged-test
PiperOrigin-RevId: 721433612
2025-01-30 10:16:15 -08:00
Tzu-Wei Sung
d4758b6d5e [Mosaic][NFC] Factor out xla-array related utils in a separate file.
Also added tests.

PiperOrigin-RevId: 721424194
2025-01-30 09:49:41 -08:00
Emily Fertig
bb951136e9 Return arrays from ArrayImpl._check_and_rearrange.
This is in preparation for a larger change, so that input buffers can be checked before Array creation in XLA and the user gets more helpful JAX error messages instead of XLA errors.

PiperOrigin-RevId: 721412760
2025-01-30 09:10:50 -08:00
Benjamin Chetioui
d8f3b33ae4 [Mosaic GPU] Eliminate the arrive attribute from mosaic_gpu.async_load.
We plan to explicitly issue an `expect_tx` operation all the time when using
the dialect.

PiperOrigin-RevId: 721411949
2025-01-30 09:08:45 -08:00
jax authors
1003ba93c3 Merge pull request #26150 from jreiffers:main
PiperOrigin-RevId: 721400896
2025-01-30 08:32:20 -08:00
jax authors
4af0481b7d Merge pull request #26189 from vfdev-5:add-ci-cache-for-cpython-build
PiperOrigin-RevId: 721398430
2025-01-30 08:23:10 -08:00
jax authors
c3814921e6 Update XLA dependency to use revision
cb96ba024f.

PiperOrigin-RevId: 721395045
2025-01-30 08:12:21 -08:00
vfdev-5
07a7b7debb Added cpython cache step 2025-01-30 17:07:02 +01:00
Dan Foreman-Mackey
9442f90cb2 [XLA:CPU] Add CPU client support for layout modes.
The main motivation for this change is to support user-specified input and output layouts for JAX interoperability with other libraries. For example, https://github.com/jax-ml/jax/issues/25066.

The logic is more-or-less a direct copy of the implementation in `PjRtStreamExecutorClient`.

PiperOrigin-RevId: 721382281
2025-01-30 07:27:02 -08:00
Dan Foreman-Mackey
19c17bb28b Skip ragged collective tests on CPU. 2025-01-30 10:03:53 -05:00
Dimitar (Mitko) Asenov
6214c25a6d [Mosaic GPU] Add ArriveExpect and Wait ops on dialect barriers with explicit handling of parities
This makes dialect tests in mgpu_test.py truly express the entire computation at the warpgroup level.

PiperOrigin-RevId: 721371327
2025-01-30 06:44:32 -08:00
Benjamin Chetioui
46512e684b [Mosaic GPU][NFC] Fix wrong type annotations, and do some NFC cleanups.
PiperOrigin-RevId: 721350296
2025-01-30 05:13:58 -08:00
jax authors
4c8d7379dd Merge pull request #26100 from gnecula:debug_info_no_pe_debug_info_3
PiperOrigin-RevId: 721247153
2025-01-29 22:30:05 -08:00
George Necula
32c98b9a76 [better_errors] Refactor more uses of pe.tracing_debug_info (part 3)
We replace uses of `pe.tracing_debug_info` with with `api_util.tracing_debug_info`,
which uses the actual args and kwargs, instead of `in_tree` to manufacture fake
args and kwargs. This ends up being more accurate, especially for `arg_names`;
see changes in debug_info_tests.py.
This means that we have to construct the debug info further upstream, before
flattening args. This will later help populate debug info in `WrappedFun` and
`Jaxpr`.

This is part 3 of a series (following #26097, #26099) for jit, pmap, checkify,
and the custom_partitioning (the last few uses).

In order to land this, I had to remove a safety check that the number of
`arg_names` and `result_paths` in a Jaxpr's debug info match the number
of Jaxpr invars and outvars, respectively. Additionally, I added two
accessors `safe_arg_names` and `safe_result_paths` to ensure that
the arg names and result paths match the expected length. These accessors
return no-op results when the lengths are not as expected.
From my testint, this happens only in Jaxprs that
are not used for lowering, hence there is no actual user-visible
change here. Simply, more internal Jaxprs are getting debug_info
and in some cases the `arg_names` and `result_paths` are not correct.
Still, this change is worth it because the `func_src_info` is the most
useful part of the debug info (used for leaked tracers), and that is
accurate. We will fix the `arg_names` and `result_paths` in a future change.

One can see in the changes in debug_info_test.py the improvements in the
user-visible debug info, including for `pjit` and `pmap` cases when
it was wrong.
2025-01-30 07:40:05 +02:00
Yash Katariya
d223dfc3f7 Allow multiple meshes for avals but in that case, just use empty_abstract_mesh instead of enabling computation follows data only for **Auto mode**.
PiperOrigin-RevId: 721224349
2025-01-29 20:47:34 -08:00
Justin Fu
b01111d96c Add skeleton for a multi-pass source mapper for Jaxprs/HLO to jax.experimental.
PiperOrigin-RevId: 721119935
2025-01-29 15:01:43 -08:00
jax authors
152099ee0e Merge pull request #26188 from dfm:revert-callback-docs
PiperOrigin-RevId: 721077634
2025-01-29 12:53:57 -08:00
jax authors
a1e4121bae Merge pull request #26185 from hawkinsp:tsan
PiperOrigin-RevId: 721063414
2025-01-29 12:09:19 -08:00
Sergei Lebedev
d4ced960ab Pulled DLDeviceType to XLA backend mapping into a global
I also updated `to_dlpack` and `from_dlpack` to handle `KeyError` instead of `TypeError`, because I think `TypeError` was never actually raised.

PiperOrigin-RevId: 721052736
2025-01-29 11:38:50 -08:00
Gleb Pobudzey
8c02731a06 Increasing shard count and removing asan builds to prevent timeouts.
PiperOrigin-RevId: 721038112
2025-01-29 10:58:51 -08:00
jax authors
0a30ef3c67 Merge pull request #25980 from codinglover222:jit-vmap-compile-test
PiperOrigin-RevId: 721035412
2025-01-29 10:52:15 -08:00
Peter Hawkins
4f00d451aa Update the list of tsan suppressions.
Lower --local_test_jobs in the bazel runner, in the hope that this lowers the number of test timeouts. I suspect we are simply oversubscribing the machine with multiple threads in each test shard.
2025-01-29 10:46:36 -08:00
Dan Foreman-Mackey
2ae018ed8e Unconditionally skip async deadlock test for pure_callback.
PiperOrigin-RevId: 721012451
2025-01-29 09:49:01 -08:00
Yash Katariya
dcb28f1218 [sharding_in_types] Add vmap + explicit sharding support. The main changes are:
* Track `explicit_mesh_axis` on `AxisData`.
* Modify `unmapped_aval` to the the above explicit mesh axis and insert it into the right place in the sharding so out_shardings are correct.
* Make `matchaxis` also handle shardings correctly
* All mapped dimensions should be sharded the same way
* spmd_axis_name and explicit sharded arrays cannot be used together
* `out_shardings` parameter on `dot_general`, `broadcast_in_dim`, `reshape`, `reshard` and `mesh_cast` is handled correctly in presence of vmap.

This should eventually help us get rid of `spmd_axis_name` from `vmap`.

PiperOrigin-RevId: 721007659
2025-01-29 09:34:27 -08:00
Bixia Zheng
20843643ab [jax:custom_partitioning] Make propagate_user_sharding default to None.
Revise documentation for sharding_rule and add a link to jax-shardy-guide.

PiperOrigin-RevId: 721001922
2025-01-29 09:14:35 -08:00
Jake VanderPlas
955e7c4793 Internal: avoid adding _DimExpr to dtypes._weak_types
This causes problems because internal code assumes it will not be modified. We replace this with an internal registration mechanism.

PiperOrigin-RevId: 721000907
2025-01-29 09:11:02 -08:00
jax authors
c01273c603 Update XLA dependency to use revision
621aa467ae.

PiperOrigin-RevId: 720992604
2025-01-29 08:45:51 -08:00
Dan Foreman-Mackey
e2eff1f8d5 Revert https://github.com/jax-ml/jax/pull/25982 since callbacks can now use JAX functions. 2025-01-29 11:12:32 -05:00
George Necula
8720e95570 [export] Fixes for export_harnesses_multi_platform_test.
The test was mistakenly skipped on slow tests. This is a highly-parameterized test, and if there are some individual instances that are slow, only those should be skipped. The slowest of all instances takes 3s.

I have also ensured that when running natively, we also use jit, like in export mode, to reduce chances that we see numerical discrepancies between eager and jit mode. This fixes a failure on GPU in Kokoro.

PiperOrigin-RevId: 720946449
2025-01-29 06:12:33 -08:00
Adam Paszke
c9dfdb4e23 Relax static offset restriction on memref_ptr for sub-byte types
It simply assumes that the base offset is a multiple of byte packing

PiperOrigin-RevId: 720919148
2025-01-29 04:27:36 -08:00
Dan Foreman-Mackey
9d39ab305a Disable async dispatch within the body of a host callback.
This is a follow up to https://github.com/jax-ml/jax/pull/26160 and https://github.com/openxla/xla/pull/21980. See those PRs for more discussion of the motivation for this change.

In this PR, we disable CPU asynchronous execution when running within the body of a host callback, because this can cause deadlocks.

PiperOrigin-RevId: 720918318
2025-01-29 04:24:33 -08:00
jax authors
a459e7e4cd Merge pull request #26151 from gnecula:debug_info_collect_lowered_jaxprs
PiperOrigin-RevId: 720911587
2025-01-29 04:00:03 -08:00
Christos Perivolaropoulos
f2f7a150f9 [mosaic_gpu] Allow tiled array instead of wgmma.
PiperOrigin-RevId: 720908864
2025-01-29 03:48:14 -08:00
Dan Foreman-Mackey
83457c115a Always dispatch CPU executables synchronously when they include callbacks.
As discussed in https://github.com/jax-ml/jax/issues/25861 and https://github.com/jax-ml/jax/issues/24255, using host callbacks within an asynchronously-dispatched CPU executable can deadlock when the body of the callback itself asynchronously dispatches JAX CPU code. My rough understanding of the problem is that the XLA intra op thread pool gets filled up with callbacks waiting for their body to execute, but there aren't enough resources to schedule the inner computations.

There's probably a better way to fix this within XLA:CPU, but the temporary fix that I've come up with is to disable asynchronous dispatch on CPU when either:

1. Executing a program that includes any host callbacks, or
2. when running within the body of a callback.

It seems like both of these conditions are needed in general because I was able to find test cases that failed with just one or the other implemented.

This PR includes just the first change, and the second will be implemented in a follow-up.

PiperOrigin-RevId: 720777713
2025-01-28 18:23:35 -08:00