188 Commits

Author SHA1 Message Date
Jake VanderPlas
822b6aad3b jax.scipy.qr: fix return type for mode='r' 2022-04-26 11:26:56 -07:00
jax authors
cbda72c988 Merge pull request #10397 from YouJiacheng:replace-int-with-operator.index-part2
PiperOrigin-RevId: 444341705
2022-04-25 12:33:33 -07:00
jax authors
6e27156199 Merge pull request #10383 from YouJiacheng:implement-scipy.cluster.vq.vq
PiperOrigin-RevId: 443733421
2022-04-22 12:40:01 -07:00
YouJiacheng
b485b8e5ce implement scipy.cluster.vq.vq
also add no check_finite and overwrite_* docstring for some scipy.linalg functions
2022-04-23 03:14:32 +08:00
YouJiacheng
667d63aa2d replace int with operator.index part2
This change align the behavior of `ravel_multi_index`, `split` and `indices` to their `numpy` counterparts.
Also ensure size argument of `nonzero` should be integer.
The changes with `*space` are only simplification
2022-04-23 01:45:28 +08:00
Jake VanderPlas
a05c97be3f CHANGELOG: update test_util deprecation discussion 2022-04-21 13:37:56 -07:00
Jake VanderPlas
d9508304e4 Deprecate remaining functionality in jax.test_util 2022-04-21 12:12:40 -07:00
Peter Hawkins
74346f464b [JAX] Change jnp.take_along_axis to return invalid (e.g. NaN) values for out-of-bounds indices.
Previously, out-of-bounds indices were clipped into range, but that behavior is error prone. We would rather fail in a more visible way when out-of-bounds indices are used. Future changes will migrate other JAX indexing operations to have the same semantics.

PiperOrigin-RevId: 443390170
2022-04-21 08:52:14 -07:00
Peter Hawkins
4fd824c36f Change jnp.take_along_axis to require that its indices are of integer type.
Previously jnp.take_along_axis silently casted its indices to integers if they were not already integers.

PiperOrigin-RevId: 443124521
2022-04-20 10:05:16 -07:00
Peter Hawkins
a52f07a21b Add an optional mode= argument to jnp.take_along_axis.
This allows users of jnp.take_along_axis to override the out-of-bounds indexing behavior.
Default to "clip", which for the forward computation is identical to the current behavior. In a future change, we will change this to "fill".
2022-04-19 16:07:00 -04:00
Peter Hawkins
e1b606934f Temporarily revert: Change default jnp.take_along_axis gather mode to "fill".
Some tests were broken by the change; reverting this PR for the moment while debugging the problem.

PiperOrigin-RevId: 442868210
2022-04-19 11:39:12 -07:00
Peter Hawkins
7c73bfbc46 Change default jnp.take_along_axis gather mode to "fill".
PiperOrigin-RevId: 442817397
2022-04-19 08:24:24 -07:00
Yash Katariya
b4d05dcae8
Update the changelog to say sharded_jit is deprecated. 2022-04-18 13:48:19 -07:00
Yash Katariya
8a61414a88 Delete the mesh context manager. The replacement for it is Mesh.
PiperOrigin-RevId: 442619711
2022-04-18 13:37:51 -07:00
Peter Hawkins
32abfa84ff Fix typo in changelog. 2022-04-18 08:17:14 -04:00
Peter Hawkins
38ea5a6bc0 Copybara import of the project:
--
391dea76bc8fe264cf26ec93d42147f87847894d by Peter Hawkins <phawkins@google.com>:

Update version numbers after jax/jaxlib 0.3.7 release.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/10324 from hawkinsp:jaxlib 391dea76bc8fe264cf26ec93d42147f87847894d
PiperOrigin-RevId: 442311051
2022-04-16 22:37:09 -07:00
Peter Hawkins
6b33c55450 2nd attempt at jax/jaxlib 0.3.7 release. 2022-04-15 15:20:02 -04:00
Jake VanderPlas
be5c84d409 Deprecate DeviceArray.tile method 2022-04-15 10:11:03 -07:00
Peter Hawkins
52a97f2e06 Jax 0.3.7 and jaxlib 0.3.7 release. 2022-04-15 12:02:05 -04:00
Peter Hawkins
3f1032cf33 Fix incorrect cross-reference breaking readthedocs build. 2022-04-14 16:12:48 -04:00
George Necula
d050327592 Deprecate jax.experimental.loops, step 2.
Add deprecation warning and remove the tests.

