19753 Commits

Author SHA1 Message Date
Sergei Lebedev
d0e0ca1e52 Ensured that Pallas GPU tests only run in x32 mode
We do not yet properly handle x64.

PiperOrigin-RevId: 613190060
2024-03-06 06:15:13 -08:00
jax authors
8aefe5e69f Update XLA dependency to use revision
08b6a3deeb.

PiperOrigin-RevId: 613101323
2024-03-05 23:43:13 -08:00
Peter Hawkins
1a193ea189 Fix segfault in cuda_plugin_extension.
The nanobind switch for the GPU callback code means that we are now using the NumPy APIs rather than pybind11's clone of them. It is important to initialize the NumPy APIs before using them in each module.

PiperOrigin-RevId: 613036056
2024-03-05 18:31:50 -08:00
Yash Katariya
c37fd1bd85 Move Device_put tests out of a disabled test class and execute them (so they are tested)
PiperOrigin-RevId: 613035562
2024-03-05 18:22:18 -08:00
Yash Katariya
ca3f3f0f17 Make sure that if gspmd_sharding1 == gspmd_sharding2, then their hash also is equal.
PiperOrigin-RevId: 613009976
2024-03-05 16:36:49 -08:00
jax authors
15da71384d [XLA:SPMD] Do not propagate sharding to parameter/output if it does not evenly partition the parameter/output.
PiperOrigin-RevId: 612998062
2024-03-05 15:57:05 -08:00
jax authors
0c9bbe4364 Merge pull request #20087 from jakevdp:fix-scipy-doc
PiperOrigin-RevId: 612994633
2024-03-05 15:46:01 -08:00
Jake VanderPlas
f35538be6a DOC: fix two minor doc issues 2024-03-05 15:22:42 -08:00
jax authors
20090dd176 Merge pull request #20083 from mattjj:attrs-fix-tracer-lifetime
PiperOrigin-RevId: 612944372
2024-03-05 13:10:25 -08:00
Matthew Johnson
c44dda84d0 [attrs] fix tracer lifetime bug, fixes #20082 2024-03-05 12:08:44 -08:00
jax authors
67e3542d32 Merge pull request #20080 from jakevdp:key-reuse-srcinfo
PiperOrigin-RevId: 612920951
2024-03-05 11:57:58 -08:00
jax authors
baed4eb1d1 Merge pull request #20081 from jakevdp:setup-tpu-err
PiperOrigin-RevId: 612914123
2024-03-05 11:40:05 -08:00
jax authors
288aca1bdf Merge pull request #20078 from mattjj:remat-saving-collectives-fix
PiperOrigin-RevId: 612912239
2024-03-05 11:31:42 -08:00
Benjamin Kramer
5005890546 Enable more tests on H100
20895965b2 fixed these

PiperOrigin-RevId: 612907679
2024-03-05 11:22:56 -08:00
Jake VanderPlas
d8a4ea42cc Use shorter error message for jax.tools.colab_tpu.setup_tpu() 2024-03-05 11:20:38 -08:00
Jevin Jiang
05f54b665c [XLA:Mosaic] Use different MXU shape based on the target
PiperOrigin-RevId: 612906617
2024-03-05 11:14:24 -08:00
Matthew Johnson
3d32262b21 ignore NamedAxisEffect for remat and dce purposes 2024-03-05 11:04:23 -08:00
Jake VanderPlas
735ec63dd1 [key reuse] improve error message using source_info_util 2024-03-05 11:02:39 -08:00
jax authors
32fec820ed Merge pull request #20077 from jakevdp:fix-dunder-array
PiperOrigin-RevId: 612895096
2024-03-05 10:53:45 -08:00
jax authors
430c7edaeb Merge pull request #20070 from jakevdp:key-reuse-errs
PiperOrigin-RevId: 612895094
2024-03-05 10:44:08 -08:00
Jake VanderPlas
851b82b89c Add copy argument to Array.__array__ 2024-03-05 09:31:16 -08:00
Jake VanderPlas
bb91bf2e09 [key reuse] improve some key reuse errors. 2024-03-05 08:14:20 -08:00
jax authors
28fa88681e Update XLA dependency to use revision
9e8b8b45b5.

PiperOrigin-RevId: 612715965
2024-03-04 23:05:31 -08:00
jax authors
843aa21884 Merge pull request #20071 from jakevdp:key-reuse-docs
PiperOrigin-RevId: 612654740
2024-03-04 18:12:40 -08:00
Chris Jones
9996b1f969 [jax_triton] Add parameter allowing user to compile for specific compute capability.
PiperOrigin-RevId: 612647104
2024-03-04 17:37:04 -08:00
jax authors
bc3f123978 Merge pull request #20069 from jakevdp:key-reuse-equality
PiperOrigin-RevId: 612642622
2024-03-04 17:17:49 -08:00
Jake VanderPlas
9a4b0fc1f8 [key reuse] improve module docs 2024-03-04 17:16:55 -08:00
Peter Hawkins
6207977fac Disable some tests that fail on H100 in CI.
PiperOrigin-RevId: 612637375
2024-03-04 16:59:52 -08:00
Philip Pham
3fe65e2005 Pipe tiled through all_to_all primitive
The `_all_to_all_transpose_rule` calls `all_to_all` which can accept a `tiled`
argument. Thus, for the transpose to know the right value of `tiled` to pass, we
need to plumb the `tiled` argument through the primitive and various
interpreters, even though it's a no-op because the `tiled` argument is handled
outside the primitive. It would be cleaner to handle `tiled` inside the
primitive, but I will leave that for followup work.

