Yash Katariya
ae4f1fcb66
Update the commit in workspace too
...
PiperOrigin-RevId: 519797427
jax-v0.4.7
jaxlib-v0.4.7
jax-v0.4.7-rc1
2023-03-27 12:30:18 -07:00
Yash Katariya
e9cac5eb47
Prepare for jax and jaxlib 0.4.7 release
...
PiperOrigin-RevId: 519785176
jax-v0.4.7-rc
2023-03-27 11:45:22 -07:00
Yash Katariya
e21aee18a8
Add deprecation warning for FROM_GDA usage since that argument is not required anymore.
...
PiperOrigin-RevId: 519781715
2023-03-27 11:33:11 -07:00
Sharad Vikram
3c3fa042e3
Copy seq_lengths before creating descriptor
...
PiperOrigin-RevId: 519771897
2023-03-27 10:59:44 -07:00
Peter Hawkins
88c2898e36
Use pytype_strict_library() in Bazel build rules.
...
PiperOrigin-RevId: 519757928
2023-03-27 10:16:08 -07:00
Peter Hawkins
40fb646e35
Fix duplicate definition of 'cuda' extra in setup.py.
...
PiperOrigin-RevId: 519750659
2023-03-27 09:52:37 -07:00
jax authors
af4d4943a7
Merge pull request #8633 from shawwn:2021-11-19/autodidax-fix-jaxpr-subcomp-return-type
...
PiperOrigin-RevId: 519745476
2023-03-27 09:52:20 -07:00
jax authors
10d51c78f6
Merge pull request #15218 from hawkinsp:mypy
...
PiperOrigin-RevId: 519745465
2023-03-27 09:37:54 -07:00
Yash Katariya
41695cc78c
Temporarily fix the compilation cache test which is failing on latest jaxlib release
...
PiperOrigin-RevId: 519745099
2023-03-27 09:37:37 -07:00
jax authors
2c4be6f662
Merge pull request #15226 from canyon289:patch-1
...
PiperOrigin-RevId: 519743393
2023-03-27 09:30:24 -07:00
jax authors
d19e60ea07
Merge pull request #15228 from canyon289:patch-2
...
PiperOrigin-RevId: 519742908
2023-03-27 09:30:07 -07:00
Yash Katariya
cf8c2b8450
Delete benchmark and pmap_benchmark files as they are legacy and replaced with api_benchmark.py
...
PiperOrigin-RevId: 519742866
2023-03-27 09:22:57 -07:00
jax authors
d473e86912
Merge pull request #13008 from hawkinsp:pipcuda
...
PiperOrigin-RevId: 519740461
2023-03-27 09:14:24 -07:00
jax authors
6715736583
Merge pull request #15205 from yhtang:editable-jaxlib-build
...
PiperOrigin-RevId: 519704474
2023-03-27 06:33:31 -07:00
jax authors
f3613a11b9
Merge pull request #15234 from gnecula:get_dim_size
...
PiperOrigin-RevId: 519691037
2023-03-27 05:21:54 -07:00
George Necula
befb449f05
[shape_poly] Fixed bug with dimension variables in unused args
...
JAX will aggressively drop module input arguments if they are not
used. This can interfere with shape polymorphism, because it may
result in dropping arguments from which we need to derive the
values of shape variables.
We fix this for now by disabling dropping arguments if there
are dimension variables in the arguments shapes. A more precise
technique would be to force keeping only of arguments that we
need for deriving the dimension variables. However, that would be
a much more involved change, for an uncertain benefit.
2023-03-27 13:37:39 +02:00
George Necula
99facbab2a
[jax2tf] Turn an error into a warning with native serialization
...
We want to allow using native_serialization_platforms even if the native_serialization is False. This is useful for code that is runnable with and without native serialization.
PiperOrigin-RevId: 519649827
2023-03-27 01:24:56 -07:00
Ravin Kumar
08a8a5e812
Fix hessian llnk
2023-03-26 19:21:52 -07:00
Ravin Kumar
8c2549519b
Update user_guides.rst
...
Fix minor typo
2023-03-26 17:21:35 -07:00
Peter Hawkins
b62f114524
Add support for using pip-installed CUDA wheels.
...
Add a currently undocumented jax[cuda11_pip] and jax[cuda12_pip] that depend on the pip CUDA wheels.
Add a currently undocumented jax[cuda11_local] and jax[cuda12_local] that avoid the CUDA wheel dependency.
2023-03-26 12:35:00 +00:00
Peter Hawkins
ec427f2c95
Split dtype argument from other arguments in special functions.
...
This helps pytype to determine that the arguments are of different kinds, preventing type errors.
PiperOrigin-RevId: 519401250
2023-03-25 11:41:14 -07:00
Peter Hawkins
05319b5b87
Suppress mypy warnings about missing imports.
2023-03-25 09:45:55 -04:00
Yash Katariya
a5d308542e
Add src
argument to device_put as an experimental arg
...
PiperOrigin-RevId: 519308082
2023-03-24 21:10:26 -07:00
Yash Katariya
5da4a5ae76
Add SDA deprecation warning to pytest.ini
...
PiperOrigin-RevId: 519281775
2023-03-24 18:04:07 -07:00
jax authors
c572155cc1
Merge pull request #15212 from google:pjrt_c_api_tests
...
PiperOrigin-RevId: 519276265
2023-03-24 17:27:40 -07:00
Skye Wanderman-Milne
ef5e4a4035
Remove 'pjrt_c_api_unimplemented' pytest mark.
...
Instead, we skip tests that the PJRT C API doesn't support. We had
this tag for feature development so it was easy to broadly disable,
but now we don't expect to need to do that.
2023-03-24 23:14:54 +00:00
Anish Tondwalkar
6842e98ca1
Migrate regularized_incomplete_beta_p off xla_fallback
...
PiperOrigin-RevId: 519244597
2023-03-24 14:53:20 -07:00
Anish Tondwalkar
ac44d2c2e3
Migrate besseli0e off xla_fallback
...
PiperOrigin-RevId: 519241252
2023-03-24 14:39:40 -07:00
Yu-Hang 'Maxin' Tang
caaa0a2669
add build option to create editable jaxlib
...
Co-authored-by: Yonghao Zhuang <zhuangyh@sjtu.edu.cn>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
2023-03-24 21:25:26 +00:00
Yash Katariya
257ac6a993
If each host has the full value of the Array, allow fetching it to host. Fixes #15162
...
Benchmarks:
```
name old cpu/op new cpu/op delta
np_asarray_8_devices 3.71ms ± 6% 3.32ms ± 7% -10.48% (p=0.008 n=5+5)
name old time/op new time/op delta
np_asarray_8_devices 3.86ms ± 6% 3.49ms ± 7% -9.72% (p=0.008 n=5+5)
```
PiperOrigin-RevId: 519222320
2023-03-24 13:21:57 -07:00
Peter Hawkins
6ed66ada0f
Delete remote TPU support.
...
TPU VMs are the only supported way to use TPUs as of JAX 0.4.0.
PiperOrigin-RevId: 519211267
2023-03-24 12:33:33 -07:00
John QiangZhang
fad4e6f95a
[1/n] store embedded tf.graph to stablehlo.custom_call
...
PiperOrigin-RevId: 519194911
2023-03-24 11:27:24 -07:00
Parker Schuh
21541e60b1
Guard ArrayImpl checks by xla_extension_version.
...
PiperOrigin-RevId: 519191714
2023-03-24 11:15:36 -07:00
Yash Katariya
bc231ee0bf
After the SPMD bug fix, always take the _rewriting_take route for getitem instead of bouncing to host.
...
PiperOrigin-RevId: 519170785
2023-03-24 10:00:41 -07:00
jax authors
61a5686f51
Merge pull request #15175 from nouiz:nightly_ci_keep_alive
...
PiperOrigin-RevId: 519167122
2023-03-24 09:52:49 -07:00
jax authors
72bb4fe3de
Merge pull request #15203 from nouiz:nightly_ci
...
PiperOrigin-RevId: 519167117
2023-03-24 09:45:21 -07:00
Anish Tondwalkar
8c75e27f67
Migrate random_gamma_grad off xla_fallback
...
PiperOrigin-RevId: 519154537
2023-03-24 08:49:40 -07:00
Anish Tondwalkar
8d1d522618
Migrate igamma_grad_a_p off xla_fallback
...
PiperOrigin-RevId: 519148548
2023-03-24 08:21:22 -07:00
Frederic Bastien
229a4cfdb4
remove another dependency not currently needed.
2023-03-24 08:04:27 -07:00
Anish Tondwalkar
4a9b09485e
Migrate igammac_p off xla_fallback path
...
It is now decomposed into stablehlo ops.
PiperOrigin-RevId: 519122775
2023-03-24 05:58:38 -07:00
Anish Tondwalkar
8081031c90
[jaxlib] fix build w/ depenency on stablehlo_serialization
...
PiperOrigin-RevId: 519120624
2023-03-24 05:42:38 -07:00
jax authors
32e712864c
Merge pull request #15192 from mattjj:issue15190
...
PiperOrigin-RevId: 519037959
2023-03-23 20:48:33 -07:00
jax authors
1982a113d6
Merge pull request #15187 from mattjj:djax-revival
...
PiperOrigin-RevId: 519036576
2023-03-23 20:38:01 -07:00
Matthew Johnson
7743fcd758
[dynamic-shapes] make dynamic shape staging-to-jaxpr work with pjit
2023-03-23 20:20:01 -07:00
Matthew Johnson
793387e496
fix jax.Array.round()
...
fixes #15190
2023-03-23 20:16:23 -07:00
Skye Wanderman-Milne
4cb3b011a0
Remove PJRT C API bypass.
...
Now that all functionality needed by frameworks is implemented, let's
remove the possibility of not noticing missing functionality due to
the bypass.
PiperOrigin-RevId: 519018438
2023-03-23 18:39:14 -07:00
jax authors
e36bd57f9f
Merge pull request #14309 from hawkinsp:numpy
...
PiperOrigin-RevId: 519015703
2023-03-23 18:24:01 -07:00
Peter Hawkins
b7375b316b
Increase minimum NumPy version to 1.21.
...
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00
jax authors
e9bc7ee866
Merge pull request #15184 from jakevdp:move-median
...
PiperOrigin-RevId: 519003606
2023-03-23 17:14:50 -07:00
Jake VanderPlas
6f8885a0c2
lax_numpy: move quantile-based functions to reductions.py
2023-03-23 16:39:20 -07:00