22164 Commits

Author SHA1 Message Date
George Necula
6d53aaf7d0 [pallas] Improve the error localization
* Add the source location information for the index map function to
    `BlockMapping`.
  * Removed the `compute_index` wrapper around the index_map, so that
    we can get the location information for the index_map, not the wrapper.
  * Added source location to the errors related to index map functions.
  * Added an error if the index map returns something other than integer
    scalars.
  * Construct BlockSpec origins for arguments using JAX helper functions
    to get argument names
  * Removed redundant API error tests from tpu_pallas_test.py
2024-07-30 14:11:57 +02:00
jax authors
cc212457d2 Merge pull request #22481 from zhenying-liu:offloading
PiperOrigin-RevId: 657413977
2024-07-29 19:43:35 -07:00
Yash Katariya
30037547d7 Bump minimum jaxlib version to 0.4.31. The corresponding xla_extension_version is 279 and mlir_api_version is 57
PiperOrigin-RevId: 657400413
2024-07-29 18:44:31 -07:00
Yash Katariya
2106a25977 Finish jax and jaxlib v0.4.31 release
PiperOrigin-RevId: 657388782
2024-07-29 17:57:37 -07:00
Bixia Zheng
c81f5cd2fc [xla] Replace debug option xla_use_shardy with execution option
use_shardy_partitioner.

Replace the use of xla_use_shardy with use_shardy_partitioner and remove
xla_use_shardy.

PiperOrigin-RevId: 657359119
2024-07-29 16:11:36 -07:00
jax authors
291438403d Update XLA dependency to use revision
ffd724e235.

PiperOrigin-RevId: 657341038
2024-07-29 15:14:04 -07:00
jax authors
63303973a2 Merge pull request #22716 from superbobry:pallas
PiperOrigin-RevId: 657333519
2024-07-29 14:50:01 -07:00
Sergei Lebedev
a44265aa73 Added a trivial discharge rule for debug_callback_p
This allows using jax.debug.print with Refs in interpreted Pallas kernels.
2024-07-29 22:26:01 +01:00
jax authors
091eba1955 Merge pull request #22736 from jakevdp:fft-device
PiperOrigin-RevId: 657313812
2024-07-29 13:48:55 -07:00
Jake VanderPlas
6516a079f8 [array API] add device argument to fftfreq & rfftfreq 2024-07-29 13:23:54 -07:00
jax authors
5ce66dc1aa Merge pull request #22734 from jakevdp:array-api-methods
PiperOrigin-RevId: 657299896
2024-07-29 13:05:22 -07:00
Yash Katariya
7fd9302785 Start JAX and jaxlib 0.4.31 release
PiperOrigin-RevId: 657295431
2024-07-29 12:49:58 -07:00
Jake VanderPlas
00ba7a6d25 [array API] move api metadata into jax.numpy namespace 2024-07-29 12:43:11 -07:00
jax authors
f070c0658f Merge pull request #22703 from Rifur13:plugin-fix
PiperOrigin-RevId: 657283607
2024-07-29 12:10:42 -07:00
jax authors
9beb4f1474 Merge pull request #19760 from Blair-Johnson:fix-pytree-grads-sparse
PiperOrigin-RevId: 657258194
2024-07-29 10:59:57 -07:00
Blair-Johnson
802a14cd61 Re-pack gradients of jax.experimental.sparse.grad() to match original pytrees & test cases 2024-07-29 13:04:05 -04:00
jax authors
85e83b508b Merge pull request #22690 from jakevdp:inplace-doc
PiperOrigin-RevId: 657237218
2024-07-29 10:03:43 -07:00
Peter Hawkins
d1c0d993fc Bump the minimum CUDNN version to v9.1.
This actually was already the minimum version since we build with that version, but we needed to tighten the constraints.

Also in passing, drop mentions of CUDA builds from the Windows build instructions. jaxlib hasn't built with CUDA enabled on Windows for a very long time, so it's probably best we just don't mention it.

PiperOrigin-RevId: 657225917
2024-07-29 09:28:47 -07:00
Vladimir Belitskiy
6127baa117 Ignore the Deprecation warning produced about native_serialization=False.
PiperOrigin-RevId: 657221363
2024-07-29 09:11:54 -07:00
Peter Hawkins
fd23b8733d Bump minimum SciPy version to 1.10.
SciPy 1.9.0 was released July 29, 2022, which is 24 months ago

PiperOrigin-RevId: 657215038
2024-07-29 08:50:18 -07:00
Jake VanderPlas
cfa1e78549 Improve documentation for jnp.put, jnp.place, jnp.fill_diagonal
These are all the APIs that have an inplace parameter
2024-07-29 08:31:54 -07:00
Gleb Pobudzey
0224235edf Skip cuda backend initialization if no nvidia GPUs are visible. 2024-07-29 15:12:58 +00:00
jax authors
e78e643b5f Merge pull request #22593 from gnecula:pallas_more_simplification
PiperOrigin-RevId: 657198330
2024-07-29 07:48:52 -07:00
Dan Foreman-Mackey
ff4e0b1214 Rearrange the LAPACK handler definitions in jaxlib to avoid duplicate handler errors.
When linking the jaxlib `cpu_kernels` target and importing JAX, we currently silently fail to instantiate the CPU backend. This refactor means that we only ever define one version of the handlers.

PiperOrigin-RevId: 657186057
2024-07-29 06:59:44 -07:00
Vladimir Belitskiy
fef91fb201 Skip tests/mock_gpu_test.py on pytest.
PiperOrigin-RevId: 657185249
2024-07-29 06:55:43 -07:00
George Necula
70a11acbb1 [pallas] More simplification of grid mapping and calling convention
In previous PR #22552 I have expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.

