216 Commits

Author SHA1 Message Date
Jake VanderPlas
d2f80ef117 [x64] deprecate unsafe type casting in scatter-update operations 2022-06-09 15:21:49 -07:00
Sharad Vikram
c0b47fdf2c Update changelog for named_scope and adds it to the docs 2022-06-09 11:22:44 -07:00
Skye Wanderman-Milne
f86282579e Add jax.default_device to CHANGELOG 2022-06-08 14:00:54 -07:00
Sharad Vikram
143ed40a78 Add collect_profile script 2022-06-03 17:56:17 -07:00
carlosgmartin
ca83a80f95 Added random.generalized_normal and random.ball. 2022-06-03 15:11:29 -04:00
jax authors
c73da15d85 Merge pull request #10906 from sharadmv:profiler
PiperOrigin-RevId: 452429099
2022-06-01 18:15:34 -07:00
Sharad Vikram
449da304b3 Store profiler server as a global variable and add a stop_server function 2022-06-01 17:50:06 -07:00
jax authors
2d87a06888 Merge pull request #10944 from hawkinsp:macminver
PiperOrigin-RevId: 452424815
2022-06-01 17:47:41 -07:00
Peter Hawkins
69bda69fb6 Bump minimum Mac OS version to 10.14 (Mojave).
It turns out that the support for C++17 is partial in 10.12, and in particular absl::optional and std::optional are not the same thing under 10.12. Increment to 10.14 which is the lowest version that builds successfully with absl::optional == std::optional.

See: 89cdaed655/absl/base/config.h (L528)
Strictly speaking, we could allow 10.13, but not without updating ABSL in the TF repository to incorporate c86347d4ce which fixes the version detection test to permit 10.13 as well.
2022-06-01 20:32:22 -04:00
Sharad Vikram
76669835ba Add an option to create a perfetto link in the JAX profiler 2022-06-01 15:48:29 -07:00
Jake VanderPlas
358f929681 [x64] jnp.ldexp: avoid implicit 64-bit promotion 2022-06-01 09:14:47 -07:00
Peter Hawkins
b6cdda763b Update changelog to incorporate some recent changes. 2022-05-31 14:03:27 -04:00
Jake VanderPlas
991ad72e24 DeviceArray: Improve support for copy, deepcopy, and pickle 2022-05-19 12:00:58 -07:00
Peter Hawkins
1bcb5e073c Add an implementation of jnp.linalg.slogdet based on QR decomposition.
Adds a non-standard `method` argument to `jnp.linalg.slogdet` to select between the current LU decomposition based implementation (like NumPy) and the QR decomposition implementation.

QR decomposition is more amenable to a high performance batched implementation particularly on TPU hardware because it does not need row pivoting. The same may be true on other hardware also, and having the option is nice either way!

PiperOrigin-RevId: 449271317
2022-05-17 11:24:11 -07:00
Skye Wanderman-Milne
6b926d5551 Update version + CHANGELOG for jax 0.3.13 release 2022-05-16 12:17:07 -07:00
Yash Katariya
6a6605263d Update values after jax release
PiperOrigin-RevId: 448854487
2022-05-15 18:35:46 -07:00
Yash Katariya
1381afc37f Update version after jax release
PiperOrigin-RevId: 448822949
2022-05-15 12:14:26 -07:00
Peter Hawkins
7ba36fc178 Change implementation of jax.scipy.linalg.polar() and jax._src.scipy.eigh to use the QDWH decomposition from jax._src.lax.qdwh.
Remove jax._src.lax.polar.

PiperOrigin-RevId: 448241206
2022-05-12 07:20:52 -07:00
Peter Hawkins
705e241409 Change non-array arguments to jax.lax.linalg functions to be keyword-only arguments.
PiperOrigin-RevId: 448066207
2022-05-11 13:06:54 -07:00
Peter Hawkins
590b9161fe Add a sort_eigenvalues option to lax.linalg.eigh().
An upcoming change to add a more scalable QDWH-based TPU symmetric eigendecomposition requires that we can obtain the TPU eigenvalues unsorted. The option already exists in XLA, so we simply need to plumb it through to the lax primitive.

PiperOrigin-RevId: 448047584
2022-05-11 11:46:03 -07:00
Yash Katariya
ff1a3c40ba jax and jaxlib release
PiperOrigin-RevId: 446295827
2022-05-03 14:52:40 -07:00
Yash Katariya
888e5c6958 Update the version numbers after JAX release.
PiperOrigin-RevId: 446092433
2022-05-02 19:51:11 -07:00
Matthew Johnson
838f22553b update version and changelog for pypi 2022-04-29 20:15:52 -07:00
Tianjian Lu
020849076c [linalg] Add tpu svd lowering rule.
PiperOrigin-RevId: 445533767
2022-04-29 16:43:53 -07:00
jax authors
227e525de2 Merge pull request #10458 from carlosgmartin:random_orthogonal_unitary
PiperOrigin-RevId: 445522278
2022-04-29 15:40:16 -07:00
Carlos Martin
b276c31b75 Added random.orthogonal. 2022-04-29 14:20:50 -04:00
Peter Hawkins
0b470361da Change the default jnp.take mode to "fill".
Previously, `jnp.take` defaulted to clamping out-of-bounds indices into range. Now, `jnp.take` returns invalid values (e.g., NaN) for out-of-bounds indices. This change attempts to prevent latent bugs caused by inadvertent out-of-bounds indices.

The previous behavior can be approximated using the "clip" or "wrap" fill modes.

PiperOrigin-RevId: 445130143
2022-04-28 06:01:56 -07:00
Peter Hawkins
7c6a550333 Change the default scatter mode to FILL_OR_DROP.
This is a reasonably safe change, because it has no effect on the forward pass of a computation: the default behavior (PROMISE_IN_BOUNDS) also drops out-of-bounds scatters.

