133 Commits

Author SHA1 Message Date
Peter Hawkins
3fef74b2d0 [JAX] Change signature of .mhlo() method on compiler IR objects to return an ir.Module object instead of its string representation.
It isn't free to pretty-print IR, so it's best to avoid it unless necessary. In addition, by returning an IR object, the user is now free to, say, print it with different options.

For example, one can now write things like:

```
In [1]: import numpy as np, jax, jax.numpy as jnp
In [2]: m = jax.jit(lambda x: x + jnp.array(np.arange(1000))).lower(7.).compiler_ir(dialect='mhlo')
In [3]: m.operation.print(large_elements_limit=10)
module @jit__lambda_.4 {
  func public @main(%arg0: tensor<f32>) -> tensor<1000xf32> {
    %0 = mhlo.constant opaque<"_", "0xDEADBEEF"> : tensor<1000xi32>
    %1 = "mhlo.convert"(%0) : (tensor<1000xi32>) -> tensor<1000xf32>
    %2 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1000xf32>
    %3 = mhlo.add %2, %1 : tensor<1000xf32>
    return %3 : tensor<1000xf32>
  }
}
```

Fixes https://github.com/google/jax/issues/9226

PiperOrigin-RevId: 422855649
2022-01-19 11:04:48 -08:00
Matthew Johnson
0066533dae update version and changelog for pypi 2022-01-18 11:38:32 -08:00
jax authors
6411f8a033 Merge pull request #9184 from jakevdp:unique-nan
PiperOrigin-RevId: 422287302
2022-01-16 23:57:40 -08:00
jax authors
c9169fa0d5 Merge pull request #9189 from gnecula:tf_reduce_window
PiperOrigin-RevId: 421875035
2022-01-14 11:35:16 -08:00
Jake VanderPlas
bd157cf056 jnp.unique: properly handle NaN values 2022-01-13 15:54:07 -08:00
Jake VanderPlas
d8bdd9a19d lax.sort: regularize handling of -0.0 and -NaN 2022-01-13 13:03:41 -08:00
George Necula
5bfe1852a4 [jax2tf] Add jax2tf_associative_scan_reductions flag
This flag allows users to match the JAX performance for
associative reductions in CPU.
See README.md for details.
2022-01-13 15:52:18 +02:00
Matthew Johnson
1cf7d4ab5d Copybara import of the project:
--
4fcdadbfb3f4c484fd4432203cf13b88782b9311 by Matthew Johnson <mattjj@google.com>:

add jax.ensure_compile_time_eval to public api

aka jax.core.eval_context

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/7987 from google:issue7535 4fcdadbfb3f4c484fd4432203cf13b88782b9311
PiperOrigin-RevId: 420928687
2022-01-10 20:58:26 -08:00
Peter Hawkins
04369a3588 Drop support for NumPy 1.18.
Per NEP-29, we can drop NumPy 1.18 support on Dec 22, 2021.

The next NumPy deprecation will be 1.19 on Jun 21, 2022.

PiperOrigin-RevId: 419651428
2022-01-04 12:11:38 -08:00
George Necula
3021d3e2e2 [hcb] Add support for remat2 to host_callback
A callback under ad_checkpoint.checkpoint will be invoked
twice when taking the gradient: once during the forward pass
and once again during the backward pass when the residuals
for the forward pass are rematerialized.
2021-12-15 10:32:15 +02:00
Matthew Johnson
9d2a3ed7a8
Merge branch 'main' into add-block-until-ready-to-docs 2021-12-14 20:57:14 -08:00
Peter Hawkins
bca17ad59c Add debugging flag for dumping the JAX-generated MHLO/HLO IR to a file.
While HLO dumping is redundant with XLA's XLA_FLAGS=--xla_dump_to=... feature, MHLO dumping is useful since XLA only ever sees and dumps the IR after it has been canonicalized and converted to HLO. Some debugging tasks require easy access to the MHLO as well.

PiperOrigin-RevId: 416435598
2021-12-14 17:44:16 -08:00
Matthew Johnson
0c68605bf1 add jax.block_until_ready to docs and changelog
also unrelatedly fix a couple of the uses of rst in changelog.md (though
many others remain)
2021-12-14 13:39:47 -08:00
Peter Hawkins
66823d1392 Include compute capability 8.0 SASS in jaxlib wheels.
Drop compute capability 6.1 to avoid growing the wheel size.

Also fix an unrelated build error due to a gcc warning in boringssl.
2021-12-14 14:27:19 -05:00
George Necula
f08156ab7c [hcb] Simplifications to the host_calback API
* dropping support for special AD handling for hcb.id_tap and id_print.
  From now on, only the primals are tapped. The old behavior can be
  obtained (for a limited time) by setting the JAX_HOST_CALLBACK_AD_TRANSFORMS
  environment variale, or the --flax_host_callback_ad_transforms flag.
  Additionally, added documentation for how to implement the old behavior
  using JAX custom AD APIs.

