165 Commits

Author SHA1 Message Date
jax authors
8b3f039252 Merge pull request #10039 from ajcr:add_scipy_linalg_rsf2csf
PiperOrigin-RevId: 439997145
2022-04-06 19:55:29 -07:00
Peter Hawkins
96ba290faf Jax 0.3.5 and jaxlib 0.3.5 release. 2022-04-06 23:56:41 +00:00
Alex Riley
869596fc2c Add jax.scipy.linalg.rsf2csf 2022-04-06 21:06:23 +01:00
Peter Hawkins
71a5eb263b [GPU] Force an input buffer copy for double precision complex-to-real IRFFTs.
Fixes https://github.com/google/jax/issues/9946

PiperOrigin-RevId: 439414091
2022-04-04 14:38:52 -07:00
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
Jake VanderPlas
b359b8ad96 Add CHANGELOG entry for #10069 2022-03-30 08:05:34 -07:00
Jake VanderPlas
093b7032a8 Implement jnp.from* array creation functions 2022-03-29 10:52:47 -07:00
Jake VanderPlas
f4d240c036 Remove lax_numpy from jax.numpy namespace
This is a private module that was inadvertently exported in the past.
2022-03-25 15:02:45 -07:00
dogeplusplus
7915c6ce27 Rename jax.nn.normalize to standardize. Add normalize alias with DeprecationWarning. 2022-03-23 20:55:22 +00:00
Jake VanderPlas
69969ef803 add random.loggamma and improve dirichlet & beta implementation 2022-03-21 08:33:11 -07:00
Matthew Johnson
4c5d8e969f update version and changelog for pypi 2022-03-18 14:16:00 -07:00
Matthew Johnson
d2b393bbf1 update version and changelog for pypi 2022-03-17 15:35:26 -07:00
Skye Wanderman-Milne
d7087abce6 Bump jax and jaxlib versions for 0.3.2 release
Also add CPU pjit to changelog
2022-03-16 14:31:00 -07:00
Skye Wanderman-Milne
f9775a2ced Update CHANGELOG and setup.py for jax + jaxlib 0.3.2 releases 2022-03-16 10:17:42 -07:00
jax authors
4d14899940 Add boolean flag to as_hlo_text to enable writing large constants.
PiperOrigin-RevId: 434556535
2022-03-14 13:46:22 -07:00
Peter Hawkins
08fbd77d90 [JAX] Deprecate .block_host_until_ready() in favor of .block_until_ready().
JAX kept an older name around (.block_host_until_ready()) in parallel with the new name (.block_until_ready()) to avoid breaking users. Deprecate it so we only have one name.

PiperOrigin-RevId: 433228545
2022-03-08 09:14:40 -08:00
Jake VanderPlas
8c57ae2a19 Call _check_arraylike on inputs to broadcast_to and broadcast_arrays 2022-03-04 11:22:27 -08:00
jax authors
fb44d7c12f [JAX] Add release note for the graduration of the experimental.ann module.
PiperOrigin-RevId: 431951602
2022-03-02 08:58:53 -08:00
Jake VanderPlas
51727033b8 Remove duplicate changelog entry 2022-02-24 08:18:30 -08:00
Peter Hawkins
f51a05a889 Remove jax.ops.index... functions.
These functions have been deprecated and have issued a DeprecationWarning since jax 0.2.22 in October 2021.
2022-02-24 09:36:28 -05:00
Yash Katariya
c161c62878 Finish jax release
PiperOrigin-RevId: 429670894
2022-02-18 16:23:39 -08:00
Jake VanderPlas
da3aaa1960 Add deprecation warning to JaxTestCase and JaxTestLoader 2022-02-17 14:58:58 -08:00
Peter Hawkins
3e5ecfe363 Add jax.distributed and jax.dlpack to the docs.
Reorder the doc modules into something closer to alphabetical order.