Fixes #15982.

PiperOrigin-RevId: 612628600
2024-03-04 16:33:56 -08:00
Yash Katariya
40038d65c2 Rename test
PiperOrigin-RevId: 612609237
2024-03-04 15:35:02 -08:00
Peter Hawkins
feda85dff3 Replace references to xla/python/status_casters.h with xla/pjrt/status_casters.h, which its current home.
PiperOrigin-RevId: 612578488
2024-03-04 14:11:01 -08:00
Peter Hawkins
6f7be3cf04 Define lax.Precision directly in Python, rather than inheriting from a C++ type in jaxlib.
Historically, we defined Precision to be an enum exported from jaxlib using pybind11, since that was the type the old XLA ComputationBuilder classes expected as input. But we build IR using StableHLO MLIR builders these days, and there's no reason for the JAX-level Precision type to match the XLA-internal one.

In a future change I plan to change the definition of Precision in jaxlib to be defined using nanobind instead of pybind11. Nanobind defines its enum classes to be final by default, which precludes this inheritance, and that's probably a good design decision by nanobind. But as discussed above, there's no good reason to inherit in the first place.

PiperOrigin-RevId: 612575404
2024-03-04 14:01:31 -08:00
Jake VanderPlas
84d11d7b11 [key reuse] don't consume on equality check 2024-03-04 13:32:35 -08:00
Yash Katariya
67b0eb3af4 Improve pytree mismatch error in AOT
PiperOrigin-RevId: 612560820
2024-03-04 13:15:32 -08:00
Abhinav Goel
2480ca383e
respond to reviewer's comments 2024-03-04 11:42:01 -08:00
jax authors
a745b8e683 Merge pull request #20067 from jakevdp:copy-failure
PiperOrigin-RevId: 612513014
2024-03-04 11:06:25 -08:00
jax authors
ee963a73a5 Merge pull request #20065 from google:dependabot/github_actions/actions/cache-4.0.1
PiperOrigin-RevId: 612503993
2024-03-04 10:42:29 -08:00
Jake VanderPlas
32da56fc95 jnp.array: fix failure under numpy 2.0 copy semantics 2024-03-04 10:39:38 -08:00
dependabot[bot]
8a62918910
Bump actions/cache from 4.0.0 to 4.0.1
Bumps [actions/cache](https://github.com/actions/cache) from 4.0.0 to 4.0.1.
- [Release notes](https://github.com/actions/cache/releases)
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md)
- [Commits](13aacd865c...ab5e6d0c87)

---
updated-dependencies:
- dependency-name: actions/cache
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-03-04 17:06:05 +00:00
Yash Katariya
46913d7e2e Reverts 18344e8647f3459ae6a559a1e0a322120ac50782
PiperOrigin-RevId: 612466742
2024-03-04 08:51:54 -08:00
jax authors
81363cefd7 Merge pull request #19808 from Micky774:cc_check
PiperOrigin-RevId: 612463272
2024-03-04 08:40:13 -08:00
Adam Paszke
63cc46df47 Treat both aarch and arm as possible ARM prefixes
On macOS `platform.machine()` returns `arm64` instead of `aarch64`.

PiperOrigin-RevId: 612458486
2024-03-04 08:26:04 -08:00
Adam Paszke
0549f18ead Refine regions_with_inaccuracies to account for ARM numerics differences
cc @pearu

PiperOrigin-RevId: 612424597
2024-03-04 06:17:05 -08:00
jax authors
7514d5c7aa [triton] Add clustering support and test
PiperOrigin-RevId: 612417957
2024-03-04 05:51:10 -08:00
Adam Paszke
18344e8647 Reverts 5c9c57fd6ff747ea37a2b74ff327a48fb72b3e69
PiperOrigin-RevId: 612417903
2024-03-04 05:50:55 -08:00
Sergei Lebedev
5283d4b4a5 Axis names are now tracked via an effect
This allows propagating the names bottom up -- from equations to the jaxpr,
instead of "discovering" them top-down by traversing (and rebuilding) the
jaxpr via core.subst_axis_names.

PiperOrigin-RevId: 612416803
2024-03-04 05:42:03 -08:00
jax authors
2dd5e9e180 Update XLA dependency to use revision
e2af8488f6.

PiperOrigin-RevId: 612318379
2024-03-03 22:24:30 -08:00
Meekail Zain
9fff9aeb69 Update 2024-03-03 19:57:26 +00:00
Yash Katariya
0b70244b1c Thread out_avals to MeshExecutable
PiperOrigin-RevId: 612037684
2024-03-02 13:35:31 -08:00
jax authors
8569b893b1 Update XLA dependency to use revision
bbf2f8bcfa.

PiperOrigin-RevId: 611934527
2024-03-01 23:44:58 -08:00