22276 Commits

Author SHA1 Message Date
George Necula
252032a368 [pallas] Improve error and debugging messages with source locations
Document the `name` argument to `pallas_call` and supplement it with source location information for the kernel function.
Pass all this as the `name_and_src_info` parameter to the `pallas_call_p` primitive.

Added some more information to the `if debug` prints.

Set the MLIR module names so that the debug dumps are named properly.

In a couple of modules, I changed `import pallas.core as pl_core` to `... as pallas_core` for consistency.

PiperOrigin-RevId: 659506675
2024-08-05 04:23:55 -07:00
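A minimal sketch of the documented `name` argument (the kernel, input, and shapes below are invented for illustration, and `interpret=True` is used only so the sketch runs without an accelerator); the string, together with the kernel's source location, now shows up in debug dumps and error messages:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
    # Trivial kernel, used only to illustrate the `name` argument.
    o_ref[...] = x_ref[...]

x = jnp.arange(8, dtype=jnp.float32)
y = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    name="copy_kernel",  # shows up in error/debug output, now with source info
    interpret=True,      # run in interpreter mode so the sketch works anywhere
)(x)
```
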
Paweł Paruzel
b2a469b361 Port Eigenvalue Decompositions to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 659492696
2024-08-05 03:18:13 -07:00
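For context, a minimal sketch of the user-facing op backed by the CPU eigendecomposition kernels being ported (the FFI port itself is internal and does not change this API):

```python
import jax.numpy as jnp

a = jnp.array([[2.0, 1.0],
               [0.0, 3.0]])
w, v = jnp.linalg.eig(a)   # CPU-only; eigenvalues come back as complex values
print(w)                   # e.g. [2.+0.j 3.+0.j]
```
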
George Necula
9b35b760ce [pallas] Enable check for GPU lowering that tensor sizes are power of 2
Triton has a restriction that all operations take arguments and produce
results that are tensors whose size is a power of 2. Added a lowering check
for this. Without it, violating the condition results in an
unfriendly crash.

PiperOrigin-RevId: 659483450
2024-08-05 02:34:21 -07:00
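A hypothetical illustration of the restriction (kernel and shapes invented for this sketch): a tensor of 96 elements is not a power of 2, so on the Triton/GPU backend the lowering should now reject it with a clear error instead of crashing; other backends and interpreter mode are unaffected.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] + 1.0

x = jnp.zeros((96,), jnp.float32)  # 96 is not a power of 2
y = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
)(x)  # on the GPU (Triton) backend this should now fail with a lowering error
```
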
jax authors
0b87bf48f9 Update XLA dependency to use revision
b94c84fa54.

PiperOrigin-RevId: 659354942
2024-08-04 14:53:18 -07:00
John Ryan
56ff247c2e Reverts 80560663d3fab4c0c3f87d7c8e52fb9931526dbb
PiperOrigin-RevId: 659334027
2024-08-04 12:11:30 -07:00
jax authors
83b5c7a0dd Merge pull request #22857 from mattjj:improve-while-loop-error
PiperOrigin-RevId: 659160966
2024-08-03 15:32:31 -07:00
jax authors
06c6a73236 Update XLA dependency to use revision
83ef35ce9e.

PiperOrigin-RevId: 659159129
2024-08-03 15:20:24 -07:00
Matthew Johnson
bdcd358b65 improve while_loop carry pytree/type mismatch errors
Now we call into the same error utility as we use in scan.
2024-08-03 21:57:29 +00:00
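A minimal sketch (constructed for illustration) of the kind of carry mismatch this error targets: the body returns a carry whose pytree structure differs from the initial value.

```python
import jax
import jax.numpy as jnp

def cond(carry):
    i, _ = carry
    return i < 10

def body(carry):
    i, x = carry
    # Bug: drops the second carry element, so the returned pytree
    # structure no longer matches the initial carry (i, x).
    return (i + 1,)

try:
    jax.lax.while_loop(cond, body, (0, jnp.zeros(3)))
except TypeError as e:
    print(e)  # the improved message points at the mismatched carry component
```
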
Yue Sheng
09beb33226 Don't call api.clean_up when there is no default backend.
PiperOrigin-RevId: 658936536
2024-08-02 16:14:29 -07:00
Yue Sheng
eb571c984a Fix lint in run_single_gpu.py
PiperOrigin-RevId: 658933291
2024-08-02 16:03:03 -07:00
jax authors
51abbf9041 Update XLA dependency to use revision
8a54f481f0.

PiperOrigin-RevId: 658923035
2024-08-02 15:25:51 -07:00
jax authors
261cf52020 Merge pull request #22828 from ROCm:ci_add_packages_dockerfile
PiperOrigin-RevId: 658901700
2024-08-02 14:16:05 -07:00
jax authors
e7fd424d9f Merge pull request #22784 from pearu:pearu/accuracy-tests-update
PiperOrigin-RevId: 658901408
2024-08-02 14:12:05 -07:00
Pearu Peterson
780b10b4c4 Update complex functions accuracy tests 2024-08-02 23:31:51 +03:00
Yue Sheng
88c8bacdca Add util.clear_all_caches to api.clear_backends and let api.clear_backends be called before the process terminates on JAX CPU. This allows the PjRt CPU client object to be successfully destroyed during Python garbage collection.
PiperOrigin-RevId: 658843789
2024-08-02 11:08:48 -07:00
Yash Katariya
958234a9c1 Thread the mesh context manager to the place where we recover out_shardings back from GSPMDShardings. Before, if you had a program like this:
```
with mesh:
  out = pjit(lambda: 1)()
```

The sharding of `out` was a `GSPMDSharding`, which is not ideal. This change fixes that and returns a `NamedSharding` instead.

This is also required for `Shardy` integration.

PiperOrigin-RevId: 658842350
2024-08-02 11:04:48 -07:00
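A minimal sketch of what the change means for users (a single-device mesh is assumed so it runs anywhere, and `pjit` is spelled as `jax.jit`, which is equivalent):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding

mesh = Mesh(np.array(jax.devices()[:1]), axis_names=("x",))
with mesh:
    out = jax.jit(lambda: jnp.ones(8))()

# After this change, the reported sharding is a NamedSharding tied to `mesh`
# rather than a low-level GSPMDSharding.
print(isinstance(out.sharding, NamedSharding))
```
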
Eugene Zhulenev
ac52890e3d [jax] Shard pallas_vmap_test
PiperOrigin-RevId: 658834942
2024-08-02 10:41:22 -07:00
jax authors
6c79c10446 Merge pull request #22840 from jakevdp:fix-backends
PiperOrigin-RevId: 658817878
2024-08-02 09:47:25 -07:00
Adam Paszke
86c9903067 [Pallas TPU] Make sure that the bug repros actually fail
One of them was fixed in the meantime but we didn't realize it.

PiperOrigin-RevId: 658799901
2024-08-02 08:37:22 -07:00
Yash Katariya
e6851e6b22 Fix the AOT check for sharding consistency, which skipped checking the devices of the sharding.
Previously, for a TPU-compiled computation, a user could have passed in a committed array on CPU and JAX wouldn't have raised an error, which is wrong.

This change fixes that. Also, `is_equivalent_to` should check for devices, HloSharding, and memory_kind (so the redundant `memory_kind` check is removed too).

PiperOrigin-RevId: 658794885
2024-08-02 08:15:32 -07:00
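A hedged sketch of the scenario described above (it assumes a machine with an accelerator as the default backend plus a CPU backend; the names are illustrative): a computation compiled ahead of time is called with an array committed to a different device, which should now be rejected.

```python
import jax
import jax.numpy as jnp

# Compile ahead of time on the default (e.g. TPU/GPU) devices.
compiled = jax.jit(lambda x: x + 1).lower(jnp.ones(4)).compile()

# Commit an input to CPU instead.
x_cpu = jax.device_put(jnp.ones(4), jax.devices("cpu")[0])

try:
    compiled(x_cpu)  # device mismatch: should now raise an error
except Exception as e:
    print(e)
```
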
Adam Paszke
f85b8e677b [Mosaic TPU] Add support for bf16 reductions
PiperOrigin-RevId: 658787017
2024-08-02 07:42:27 -07:00
Jake VanderPlas
3fa86a9b32 remove jax.extend.backend.default_backend in favor of jax.backend
I added this two days ago before realizing there is already a canonical API
for this in the top-level namespace, so it should be safe to remove.
2024-08-02 07:07:29 -07:00
Adam Paszke
e88887eda5 [Mosaic TPU] Add a missing reshape in relayout
The fact that src generalizes dst does not mean that they have the same implicit
tile shape (if one has an implicit dim and the other one doesn't, then they will
differ by a singleton dimension).

PiperOrigin-RevId: 658775019
2024-08-02 06:44:31 -07:00
Sergei Lebedev
02d836d990 Updated Mosaic GPU lowering registration in Pallas
The lowering rule for mosaic_gpu_p now expects a serialized module.

PiperOrigin-RevId: 658772330
2024-08-02 06:30:04 -07:00
Adam Paszke
959657a489 [Mosaic TPU] Remove special handling of implicit dim in relayout
Now all changes happen inside the dedicated functions.

PiperOrigin-RevId: 658763465
2024-08-02 05:46:26 -07:00
Dan Foreman-Mackey
80560663d3 Enable FFI implementation of GPU Getrf FFI handler.
PiperOrigin-RevId: 658755392
2024-08-02 05:07:02 -07:00
Adam Paszke
99625ff577 [Mosaic TPU] Break out implicit dim changes from relayout
PiperOrigin-RevId: 658752228
2024-08-02 04:50:40 -07:00
jax authors
efba5f61b5 Merge pull request #22812 from superbobry:maint
PiperOrigin-RevId: 658751187
2024-08-02 04:43:33 -07:00
Paweł Paruzel
6b0b222a38 Activate LU Decomposition to XLA's FFI
PiperOrigin-RevId: 658721697
2024-08-02 02:22:53 -07:00
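For reference, a minimal sketch of the user-facing LU decomposition whose CPU lowering this activates (the API itself is unchanged):

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu

a = jnp.array([[4.0, 3.0],
               [6.0, 3.0]])
p, l, u = lu(a)                       # permutation, lower, and upper factors
print(jnp.allclose(p @ l @ u, a))     # True
```
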
George Necula
20e9c15ff5 [pallas] Small cleanup in the Mosaic lowering
Uses the helper functions for the calling convention from #22552 and #22593.

PiperOrigin-RevId: 658692284
2024-08-02 00:16:35 -07:00
Enrique Piqueras
d57447a2d3 Double-buffer pipeline semaphores so we can hide DMA latency under compute and not just bandwidth. Also allow disabling automatic accumulation across pipelines.
PiperOrigin-RevId: 658585671
2024-08-01 16:58:19 -07:00
Christos Perivolaropoulos
28b86604b3 [pallas:mosaic_gpu] Make the linter happy.
PiperOrigin-RevId: 658580241
2024-08-01 16:37:12 -07:00
Rahul Batra
7d6fa3c05b [ROCm]: Add support for continuing on failure, fix script paths, and update the Dockerfile to add necessary packages 2024-08-01 17:55:15 -05:00
jax authors
2241dadab6 Merge pull request #22814 from superbobry:maint-2
PiperOrigin-RevId: 658560253
2024-08-01 15:31:40 -07:00
Jieying Luo
bc0229a61f Rollback as it broke some tests.
Reverts ff17b76e3eec3e573788f64fafe23fabcfc09ce2

PiperOrigin-RevId: 658557091
2024-08-01 15:21:42 -07:00
jax authors
16c868af82 Merge pull request #22825 from jakevdp:fix-old-array-api
PiperOrigin-RevId: 658552229
2024-08-01 15:07:57 -07:00
Dan Foreman-Mackey
8df0c3a9cc Port Getrf GPU kernel from custom call to FFI.
PiperOrigin-RevId: 658550170
2024-08-01 15:02:25 -07:00
Jake VanderPlas
48c5fab023 [array api] fix deprecation to support old import pattern 2024-08-01 14:38:59 -07:00
Sergei Lebedev
fb1dbf15df Bumped mypy to 1.11.0 and jaxlib to 0.4.31 on the CI 2024-08-01 22:30:24 +01:00
jax authors
b64b19f1e3 Update XLA dependency to use revision
7f8cc3357f.

PiperOrigin-RevId: 658538981
2024-08-01 14:27:24 -07:00
jax authors
aa9e1e42a1 Merge pull request #22095 from dfm:ffi-call-tutorial
PiperOrigin-RevId: 658523039
2024-08-01 13:43:34 -07:00
Abhinav Gunjal
dfe8d94170 Integrate StableHLO at openxla/stablehlo@fb18ee25
PiperOrigin-RevId: 658515936
2024-08-01 13:23:01 -07:00
Jieying Luo
ff17b76e3e Cleanup. Remove build:cuda_plugin and set enable_gpu and xla_python_enable_gpu to false in build:cuda.
JAX has already migrated from jaxlib[cuda] to the cuda plugin.

PiperOrigin-RevId: 658508037
2024-08-01 12:59:12 -07:00
jax authors
8eaa6bf661 Merge pull request #22818 from jakevdp:finalize-array-api
PiperOrigin-RevId: 658500983
2024-08-01 12:37:56 -07:00
Dan Foreman-Mackey
0b4800a193 Add ffi_call tutorial
Building on #21925, this tutorial demonstrates the use of the FFI using
`ffi_call` with a simple example. I don't think this should cover all of
the most advanced use cases, but it should be sufficient for the most
common examples. I think it would be useful to eventually replace the
existing CUDA tutorial, but I'm not sure that it'll get there in the
first draft.

As an added benefit, this also runs a simple test (akin to
`docs/cuda_custom_call`) which actually executes using a toolchain that
open source users would use in practice.
2024-08-01 15:36:32 -04:00
Dan Foreman-Mackey
f20efc630f Move jaxlib GPU handlers to separate build target.
In anticipation of refactoring the jaxlib GPU custom calls into FFI calls, this change moves the implementation of `BlasHandlePool`, `SolverHandlePool`, and `SpSolverHandlePool` into a new target.

PiperOrigin-RevId: 658497960
2024-08-01 12:30:04 -07:00
Jake VanderPlas
14fa06298e [array api] Finalize array API in jax.numpy & deprecate jax.experimental.array_api 2024-08-01 11:19:17 -07:00
jax authors
b3924da2a1 Merge pull request #22794 from jakevdp:array-api-astype
PiperOrigin-RevId: 658456484
2024-08-01 10:47:45 -07:00
jax authors
06ffa70cf3 Merge pull request #22810 from nstarman:add-annot
PiperOrigin-RevId: 658456382
2024-08-01 10:44:01 -07:00
George Necula
43163ff2e3 [pallas] Add error message for block_shapes of rank less than 1.
PiperOrigin-RevId: 658424421
2024-08-01 09:15:07 -07:00