PiperOrigin-RevId: 441828243
2022-04-14 12:38:55 -07:00
Yash Katariya
08f28a6119 Finish jax release
PiperOrigin-RevId: 441342919
2022-04-12 18:13:16 -07:00
Peter Hawkins
4dc69034b0 Update version numbers after jax/jaxlib release. 2022-04-07 16:40:19 -04:00
jax authors
8b3f039252 Merge pull request #10039 from ajcr:add_scipy_linalg_rsf2csf
PiperOrigin-RevId: 439997145
2022-04-06 19:55:29 -07:00
Peter Hawkins
96ba290faf Jax 0.3.5 and jaxlib 0.3.5 release. 2022-04-06 23:56:41 +00:00
Alex Riley
869596fc2c Add jax.scipy.linalg.rsf2csf 2022-04-06 21:06:23 +01:00
Peter Hawkins
71a5eb263b [GPU] Force an input buffer copy for double precision complex-to-real IRFFTs.
Fixes https://github.com/google/jax/issues/9946

PiperOrigin-RevId: 439414091
2022-04-04 14:38:52 -07:00
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
Jake VanderPlas
b359b8ad96 Add CHANGELOG entry for #10069 2022-03-30 08:05:34 -07:00
Jake VanderPlas
093b7032a8 Implement jnp.from* array creation functions 2022-03-29 10:52:47 -07:00
Jake VanderPlas
f4d240c036 Remove lax_numpy from jax.numpy namespace
This is a private module that was inadvertently exported in the past.
2022-03-25 15:02:45 -07:00
dogeplusplus
7915c6ce27 Rename jax.nn.normalize to standardize. Add normalize alias with DeprecationWarning. 2022-03-23 20:55:22 +00:00
Jake VanderPlas
69969ef803 add random.loggamma and improve dirichlet & beta implementation 2022-03-21 08:33:11 -07:00
Matthew Johnson
4c5d8e969f update version and changelog for pypi 2022-03-18 14:16:00 -07:00
Matthew Johnson
d2b393bbf1 update version and changelog for pypi 2022-03-17 15:35:26 -07:00
Skye Wanderman-Milne
d7087abce6 Bump jax and jaxlib versions for 0.3.2 release
Also add CPU pjit to changelog
2022-03-16 14:31:00 -07:00
Skye Wanderman-Milne
f9775a2ced Update CHANGELOG and setup.py for jax + jaxlib 0.3.2 releases 2022-03-16 10:17:42 -07:00
jax authors
4d14899940 Add boolean flag to as_hlo_text to enable writing large constants.
PiperOrigin-RevId: 434556535
2022-03-14 13:46:22 -07:00
Peter Hawkins
08fbd77d90 [JAX] Deprecate .block_host_until_ready() in favor of .block_until_ready().
JAX kept an older name around (.block_host_until_ready()) in parallel with the new name (.block_until_ready()) to avoid breaking users. Deprecate it so we only have one name.

PiperOrigin-RevId: 433228545
2022-03-08 09:14:40 -08:00
Jake VanderPlas
8c57ae2a19 Call _check_arraylike on inputs to broadcast_to and broadcast_arrays 2022-03-04 11:22:27 -08:00
jax authors
fb44d7c12f [JAX] Add release note for the graduration of the experimental.ann module.
PiperOrigin-RevId: 431951602
2022-03-02 08:58:53 -08:00
Jake VanderPlas
51727033b8 Remove duplicate changelog entry 2022-02-24 08:18:30 -08:00
Peter Hawkins
f51a05a889 Remove jax.ops.index... functions.
These functions have been deprecated and have issued a DeprecationWarning since jax 0.2.22 in October 2021.
2022-02-24 09:36:28 -05:00
Yash Katariya
c161c62878 Finish jax release
PiperOrigin-RevId: 429670894
2022-02-18 16:23:39 -08:00
Jake VanderPlas
da3aaa1960 Add deprecation warning to JaxTestCase and JaxTestLoader 2022-02-17 14:58:58 -08:00
Peter Hawkins
3e5ecfe363 Add jax.distributed and jax.dlpack to the docs.
Reorder the doc modules into something closer to alphabetical order.

Add missing functions from jax.scipy.linalg and jax.scipy.signal to the docs.
2022-02-17 16:10:07 -05:00
Jake VanderPlas
97512e9e44 JaxTestCase: set jax_numpy_rank_promotion='raise' by default 2022-02-14 09:22:05 -08:00
Yash Katariya
2162868ed9 Update values after release
PiperOrigin-RevId: 427910510
2022-02-10 20:32:53 -08:00
jax authors
5691010d2f Copybara import of the project:
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:

JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
2022-02-10 19:08:29 -08:00
Jake VanderPlas
6324577a63 JaxTestCase: set numpy_rank_promotion='raise' by default 2022-02-10 16:54:31 -08:00