239 Commits

Author SHA1 Message Date
Jake VanderPlas
91dbcbf525 Remove deprecated jax.experimental.stax
The new location is jax.example_libraries.stax
2022-08-02 16:50:06 -07:00
Lena Martens
8ca5ecc7f3 Re-land #11498 after internal fixes.
maintain an alias to `jax.tree_util.tree_map` in the top level `jax` module

PiperOrigin-RevId: 463885774
2022-07-28 11:33:34 -07:00
Jake VanderPlas
108376d792 Remove deprecated function jax.tree_util.tree_multimap 2022-07-26 09:37:27 -07:00
George Necula
afa8f5acb4 Remove jax.experimental.loops. See CHANGELOG
PiperOrigin-RevId: 463297399
2022-07-26 03:39:47 -07:00
Jake VanderPlas
bc90743603 Update changelog for jax/jaxlib v0.3.15 release 2022-07-25 09:47:44 -07:00
George Necula
66dc95e2de removes the jax.mask and jax.shapecheck APIs.
PiperOrigin-RevId: 463026577
2022-07-25 01:23:38 -07:00
George Necula
07fcf79324 jax.mask and jax.shapecheck are being deprecated
Issue: #11557
PiperOrigin-RevId: 462315754
2022-07-21 00:09:31 -07:00
Jake VanderPlas
9090dd179d jax.scipy.linalg.solve: deprecate the sym_pos argument following scipy 1.9.0 2022-07-19 13:57:49 -07:00
jax authors
023e6f5955 Copybara import of the project:
--
e1f1e93e0c8b53e62a064b06b56c84a2bfedb911 by Roy Frostig <frostig@google.com>:

maintain an alias to `jax.tree_util.tree_map` in the top level `jax` module

PiperOrigin-RevId: 461146464
2022-07-15 01:23:51 -07:00
Roy Frostig
e1f1e93e0c maintain an alias to jax.tree_util.tree_map in the top level jax module 2022-07-14 11:00:54 -07:00
Jake VanderPlas
ce08a9fc5c Deprecate top-level aliases of jax.tree_util functions 2022-07-07 11:41:46 -07:00
Jake VanderPlas
39b0ff7eb6 jnp.ndarray: raise TypeError for binary operations with builtin collections 2022-06-29 08:22:05 -07:00
Dan F-M
0788d5708a Implementation of jax.scipy.stats.gaussian_kde 2022-06-28 15:17:12 -04:00
Jake VanderPlas
887abbc3b9 jax.test_util: remove deprecated test classes.
JaxTestCase and JaxTestLoader were deprecated in jax v0.3.1, released Feb 2022.
2022-06-27 11:04:50 -07:00
Peter Hawkins
1e29b7b762 Update CHANGELOG.md and setup.py for 0.3.14 release. 2022-06-27 09:38:41 -04:00
jax authors
406a61cf52 Merge pull request #11146 from sshahrokhi:AbortIfNotInitialized
PiperOrigin-RevId: 457115405
2022-06-24 16:24:57 -07:00
Shiva Shahrokhi
df8c6263de Change JAX_PLATFORMS to raise an exception when platform initialization fails 2022-06-24 21:54:53 +00:00
Ian McKenzie
0cc2ada432 Fix broken links for moved design_notes folder 2022-06-24 12:18:11 -07:00
Jake VanderPlas
f6476f7a03 jnp.roots: better support for computation under JIT 2022-06-23 14:48:53 -07:00
Sharad Vikram
9bd1bd67e0 Update versions for jax/jaxlib release 2022-06-21 12:57:28 -07:00
carlosgmartin
57b89ba7cb Added scipy.stats.gennorm. 2022-06-14 13:38:24 -04:00
jax authors
b174b7751b Merge pull request #10771 from sshahrokhi:gfilecache
PiperOrigin-RevId: 454692872
2022-06-13 13:58:15 -07:00
Shiva Shahrokhi
498ee6007d Using etils(gfile) to support gcs buckets and file system for persistent compilation caching 2022-06-10 00:17:13 +00:00
Jake VanderPlas
d2f80ef117 [x64] deprecate unsafe type casting in scatter-update operations 2022-06-09 15:21:49 -07:00
Sharad Vikram
c0b47fdf2c Update changelog for named_scope and adds it to the docs 2022-06-09 11:22:44 -07:00
Skye Wanderman-Milne
f86282579e Add jax.default_device to CHANGELOG 2022-06-08 14:00:54 -07:00
Sharad Vikram
143ed40a78 Add collect_profile script 2022-06-03 17:56:17 -07:00
carlosgmartin
ca83a80f95 Added random.generalized_normal and random.ball. 2022-06-03 15:11:29 -04:00
jax authors
c73da15d85 Merge pull request #10906 from sharadmv:profiler
PiperOrigin-RevId: 452429099
2022-06-01 18:15:34 -07:00
Sharad Vikram
449da304b3 Store profiler server as a global variable and add a stop_server function 2022-06-01 17:50:06 -07:00
jax authors
2d87a06888 Merge pull request #10944 from hawkinsp:macminver
PiperOrigin-RevId: 452424815
2022-06-01 17:47:41 -07:00
Peter Hawkins
69bda69fb6 Bump minimum Mac OS version to 10.14 (Mojave).
It turns out that the support for C++17 is partial in 10.12, and in particular absl::optional and std::optional are not the same thing under 10.12. Increment to 10.14 which is the lowest version that builds successfully with absl::optional == std::optional.

