20801 Commits

Author SHA1 Message Date
Peter Hawkins
483f924ea1 Bump shard count for experimental_rnn_test, which is timing out in CI when built under ASAN.
PiperOrigin-RevId: 635850400
2024-05-21 10:25:24 -07:00
jax authors
5350bc960d Merge pull request #21258 from olupton:skip-cusolver-test-with-cuda-12.4
PiperOrigin-RevId: 635845926
2024-05-21 10:11:54 -07:00
Kyle Lucke
418b68828a Automated Code Change
PiperOrigin-RevId: 635818645
2024-05-21 08:40:34 -07:00
Olli Lupton
9ba77f8ecd Skip a test when run with cuSolver >= 11.6
This version is shipped with CUDA 12.4. The test assumes that a
workspace size baked in with an older version of cuSolver can be used
with a newer version of cuSolver. This is not safe, and leads to an
error when upgrading from 11.5 to 11.6.
2024-05-21 14:46:43 +00:00
Dan Suh
4394bdc2ad Change the log message in pxla.py to be less confusing.
PiperOrigin-RevId: 635789016
2024-05-21 06:42:07 -07:00
jax authors
3f1b059503 Update XLA dependency to use revision
c82597f555.

PiperOrigin-RevId: 635634941
2024-05-20 18:47:32 -07:00
Jake VanderPlas
d33a5689de Refactor & test internal deprecation APIs
The names and APIs were previously too similar and therefore somewhat confusing; this will be more clear I think.

PiperOrigin-RevId: 635615163
2024-05-20 17:16:31 -07:00
jax authors
2eff241f62 Merge pull request #21319 from gnecula:exp_fix_mesh
PiperOrigin-RevId: 635611557
2024-05-20 17:02:44 -07:00
George Necula
6deeee27db [export] Fix device assignment error for grad of exported.
Currently, the export code uses a manufactured device assignment
for exporting the VJP function. We should use instead the same
device assigment that was used when exporting the primal function.

This PR fixes that for the case when the export is done through
the direct use of `jax.experimental.export`, and leaves as future
work the case when the use is from `jax2tf`. We add a disabled
tests for the latter case.

Bug: #21314
2024-05-20 16:11:01 -07:00
Tomás Longeri
b197ae527e [Mosaic] Also check bitwidth in apply-vector-layout's layoutIsValidForValue.
PiperOrigin-RevId: 635595321
2024-05-20 15:57:08 -07:00
jax authors
118ca21b5b Merge pull request #21318 from jakevdp:fix-mypy
PiperOrigin-RevId: 635553370
2024-05-20 13:31:33 -07:00
Jake VanderPlas
329ab036ee CI: fix mypy error 2024-05-20 13:23:15 -07:00
Shanbin Ke
06d2e489eb Copybara import of the project:
--
f625317cc80639178882316df6f8775294adc6b7 by cjkkkk <ske@nvidia.com>:

init

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21228 from Cjkkkk:sdpa_new_cudnn_frontend f625317cc80639178882316df6f8775294adc6b7
PiperOrigin-RevId: 635518631
2024-05-20 11:31:15 -07:00
George Karpenkov
e0a6453a39 Simplify JAX lowering rules for cumulative sum
Upstream fix has landed => removing CPU workaround.

PiperOrigin-RevId: 635505632
2024-05-20 10:52:29 -07:00
jax authors
cc3a380f76 Add unit test to check if the backend serialization/deserialization result equal to the original executable.
PiperOrigin-RevId: 635485374
2024-05-20 09:52:38 -07:00
jax authors
61ff828715 Add support for TPU delay in Mosaic
PiperOrigin-RevId: 635473532
2024-05-20 09:07:56 -07:00
jax authors
2f45830de5 [Mosaic GPU] Prepare matmul example so it can be exposed to other projects.
PiperOrigin-RevId: 635442413
2024-05-20 06:54:44 -07:00
Sergei Lebedev
f600caa2c4 Use @register_lowering to register Pallas GPU lowering rules
This leads to slightly more compact code, but should otherwise be identical.

PiperOrigin-RevId: 635442002
2024-05-20 06:51:33 -07:00
jax authors
c3d0b0d12c Merge pull request #21305 from jakevdp:scalar-bool
PiperOrigin-RevId: 635436437
2024-05-20 06:20:32 -07:00
jax authors
974c72b9a1 Merge pull request #21292 from ROCm:rv_stable_051624
PiperOrigin-RevId: 635430659
2024-05-20 05:52:36 -07:00
Jake VanderPlas
4bac10e750 Finalize deprecation of the config module.
To configure JAX, use `import jax` and reference the config object via `jax.config`.

PiperOrigin-RevId: 635430169
2024-05-20 05:49:31 -07:00
jax authors
bb616eff8a Merge pull request #21231 from nouiz:doc_experimental_serialize_executable
PiperOrigin-RevId: 635428472
2024-05-20 05:39:21 -07:00
Jake VanderPlas
5b28170b94 Support scalar boolean indices in arr.at[idx].set(vals) 2024-05-20 05:33:36 -07:00
Adam Paszke
53ec2cd26f Add notap tag to Mosaic tests
PiperOrigin-RevId: 635379982
2024-05-20 01:35:56 -07:00
jax authors
ffdb9bb0b0 Update XLA dependency to use revision
4c566c945a.