This allows us to make some significant cleanup in the internals.
2021-12-11 08:24:56 +01:00
Yash Katariya
8be304c936 Bump jax version after jax release
PiperOrigin-RevId: 415064518
2021-12-08 12:08:14 -08:00
Yash Katariya
1b5630eed6 Update jaxlib version number to 0.1.76
PiperOrigin-RevId: 415050863
2021-12-08 11:14:12 -08:00
George Necula
43433078bc [jax2tf] Force TF compilation for code under jax.jit.
Previously, jax.jit was ignored by jax2tf. This can result in the
converted code being much slower than the JAX core, unless the
user adds an explicit `tf.function(jit_compile=True)`. With this
change that wrapper is added automatically for all code fragments
under jax.jit. Note that most jax.numpy functions are annotated
with jax.jit, so with this change they will all be compiled.

When doing this I ran into problems with tf.custom_gradient and
tf.function. As documented in the
[tf.custom_gradient](https://www.tensorflow.org/api_docs/python/tf/custom_gradient)
documentation, you get a LookupError when trying to build the gradient
of a tf.function, even if it has a tf.custom_gradient defined. The
recommended solution is to add a tf.stop_gradient. This is safe, since
jax2tf will always wrap the converted functions with a tf.custom_gradient.
2021-11-23 10:24:46 +02:00
Peter Hawkins
4679f455f9 Change the default out-of-bounds behavior for jax.ops.segment_... to FILL_OR_DROP.
This matches the documented behavior.

Fixes https://github.com/google/jax/issues/8634

PiperOrigin-RevId: 411635687
2021-11-22 13:32:58 -08:00
Peter Hawkins
7902ddaca2 Update jaxlib versions. 2021-11-17 11:46:41 -05:00
Peter Hawkins
1bcedd58cb Fix test failures and update changelog.
Use dtypes.issubdtype to test for subtyping otherwise we mishandle bfloat16 dtypes.
Don't pass an empty list to concatenate() when converting a shape to a value.
Forbid empty lists as arguments to lax.concatenate().
2021-11-16 17:36:55 -05:00
Qiao Zhang
ad4cb94734 update version and changelog for pypi 2021-11-10 14:21:26 -08:00
Jake VanderPlas
734a91350b jax.random.permutation: add independent keyword 2021-11-02 11:39:41 -07:00
Tianjian Lu
c5f73b3d8e [JAX] Added jax.lax.linalg.qdwh.
PiperOrigin-RevId: 406453671
2021-10-29 14:45:06 -07:00
Qiao Zhang
0be30fbf96 Add jax.distributed.initialize for multi-host GPU. 2021-10-26 14:37:54 -07:00
Roy Frostig
623c201054 [JAX] move example libraries from jax.experimental into jax.example_libraries
The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone examples libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.

PiperOrigin-RevId: 404405186
2021-10-19 17:30:45 -07:00
Yash Katariya
a7c9b6d11f Update jax version number for jax release.
PiperOrigin-RevId: 404262742
2021-10-19 08:05:31 -07:00
Yash Katariya
ee752b32f7 Use cuda11_cudnn82 instead of cuda=11,cudnn=82 because the latter one is a syntax error
PiperOrigin-RevId: 404240654
2021-10-19 06:24:53 -07:00
Yash Katariya
4d8bce1b85 Add a default cuda installation path and more explicit installation paths for CUDA jaxlib.
```
# Installs Cuda 11 with Cudnn 8.2
$ pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html

$ pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

$ pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

PiperOrigin-RevId: 404134291
2021-10-18 19:56:22 -07:00
Jake VanderPlas
a353e3eafa jnp.take/jnp.take_along_axis: require array inputs 2021-10-15 09:37:05 -07:00
Julius Kunze
f66cbb9b3d Fix CHANGELOG.md 2021-10-13 17:11:50 -06:00
jax authors
10af170a85 Merge pull request #8161 from juliuskunze:multidim-permutation
PiperOrigin-RevId: 402852030
2021-10-13 09:31:19 -07:00
Julius Kunze
63898b6ca6 Allow random.choice and random.permutation on multidimensional arrays 2021-10-13 09:39:25 -06:00
Peter Hawkins
2388804abc Add a regression test for #7461.
Fixes #7461
2021-10-13 11:11:24 -04:00
Skye Wanderman-Milne
962c496b25 Update jax version and CHANGELOG for 0.2.22 release 2021-10-12 18:46:37 -07:00
Yash Katariya
66a4a9ff3f Remove 10.2 cuda support
PiperOrigin-RevId: 402707900
2021-10-12 18:44:07 -07:00
Skye Wanderman-Milne
0072c32546 Update CHANGELOG and verson numbers for jaxlib 0.1.72 release 2021-10-12 17:37:29 -07:00
Jake VanderPlas
0b93c46c71 jnp.unique: add fill_value for when size is not None 2021-10-06 16:28:36 -07:00
Peter Hawkins
b466187bbe Add note to changelog about deprecation of jax.ops.index_... 2021-10-06 17:11:35 -04:00
Jean-Baptiste Lespiau
803b83ee15 Enable C++ pmap.
On CPU:
```
name                                     old cpu/op  new cpu/op  delta
pmap_trivial_2_devices                    128µs ± 6%    14µs ± 3%  -89.06%  (p=0.008 n=5+5)
pmap_trivial_dispatch_8_devices           212µs ± 2%    35µs ± 1%  -83.54%  (p=0.008 n=5+5)
pmap_trivial_8_devices                    215µs ± 1%    40µs ± 4%  -81.31%  (p=0.008 n=5+5)
pmap_simple_2_devices                     123µs ± 5%    15µs ± 6%  -87.70%  (p=0.008 n=5+5)
pmap_simple_dispatch_8_devices            211µs ± 3%    35µs ± 2%  -83.24%  (p=0.008 n=5+5)
pmap_simple_8_devices                     217µs ± 5%    40µs ± 2%  -81.68%  (p=0.008 n=5+5)
pmap_simple_dispatch_8_devices_100_args  5.42ms ± 7%  0.52ms ± 2%  -90.44%  (p=0.008 n=5+5)
pmap_simple_8_devices_100_args           26.5ms ±21%  17.5ms ±37%  -34.18%  (p=0.008 n=5+5)
sda_index_1                              7.45µs ± 6%  7.53µs ± 6%     ~     (p=0.222 n=5+5)
sda_index_2                              14.1µs ± 1%  14.3µs ± 4%     ~     (p=0.690 n=5+5)
sda_index_8                              56.0µs ± 3%  56.9µs ± 4%     ~     (p=0.310 n=5+5)

name                                     old time/op             new time/op             delta
pmap_trivial_2_devices                    136µs ± 8%               19µs ± 3%  -86.08%          (p=0.008 n=5+5)
pmap_trivial_dispatch_8_devices           216µs ± 3%               39µs ± 2%  -81.94%          (p=0.008 n=5+5)
pmap_trivial_8_devices                    219µs ± 2%               49µs ±38%  -77.67%          (p=0.008 n=5+5)
pmap_simple_2_devices                     130µs ± 5%               20µs ± 5%  -84.38%          (p=0.008 n=5+5)
pmap_simple_dispatch_8_devices            216µs ± 3%               39µs ± 5%  -81.71%          (p=0.008 n=5+5)
pmap_simple_8_devices                     221µs ± 6%               43µs ± 1%  -80.41%          (p=0.016 n=5+4)
pmap_simple_dispatch_8_devices_100_args  5.52ms ± 7%             0.59ms ± 2%  -89.28%          (p=0.008 n=5+5)
pmap_simple_8_devices_100_args           26.6ms ±21%             17.6ms ±37%  -34.04%          (p=0.008 n=5+5)
sda_index_1                              7.48µs ± 8%             7.53µs ± 6%     ~             (p=0.310 n=5+5)
sda_index_2                              14.1µs ± 1%             14.3µs ± 4%     ~             (p=0.690 n=5+5)
sda_index_8                              56.0µs ± 3%             56.9µs ± 4%     ~             (p=0.310 n=5+5)
```

PiperOrigin-RevId: 401274089
2021-10-06 10:08:28 -07:00
Peter Hawkins
a11d957e61 Disallow non-hashable static arguments in pmap().
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that either we should check for hashability early or we should make sure it's clear that it's not something we intended..
* Delete argnames_partial, which appears unused.
2021-09-30 15:50:07 -04:00
Jake VanderPlas
48157a7c1e Update v0.2.21 changelog for #7927 2021-09-27 11:38:36 -07:00
Yash Katariya
dbeb97d394 Create 0.2.21 jax release
PiperOrigin-RevId: 398528427
2021-09-23 11:00:31 -07:00
jax authors
fc7775e1d1 Merge pull request #7968 from hawkinsp:partial
PiperOrigin-RevId: 398025545
2021-09-21 10:21:13 -07:00
Peter Hawkins
1163e218e8 Attempt to land https://github.com/google/jax/pull/6400 again.
This PR changes `jax.numpy.array()` to avoid creating any on-device arrays during tracing. As a consequence, calls to `jnp.array()` in a traced context, such as `jax.jit` will always be staged into the trace.

This change may break code that depends on the current (undocumented and unintentional) behavior of `jnp.array()` to perform shape or index calculations that must be known statically (at trace time). The workaround for such cases is to use classic NumPy to perform shape/index calculations.

PiperOrigin-RevId: 398008511
2021-09-21 09:06:40 -07:00
Peter Hawkins
58c7ee46bc Remove jax.util.partial. 2021-09-20 20:32:49 -04:00
Peter Hawkins
f35ab3693d Remove jax.partial from the JAX API.
Use functools.partial instead.
2021-09-20 09:19:53 -04:00
jax authors
f47926a23d Merge pull request #7940 from hawkinsp:api
PiperOrigin-RevId: 397319298
2021-09-17 07:58:17 -07:00
Jake VanderPlas
9a2697437e Update changelog for several recent PRs 2021-09-16 14:10:08 -07:00
Peter Hawkins
6a1b626564 Remove jax.api.
Functions exported as jax.api were aliases for names in jax.*. Use the jax.* names instead.
2021-09-16 16:29:06 -04:00