Peter Hawkins
483f924ea1
Bump shard count for experimental_rnn_test, which is timing out in CI when built under ASAN.
...
PiperOrigin-RevId: 635850400
2024-05-21 10:25:24 -07:00
jax authors
5350bc960d
Merge pull request #21258 from olupton:skip-cusolver-test-with-cuda-12.4
...
PiperOrigin-RevId: 635845926
2024-05-21 10:11:54 -07:00
Kyle Lucke
418b68828a
Automated Code Change
...
PiperOrigin-RevId: 635818645
2024-05-21 08:40:34 -07:00
Olli Lupton
9ba77f8ecd
Skip a test when run with cuSolver >= 11.6
...
This version is shipped with CUDA 12.4. The test assumes that a
workspace size baked in with an older version of cuSolver can be used
with a newer version of cuSolver. This is not safe, and leads to an
error when upgrading from 11.5 to 11.6.
2024-05-21 14:46:43 +00:00
Dan Suh
4394bdc2ad
Change the log message in pxla.py
to be less confusing.
...
PiperOrigin-RevId: 635789016
2024-05-21 06:42:07 -07:00
jax authors
3f1b059503
Update XLA dependency to use revision
...
c82597f555
.
PiperOrigin-RevId: 635634941
2024-05-20 18:47:32 -07:00
Jake VanderPlas
d33a5689de
Refactor & test internal deprecation APIs
...
The names and APIs were previously too similar and therefore somewhat confusing; this will be more clear I think.
PiperOrigin-RevId: 635615163
2024-05-20 17:16:31 -07:00
jax authors
2eff241f62
Merge pull request #21319 from gnecula:exp_fix_mesh
...
PiperOrigin-RevId: 635611557
2024-05-20 17:02:44 -07:00
George Necula
6deeee27db
[export] Fix device assignment error for grad of exported.
...
Currently, the export code uses a manufactured device assignment
for exporting the VJP function. We should use instead the same
device assigment that was used when exporting the primal function.
This PR fixes that for the case when the export is done through
the direct use of `jax.experimental.export`, and leaves as future
work the case when the use is from `jax2tf`. We add a disabled
tests for the latter case.
Bug: #21314
2024-05-20 16:11:01 -07:00
Tomás Longeri
b197ae527e
[Mosaic] Also check bitwidth in apply-vector-layout's layoutIsValidForValue
.
...
PiperOrigin-RevId: 635595321
2024-05-20 15:57:08 -07:00
jax authors
118ca21b5b
Merge pull request #21318 from jakevdp:fix-mypy
...
PiperOrigin-RevId: 635553370
2024-05-20 13:31:33 -07:00
Jake VanderPlas
329ab036ee
CI: fix mypy error
2024-05-20 13:23:15 -07:00
Shanbin Ke
06d2e489eb
Copybara import of the project:
...
--
f625317cc80639178882316df6f8775294adc6b7 by cjkkkk <ske@nvidia.com>:
init
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/21228 from Cjkkkk:sdpa_new_cudnn_frontend f625317cc80639178882316df6f8775294adc6b7
PiperOrigin-RevId: 635518631
2024-05-20 11:31:15 -07:00
George Karpenkov
e0a6453a39
Simplify JAX lowering rules for cumulative sum
...
Upstream fix has landed => removing CPU workaround.
PiperOrigin-RevId: 635505632
2024-05-20 10:52:29 -07:00
jax authors
cc3a380f76
Add unit test to check if the backend serialization/deserialization result equal to the original executable.
...
PiperOrigin-RevId: 635485374
2024-05-20 09:52:38 -07:00
jax authors
61ff828715
Add support for TPU delay in Mosaic
...
PiperOrigin-RevId: 635473532
2024-05-20 09:07:56 -07:00
jax authors
2f45830de5
[Mosaic GPU] Prepare matmul example so it can be exposed to other projects.
...
PiperOrigin-RevId: 635442413
2024-05-20 06:54:44 -07:00
Sergei Lebedev
f600caa2c4
Use @register_lowering to register Pallas GPU lowering rules
...
This leads to slightly more compact code, but should otherwise be identical.
PiperOrigin-RevId: 635442002
2024-05-20 06:51:33 -07:00
jax authors
c3d0b0d12c
Merge pull request #21305 from jakevdp:scalar-bool
...
PiperOrigin-RevId: 635436437
2024-05-20 06:20:32 -07:00
jax authors
974c72b9a1
Merge pull request #21292 from ROCm:rv_stable_051624
...
PiperOrigin-RevId: 635430659
2024-05-20 05:52:36 -07:00
Jake VanderPlas
4bac10e750
Finalize deprecation of the config module.
...
To configure JAX, use `import jax` and reference the config object via `jax.config`.
PiperOrigin-RevId: 635430169
2024-05-20 05:49:31 -07:00
jax authors
bb616eff8a
Merge pull request #21231 from nouiz:doc_experimental_serialize_executable
...
PiperOrigin-RevId: 635428472
2024-05-20 05:39:21 -07:00
Jake VanderPlas
5b28170b94
Support scalar boolean indices in arr.at[idx].set(vals)
2024-05-20 05:33:36 -07:00
Adam Paszke
53ec2cd26f
Add notap tag to Mosaic tests
...
PiperOrigin-RevId: 635379982
2024-05-20 01:35:56 -07:00
jax authors
ffdb9bb0b0
Update XLA dependency to use revision
...
4c566c945a
.
PiperOrigin-RevId: 635307171
2024-05-19 18:27:13 -07:00
Vadym Matsishevskyi
45a7c22e93
fix: Update hermetic python dependencies to numpy=2.0.0rc2 and scipy=1.13.0 for all python version
...
Also install built jaxlib in hermetic python to support //jax:build_jaxlib=false tests.
PiperOrigin-RevId: 635169327
2024-05-18 23:39:09 -07:00
jax authors
8caeaa29f4
Update XLA dependency to use revision
...
ea6c3d5916
.
PiperOrigin-RevId: 635132837
2024-05-18 19:18:57 -07:00
Yash Katariya
6577f47b83
Make eqn.ctx
context manager thread safe by creating eqn.ctx.manager
.
...
PiperOrigin-RevId: 635057475
2024-05-18 08:46:18 -07:00
jax authors
e3a7a87f92
Update XLA dependency to use revision
...
e631101673
.
PiperOrigin-RevId: 634954649
2024-05-17 20:09:52 -07:00
Yash Katariya
25aa13c46b
Support remat + compute_on. If the rematted computation is annotated to run on host, the backward pass will also execute on host. Also enable no-op nested compute tests.
...
PiperOrigin-RevId: 634943450
2024-05-17 18:59:49 -07:00
Ruturaj4
79fccf6c82
add cholesky changes in bazel
2024-05-18 00:37:09 +00:00
jax authors
641d5c8be3
jax/pallas support ellipsis indexing
...
PiperOrigin-RevId: 634922391
2024-05-17 16:57:53 -07:00
Yash Katariya
02c19e9600
Make jax.grad
and compute_on
work correctly. If the forward pass has annotation to execute on CPU, then it's backward pass also executes on CPU.
...
PiperOrigin-RevId: 634917402
2024-05-17 16:38:35 -07:00
Ashish Shenoy
1043e24f6a
Add quantization support for PagedAttention TPU Pallas kernel.
...
PiperOrigin-RevId: 634914369
2024-05-17 16:17:33 -07:00
Yash Katariya
2d6d408b19
Initial commit for jax.experimental.compute_on
API.
...
The current supported values for compute type is `device_host`, `device`. `device_sparse` will be allowed in follow up CL. Using `device_host` means that the device's PJRT client will be orchestrating the execution of the computation on the host.
`cpu` as a compute_type is reserved for pure CPU only computations without a device's pjrt client orchestrating the computation.
PiperOrigin-RevId: 634909918
2024-05-17 15:59:21 -07:00
jax authors
7a3fc7113b
Merge pull request #21289 from jakevdp:stacklevel
...
PiperOrigin-RevId: 634895282
2024-05-17 15:01:37 -07:00
Jake VanderPlas
9ad6729b5e
jax.experimental.export: fix stacklevel for warning
2024-05-17 14:36:34 -07:00
jax authors
56301d471c
Merge pull request #21283 from dfm:gh21279
...
PiperOrigin-RevId: 634819832
2024-05-17 10:32:37 -07:00
jax authors
815256687f
[pallas:Mosaic GPU] Configurable smem scratch and a small bug fix in Mosaic GPU
...
PiperOrigin-RevId: 634813241
2024-05-17 10:10:20 -07:00
Dan Foreman-Mackey
be0695474a
Add RegularGridInterpolator to generated API docs
...
In responding to gh21279, I noticed that `RegularGridInterpolator` isn't
currently listed in the API docs. I know that `scipy.interpolate` is out
of scope (https://jax.readthedocs.io/en/latest/jep/18137-numpy-scipy-scope.html#scipy-interpolate ),
but since we do currently provide this wrapper, it seems like it makes
sense to include it in the docs!
2024-05-17 11:24:35 -04:00
jax authors
e93f36aa7c
Merge pull request #21281 from jaro-sevcik:enable-host-offloading-scan-tests-gpu
...
PiperOrigin-RevId: 634770894
2024-05-17 07:39:24 -07:00
Jaroslav Sevcik
7be7c1eace
Enable remat-scan offloading test
2024-05-17 07:29:31 -07:00
jax authors
f87be35b0f
[Mosaic GPU] reduce_sum does an intra-warp reduction before communicating with the other warps
...
PiperOrigin-RevId: 634765339
2024-05-17 07:24:35 -07:00
Sergei Lebedev
210f8bbfca
Use absl.flags.FlagHolder to defined --mosaic_gpu_debug
...
PiperOrigin-RevId: 634713545
2024-05-17 04:33:17 -07:00
jax authors
c4559115ec
Internal BUILD file change
...
PiperOrigin-RevId: 634713068
2024-05-17 04:30:21 -07:00
Sergei Lebedev
527aef3a01
Added a slow (but working!) implementation of layer norm in Pallas via Mosaic GPU
...
PiperOrigin-RevId: 634710243
2024-05-17 04:15:40 -07:00
jax authors
5e2710c2c2
Merge pull request #21261 from superbobry:mypy-ruff
...
PiperOrigin-RevId: 634654578
2024-05-17 00:10:27 -07:00
jax authors
1829a66739
Merge pull request #21268 from jakevdp:register-dataclass
...
PiperOrigin-RevId: 634624518
2024-05-16 21:27:30 -07:00
jax authors
0e9243391b
[Mosaic GPU] Add a WGSplatLayout that trivially supports reshape and broadcast.
...
PiperOrigin-RevId: 634610004
2024-05-16 20:04:05 -07:00
jax authors
efa420b299
Update XLA dependency to use revision
...
0821ce5408
.
PiperOrigin-RevId: 634608759
2024-05-16 20:01:07 -07:00