24934 Commits

Author SHA1 Message Date
George Necula
fdb6af82d2 Clean up backend_or_name vs. platforms in lowering code.
It turns out that the backend is rarely needed when lowering, e.g.,
for lowering callbacks. Whenever we need the backend for lowering,
we must be in single-platform lowering mode (`len(platforms) == 1`)
and we can look up the backend from `platforms[0]`.

However, in some rare cases we can have a custom `XlaBackend` whose
platform matches `platforms[0]`. We rename `backend_or_name` to just `backend`
and we restrict its type to be an optional `XlaBackend` (not a platform
string).

PiperOrigin-RevId: 712926140
2025-01-07 08:42:57 -08:00
Dan Foreman-Mackey
a7f384cc6e Add a register_custom_type_id function to the GPU plugins.
This enables dynamic registration of custom FFI types on the appropriate platform via PJRT.

PiperOrigin-RevId: 712904085
2025-01-07 07:29:38 -08:00
jax authors
853af56007 Merge pull request #25748 from shoyer:divmod
PiperOrigin-RevId: 712864349
2025-01-07 04:44:23 -08:00
jax authors
56f0f9534d Merge pull request #25633 from dfm:move-ffi
PiperOrigin-RevId: 712863350
2025-01-07 04:40:21 -08:00
jax authors
712bece2c8 Merge pull request #25731 from gnecula:poly_random_even
PiperOrigin-RevId: 712826758
2025-01-07 02:06:40 -08:00
Stephan Hoyer
7fb68cac20 Fix type signature for __divmod__ 2025-01-07 00:24:24 -08:00
George Necula
bc3306c8bc [shape_poly] Improve threefry with symbolic shapes
Previously, we could only handle threefry for the case when
it was possible to tell statically that the size of the `count`
array is even or odd. This meant that often we had to add a constraint
that one of the dimensions is even.

Here we rewrite the handling of threefry to not require a Python-level
conditional about evenness of the size of the count array. We use
a couple of `lax.dynamic_slice` rather than a `lax.split`.

We also generalize the tests to cases where the size if fully symbolic,
and we cannot tell statically that it is even.
2025-01-07 09:10:04 +02:00
jax authors
7997f080f2 Merge pull request #25728 from zhenying-liu:scipy.misc
PiperOrigin-RevId: 712707311
2025-01-06 17:53:10 -08:00
Yash Katariya
23eaf2160a Make inspect_array_sharding work without mesh context manager too.
PiperOrigin-RevId: 712702329
2025-01-06 17:32:15 -08:00
jax authors
b304b9efd5 Merge pull request #25740 from jakevdp:remove-array-api
PiperOrigin-RevId: 712689888
2025-01-06 16:32:54 -08:00
Jake VanderPlas
c7b0d681bd Remove deprecated jax.experimental.array_api 2025-01-06 15:19:02 -08:00
Sharad Vikram
4caa263a94 [Mosaic TPU] Add some elementwise canonicalizations
PiperOrigin-RevId: 712671502
2025-01-06 15:10:02 -08:00
Parker Schuh
b49ba6553c Remove the need for check_rep for with_sharding_constraint.
PiperOrigin-RevId: 712630197
2025-01-06 12:59:22 -08:00
Jane Liu
77c6947a59 fix the doc error: module 'scipy.misc' has no attribute 'face' 2025-01-06 12:00:02 -08:00
Peter Hawkins
90d8f37863 Rename pybind_extension to nanobind_extension.
We have no remaining uses of pybind11 outside a GPU custom call example.

PiperOrigin-RevId: 712608834
2025-01-06 11:53:44 -08:00
Peter Hawkins
61dd041225 Suppress MSAN warnings from SVD that are showing up in CI.
In our MSAN CI, the copy of LAPACK we use is not MSAN-instrumented, leading to false positives. Suppress those false-positives via annotations.

PiperOrigin-RevId: 712607044
2025-01-06 11:49:05 -08:00
jax authors
52cc5c7f05 Merge pull request #25214 from jakevdp:einsum-optimize
PiperOrigin-RevId: 712603103
2025-01-06 11:37:54 -08:00
Jevin Jiang
9f842909ce [Mosaic TPU] Validate inserted layout in relayout-insertion pass.
PiperOrigin-RevId: 712595778
2025-01-06 11:15:47 -08:00
jax authors
634b45bf00 Merge pull request #25699 from yliu120:fix_iota
PiperOrigin-RevId: 712594991
2025-01-06 11:13:28 -08:00
Jake VanderPlas
2f7204fff6 jnp.einsum: default to optimize='auto' 2025-01-06 11:02:31 -08:00
John QiangZhang
c39e38fe5a bazel: export serialization.fbs for downstream usage
PiperOrigin-RevId: 712587802
2025-01-06 10:57:35 -08:00
jax authors
74be8bd99f Merge pull request #25675 from jakevdp:dep-lpmn
PiperOrigin-RevId: 712579230
2025-01-06 10:37:10 -08:00
jax authors
18b193cbbd Update XLA dependency to use revision
1a6361a734.