I have removed the internal methods from `BlockSpec` and `GridSpec` because
these classes are part of the API.

I added entries to pallas/CHANGELOG.
2024-07-29 15:53:47 +02:00
George Necula
68972de021 [pallas] Add lowering errors for block shapes that are not supported.
Previously these errors came from Mosaic with less useful stack traces, and in the case of GPU we were getting a crash instead of an exception.

PiperOrigin-RevId: 657184114
2024-07-29 06:49:27 -07:00
Sergei Lebedev
ccc4c42ec9 Reduced the input size in PallasCallInputOutputAliasingTest
This ensures the test doesn't OOM when running on A100 on the CI.

PiperOrigin-RevId: 657165032
2024-07-29 05:29:45 -07:00
Adam Paszke
a00b659b03 [Mosaic GPU] Fix two subtle issues with kernel lowering
1. The MLIR context is created by the user and its lifetime is not
   in our control. To avoid depending on it, we serialize the module.
2. The operand and result layout requirements were missing from the custom call.

PiperOrigin-RevId: 657164985
2024-07-29 05:25:50 -07:00
jax authors
6a7822a73b Update XLA dependency to use revision
95e3eea8d2.

PiperOrigin-RevId: 657003194
2024-07-28 15:32:56 -07:00
jax authors
74649be7ed Update XLA dependency to use revision
89089aa569.

PiperOrigin-RevId: 656797625
2024-07-27 15:23:41 -07:00
Jake VanderPlas
a17c8d945b Finalize deprecation of jax.random.shuffle
This has been raising a DeprecationWarning for longer than anyone can remember.

PiperOrigin-RevId: 656765001
2024-07-27 11:21:49 -07:00
jax authors
dab15d6fdd Merge pull request #22684 from froystig:rngdoc
PiperOrigin-RevId: 656600958
2024-07-26 19:12:36 -07:00
jax authors
40d569b22e Update XLA dependency to use revision
cf139009c9.

PiperOrigin-RevId: 656531286
2024-07-26 14:34:09 -07:00
Roy Frostig
f30ebd8586 document vmap peculiarity of experimental RNG implementations 2024-07-26 13:40:16 -07:00
Roy Frostig
6ddd488df0 improve RNG doc around implementation configuration 2024-07-26 13:40:16 -07:00
jax authors
aeff5b61a9 Merge pull request #22080 from vfdev-5:add-device-kwarg-linspace-array
PiperOrigin-RevId: 656467191
2024-07-26 11:18:24 -07:00
Vladimir Belitskiy
7f96b263d4 Un-skip //third_party/py/jax/tests:pytorch_interoperability_test_cpu on ASAN.
It should have been fixed via
https://github.com/pytorch/pytorch/issues/117058#issuecomment-1973020150

PiperOrigin-RevId: 656464550
2024-07-26 11:10:41 -07:00
Vladimir Belitskiy
282ebf4882 Skip //third_party/py/jax/tests:pytorch_interoperability_test_cpu on MSAN.
MSAN has issues with using `-c opt` in some cases, which prevents this
test from running properly.

PiperOrigin-RevId: 656454585
2024-07-26 10:44:19 -07:00
jax authors
0df074c285 Merge pull request #22680 from superbobry:maint
PiperOrigin-RevId: 656427681
2024-07-26 09:25:47 -07:00
Adam Paszke
d862f78dcc [Mosaic GPU] Skip matmul tests with large clusters
I'm still investigating but they sometimes hang for an unclear reason.

PiperOrigin-RevId: 656426326
2024-07-26 09:21:13 -07:00
Yash Katariya
05677694d8 Document copy_to_host_async method of jax.Array
PiperOrigin-RevId: 656408298
2024-07-26 08:21:01 -07:00
jax authors
694c14bbe6 Merge pull request #22556 from cool-RR:log-cache-key
PiperOrigin-RevId: 656364840
2024-07-26 05:32:11 -07:00
Ayaka
bb160cf54e Move TPU ops test to ops_test.py
Move the TPU ops test from `tpu_ops_test.py` to `ops_test.py`. The functions tested in this file are not TPU-specific operations, so we don't need a separate test file.

PiperOrigin-RevId: 656347969
2024-07-26 04:24:13 -07:00
Sergei Lebedev
8d33a6c9a6 Bumped jaxlib version mypy uses on the CI
I also enabled unnecessary cast checking, because turns out we have quite
a few of those.
2024-07-26 11:22:39 +01:00
jax authors
2db99e03dd Merge pull request #22283 from ayaka14732:ayx/lowering/sign
PiperOrigin-RevId: 656317943
2024-07-26 02:28:33 -07:00
jax authors
8ed94bcfb6 [shard_map docs]: Fix doc typos
PiperOrigin-RevId: 656265100
2024-07-25 23:29:55 -07:00
Tomás Longeri
0f834cdf24 [Mosaic TPU] Enable lane broadcast for packed types and offsets outside of first tile, and fix some broadcast infer logic
PiperOrigin-RevId: 656201666
2024-07-25 19:48:20 -07:00
Eugene Zhulenev
15d4389247 Use vmap for random_gamma implementation on CPU backend
XLA:CPU is preparing to switch from compiling whole XLA program into a single LLVM function to a mode where each fusion/kernel will have its own entry point, and a thin runtime that will dispatch compute functions concurrently. This execution mode does not work very well with while loops with tiny computations and large number of iterations. Similar to GPU backend use vmap to avoid excessive runtime overheads.

Context: https://github.com/openxla/community/pull/96
PiperOrigin-RevId: 656199716
2024-07-25 19:41:59 -07:00
Ayaka
6cc09173d5 Add lowering for lax.sign 2024-07-26 10:33:42 +08:00