25032 Commits

Author SHA1 Message Date
David Boetius
6e9a34f791
Move _reduce_window docstring to public func lax.reduce_window. 2025-01-09 13:31:48 +01:00
Jake VanderPlas
640cb009f1 bazel visibility change
PiperOrigin-RevId: 713488528
2025-01-08 18:34:10 -08:00
Yash Katariya
b2b38679e2 Make sharding_in_types work with Shardy
PiperOrigin-RevId: 713479962
2025-01-08 18:05:43 -08:00
Yash Katariya
fb832afc00 Respect the original memory kind on reshape, transpose and replicate methods of PositionalSharding. Fixes https://github.com/jax-ml/jax/issues/25769
PiperOrigin-RevId: 713446871
2025-01-08 16:03:03 -08:00
Matthew Johnson
f0392a1535 fix grad(logsumexp) to produce 0s where where is False 2025-01-08 23:38:06 +00:00
Bart Chrzaszcz
cbcc883ea3 #sdy add repr for Sdy ArraySharding and DimSharding
PiperOrigin-RevId: 713422071
2025-01-08 14:41:41 -08:00
jax authors
196eec8296 Merge pull request #25786 from vfdev-5:add-313-ft-configs
PiperOrigin-RevId: 713415451
2025-01-08 14:21:40 -08:00
Peter Hawkins
e20523c2e3 Make api_test.py work when test cases are run using multiple threads.
* keep track of all known config.State objects so we can find them by name.
* change `@jtu.with_config` to default to setting thread-local configurations.
* add a `@jtu.with_global_config` for those things that truly need to be set globally.
* add a `@jtu.thread_local_config_context` that overrides thread-local configuration options, just as `jtu.global_config_context` overrides global configuration options.
* change the pretty printer color option to be a State so it can be set locally.
* tag a number of tests as thread-hostile, in particular tests that check counters for numbers of compilations, rely on garbage collection having particular semantics, or look at log output.

PiperOrigin-RevId: 713411171
2025-01-08 14:09:07 -08:00
vfdev-5
00806ddaf5 Added 3.13 ft requirements lock file and updated WORKSPACE 2025-01-08 22:47:29 +01:00
Bixia Zheng
c4ac0dd6bd Implement the extension to the custom_partitioning API.
Add a sharding rule string and trailing factor_sizes to def_partition, to
provide a sharding rule specification when Shardy is used. We use this
information to construct a SdyShardingRule and invoke SdyShardingRule.build
during MLIR lowering.

Extend custom_partitioner tests in  pjit_test.py for Shardy sharding rule.

PiperOrigin-RevId: 713399604
2025-01-08 13:34:47 -08:00
Yash Katariya
b3833dc705 Match the behavior on single host wrt multi-host if tiled=False. Fixes https://github.com/jax-ml/jax/issues/25783
PiperOrigin-RevId: 713398173
2025-01-08 13:31:35 -08:00
jax authors
6e1f060ad3 Merge pull request #25527 from vfdev-5:single-python-version-build-py
PiperOrigin-RevId: 713365267
2025-01-08 11:49:59 -08:00
Justin Fu
d99a637d8b [Mosaic GPU] Allow multiple indexing on refs
PiperOrigin-RevId: 713355813
2025-01-08 11:21:19 -08:00
Yash Katariya
3848f0d2ac [sharding_in_types] Functions like einsum, reshape, broadcast_in_dim, broadcasted_iota, convert_element_type and sharding_cast that take out_sharding as an argument in their signature should also allow PartitionSpec instead of just NamedSharding as an input.
If PartitionSpec is passed, the mesh is read from the context. The primitives though take `NamedSharding` only. The conversion from `PartitionSpec` to `NamedSharding` happens above `.bind`.

We also raise an error if `PartitionSpec` contain mesh axis names that are of type Auto or Collective for the above functions.

PiperOrigin-RevId: 713352542
2025-01-08 11:11:16 -08:00
Sharad Vikram
c1a60c676a [Pallas] Add empty/empty_like helper functions
PiperOrigin-RevId: 713344151
2025-01-08 10:49:11 -08:00
jax authors
55119490eb Update XLA dependency to use revision
1b99693486.