PiperOrigin-RevId: 712558157
2025-01-06 09:41:02 -08:00
Jake VanderPlas
245a13a329 Deprecate scipy.special.lpmn & lpmn_values 2025-01-06 09:31:15 -08:00
Mark Sandler
6c87bf389f Fixes tril/triu comments (they were flipped)
PiperOrigin-RevId: 712544847
2025-01-06 08:55:11 -08:00
Yunlong Liu
3ff000ee3e fix the degenerated case 2025-01-06 16:08:07 +00:00
George Necula
e87a2a5929 [shape_poly] Remove old non_negative support.
This was deprecated in January 2024, replaced by
`core_max_dim(..., 0)`.

PiperOrigin-RevId: 712523579
2025-01-06 07:36:11 -08:00
Vladimir Belitskiy
f2e210b315 Disable avxvnniint8 when building with Clang version < 19, or GCC < 13.
PiperOrigin-RevId: 712516025
2025-01-06 07:06:09 -08:00
Dan Foreman-Mackey
512d5450ae Temporarily allow deprecation warnings for scipy.special.lpmn and scipy.special.sph_harm.
These functions are deprecated in scipy 1.15.0. I'll fix this properly soon, but let's start by getting CI working again!

PiperOrigin-RevId: 712512363
2025-01-06 06:50:51 -08:00
jax authors
4de7794faf Merge pull request #25715 from ROCm:ci_build_code_fixes-upstream
PiperOrigin-RevId: 712463302
2025-01-06 03:05:14 -08:00
jax authors
d0a92c5c7d Update XLA dependency to use revision
ac6e71fe0c.

PiperOrigin-RevId: 712263421
2025-01-05 08:45:44 -08:00
jax authors
54fd738ecb Add SMEM as a supported Pallas output memory space.
PiperOrigin-RevId: 712144883
2025-01-04 19:33:18 -08:00
jax authors
9af2970042 Update XLA dependency to use revision
c12c114858.

PiperOrigin-RevId: 712052097
2025-01-04 08:55:59 -08:00
jax authors
e4278f7866 Update XLA dependency to use revision
a84e3b7f8f.

PiperOrigin-RevId: 711770147
2025-01-03 09:07:45 -08:00
jax authors
0f4677bcf6 Merge pull request #25713 from jakevdp:debug-printoptions
PiperOrigin-RevId: 711671926
2025-01-03 01:33:31 -08:00
Jake VanderPlas
330606320a jax.debug.print: respect local np.printoptions 2025-01-02 16:10:54 -08:00
Tzu-Wei Sung
57b21541a2 [Mosaic] NFC: Pull out vreg related functions to util.
These functions are related to vreg manipulation and are used in different rules.

PiperOrigin-RevId: 711484002
2025-01-02 11:50:19 -08:00
Zac Mustin
df36c29803 Compute cost-analysis on only one HLO module.
There was historically a goal to support multiple HLOs in an executable, but this work was never finished and is no longer planned so we don't need this support.

This will soon enable us to return only a dict, instead of a list of dicts with only one item.

PiperOrigin-RevId: 711477481
2025-01-02 11:24:52 -08:00
jax authors
800f903f9b Merge pull request #25686 from Mikcl:docs/working-with-pytrees-formatting
PiperOrigin-RevId: 711448394
2025-01-02 09:40:17 -08:00
jax authors
726950b885 Update XLA dependency to use revision
06078480db.

PiperOrigin-RevId: 711436464
2025-01-02 08:48:53 -08:00
jax authors
68483b8ed6 Merge pull request #25710 from apaszke:mgpu_dialect_fix
PiperOrigin-RevId: 711430610
2025-01-02 08:23:28 -08:00
Adam Paszke
64433435ff Fix OSS build for the Mosaic GPU dialect 2025-01-02 15:55:03 +00:00
Tomás Longeri
ac817b48ca [Mosaic:TPU][NFC] Clean up unused variable
PiperOrigin-RevId: 711412888
2025-01-02 06:57:38 -08:00
jax authors
82001ed5b3 Merge pull request #25706 from pearu:pearu/log10-large
PiperOrigin-RevId: 711411578
2025-01-02 06:50:54 -08:00
jax authors
04a0fbecc3 Merge pull request #25661 from rdyro:tb-nightly-instructions
PiperOrigin-RevId: 711407514
2025-01-02 06:27:24 -08:00
Adam Paszke
7c984c600b Don't use x32 mode for pallas_test
There's no need to, and it caused our GPU tests for this target to only
run nightly.

PiperOrigin-RevId: 711406571
2025-01-02 06:23:32 -08:00
Adam Paszke
dbe9ccd6dc Reverts 83e60a9697ec20023f4e11169edf64e910b93031
PiperOrigin-RevId: 711403091
2025-01-02 06:04:14 -08:00
Robert Dyro
213e1782ac tbp nightly instructions 2025-01-02 09:49:31 +01:00
Ruturaj4
20b75ab82f Update package indentation fix 2025-01-01 18:50:47 -06:00
jax authors
4a6cfebcea Update XLA dependency to use revision
045356d8c8.

PiperOrigin-RevId: 711195871
2025-01-01 08:23:24 -08:00