15481 Commits

Author SHA1 Message Date
George Necula
befb449f05 [shape_poly] Fixed bug with dimension variables in unused args
JAX will aggressively drop module input arguments if they are not
used. This can interfere with shape polymorphism, because it may
result in dropping arguments from which we need to derive the
values of shape variables.

We fix this for now by disabling dropping arguments if there
are dimension variables in the arguments shapes. A more precise
technique would be to force keeping only of arguments that we
need for deriving the dimension variables. However, that would be
a much more involved change, for an uncertain benefit.
2023-03-27 13:37:39 +02:00
George Necula
99facbab2a [jax2tf] Turn an error into a warning with native serialization
We want to allow using native_serialization_platforms even if the native_serialization is False. This is useful for code that is runnable with and without native serialization.

PiperOrigin-RevId: 519649827
2023-03-27 01:24:56 -07:00
Ravin Kumar
08a8a5e812
Fix hessian llnk 2023-03-26 19:21:52 -07:00
Ravin Kumar
8c2549519b
Update user_guides.rst
Fix minor typo
2023-03-26 17:21:35 -07:00
Peter Hawkins
b62f114524 Add support for using pip-installed CUDA wheels.
Add a currently undocumented jax[cuda11_pip] and jax[cuda12_pip] that depend on the pip CUDA wheels.
Add a currently undocumented jax[cuda11_local] and jax[cuda12_local] that avoid the CUDA wheel dependency.
2023-03-26 12:35:00 +00:00
Peter Hawkins
ec427f2c95 Split dtype argument from other arguments in special functions.
This helps pytype to determine that the arguments are of different kinds, preventing type errors.

PiperOrigin-RevId: 519401250
2023-03-25 11:41:14 -07:00
Peter Hawkins
05319b5b87 Suppress mypy warnings about missing imports. 2023-03-25 09:45:55 -04:00
Yash Katariya
a5d308542e Add src argument to device_put as an experimental arg
PiperOrigin-RevId: 519308082
2023-03-24 21:10:26 -07:00
Yash Katariya
5da4a5ae76 Add SDA deprecation warning to pytest.ini
PiperOrigin-RevId: 519281775
2023-03-24 18:04:07 -07:00
jax authors
c572155cc1 Merge pull request #15212 from google:pjrt_c_api_tests
PiperOrigin-RevId: 519276265
2023-03-24 17:27:40 -07:00
Skye Wanderman-Milne
ef5e4a4035 Remove 'pjrt_c_api_unimplemented' pytest mark.
Instead, we skip tests that the PJRT C API doesn't support. We had
this tag for feature development so it was easy to broadly disable,
but now we don't expect to need to do that.
2023-03-24 23:14:54 +00:00
Anish Tondwalkar
6842e98ca1 Migrate regularized_incomplete_beta_p off xla_fallback
PiperOrigin-RevId: 519244597
2023-03-24 14:53:20 -07:00
Anish Tondwalkar
ac44d2c2e3 Migrate besseli0e off xla_fallback
PiperOrigin-RevId: 519241252
2023-03-24 14:39:40 -07:00
Yu-Hang 'Maxin' Tang
caaa0a2669 add build option to create editable jaxlib
Co-authored-by: Yonghao Zhuang <zhuangyh@sjtu.edu.cn>
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
2023-03-24 21:25:26 +00:00
Yash Katariya
257ac6a993 If each host has the full value of the Array, allow fetching it to host. Fixes #15162
Benchmarks:

```
name                  old cpu/op   new cpu/op   delta
np_asarray_8_devices  3.71ms ± 6%  3.32ms ± 7%  -10.48%  (p=0.008 n=5+5)

name                  old time/op  new time/op  delta
np_asarray_8_devices  3.86ms ± 6%  3.49ms ± 7%  -9.72%  (p=0.008 n=5+5)
```

PiperOrigin-RevId: 519222320
2023-03-24 13:21:57 -07:00
Peter Hawkins
6ed66ada0f Delete remote TPU support.
TPU VMs are the only supported way to use TPUs as of JAX 0.4.0.

PiperOrigin-RevId: 519211267
2023-03-24 12:33:33 -07:00
John QiangZhang
fad4e6f95a [1/n] store embedded tf.graph to stablehlo.custom_call
PiperOrigin-RevId: 519194911
2023-03-24 11:27:24 -07:00
Parker Schuh
21541e60b1 Guard ArrayImpl checks by xla_extension_version.
PiperOrigin-RevId: 519191714
2023-03-24 11:15:36 -07:00
Yash Katariya
bc231ee0bf After the SPMD bug fix, always take the _rewriting_take route for getitem instead of bouncing to host.
PiperOrigin-RevId: 519170785
2023-03-24 10:00:41 -07:00
jax authors
61a5686f51 Merge pull request #15175 from nouiz:nightly_ci_keep_alive
PiperOrigin-RevId: 519167122
2023-03-24 09:52:49 -07:00
jax authors
72bb4fe3de Merge pull request #15203 from nouiz:nightly_ci
PiperOrigin-RevId: 519167117
2023-03-24 09:45:21 -07:00
Anish Tondwalkar
8c75e27f67 Migrate random_gamma_grad off xla_fallback
PiperOrigin-RevId: 519154537
2023-03-24 08:49:40 -07:00
Anish Tondwalkar
8d1d522618 Migrate igamma_grad_a_p off xla_fallback
PiperOrigin-RevId: 519148548
2023-03-24 08:21:22 -07:00
Frederic Bastien
229a4cfdb4 remove another dependency not currently needed. 2023-03-24 08:04:27 -07:00
Anish Tondwalkar
4a9b09485e Migrate igammac_p off xla_fallback path
It is now decomposed into stablehlo ops.