PiperOrigin-RevId: 713320368
2025-01-08 09:37:59 -08:00
Bart Chrzaszcz
5c097c8f62 #sdy Move Shardy mesh lift inlining pass after verification.
Before if something went wrong during JAX lowering, then instead of verification catching this, the pass would making the error message difficult to read and incorrectly pointing to the pass as the source of the error. For example
```
File "jax/_src/interpreters/mlir.py", line 1211, in lower_jaxpr_to_module
    pipeline.run(ctx.module.operation)
MLIRError: Failure while executing pass pipeline:
error:
...
'sdy.sharding_constraint' op sharding doesn't match tensor rank: 0 != 2
...
see current operation: %2 = "sdy.sharding_constraint"(%1) <{sharding = #sdy.sharding<@mesh, []>}> : (tensor<8x2xf64>) -> tensor<8x2xf64>
```
PiperOrigin-RevId: 713314555
2025-01-08 09:17:54 -08:00
Peter Hawkins
0389d617c8 Add a unittest test extension that runs test cases in parallel using threads.
This change does not yet do the work necessary to make any tests pass with threading enabled, which will come in future changes.

This approach is broadly inspired by a6d205dd4c/testtools/testsuite.py (L113) and by unittest-ft.

We add a custom TestResult class that batches up any test result actions and applies them under a lock. We also add a custom TestSuite class that runs individual test cases in parallel using a thread-pool.

We need a reader-writer lock to implement a `@jtu.thread_hostile_test` decorator, which we do by adding bindings around absl::Mutex to jaxlib.

PiperOrigin-RevId: 713312937
2025-01-08 09:11:47 -08:00
Peter Hawkins
3fa557289a Port tests away from setUpClass and setUpModule to setUp alone.
This change prepares for upcoming changes in which we run tests in parallel using threads, which we are doing partially to test free threading but also partially to speed up TPU tests via thread-parallelism.

If independent tests run in parallel in no particular order, there's no natural scope around which to call setUpClass or SetUpModule. But for JAX tests this never seems necessary: we can just do the same work in setUp() or do it globally.

