24961 Commits

Author SHA1 Message Date
Peter Hawkins
51b9fe3010 [JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU devices directly.
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.

In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.

PiperOrigin-RevId: 713272197
2025-01-08 06:37:44 -08:00
Adam Paszke
f96339be1e [Mosaic TPU] Be much more aggressive in inferring large 2nd minor layouts for 16-bit types on v6
This often lets us avoid ambiguities between selecting the (8, 128) and (16, 128) tiling,
by biasing the layout inference to prefer the latter.

PiperOrigin-RevId: 713270421
2025-01-08 06:30:36 -08:00
Adam Paszke
5fd1b2f825 [Mosaic TPU] Add support for second minor broadcasts with packed types
PiperOrigin-RevId: 713259707
2025-01-08 05:45:02 -08:00
Adam Paszke
e954930eaf [Mosaic TPU] Add support for true divide in bf16 on TPUv6
PiperOrigin-RevId: 713247480
2025-01-08 04:49:22 -08:00
Tzu-Wei Sung
bf94389b08 [Mosaic] Use tpu::CreateMask for getX32VmaskByPaddingEnd.
It was cmp + iota before.

PiperOrigin-RevId: 713240888
2025-01-08 04:18:53 -08:00
jax authors
4718121efe Merge pull request #25754 from andportnoy:patch-4
PiperOrigin-RevId: 713222111
2025-01-08 02:57:20 -08:00
Sergei Lebedev
90201ce2b7 Removed leftover mentions of xmap from the code
PiperOrigin-RevId: 713202387
2025-01-08 01:39:38 -08:00
jax authors
81db3219b7 Merge pull request #25594 from zhenying-liu:activation-offloading-doc
PiperOrigin-RevId: 713170813
2025-01-07 23:26:21 -08:00
jax authors
1bd781d992 Add JAX events that have time spans, not only durations.
Log such events for log_elapsed_time.

The rationale for not replacing durations with it is that it appears that
record_event_duration_secs() is widely used outside of the code of JAX itself.

PiperOrigin-RevId: 713167192
2025-01-07 23:08:30 -08:00
Jane Liu
21fb171ef9 Merge branch 'jax-ml:main' into activation-offloading-doc 2025-01-07 21:20:29 -08:00
jax authors
6d08f36f5b Merge pull request #25761 from jakevdp:array-api-update
PiperOrigin-RevId: 713110147
2025-01-07 18:23:09 -08:00
Yash Katariya
755d6cdad8 [sharding_in_types] Aval sharding under full auto mode should contain None and not UNCONSTRAINED because axis_types + pspec give the full picture.
PiperOrigin-RevId: 713105375
2025-01-07 18:04:20 -08:00
Sharad Vikram
7be127f23c [Pallas] Improvements to core_map
PiperOrigin-RevId: 713075852
2025-01-07 16:18:30 -08:00
Peter Hawkins
392a851769 Increase the minimum SciPy version to 1.11.1.
(1.11.0 was yanked from PyPI because of licensing problems, so 1.11.1 is the oldest 1.11 release.)

PiperOrigin-RevId: 713073731
2025-01-07 16:10:45 -08:00
Jake VanderPlas
f6c9e87d97 [array api] update test suite to latest commit 2025-01-07 13:58:14 -08:00
jax authors
f1777d5b05 Merge pull request #25042 from dfm:ffi-example-input-output-alias
PiperOrigin-RevId: 712979906
2025-01-07 11:20:52 -08:00
Zixuan Jiang
64c0f62ec4 Sort manual axes when lowering jax.shard_map to sdy.manual_computation, which ensures determinism in the generated sdy.manual_computation.
PiperOrigin-RevId: 712973327
2025-01-07 11:02:55 -08:00
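
The determinism concern in the entry above can be illustrated in plain Python. This is a minimal sketch and not the JAX lowering code: `lower_manual_axes`, the axis names, and the output string are all hypothetical and chosen only for illustration. It shows why sorting a set of axis names before emitting them gives the same text on every run.

```python
def lower_manual_axes(manual_axes):
    """Render a set of axis names in a stable order (hypothetical helper)."""
    # Iterating a set of strings directly can yield a different order from
    # one Python process to the next, because string hashing is randomized
    # by default. Sorting first makes the output independent of that order.
    return "manual_axes={" + ", ".join(sorted(manual_axes)) + "}"


text = lower_manual_axes(frozenset({"model", "data", "stage"}))
print(text)
```

Without the `sorted` call, two otherwise identical runs could emit the axes in different orders, which makes generated IR harder to diff and cache.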
Dan Foreman-Mackey
62656b32db Add an example demonstrating input-output aliasing with the FFI. 2025-01-07 13:21:59 -05:00
jax authors
00c363e15d Update XLA dependency to use revision 9b8f679bd2.

PiperOrigin-RevId: 712940327
2025-01-07 09:31:20 -08:00
Justin Fu
8c9a539405 [Pallas] Fix pallas_call lowering mutating compiler params during Triton lowering.
Addresses: https://github.com/jax-ml/jax/issues/25714
PiperOrigin-RevId: 712930760
2025-01-07 09:01:51 -08:00
jax authors
4023810565 [AutoPGLE] Fix PGLE kokoro test failures.
PiperOrigin-RevId: 712930537
2025-01-07 08:59:59 -08:00
jax authors
57c2afe7a8 Merge pull request #25441 from Exferro:fixed_advanced_autodiff_doc
PiperOrigin-RevId: 712929769
2025-01-07 08:56:05 -08:00
George Necula
fdb6af82d2 Clean up backend_or_name vs. platforms in lowering code.
It turns out that the backend is rarely needed when lowering, e.g.,
for lowering callbacks. Whenever we need the backend for lowering,
we must be in single-platform lowering mode (`len(platforms) == 1`)
and we can look up the backend from `platforms[0]`.

However, in some rare cases we can have a custom `XlaBackend` whose
platform matches `platforms[0]`. We rename `backend_or_name` to just `backend`
and we restrict its type to be an optional `XlaBackend` (not a platform
string).

PiperOrigin-RevId: 712926140
2025-01-07 08:42:57 -08:00
Andrey Portnoy
5b80892169 [Mosaic GPU] Use num_q_heads=2 in flash_attention.py
Previously with 4 heads the reference function `ref` would allocate 32 GiB since it materializes large intermediate tensors. That causes CI on an 80GB H100 to run out of memory when 4 tests run in parallel. `num_q_heads=2` allows us to test multiple heads while cutting memory in half.
2025-01-07 10:31:56 -05:00
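
The memory argument in the entry above can be checked with a little arithmetic. The sketch below takes the figures from the commit message (32 GiB with 4 heads, 4 tests in parallel, an 80 GB device) and adds two assumptions of its own: that the reference function's memory scales linearly with `num_q_heads`, and that nothing else on the device uses significant memory. The helper names are hypothetical.

```python
GIB = 2**30  # bytes in one gibibyte
GB = 10**9   # bytes in one gigabyte


def peak_reference_bytes(num_q_heads, bytes_at_4_heads=32 * GIB):
    """Assumed linear scaling from the 32 GiB figure quoted for 4 heads."""
    return bytes_at_4_heads * num_q_heads // 4


def fits(num_q_heads, parallel_tests=4, device_bytes=80 * GB):
    """True if all parallel reference computations fit on one device."""
    return parallel_tests * peak_reference_bytes(num_q_heads) <= device_bytes


for heads in (4, 2):
    total_gib = 4 * peak_reference_bytes(heads) / GIB
    print(f"num_q_heads={heads}: {total_gib:.0f} GiB total, fits={fits(heads)}")
```

Under these assumptions, four parallel tests need about 128 GiB with 4 heads, which exceeds the device, and about 64 GiB with 2 heads, which fits.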
Dan Foreman-Mackey
a7f384cc6e Add a register_custom_type_id function to the GPU plugins.
This enables dynamic registration of custom FFI types on the appropriate platform via PJRT.

PiperOrigin-RevId: 712904085
2025-01-07 07:29:38 -08:00
Aleksei Malyshev
f881f507d6 Update the advanced autodiff tutorial and replace some vmap with grad 2025-01-07 15:56:23 +01:00
jax authors
853af56007 Merge pull request #25748 from shoyer:divmod
PiperOrigin-RevId: 712864349
2025-01-07 04:44:23 -08:00
jax authors
56f0f9534d Merge pull request #25633 from dfm:move-ffi
PiperOrigin-RevId: 712863350
2025-01-07 04:40:21 -08:00
jax authors
712bece2c8 Merge pull request #25731 from gnecula:poly_random_even
PiperOrigin-RevId: 712826758
2025-01-07 02:06:40 -08:00
Stephan Hoyer
7fb68cac20 Fix type signature for __divmod__ 2025-01-07 00:24:24 -08:00
George Necula
bc3306c8bc [shape_poly] Improve threefry with symbolic shapes
Previously, we could only handle threefry for the case when
it was possible to tell statically that the size of the `count`
array is even or odd. This meant that often we had to add a constraint
that one of the dimensions is even.

Here we rewrite the handling of threefry to not require a Python-level
conditional about evenness of the size of the count array. We use
a couple of `lax.dynamic_slice` calls rather than a `lax.split`.

We also generalize the tests to cases where the size is fully symbolic,
and we cannot tell statically that it is even.
2025-01-07 09:10:04 +02:00
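
The idea of avoiding a Python-level even/odd conditional can be sketched in plain Python. This is not the code from the commit: lists stand in for arrays, list slices stand in for `lax.dynamic_slice`, and `mix_pair` is a made-up placeholder rather than threefry. The only point is that always padding by one element and slicing two halves of length `(n + 1) // 2` works for even and odd sizes alike.

```python
def mix_pair(x0, x1):
    """Placeholder for the real two-lane mixing function (not threefry)."""
    return [a + b for a, b in zip(x0, x1)], [a - b for a, b in zip(x0, x1)]


def process_counts(count):
    n = len(count)
    half = (n + 1) // 2         # ceil(n / 2), no even/odd branch needed
    padded = count + [0]        # always pad by one element
    x0 = padded[0:half]         # first slice of length `half`
    x1 = padded[half:2 * half]  # second slice of length `half`
    y0, y1 = mix_pair(x0, x1)
    return (y0 + y1)[:n]        # drop the extra element, if one was produced


print(process_counts([1, 2, 3, 4]))
print(process_counts([1, 2, 3, 4, 5]))
```

Because `half` is an arithmetic expression in `n` rather than the result of a branch, the same shape computation stays valid when `n` is only known symbolically.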
jax authors
7997f080f2 Merge pull request #25728 from zhenying-liu:scipy.misc
PiperOrigin-RevId: 712707311
2025-01-06 17:53:10 -08:00
Yash Katariya
23eaf2160a Make inspect_array_sharding work without mesh context manager too.
PiperOrigin-RevId: 712702329
2025-01-06 17:32:15 -08:00
jax authors
b304b9efd5 Merge pull request #25740 from jakevdp:remove-array-api
PiperOrigin-RevId: 712689888
2025-01-06 16:32:54 -08:00
Jake VanderPlas
c7b0d681bd Remove deprecated jax.experimental.array_api 2025-01-06 15:19:02 -08:00
Sharad Vikram
4caa263a94 [Mosaic TPU] Add some elementwise canonicalizations
PiperOrigin-RevId: 712671502
2025-01-06 15:10:02 -08:00
Parker Schuh
b49ba6553c Remove the need for check_rep for with_sharding_constraint.
PiperOrigin-RevId: 712630197
2025-01-06 12:59:22 -08:00
Jane Liu
77c6947a59 fix the doc error: module 'scipy.misc' has no attribute 'face' 2025-01-06 12:00:02 -08:00
Peter Hawkins
90d8f37863 Rename pybind_extension to nanobind_extension.
We have no remaining uses of pybind11 outside a GPU custom call example.

PiperOrigin-RevId: 712608834
2025-01-06 11:53:44 -08:00
Peter Hawkins
61dd041225 Suppress MSAN warnings from SVD that are showing up in CI.
In our MSAN CI, the copy of LAPACK we use is not MSAN-instrumented, leading to false positives. Suppress those false positives via annotations.

PiperOrigin-RevId: 712607044
2025-01-06 11:49:05 -08:00
jax authors
52cc5c7f05 Merge pull request #25214 from jakevdp:einsum-optimize
PiperOrigin-RevId: 712603103
2025-01-06 11:37:54 -08:00
Jevin Jiang
9f842909ce [Mosaic TPU] Validate inserted layout in relayout-insertion pass.
PiperOrigin-RevId: 712595778
2025-01-06 11:15:47 -08:00
jax authors
634b45bf00 Merge pull request #25699 from yliu120:fix_iota
PiperOrigin-RevId: 712594991
2025-01-06 11:13:28 -08:00
Jake VanderPlas
2f7204fff6 jnp.einsum: default to optimize='auto' 2025-01-06 11:02:31 -08:00
John QiangZhang
c39e38fe5a bazel: export serialization.fbs for downstream usage
PiperOrigin-RevId: 712587802
2025-01-06 10:57:35 -08:00
jax authors
74be8bd99f Merge pull request #25675 from jakevdp:dep-lpmn
PiperOrigin-RevId: 712579230
2025-01-06 10:37:10 -08:00
jax authors
18b193cbbd Update XLA dependency to use revision 1a6361a734.

PiperOrigin-RevId: 712558157
2025-01-06 09:41:02 -08:00
Jake VanderPlas
245a13a329 Deprecate scipy.special.lpmn & lpmn_values 2025-01-06 09:31:15 -08:00
Mark Sandler
6c87bf389f Fixes tril/triu comments (they were flipped)
PiperOrigin-RevId: 712544847
2025-01-06 08:55:11 -08:00
Yunlong Liu
3ff000ee3e fix the degenerated case 2025-01-06 16:08:07 +00:00