This change does however affect the transpose (gradient) of a scatter with out-of-bounds indices: the gradient of a PROMISE_IN_BOUNDS scatter is a PROMISE_IN_BOUNDS gather, and a PROMISE_IN_BOUNDS gather clips out-of-bounds indices into range. This is not mathematically correct: a dropped scatter index does not contribute to the primal output, and so its transpose should yield a zero cotangent.

After this change, the gradient of a default scatter is a gather with a fill value of 0: i.e., the indices that were dropped do not make gradient contributions, which is mathematically correct.

Separately, I am working towards switching out-of-bounds gather() operations to also have FILL_OR_DROP semantics, although that change is more disruptive because a number of users have out-of-bounds indices in their gather()s.

Issues: https://github.com/google/jax/issues/278 https://github.com/google/jax/issues/9839
PiperOrigin-RevId: 444935241
2022-04-27 12:26:55 -07:00
Jake VanderPlas
822b6aad3b jax.scipy.qr: fix return type for mode='r' 2022-04-26 11:26:56 -07:00
jax authors
cbda72c988 Merge pull request #10397 from YouJiacheng:replace-int-with-operator.index-part2
PiperOrigin-RevId: 444341705
2022-04-25 12:33:33 -07:00
jax authors
6e27156199 Merge pull request #10383 from YouJiacheng:implement-scipy.cluster.vq.vq
PiperOrigin-RevId: 443733421
2022-04-22 12:40:01 -07:00
YouJiacheng
b485b8e5ce implement scipy.cluster.vq.vq
also add no check_finite and overwrite_* docstring for some scipy.linalg functions
2022-04-23 03:14:32 +08:00
YouJiacheng
667d63aa2d replace int with operator.index part2
This change align the behavior of `ravel_multi_index`, `split` and `indices` to their `numpy` counterparts.
Also ensure size argument of `nonzero` should be integer.
The changes with `*space` are only simplification
2022-04-23 01:45:28 +08:00
Jake VanderPlas
a05c97be3f CHANGELOG: update test_util deprecation discussion 2022-04-21 13:37:56 -07:00
Jake VanderPlas
d9508304e4 Deprecate remaining functionality in jax.test_util 2022-04-21 12:12:40 -07:00
Peter Hawkins
74346f464b [JAX] Change jnp.take_along_axis to return invalid (e.g. NaN) values for out-of-bounds indices.
Previously, out-of-bounds indices were clipped into range, but that behavior is error prone. We would rather fail in a more visible way when out-of-bounds indices are used. Future changes will migrate other JAX indexing operations to have the same semantics.

PiperOrigin-RevId: 443390170
2022-04-21 08:52:14 -07:00
Peter Hawkins
4fd824c36f Change jnp.take_along_axis to require that its indices are of integer type.
Previously jnp.take_along_axis silently casted its indices to integers if they were not already integers.

PiperOrigin-RevId: 443124521
2022-04-20 10:05:16 -07:00
Peter Hawkins
a52f07a21b Add an optional mode= argument to jnp.take_along_axis.
This allows users of jnp.take_along_axis to override the out-of-bounds indexing behavior.
Default to "clip", which for the forward computation is identical to the current behavior. In a future change, we will change this to "fill".
2022-04-19 16:07:00 -04:00
Peter Hawkins
e1b606934f Temporarily revert: Change default jnp.take_along_axis gather mode to "fill".
Some tests were broken by the change; reverting this PR for the moment while debugging the problem.

PiperOrigin-RevId: 442868210
2022-04-19 11:39:12 -07:00
Peter Hawkins
7c73bfbc46 Change default jnp.take_along_axis gather mode to "fill".
PiperOrigin-RevId: 442817397
2022-04-19 08:24:24 -07:00
Yash Katariya
b4d05dcae8
Update the changelog to say sharded_jit is deprecated. 2022-04-18 13:48:19 -07:00
Yash Katariya
8a61414a88 Delete the mesh context manager. The replacement for it is Mesh.
PiperOrigin-RevId: 442619711
2022-04-18 13:37:51 -07:00
Peter Hawkins
32abfa84ff Fix typo in changelog. 2022-04-18 08:17:14 -04:00
Peter Hawkins
38ea5a6bc0 Copybara import of the project:
--
391dea76bc8fe264cf26ec93d42147f87847894d by Peter Hawkins <phawkins@google.com>:

Update version numbers after jax/jaxlib 0.3.7 release.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/10324 from hawkinsp:jaxlib 391dea76bc8fe264cf26ec93d42147f87847894d
PiperOrigin-RevId: 442311051
2022-04-16 22:37:09 -07:00
Peter Hawkins
6b33c55450 2nd attempt at jax/jaxlib 0.3.7 release. 2022-04-15 15:20:02 -04:00
Jake VanderPlas
be5c84d409 Deprecate DeviceArray.tile method 2022-04-15 10:11:03 -07:00
Peter Hawkins
52a97f2e06 Jax 0.3.7 and jaxlib 0.3.7 release. 2022-04-15 12:02:05 -04:00
Peter Hawkins
3f1032cf33 Fix incorrect cross-reference breaking readthedocs build. 2022-04-14 16:12:48 -04:00
George Necula
d050327592 Deprecate jax.experimental.loops, step 2.
Add deprecation warning and remove the tests.

PiperOrigin-RevId: 441828243
2022-04-14 12:38:55 -07:00
Yash Katariya
08f28a6119 Finish jax release
PiperOrigin-RevId: 441342919
2022-04-12 18:13:16 -07:00