247 Commits

Author SHA1 Message Date
Sharad Vikram
b0fdf10a63 Apply suggestions from code review
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-18 10:50:50 -07:00
Sharad Vikram
393bca122d Expose pure callback and enable rank polymorphic callbacks 2022-08-17 10:56:42 -07:00
Matthew Johnson
d19e34fa4a delete old remat implementation
moved lowering rule logic from remat_impl.py (now deleted) to ad_checkpoint.py
2022-08-16 23:16:37 -07:00
Sharad Vikram
841a662894 Update version + changelog 2022-08-11 17:11:52 -07:00
Matthew Johnson
be6f6bfe9f set new jax.remat / jax.checkpoint to be on-by-default 2022-08-10 10:29:38 -07:00
Jake VanderPlas
79406757d0 Remove deprecated jax.experimental.optimizers
The new location is jax.example_libraries.optimizers
2022-08-09 08:50:59 -07:00
Peter Hawkins
c735c6bf0e Increase minimum NumPy version to 1.20.
Per NEP 29, support for 1.19 ended on Jun 21, 2022.
2022-08-06 14:51:14 +00:00
Sharad Vikram
e1da5c7a36
Update CHANGELOG.md 2022-08-02 22:28:52 -07:00
Jake VanderPlas
91dbcbf525 Remove deprecated jax.experimental.stax
The new location is jax.example_libraries.stax
2022-08-02 16:50:06 -07:00
Lena Martens
8ca5ecc7f3 Re-land #11498 after internal fixes.
maintain an alias to `jax.tree_util.tree_map` in the top level `jax` module

PiperOrigin-RevId: 463885774
2022-07-28 11:33:34 -07:00
Jake VanderPlas
108376d792 Remove deprecated function jax.tree_util.tree_multimap 2022-07-26 09:37:27 -07:00
George Necula
afa8f5acb4 Remove jax.experimental.loops. See CHANGELOG
PiperOrigin-RevId: 463297399
2022-07-26 03:39:47 -07:00
Jake VanderPlas
bc90743603 Update changelog for jax/jaxlib v0.3.15 release 2022-07-25 09:47:44 -07:00
George Necula
66dc95e2de removes the jax.mask and jax.shapecheck APIs.
PiperOrigin-RevId: 463026577
2022-07-25 01:23:38 -07:00
George Necula
07fcf79324 jax.mask and jax.shapecheck are being deprecated
Issue: #11557
PiperOrigin-RevId: 462315754
2022-07-21 00:09:31 -07:00
Jake VanderPlas
9090dd179d jax.scipy.linalg.solve: deprecate the sym_pos argument following scipy 1.9.0 2022-07-19 13:57:49 -07:00
jax authors
023e6f5955 Copybara import of the project:
--
e1f1e93e0c8b53e62a064b06b56c84a2bfedb911 by Roy Frostig <frostig@google.com>:

maintain an alias to `jax.tree_util.tree_map` in the top level `jax` module

