Gleb Pobudzey
35ba6f78bb
Add /dev/nvidiactl to the list of NVIDIA GPU devices. This is to cover the use case where a subset of GPUs are exposed to Docker/Kubernetes; the container might not necessarily see /dev/nvidia0.
...
Testing:
Tested on a A100 16 GPU VM with only 8 GPU exposed to the container.
PiperOrigin-RevId: 657801141
2024-07-30 18:25:10 -07:00
jax authors
bd14d6ab64
Merge pull request #22768 from jakevdp:array-api-reshape
...
PiperOrigin-RevId: 657782900
2024-07-30 17:15:02 -07:00
jax authors
7583cbb438
Merge pull request #22766 from jakevdp:array-api-cleanup
...
PiperOrigin-RevId: 657775853
2024-07-30 16:50:15 -07:00
jax authors
9c669a74d6
Update XLA dependency to use revision
...
78418c6a4a
.
PiperOrigin-RevId: 657742817
2024-07-30 15:04:30 -07:00
Jake VanderPlas
5198db9fdb
jnp.repeat: add copy argument for Array API
2024-07-30 14:07:08 -07:00
Jake VanderPlas
1259322f86
[array api] remove redundant definitions for clip() & hypot()
2024-07-30 13:51:04 -07:00
Jake VanderPlas
8bcd288621
Raise ValueError for complex inputs to jnp.clip and jnp.hypot.
...
Such inputs were deprecated in JAX v0.4.27, and have been raising a DeprecationWarning for the last several releases.
PiperOrigin-RevId: 657717875
2024-07-30 13:49:37 -07:00
jax authors
b996612865
Merge pull request #22085 from vfdev-5:add-device-kwarg-fftfreq
...
PiperOrigin-RevId: 657707966
2024-07-30 13:20:26 -07:00
jax authors
b1066ee413
Merge pull request #22764 from jakevdp:array-api-methods
...
PiperOrigin-RevId: 657697622
2024-07-30 12:50:48 -07:00
Jake VanderPlas
c2f2b0ed28
[array API] move api metadata into jax.numpy namespace
2024-07-30 12:15:24 -07:00
Kanglan Tang
d7c2b49c5a
Skip test_concrete_layout_in_shardings on GPU backend.
...
PiperOrigin-RevId: 657661214
2024-07-30 11:10:29 -07:00
Jake VanderPlas
ff8e8ad2fe
revert #22734
...
Reverts 5ce66dc1aae67a88a8ed72584bdc3f5a7f712507
PiperOrigin-RevId: 657638187
2024-07-30 10:17:34 -07:00
jax authors
256956ad58
Merge pull request #22704 from gnecula:pallas_better_errors
...
PiperOrigin-RevId: 657571604
2024-07-30 06:39:44 -07:00
Peter Hawkins
c1cd7f9e2d
Drop support for mhlo in JAX's public API.
...
PiperOrigin-RevId: 657551590
2024-07-30 05:29:52 -07:00
George Necula
6d53aaf7d0
[pallas] Improve the error localization
...
* Add the source location information for the index map function to
`BlockMapping`.
* Removed the `compute_index` wrapper around the index_map, so that
we can get the location information for the index_map, not the wrapper.
* Added source location to the errors related to index map functions.
* Added an error if the index map returns something other than integer
scalars.
* Construct BlockSpec origins for arguments using JAX helper functions
to get argument names
* Removed redundant API error tests from tpu_pallas_test.py
2024-07-30 14:11:57 +02:00
vfdev-5
bb1fb3ba45
Follow-up to #22736
...
On adding device kwarg to jnp.fft.fftfreq and jnp.fft.rfftfreq
2024-07-30 05:39:19 +02:00
jax authors
cc212457d2
Merge pull request #22481 from zhenying-liu:offloading
...
PiperOrigin-RevId: 657413977
2024-07-29 19:43:35 -07:00
Yash Katariya
30037547d7
Bump minimum jaxlib version to 0.4.31. The corresponding xla_extension_version is 279 and mlir_api_version is 57
...
PiperOrigin-RevId: 657400413
2024-07-29 18:44:31 -07:00
Yash Katariya
2106a25977
Finish jax and jaxlib v0.4.31 release
...
PiperOrigin-RevId: 657388782
2024-07-29 17:57:37 -07:00
Bixia Zheng
c81f5cd2fc
[xla] Replace debug option xla_use_shardy with execution option
...
use_shardy_partitioner.
Replace the use of xla_use_shardy with use_shardy_partitioner and remove
xla_use_shardy.
PiperOrigin-RevId: 657359119
2024-07-29 16:11:36 -07:00
jax authors
291438403d
Update XLA dependency to use revision
...
ffd724e235
.
PiperOrigin-RevId: 657341038
2024-07-29 15:14:04 -07:00
jax authors
63303973a2
Merge pull request #22716 from superbobry:pallas
...
PiperOrigin-RevId: 657333519
2024-07-29 14:50:01 -07:00
Sergei Lebedev
a44265aa73
Added a trivial discharge rule for debug_callback_p
...
This allows using jax.debug.print with Refs in interpreted Pallas kernels.
2024-07-29 22:26:01 +01:00
jax authors
091eba1955
Merge pull request #22736 from jakevdp:fft-device
...
PiperOrigin-RevId: 657313812
2024-07-29 13:48:55 -07:00
Jake VanderPlas
6516a079f8
[array API] add device argument to fftfreq & rfftfreq
2024-07-29 13:23:54 -07:00
jax authors
5ce66dc1aa
Merge pull request #22734 from jakevdp:array-api-methods
...
PiperOrigin-RevId: 657299896
2024-07-29 13:05:22 -07:00
Yash Katariya
7fd9302785
Start JAX and jaxlib 0.4.31 release
...
PiperOrigin-RevId: 657295431
2024-07-29 12:49:58 -07:00
Jake VanderPlas
00ba7a6d25
[array API] move api metadata into jax.numpy namespace
2024-07-29 12:43:11 -07:00
jax authors
f070c0658f
Merge pull request #22703 from Rifur13:plugin-fix
...
PiperOrigin-RevId: 657283607
2024-07-29 12:10:42 -07:00
jax authors
9beb4f1474
Merge pull request #19760 from Blair-Johnson:fix-pytree-grads-sparse
...
PiperOrigin-RevId: 657258194
2024-07-29 10:59:57 -07:00
Blair-Johnson
802a14cd61
Re-pack gradients of jax.experimental.sparse.grad() to match original pytrees & test cases
2024-07-29 13:04:05 -04:00
jax authors
85e83b508b
Merge pull request #22690 from jakevdp:inplace-doc
...
PiperOrigin-RevId: 657237218
2024-07-29 10:03:43 -07:00
Peter Hawkins
d1c0d993fc
Bump the minimum CUDNN version to v9.1.
...
This actually was already the minimum version since we build with that version, but we needed to tighten the constraints.
Also in passing, drop mentions of CUDA builds from the Windows build instructions. jaxlib hasn't built with CUDA enabled on Windows for a very long time, so it's probably best we just don't mention it.
PiperOrigin-RevId: 657225917
2024-07-29 09:28:47 -07:00
Vladimir Belitskiy
6127baa117
Ignore the Deprecation warning produced about native_serialization=False
.
...
PiperOrigin-RevId: 657221363
2024-07-29 09:11:54 -07:00
Peter Hawkins
fd23b8733d
Bump minimum SciPy version to 1.10.
...
SciPy 1.9.0 was released July 29, 2022, which is 24 months ago
PiperOrigin-RevId: 657215038
2024-07-29 08:50:18 -07:00
Jake VanderPlas
cfa1e78549
Improve documentation for jnp.put, jnp.place, jnp.fill_diagonal
...
These are all the APIs that have an inplace parameter
2024-07-29 08:31:54 -07:00
Gleb Pobudzey
0224235edf
Skip cuda backend initialization if no nvidia GPUs are visible.
2024-07-29 15:12:58 +00:00
jax authors
e78e643b5f
Merge pull request #22593 from gnecula:pallas_more_simplification
...
PiperOrigin-RevId: 657198330
2024-07-29 07:48:52 -07:00
Dan Foreman-Mackey
ff4e0b1214
Rearrange the LAPACK handler definitions in jaxlib to avoid duplicate handler errors.
...
When linking the jaxlib `cpu_kernels` target and importing JAX, we currently silently fail to instantiate the CPU backend. This refactor means that we only ever define one version of the handlers.
PiperOrigin-RevId: 657186057
2024-07-29 06:59:44 -07:00
Vladimir Belitskiy
fef91fb201
Skip tests/mock_gpu_test.py on pytest.
...
PiperOrigin-RevId: 657185249
2024-07-29 06:55:43 -07:00
George Necula
70a11acbb1
[pallas] More simplification of grid mapping and calling convention
...
In previous PR #22552 I have expanded `GridMapping` to encode more
parts of the calling convention. Here we use that new functionality
and clean up some code.
I have removed the internal methods from `BlockSpec` and `GridSpec` because
these classes are part of the API.
I added entries to pallas/CHANGELOG.
2024-07-29 15:53:47 +02:00
George Necula
68972de021
[pallas] Add lowering errors for block shapes that are not supported.
...
Previously these errors came from Mosaic with less useful stack traces, and in the case of GPU we were getting a crash instead of an exception.
PiperOrigin-RevId: 657184114
2024-07-29 06:49:27 -07:00
Sergei Lebedev
ccc4c42ec9
Reduced the input size in PallasCallInputOutputAliasingTest
...
This ensures the test doesn't OOM when running on A100 on the CI.
PiperOrigin-RevId: 657165032
2024-07-29 05:29:45 -07:00
Adam Paszke
a00b659b03
[Mosaic GPU] Fix two subtle issues with kernel lowering
...
1. The MLIR context is created by the user and its lifetime is not
in our control. To avoid depending on it, we serialize the module.
2. The operand and result layout requirements were missing from the custom call.
PiperOrigin-RevId: 657164985
2024-07-29 05:25:50 -07:00
jax authors
6a7822a73b
Update XLA dependency to use revision
...
95e3eea8d2
.
PiperOrigin-RevId: 657003194
2024-07-28 15:32:56 -07:00
jax authors
74649be7ed
Update XLA dependency to use revision
...
89089aa569
.
PiperOrigin-RevId: 656797625
2024-07-27 15:23:41 -07:00
Jake VanderPlas
a17c8d945b
Finalize deprecation of jax.random.shuffle
...
This has been raising a DeprecationWarning for longer than anyone can remember.
PiperOrigin-RevId: 656765001
2024-07-27 11:21:49 -07:00
jax authors
dab15d6fdd
Merge pull request #22684 from froystig:rngdoc
...
PiperOrigin-RevId: 656600958
2024-07-26 19:12:36 -07:00
jax authors
40d569b22e
Update XLA dependency to use revision
...
cf139009c9
.
PiperOrigin-RevId: 656531286
2024-07-26 14:34:09 -07:00
Roy Frostig
f30ebd8586
document vmap peculiarity of experimental RNG implementations
2024-07-26 13:40:16 -07:00