See: 89cdaed655/absl/base/config.h (L528)
Strictly speaking, we could allow 10.13, but not without updating ABSL in the TF repository to incorporate c86347d4ce which fixes the version detection test to permit 10.13 as well.
2022-06-01 20:32:22 -04:00
Sharad Vikram
76669835ba Add an option to create a perfetto link in the JAX profiler 2022-06-01 15:48:29 -07:00
Jake VanderPlas
358f929681 [x64] jnp.ldexp: avoid implicit 64-bit promotion 2022-06-01 09:14:47 -07:00
Peter Hawkins
b6cdda763b Update changelog to incorporate some recent changes. 2022-05-31 14:03:27 -04:00
Jake VanderPlas
991ad72e24 DeviceArray: Improve support for copy, deepcopy, and pickle 2022-05-19 12:00:58 -07:00
Peter Hawkins
1bcb5e073c Add an implementation of jnp.linalg.slogdet based on QR decomposition.
Adds a non-standard `method` argument to `jnp.linalg.slogdet` to select between the current LU decomposition based implementation (like NumPy) and the QR decomposition implementation.

QR decomposition is more amenable to a high performance batched implementation particularly on TPU hardware because it does not need row pivoting. The same may be true on other hardware also, and having the option is nice either way!

PiperOrigin-RevId: 449271317
2022-05-17 11:24:11 -07:00
Skye Wanderman-Milne
6b926d5551 Update version + CHANGELOG for jax 0.3.13 release 2022-05-16 12:17:07 -07:00
Yash Katariya
6a6605263d Update values after jax release
PiperOrigin-RevId: 448854487
2022-05-15 18:35:46 -07:00
Yash Katariya
1381afc37f Update version after jax release
PiperOrigin-RevId: 448822949
2022-05-15 12:14:26 -07:00
Peter Hawkins
7ba36fc178 Change implementation of jax.scipy.linalg.polar() and jax._src.scipy.eigh to use the QDWH decomposition from jax._src.lax.qdwh.
Remove jax._src.lax.polar.

PiperOrigin-RevId: 448241206
2022-05-12 07:20:52 -07:00
Peter Hawkins
705e241409 Change non-array arguments to jax.lax.linalg functions to be keyword-only arguments.
PiperOrigin-RevId: 448066207
2022-05-11 13:06:54 -07:00
Peter Hawkins
590b9161fe Add a sort_eigenvalues option to lax.linalg.eigh().
An upcoming change to add a more scalable QDWH-based TPU symmetric eigendecomposition requires that we can obtain the TPU eigenvalues unsorted. The option already exists in XLA, so we simply need to plumb it through to the lax primitive.

PiperOrigin-RevId: 448047584
2022-05-11 11:46:03 -07:00
Yash Katariya
ff1a3c40ba jax and jaxlib release
PiperOrigin-RevId: 446295827
2022-05-03 14:52:40 -07:00
Yash Katariya
888e5c6958 Update the version numbers after JAX release.
PiperOrigin-RevId: 446092433
2022-05-02 19:51:11 -07:00
Matthew Johnson
838f22553b update version and changelog for pypi 2022-04-29 20:15:52 -07:00
Tianjian Lu
020849076c [linalg] Add tpu svd lowering rule.
PiperOrigin-RevId: 445533767
2022-04-29 16:43:53 -07:00
jax authors
227e525de2 Merge pull request #10458 from carlosgmartin:random_orthogonal_unitary
PiperOrigin-RevId: 445522278
2022-04-29 15:40:16 -07:00
Carlos Martin
b276c31b75 Added random.orthogonal. 2022-04-29 14:20:50 -04:00
Peter Hawkins
0b470361da Change the default jnp.take mode to "fill".
Previously, `jnp.take` defaulted to clamping out-of-bounds indices into range. Now, `jnp.take` returns invalid values (e.g., NaN) for out-of-bounds indices. This change attempts to prevent latent bugs caused by inadvertent out-of-bounds indices.

The previous behavior can be approximated using the "clip" or "wrap" fill modes.

PiperOrigin-RevId: 445130143
2022-04-28 06:01:56 -07:00