PiperOrigin-RevId: 461146464
2022-07-15 01:23:51 -07:00
Roy Frostig
e1f1e93e0c maintain an alias to jax.tree_util.tree_map in the top level jax module 2022-07-14 11:00:54 -07:00
Jake VanderPlas
ce08a9fc5c Deprecate top-level aliases of jax.tree_util functions 2022-07-07 11:41:46 -07:00
Jake VanderPlas
39b0ff7eb6 jnp.ndarray: raise TypeError for binary operations with builtin collections 2022-06-29 08:22:05 -07:00
Dan F-M
0788d5708a Implementation of jax.scipy.stats.gaussian_kde 2022-06-28 15:17:12 -04:00
Jake VanderPlas
887abbc3b9 jax.test_util: remove deprecated test classes.
JaxTestCase and JaxTestLoader were deprecated in jax v0.3.1, released Feb 2022.
2022-06-27 11:04:50 -07:00
Peter Hawkins
1e29b7b762 Update CHANGELOG.md and setup.py for 0.3.14 release. 2022-06-27 09:38:41 -04:00
jax authors
406a61cf52 Merge pull request #11146 from sshahrokhi:AbortIfNotInitialized
PiperOrigin-RevId: 457115405
2022-06-24 16:24:57 -07:00
Shiva Shahrokhi
df8c6263de Change JAX_PLATFORMS to raise an exception when platform initialization fails 2022-06-24 21:54:53 +00:00
Ian McKenzie
0cc2ada432 Fix broken links for moved design_notes folder 2022-06-24 12:18:11 -07:00
Jake VanderPlas
f6476f7a03 jnp.roots: better support for computation under JIT 2022-06-23 14:48:53 -07:00
Sharad Vikram
9bd1bd67e0 Update versions for jax/jaxlib release 2022-06-21 12:57:28 -07:00
carlosgmartin
57b89ba7cb Added scipy.stats.gennorm. 2022-06-14 13:38:24 -04:00
jax authors
b174b7751b Merge pull request #10771 from sshahrokhi:gfilecache
PiperOrigin-RevId: 454692872
2022-06-13 13:58:15 -07:00
Shiva Shahrokhi
498ee6007d Using etils(gfile) to support gcs buckets and file system for persistent compilation caching 2022-06-10 00:17:13 +00:00
Jake VanderPlas
d2f80ef117 [x64] deprecate unsafe type casting in scatter-update operations 2022-06-09 15:21:49 -07:00
Sharad Vikram
c0b47fdf2c Update changelog for named_scope and adds it to the docs 2022-06-09 11:22:44 -07:00
Skye Wanderman-Milne
f86282579e Add jax.default_device to CHANGELOG 2022-06-08 14:00:54 -07:00
Sharad Vikram
143ed40a78 Add collect_profile script 2022-06-03 17:56:17 -07:00
carlosgmartin
ca83a80f95 Added random.generalized_normal and random.ball. 2022-06-03 15:11:29 -04:00
jax authors
c73da15d85 Merge pull request #10906 from sharadmv:profiler
PiperOrigin-RevId: 452429099
2022-06-01 18:15:34 -07:00
Sharad Vikram
449da304b3 Store profiler server as a global variable and add a stop_server function 2022-06-01 17:50:06 -07:00
jax authors
2d87a06888 Merge pull request #10944 from hawkinsp:macminver
PiperOrigin-RevId: 452424815
2022-06-01 17:47:41 -07:00
Peter Hawkins
69bda69fb6 Bump minimum Mac OS version to 10.14 (Mojave).
It turns out that the support for C++17 is partial in 10.12, and in particular absl::optional and std::optional are not the same thing under 10.12. Increment to 10.14 which is the lowest version that builds successfully with absl::optional == std::optional.

See: 89cdaed655/absl/base/config.h (L528)
Strictly speaking, we could allow 10.13, but not without updating ABSL in the TF repository to incorporate c86347d4ce which fixes the version detection test to permit 10.13 as well.
2022-06-01 20:32:22 -04:00
Sharad Vikram
76669835ba Add an option to create a perfetto link in the JAX profiler 2022-06-01 15:48:29 -07:00
Jake VanderPlas
358f929681 [x64] jnp.ldexp: avoid implicit 64-bit promotion 2022-06-01 09:14:47 -07:00
Peter Hawkins
b6cdda763b Update changelog to incorporate some recent changes. 2022-05-31 14:03:27 -04:00
Jake VanderPlas
991ad72e24 DeviceArray: Improve support for copy, deepcopy, and pickle 2022-05-19 12:00:58 -07:00
Peter Hawkins
1bcb5e073c Add an implementation of jnp.linalg.slogdet based on QR decomposition.
Adds a non-standard `method` argument to `jnp.linalg.slogdet` to select between the current LU decomposition based implementation (like NumPy) and the QR decomposition implementation.

QR decomposition is more amenable to a high performance batched implementation particularly on TPU hardware because it does not need row pivoting. The same may be true on other hardware also, and having the option is nice either way!

PiperOrigin-RevId: 449271317
2022-05-17 11:24:11 -07:00
Skye Wanderman-Milne
6b926d5551 Update version + CHANGELOG for jax 0.3.13 release 2022-05-16 12:17:07 -07:00
Yash Katariya
6a6605263d Update values after jax release
PiperOrigin-RevId: 448854487
2022-05-15 18:35:46 -07:00
Yash Katariya
1381afc37f Update version after jax release
PiperOrigin-RevId: 448822949
2022-05-15 12:14:26 -07:00
Peter Hawkins
7ba36fc178 Change implementation of jax.scipy.linalg.polar() and jax._src.scipy.eigh to use the QDWH decomposition from jax._src.lax.qdwh.
Remove jax._src.lax.polar.

PiperOrigin-RevId: 448241206
2022-05-12 07:20:52 -07:00
Peter Hawkins
705e241409 Change non-array arguments to jax.lax.linalg functions to be keyword-only arguments.
PiperOrigin-RevId: 448066207
2022-05-11 13:06:54 -07:00