Yash Katariya
958234a9c1
Thread the mesh context manager to the place where we recover out_shardings back from GSPMDShardings. Before if you had a program like this:
...
```
with mesh:
out = pjit(lambda: 1)()
```
The sharding of `out` was a `GSPMDSharding` which is not ideal. This change fixes that and returns a `NamedSharding` instead.
This is also required for `Shardy` integration.
PiperOrigin-RevId: 658842350
2024-08-02 11:04:48 -07:00
Eugene Zhulenev
ac52890e3d
[jax] Shard pallas_vmap_test
...
PiperOrigin-RevId: 658834942
2024-08-02 10:41:22 -07:00
jax authors
6c79c10446
Merge pull request #22840 from jakevdp:fix-backends
...
PiperOrigin-RevId: 658817878
2024-08-02 09:47:25 -07:00
Adam Paszke
86c9903067
[Pallas TPU] Make sure that the bug repros actually fail
...
One of them was fixed in the meantime but we didn't realize it.
PiperOrigin-RevId: 658799901
2024-08-02 08:37:22 -07:00
Yash Katariya
e6851e6b22
Fix the AOT check for sharding consistency which skipped checking the devices of the sharding.
...
So before for TPU compiled computation, a user could have passed in a committed array on CPU and JAX wouldn't have errored which is wrong.
This change fixes that. Also `is_equivalent_to` should check for devices, HloSharding and memory_kind (so removing the redundant `memory_kind` check too).
PiperOrigin-RevId: 658794885
2024-08-02 08:15:32 -07:00
Adam Paszke
f85b8e677b
[Mosaic TPU] Add support for bf16 reductions
...
PiperOrigin-RevId: 658787017
2024-08-02 07:42:27 -07:00
Jake VanderPlas
3fa86a9b32
remove jax.extend.backend.default_backend in favor of jax.backend
...
I added this two days ago before realizing there is already a canonical API
for this in the top-level namespace, so it should be safe to remove.
2024-08-02 07:07:29 -07:00
Adam Paszke
e88887eda5
[Mosaic TPU] Add a missing reshape in relayout
...
The fact that src generalizes dst does not mean that they have the same implicit
tile shape (if one has an implicit dim and the other one doesn't, then they will
differ by a singleton dimension).
PiperOrigin-RevId: 658775019
2024-08-02 06:44:31 -07:00
Sergei Lebedev
02d836d990
Updated Mosaic GPU lowering registration in Pallas
...
The lowering rule for mosaic_gpu_p now expects a serialized module.
PiperOrigin-RevId: 658772330
2024-08-02 06:30:04 -07:00
Adam Paszke
959657a489
[Mosaic TPU] Remove special handling of implicit dim in relayout
...
Now all changes happen inside the dedicated functions.
PiperOrigin-RevId: 658763465
2024-08-02 05:46:26 -07:00
Dan Foreman-Mackey
80560663d3
Enable FFI implementation of GPU Getrf FFI handler.
...
PiperOrigin-RevId: 658755392
2024-08-02 05:07:02 -07:00
Adam Paszke
99625ff577
[Mosaic TPU] Break out implicit dim changes from relayout
...
PiperOrigin-RevId: 658752228
2024-08-02 04:50:40 -07:00
jax authors
efba5f61b5
Merge pull request #22812 from superbobry:maint
...
PiperOrigin-RevId: 658751187
2024-08-02 04:43:33 -07:00
Paweł Paruzel
6b0b222a38
Activate LU Decomposition to XLA's FFI
...
PiperOrigin-RevId: 658721697
2024-08-02 02:22:53 -07:00
George Necula
20e9c15ff5
[pallas] Small cleanup in the Mosaic lowering
...
Uses the helper functions for the calling convention from #22552 and #22593 .
PiperOrigin-RevId: 658692284
2024-08-02 00:16:35 -07:00
Enrique Piqueras
d57447a2d3
Double buffer pipeline semaphores so we can hide DMA latency under compute and not just BW. Also enable disabling automatic accumulation across pipelines.
...
PiperOrigin-RevId: 658585671
2024-08-01 16:58:19 -07:00
Christos Perivolaropoulos
28b86604b3
[pallas:mosaic_gpu] Make the linter happy.
...
PiperOrigin-RevId: 658580241
2024-08-01 16:37:12 -07:00
jax authors
2241dadab6
Merge pull request #22814 from superbobry:maint-2
...
PiperOrigin-RevId: 658560253
2024-08-01 15:31:40 -07:00
Jieying Luo
bc0229a61f
Rollback as it broke some tests.
...
Reverts ff17b76e3eec3e573788f64fafe23fabcfc09ce2
PiperOrigin-RevId: 658557091
2024-08-01 15:21:42 -07:00
jax authors
16c868af82
Merge pull request #22825 from jakevdp:fix-old-array-api
...
PiperOrigin-RevId: 658552229
2024-08-01 15:07:57 -07:00
Dan Foreman-Mackey
8df0c3a9cc
Port Getrf GPU kernel from custom call to FFI.
...
PiperOrigin-RevId: 658550170
2024-08-01 15:02:25 -07:00
Jake VanderPlas
48c5fab023
[array api] fix deprecation to support old import pattern
2024-08-01 14:38:59 -07:00
Sergei Lebedev
fb1dbf15df
Bumped mypy to 1.11.0 and jaxlib to 0.4.31 on the CI
2024-08-01 22:30:24 +01:00
jax authors
b64b19f1e3
Update XLA dependency to use revision
...
7f8cc3357f
.
PiperOrigin-RevId: 658538981
2024-08-01 14:27:24 -07:00
jax authors
aa9e1e42a1
Merge pull request #22095 from dfm:ffi-call-tutorial
...
PiperOrigin-RevId: 658523039
2024-08-01 13:43:34 -07:00
Abhinav Gunjal
dfe8d94170
Integrate StableHLO at openxla/stablehlo@fb18ee25
...
PiperOrigin-RevId: 658515936
2024-08-01 13:23:01 -07:00
Jieying Luo
ff17b76e3e
Cleanup. Remove build:cuda_plugin and set enable_gpu and xla_python_enable_gpu to false in build:cuda.
...
JAX already migrated from jaxlib[cuda] to cuda plugin.
PiperOrigin-RevId: 658508037
2024-08-01 12:59:12 -07:00
jax authors
8eaa6bf661
Merge pull request #22818 from jakevdp:finalize-array-api
...
PiperOrigin-RevId: 658500983
2024-08-01 12:37:56 -07:00
Dan Foreman-Mackey
0b4800a193
Add ffi_call tutorial
...
Building on #21925 , this tutorial demonstrates the use of the FFI using
`ffi_call` with a simple example. I don't think this should cover all of
the most advanced use cases, but it should be sufficient for the most
common examples. I think it would be useful to eventually replace the
existing CUDA tutorial, but I'm not sure that it'll get there in the
first draft.
As an added benefit, this also runs a simple test (akin to
`docs/cuda_custom_call`) which actually executes using a tool chain that
open source users would use in practice.
2024-08-01 15:36:32 -04:00
Dan Foreman-Mackey
f20efc630f
Move jaxlib GPU handlers to separate build target.
...
In anticipation of refactoring the jaxlib GPU custom calls into FFI calls, this change moves the implementation of `BlasHandlePool`, `SolverHandlePool`, and `SpSolverHandlePool` into new target.
PiperOrigin-RevId: 658497960
2024-08-01 12:30:04 -07:00
Jake VanderPlas
14fa06298e
[array api] Finalize array API in jax.numpy & deprecate jax.experimental.array_api
2024-08-01 11:19:17 -07:00
jax authors
b3924da2a1
Merge pull request #22794 from jakevdp:array-api-astype
...
PiperOrigin-RevId: 658456484
2024-08-01 10:47:45 -07:00
jax authors
06ffa70cf3
Merge pull request #22810 from nstarman:add-annot
...
PiperOrigin-RevId: 658456382
2024-08-01 10:44:01 -07:00
George Necula
43163ff2e3
[pallas] Add error message for block_shapes of rank less than 1.
...
PiperOrigin-RevId: 658424421
2024-08-01 09:15:07 -07:00
jax authors
6bddaad6c7
Merge pull request #22816 from dfm:fix-remat-opt
...
PiperOrigin-RevId: 658378186
2024-08-01 06:16:57 -07:00
Dan Foreman-Mackey
80cfe83ddc
Fix issue with multiple arguments when using custom_vjp with optimize_remat.
...
This was a silly bug in how we were handling the fact that the `fwd`
function expects `bool` entries for symbolic zeros. At least now I've
added a test!
2024-08-01 08:44:35 -04:00
Sergei Lebedev
92b1f71314
Removed various ununsed functions
...
To rerun the analysis do
python -m vulture jax/_src --ignore-names "[A-Za-z]*" --ignore-decorators "*"
2024-08-01 11:18:19 +01:00
Adam Paszke
0307438c3d
[NFC][Mosaic TPU] Separate out retiling from relayout
...
PiperOrigin-RevId: 658335679
2024-08-01 03:09:15 -07:00
Adam Paszke
0734345279
[NFC][Mosaic TPU] Start breaking up relayout into smaller pieces
...
We're constantly hitting unimpelmented relayouts, but it's hard to even know what's
in there given the way the code is written. This is the first of a few clean-up CLs
that aims to partition the process into steps with clear responsibilities. It should
help us better understand what's missing.
PiperOrigin-RevId: 658318811
2024-08-01 02:02:09 -07:00
Jake VanderPlas
3551fcc077
Deprecate several APIs in jax.lib.xla_bridge
...
PiperOrigin-RevId: 658274719
2024-07-31 23:00:35 -07:00
jax authors
6c083d78e6
Merge pull request #22779 from gnecula:tril
...
PiperOrigin-RevId: 658271885
2024-07-31 22:47:25 -07:00
George Necula
ffd2b00516
Add concretization error check in core.min_dim and core.max_dim
...
Fixes : #22751
2024-08-01 07:27:35 +02:00
nstarman
9f344863d4
type: add annot
...
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
2024-07-31 22:24:21 -04:00
jax authors
f49f2da8fe
Merge pull request #22802 from selamw1:diag_indices_docstring
...
PiperOrigin-RevId: 658196557
2024-07-31 17:21:28 -07:00
Yash Katariya
ec6514cc08
Add donation test for a pure compute offloaded computation
...
PiperOrigin-RevId: 658187714
2024-07-31 16:49:27 -07:00
jax authors
a911d76982
Rollback due to internal test failure
...
PiperOrigin-RevId: 658185213
2024-07-31 16:40:03 -07:00
Jake VanderPlas
19d185ac8d
jax.extend.backend: add semi-private backend utils from xla_bridge
...
For context, see https://jax.readthedocs.io/en/latest/jep/15856-jex.html .
PiperOrigin-RevId: 658179318
2024-07-31 16:19:30 -07:00
selamw1
a11ddfd4bc
diag_indices_docstring_added
...
see_also_diagonal_added
2024-07-31 16:14:00 -07:00
Kanglan Tang
a7e071ec42
Skip flaky memories tests on GPU backend.
...
PiperOrigin-RevId: 658177202
2024-07-31 16:12:52 -07:00
jax authors
b677a712ab
Merge pull request #22800 from jakevdp:where-mode
...
PiperOrigin-RevId: 658165407
2024-07-31 15:37:30 -07:00