PiperOrigin-RevId: 519122775
2023-03-24 05:58:38 -07:00
Anish Tondwalkar
8081031c90 [jaxlib] fix build w/ depenency on stablehlo_serialization
PiperOrigin-RevId: 519120624
2023-03-24 05:42:38 -07:00
jax authors
32e712864c Merge pull request #15192 from mattjj:issue15190
PiperOrigin-RevId: 519037959
2023-03-23 20:48:33 -07:00
jax authors
1982a113d6 Merge pull request #15187 from mattjj:djax-revival
PiperOrigin-RevId: 519036576
2023-03-23 20:38:01 -07:00
Matthew Johnson
7743fcd758 [dynamic-shapes] make dynamic shape staging-to-jaxpr work with pjit 2023-03-23 20:20:01 -07:00
Matthew Johnson
793387e496 fix jax.Array.round()
fixes #15190
2023-03-23 20:16:23 -07:00
Skye Wanderman-Milne
4cb3b011a0 Remove PJRT C API bypass.
Now that all functionality needed by frameworks is implemented, let's
remove the possibility of not noticing missing functionality due to
the bypass.

PiperOrigin-RevId: 519018438
2023-03-23 18:39:14 -07:00
jax authors
e36bd57f9f Merge pull request #14309 from hawkinsp:numpy
PiperOrigin-RevId: 519015703
2023-03-23 18:24:01 -07:00
Peter Hawkins
b7375b316b Increase minimum NumPy version to 1.21.
Also increase minimum SciPy version to 1.7, which was released just before NumPy 1.21.
2023-03-23 21:15:10 -04:00
jax authors
e9bc7ee866 Merge pull request #15184 from jakevdp:move-median
PiperOrigin-RevId: 519003606
2023-03-23 17:14:50 -07:00
Jake VanderPlas
6f8885a0c2 lax_numpy: move quantile-based functions to reductions.py 2023-03-23 16:39:20 -07:00
Anish Tondwalkar
f981243af5 Migrate igamma_p off xla_fallback
We decompose it into a series or a call to igammac.

PiperOrigin-RevId: 518993077
2023-03-23 16:26:59 -07:00
George Necula
d777cf229e [jax2tf] A simple failing test on TPU with native serialization
PiperOrigin-RevId: 518987577
2023-03-23 16:04:53 -07:00
John QiangZhang
171b22dbbc Add padding option "SAME_LOWER" for ticket https://github.com/google/jax/pull/14990
PiperOrigin-RevId: 518984018
2023-03-23 15:50:16 -07:00
George Necula
7df52459d0 [jax2tf] Create a jax_export library with JAX-only pieces for native serialization
This is a pure refactor, no functionality should change.

PiperOrigin-RevId: 518982222
2023-03-23 15:42:45 -07:00
jax authors
6407fa63ae Merge pull request #15178 from mattjj:improve-scan-errors
PiperOrigin-RevId: 518977054
2023-03-23 15:28:16 -07:00
Anish Tondwalkar
adbdaa47a3 Refactor special functions into their own module.
We're going to want to decompose these using series and
continued fraction representations, and for that we'll need
control flow

PiperOrigin-RevId: 518977008
2023-03-23 15:21:15 -07:00
jax authors
1f7c305cf6 Merge pull request #15172 from jakevdp:jax-array-refactor
PiperOrigin-RevId: 518971494
2023-03-23 15:00:15 -07:00
Matthew Johnson
ba2ff519ca improve scan error messages 2023-03-23 14:53:05 -07:00
George Necula
1136d0f6c7 [jax2tf] Minor addition to the documentation
PiperOrigin-RevId: 518969936
2023-03-23 14:52:01 -07:00
jax authors
d600c5aad4 Merge pull request #15180 from cgarciae:add-trailing-whitespaces-hook
PiperOrigin-RevId: 518967503
2023-03-23 14:42:55 -07:00
Cristian Garcia
dfc24f2981 add trailing-whitespace pre-commit hook 2023-03-23 21:17:54 +00:00
jax authors
d857187503 Merge pull request #15179 from jakevdp:fix-jet-mypy
PiperOrigin-RevId: 518958675
2023-03-23 14:07:27 -07:00
Jake VanderPlas
1286446b52 Fix mypy issue in jax/experimental/jet.py 2023-03-23 13:56:11 -07:00
Frederic Bastien
f3be75cb53 WAR ssh timeout like:
client_loop: send disconnect: Broken pipe
https://github.com/google/jax/actions/runs/4500333187/jobs/7919324156#step:8:42
2023-03-23 11:58:57 -07:00
Yash Katariya
a9e48af260 Deprecated xla_call_p since it has been replaced with pjit.pjit_p
PiperOrigin-RevId: 518921538
2023-03-23 11:44:42 -07:00