PiperOrigin-RevId: 713296722
2025-01-08 08:14:50 -08:00
Sergei Lebedev
f1f98afee8 [pallas:mosaic_gpu] Fix the tests following the changes to pl.core_map
PiperOrigin-RevId: 713283207
2025-01-08 07:24:08 -08:00
Peter Hawkins
51b9fe3010 [JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU directly.
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.

In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.

PiperOrigin-RevId: 713272197
2025-01-08 06:37:44 -08:00
Adam Paszke
f96339be1e [Mosaic TPU] Be much more aggressive in inferring large 2nd minor layouts for 16-bit types on v6
This often lets us avoid ambiguities between selecting the (8, 128) and (16, 128) tiling,
by biasing the layout inference to prefer the latter.

PiperOrigin-RevId: 713270421
2025-01-08 06:30:36 -08:00
Adam Paszke
5fd1b2f825 [Mosaic TPU] Add support for second minor broadcasts with packed types
PiperOrigin-RevId: 713259707
2025-01-08 05:45:02 -08:00
Adam Paszke
e954930eaf [Mosaic TPU] Add support for true divide in bf16 on TPUv6
PiperOrigin-RevId: 713247480
2025-01-08 04:49:22 -08:00
Tzu-Wei Sung
bf94389b08 [Mosaic] Use tpu::CreateMask for getX32VmaskByPaddingEnd.
It was cmp + iota before.

PiperOrigin-RevId: 713240888
2025-01-08 04:18:53 -08:00
jax authors
4718121efe Merge pull request #25754 from andportnoy:patch-4
PiperOrigin-RevId: 713222111
2025-01-08 02:57:20 -08:00
Sergei Lebedev
90201ce2b7 Removed leftover mentions of xmap from the code
PiperOrigin-RevId: 713202387
2025-01-08 01:39:38 -08:00
jax authors
81db3219b7 Merge pull request #25594 from zhenying-liu:activation-offloading-doc
PiperOrigin-RevId: 713170813
2025-01-07 23:26:21 -08:00
jax authors
1bd781d992 Add JAX events that have time spans, not only durations.
Log such events for log_elapsed_time.

The rationale for not replacing durations with it is that it appears that
record_event_duration_secs() is widely used outside of the code of JAX itself.

PiperOrigin-RevId: 713167192
2025-01-07 23:08:30 -08:00
Jane Liu
21fb171ef9
Merge branch 'jax-ml:main' into activation-offloading-doc 2025-01-07 21:20:29 -08:00
jax authors
6d08f36f5b Merge pull request #25761 from jakevdp:array-api-update
PiperOrigin-RevId: 713110147
2025-01-07 18:23:09 -08:00
Yash Katariya
755d6cdad8 [sharding_in_types] Aval sharding under full auto mode should contain None and not UNCONSTRAINED because axis_types + pspec give the full picture.
PiperOrigin-RevId: 713105375
2025-01-07 18:04:20 -08:00
Sharad Vikram
7be127f23c [Pallas] Improvements to core_map
PiperOrigin-RevId: 713075852
2025-01-07 16:18:30 -08:00
Peter Hawkins
392a851769 Increase the minimum SciPy version to 1.11.1.
(1.11.0 was yanked from PyPi because of licensing problems, so 1.11.1 is the oldest 1.11 release.)

PiperOrigin-RevId: 713073731
2025-01-07 16:10:45 -08:00
Jake VanderPlas
f6c9e87d97 [array api] update test suite to latest commit 2025-01-07 13:58:14 -08:00
jax authors
f1777d5b05 Merge pull request #25042 from dfm:ffi-example-input-output-alias
PiperOrigin-RevId: 712979906
2025-01-07 11:20:52 -08:00
Zixuan Jiang
64c0f62ec4 Sort manual axes when lowering jax.shard_map to sdy.manual_computation, which ensures the determinism in the generated sdy.manual_computation.
PiperOrigin-RevId: 712973327
2025-01-07 11:02:55 -08:00
Dan Foreman-Mackey
62656b32db Add an example demonstrating input-output aliasing with the FFI. 2025-01-07 13:21:59 -05:00
jax authors
00c363e15d Update XLA dependency to use revision
9b8f679bd2.

PiperOrigin-RevId: 712940327
2025-01-07 09:31:20 -08:00
Justin Fu
8c9a539405 [Pallas] Fix pallas_call lowering mutating compiler params during Triton lowering.
Addresses: https://github.com/jax-ml/jax/issues/25714
PiperOrigin-RevId: 712930760
2025-01-07 09:01:51 -08:00
jax authors
4023810565 [AutoPGLE] FIx PGLE kokoro test failures.
PiperOrigin-RevId: 712930537
2025-01-07 08:59:59 -08:00
jax authors
57c2afe7a8 Merge pull request #25441 from Exferro:fixed_advanced_autodiff_doc
PiperOrigin-RevId: 712929769
2025-01-07 08:56:05 -08:00
George Necula
fdb6af82d2 Clean up backend_or_name vs. platforms in lowering code.
It turns out that the backend is rarely needed when lowering, e.g.,
for lowering callbacks. Whenever we need the backend for lowering,
we must be in single-platform lowering mode (`len(platforms) == 1`)
and we can look up the backend from `platforms[0]`.

However, in some rare cases we can have a custom `XlaBackend` whose
platform matches `platforms[0]`. We rename `backend_or_name` to just `backend`
and we restrict its type to be an optional `XlaBackend` (not a platform
string).

PiperOrigin-RevId: 712926140
2025-01-07 08:42:57 -08:00
Andrey Portnoy
5b80892169
[Mosaic GPU] Use num_q_heads=2 in flash_attention.py
Previously with 4 heads the reference function `ref` would allocate 32 GiB since it materializes large intermediate tensors. That causes CI on an 80GB H100 to run out of memory when 4 tests run in parallel. `num_q_heads=2` allows us to test multiple heads while cutting memory in half.
2025-01-07 10:31:56 -05:00
Dan Foreman-Mackey
a7f384cc6e Add a register_custom_type_id function to the GPU plugins.
This enables dynamic registration of custom FFI types on the appropriate platform via PJRT.

PiperOrigin-RevId: 712904085
2025-01-07 07:29:38 -08:00
Aleksei Malyshev
f881f507d6 Update the advanced autodiff tutorial and replace some vmap with grad 2025-01-07 15:56:23 +01:00
jax authors
853af56007 Merge pull request #25748 from shoyer:divmod
PiperOrigin-RevId: 712864349
2025-01-07 04:44:23 -08:00
jax authors
56f0f9534d Merge pull request #25633 from dfm:move-ffi
PiperOrigin-RevId: 712863350
2025-01-07 04:40:21 -08:00
jax authors
712bece2c8 Merge pull request #25731 from gnecula:poly_random_even
PiperOrigin-RevId: 712826758
2025-01-07 02:06:40 -08:00
Stephan Hoyer
7fb68cac20 Fix type signature for __divmod__ 2025-01-07 00:24:24 -08:00