PiperOrigin-RevId: 635307171
2024-05-19 18:27:13 -07:00
Vadym Matsishevskyi
45a7c22e93 fix: Update hermetic python dependencies to numpy=2.0.0rc2 and scipy=1.13.0 for all python version
Also install built jaxlib in hermetic python to support //jax:build_jaxlib=false tests.

PiperOrigin-RevId: 635169327
2024-05-18 23:39:09 -07:00
jax authors
8caeaa29f4 Update XLA dependency to use revision
ea6c3d5916.

PiperOrigin-RevId: 635132837
2024-05-18 19:18:57 -07:00
Yash Katariya
6577f47b83 Make eqn.ctx context manager thread safe by creating eqn.ctx.manager.
PiperOrigin-RevId: 635057475
2024-05-18 08:46:18 -07:00
jax authors
e3a7a87f92 Update XLA dependency to use revision
e631101673.

PiperOrigin-RevId: 634954649
2024-05-17 20:09:52 -07:00
Yash Katariya
25aa13c46b Support remat + compute_on. If the rematted computation is annotated to run on host, the backward pass will also execute on host. Also enable no-op nested compute tests.
PiperOrigin-RevId: 634943450
2024-05-17 18:59:49 -07:00
Ruturaj4
79fccf6c82 add cholesky changes in bazel 2024-05-18 00:37:09 +00:00
jax authors
641d5c8be3 jax/pallas support ellipsis indexing
PiperOrigin-RevId: 634922391
2024-05-17 16:57:53 -07:00
Yash Katariya
02c19e9600 Make jax.grad and compute_on work correctly. If the forward pass has annotation to execute on CPU, then it's backward pass also executes on CPU.
PiperOrigin-RevId: 634917402
2024-05-17 16:38:35 -07:00
Ashish Shenoy
1043e24f6a Add quantization support for PagedAttention TPU Pallas kernel.
PiperOrigin-RevId: 634914369
2024-05-17 16:17:33 -07:00
Yash Katariya
2d6d408b19 Initial commit for jax.experimental.compute_on API.
The current supported values for compute type is `device_host`, `device`. `device_sparse` will be allowed in follow up CL. Using `device_host` means that the device's PJRT client will be orchestrating the execution of the computation on the host.

`cpu` as a compute_type is reserved for pure CPU only computations without a device's pjrt client orchestrating the computation.

PiperOrigin-RevId: 634909918
2024-05-17 15:59:21 -07:00
jax authors
7a3fc7113b Merge pull request #21289 from jakevdp:stacklevel
PiperOrigin-RevId: 634895282
2024-05-17 15:01:37 -07:00
Jake VanderPlas
9ad6729b5e jax.experimental.export: fix stacklevel for warning 2024-05-17 14:36:34 -07:00
jax authors
56301d471c Merge pull request #21283 from dfm:gh21279
PiperOrigin-RevId: 634819832
2024-05-17 10:32:37 -07:00
jax authors
815256687f [pallas:Mosaic GPU] Configurable smem scratch and a small bug fix in Mosaic GPU
PiperOrigin-RevId: 634813241
2024-05-17 10:10:20 -07:00
Dan Foreman-Mackey
be0695474a Add RegularGridInterpolator to generated API docs
In responding to gh21279, I noticed that `RegularGridInterpolator` isn't
currently listed in the API docs. I know that `scipy.interpolate` is out
of scope (https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html#scipy-interpolate),
but since we do currently provide this wrapper, it seems like it makes
sense to include it in the docs!
2024-05-17 11:24:35 -04:00
jax authors
e93f36aa7c Merge pull request #21281 from jaro-sevcik:enable-host-offloading-scan-tests-gpu
PiperOrigin-RevId: 634770894
2024-05-17 07:39:24 -07:00
Jaroslav Sevcik
7be7c1eace Enable remat-scan offloading test 2024-05-17 07:29:31 -07:00
jax authors
f87be35b0f [Mosaic GPU] reduce_sum does an intra-warp reduction before communicating with the other warps
PiperOrigin-RevId: 634765339
2024-05-17 07:24:35 -07:00
Sergei Lebedev
210f8bbfca Use absl.flags.FlagHolder to defined --mosaic_gpu_debug
PiperOrigin-RevId: 634713545
2024-05-17 04:33:17 -07:00
jax authors
c4559115ec Internal BUILD file change
PiperOrigin-RevId: 634713068
2024-05-17 04:30:21 -07:00
Sergei Lebedev
527aef3a01 Added a slow (but working!) implementation of layer norm in Pallas via Mosaic GPU
PiperOrigin-RevId: 634710243
2024-05-17 04:15:40 -07:00
jax authors
5e2710c2c2 Merge pull request #21261 from superbobry:mypy-ruff
PiperOrigin-RevId: 634654578
2024-05-17 00:10:27 -07:00
jax authors
1829a66739 Merge pull request #21268 from jakevdp:register-dataclass
PiperOrigin-RevId: 634624518
2024-05-16 21:27:30 -07:00
jax authors
0e9243391b [Mosaic GPU] Add a WGSplatLayout that trivially supports reshape and broadcast.
PiperOrigin-RevId: 634610004
2024-05-16 20:04:05 -07:00
jax authors
efa420b299 Update XLA dependency to use revision
0821ce5408.

PiperOrigin-RevId: 634608759
2024-05-16 20:01:07 -07:00