Add missing functions from jax.scipy.linalg and jax.scipy.signal to the docs.
2022-02-17 16:10:07 -05:00
Jake VanderPlas
97512e9e44 JaxTestCase: set jax_numpy_rank_promotion='raise' by default 2022-02-14 09:22:05 -08:00
Yash Katariya
2162868ed9 Update values after release
PiperOrigin-RevId: 427910510
2022-02-10 20:32:53 -08:00
jax authors
5691010d2f Copybara import of the project:
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:

JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
2022-02-10 19:08:29 -08:00
Jake VanderPlas
6324577a63 JaxTestCase: set numpy_rank_promotion='raise' by default 2022-02-10 16:54:31 -08:00
Yash Katariya
1ad3551ec9 Release jax and jaxlib 0.3.0 as per the new release process.
PiperOrigin-RevId: 427809845
2022-02-10 11:59:13 -08:00
Skye Wanderman-Milne
715066c624 Bump jax version to 0.2.29 and update CHANGELOG 2022-02-01 17:59:57 -08:00
Peter Hawkins
2388e353da Increase bazel version to 5.0.0 to match TensorFlow
(8871926b0a).
2022-01-28 21:11:02 +00:00
Peter Hawkins
be2f6a91ec Update XLA for jaxlib 0.1.76 release. 2022-01-27 14:26:16 +00:00
Peter Hawkins
74e4db47da Change the default IR dialect returned by .compiler_ir() to MHLO.
PiperOrigin-RevId: 423091674
2022-01-20 09:50:17 -08:00
Peter Hawkins
3fef74b2d0 [JAX] Change signature of .mhlo() method on compiler IR objects to return an ir.Module object instead of its string representation.
It isn't free to pretty-print IR, so it's best to avoid it unless necessary. In addition, by returning an IR object, the user is now free to, say, print it with different options.

For example, one can now write things like:

```
In [1]: import numpy as np, jax, jax.numpy as jnp
In [2]: m = jax.jit(lambda x: x + jnp.array(np.arange(1000))).lower(7.).compiler_ir(dialect='mhlo')
In [3]: m.operation.print(large_elements_limit=10)
module @jit__lambda_.4 {
  func public @main(%arg0: tensor<f32>) -> tensor<1000xf32> {
    %0 = mhlo.constant opaque<"_", "0xDEADBEEF"> : tensor<1000xi32>
    %1 = "mhlo.convert"(%0) : (tensor<1000xi32>) -> tensor<1000xf32>
    %2 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1000xf32>
    %3 = mhlo.add %2, %1 : tensor<1000xf32>
    return %3 : tensor<1000xf32>
  }
}
```

Fixes https://github.com/google/jax/issues/9226

PiperOrigin-RevId: 422855649
2022-01-19 11:04:48 -08:00
Matthew Johnson
0066533dae update version and changelog for pypi 2022-01-18 11:38:32 -08:00
jax authors
6411f8a033 Merge pull request #9184 from jakevdp:unique-nan
PiperOrigin-RevId: 422287302
2022-01-16 23:57:40 -08:00
jax authors
c9169fa0d5 Merge pull request #9189 from gnecula:tf_reduce_window
PiperOrigin-RevId: 421875035
2022-01-14 11:35:16 -08:00
Jake VanderPlas
bd157cf056 jnp.unique: properly handle NaN values 2022-01-13 15:54:07 -08:00
Jake VanderPlas
d8bdd9a19d lax.sort: regularize handling of -0.0 and -NaN 2022-01-13 13:03:41 -08:00
George Necula
5bfe1852a4 [jax2tf] Add jax2tf_associative_scan_reductions flag
This flag allows users to match the JAX performance for
associative reductions in CPU.
See README.md for details.
2022-01-13 15:52:18 +02:00
Matthew Johnson
1cf7d4ab5d Copybara import of the project:
--
4fcdadbfb3f4c484fd4432203cf13b88782b9311 by Matthew Johnson <mattjj@google.com>:

add jax.ensure_compile_time_eval to public api

aka jax.core.eval_context

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7987 from google:issue7535 4fcdadbfb3f4c484fd4432203cf13b88782b9311
PiperOrigin-RevId: 420928687
2022-01-10 20:58:26 -08:00
Peter Hawkins
04369a3588 Drop support for NumPy 1.18.
Per NEP-29, we can drop NumPy 1.18 support on Dec 22, 2021.

The next NumPy deprecation will be 1.19 on Jun 21, 2022.

PiperOrigin-RevId: 419651428
2022-01-04 12:11:38 -08:00
George Necula
3021d3e2e2 [hcb] Add support for remat2 to host_callback
A callback under ad_checkpoint.checkpoint will be invoked
twice when taking the gradient: once during the forward pass
and once again during the backward pass when the residuals
for the forward pass are rematerialized.
2021-12-15 10:32:15 +02:00
Matthew Johnson
9d2a3ed7a8
Merge branch 'main' into add-block-until-ready-to-docs 2021-12-14 20:57:14 -08:00
Peter Hawkins
bca17ad59c Add debugging flag for dumping the JAX-generated MHLO/HLO IR to a file.
While HLO dumping is redundant with XLA's XLA_FLAGS=--xla_dump_to=... feature, MHLO dumping is useful since XLA only ever sees and dumps the IR after it has been canonicalized and converted to HLO. Some debugging tasks require easy access to the MHLO as well.

PiperOrigin-RevId: 416435598
2021-12-14 17:44:16 -08:00
Matthew Johnson
0c68605bf1 add jax.block_until_ready to docs and changelog
also unrelatedly fix a couple of the uses of rst in changelog.md (though
many others remain)
2021-12-14 13:39:47 -08:00
Peter Hawkins
66823d1392 Include compute capability 8.0 SASS in jaxlib wheels.
Drop compute capability 6.1 to avoid growing the wheel size.

Also fix an unrelated build error due to a gcc warning in boringssl.
2021-12-14 14:27:19 -05:00
George Necula
f08156ab7c [hcb] Simplifications to the host_calback API
* dropping support for special AD handling for hcb.id_tap and id_print.
  From now on, only the primals are tapped. The old behavior can be
  obtained (for a limited time) by setting the JAX_HOST_CALLBACK_AD_TRANSFORMS
  environment variale, or the --flax_host_callback_ad_transforms flag.
  Additionally, added documentation for how to implement the old behavior
  using JAX custom AD APIs.

This allows us to make some significant cleanup in the internals.
2021-12-11 08:24:56 +01:00
Yash Katariya
8be304c936 Bump jax version after jax release
PiperOrigin-RevId: 415064518
2021-12-08 12:08:14 -08:00
Yash Katariya
1b5630eed6 Update jaxlib version number to 0.1.76
PiperOrigin-RevId: 415050863
2021-12-08 11:14:12 -08:00
George Necula
43433078bc [jax2tf] Force TF compilation for code under jax.jit.
Previously, jax.jit was ignored by jax2tf. This can result in the
converted code being much slower than the JAX core, unless the
user adds an explicit `tf.function(jit_compile=True)`. With this
change that wrapper is added automatically for all code fragments
under jax.jit. Note that most jax.numpy functions are annotated
with jax.jit, so with this change they will all be compiled.

When doing this I ran into problems with tf.custom_gradient and
tf.function. As documented in the
[tf.custom_gradient](https://www.tensorflow.org/api_docs/python/tf/custom_gradient)
documentation, you get a LookupError when trying to build the gradient
of a tf.function, even if it has a tf.custom_gradient defined. The
recommended solution is to add a tf.stop_gradient. This is safe, since
jax2tf will always wrap the converted functions with a tf.custom_gradient.
2021-11-23 10:24:46 +02:00