112 Commits

Qiao Zhang
ad4cb94734 update version and changelog for pypi 2021-11-10 14:21:26 -08:00
Jake VanderPlas
734a91350b jax.random.permutation: add independent keyword 2021-11-02 11:39:41 -07:00
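For illustration, a hedged sketch of the new keyword; the semantics shown are an assumption based on later documentation, and the `axis` argument comes from the related multidimensional-permutation change further down this log:

```
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.arange(6).reshape(2, 3)
jax.random.permutation(key, x)                            # permutes rows jointly
jax.random.permutation(key, x, axis=1, independent=True)  # permutes each row independently
```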
Tianjian Lu
c5f73b3d8e [JAX] Added jax.lax.linalg.qdwh.
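A hedged usage sketch; the `is_hermitian` keyword and the four-value return are assumptions based on later JAX documentation, not taken from this commit:

```
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.array([[2.0, 1.0], [1.0, 2.0]])
# QDWH-based polar decomposition: a ~= u @ h
u, h, num_iters, is_converged = lax_linalg.qdwh(a, is_hermitian=False)
```
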
PiperOrigin-RevId: 406453671
2021-10-29 14:45:06 -07:00
Qiao Zhang
0be30fbf96 Add jax.distributed.initialize for multi-host GPU. 2021-10-26 14:37:54 -07:00
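A hedged sketch of the call; the address, port, and process counts below are placeholder assumptions:

```
import jax

# Run on every host of a 2-host GPU cluster, before any other JAX call:
jax.distributed.initialize(
    coordinator_address="192.168.0.1:1234",  # host 0's IP:port (placeholder)
    num_processes=2,
    process_id=0,                            # 0 on host 0, 1 on host 1
)
```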
Roy Frostig
623c201054 [JAX] move example libraries from jax.experimental into jax.example_libraries
The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone example libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.
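The corresponding import change, for reference:

```
# Before:
#   from jax.experimental import stax, optimizers
# After:
from jax.example_libraries import stax, optimizers
```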

PiperOrigin-RevId: 404405186
2021-10-19 17:30:45 -07:00
Yash Katariya
a7c9b6d11f Update jax version number for jax release.
PiperOrigin-RevId: 404262742
2021-10-19 08:05:31 -07:00
Yash Katariya
ee752b32f7 Use cuda11_cudnn82 instead of cuda=11,cudnn=82, because the latter is a syntax error
PiperOrigin-RevId: 404240654
2021-10-19 06:24:53 -07:00
Yash Katariya
4d8bce1b85 Add a default cuda installation path and more explicit installation paths for CUDA jaxlib.
```
# Installs CUDA 11 with cuDNN 8.2
$ pip install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Same versions, pinned explicitly
$ pip install jax[cuda=11,cudnn=82] -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Installs CUDA 11 with cuDNN 8.0.5
$ pip install jax[cuda=11,cudnn=805] -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

PiperOrigin-RevId: 404134291
2021-10-18 19:56:22 -07:00
Jake VanderPlas
a353e3eafa jnp.take/jnp.take_along_axis: require array inputs 2021-10-15 09:37:05 -07:00
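A hedged sketch of the behavior change (the exact error raised for list inputs is an assumption):

```
import jax.numpy as jnp

jnp.take(jnp.array([1, 2, 3]), jnp.array([0, 2]))  # OK: array inputs
# jnp.take([1, 2, 3], [0, 2])  # lists are now rejected rather than converted
```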
Julius Kunze
f66cbb9b3d Fix CHANGELOG.md 2021-10-13 17:11:50 -06:00
jax authors
10af170a85 Merge pull request #8161 from juliuskunze:multidim-permutation
PiperOrigin-RevId: 402852030
2021-10-13 09:31:19 -07:00
Julius Kunze
63898b6ca6 Allow random.choice and random.permutation on multidimensional arrays 2021-10-13 09:39:25 -06:00
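A hedged sketch of the new behavior; the result shapes noted in the comments assume selection along the leading axis by default:

```
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.arange(12).reshape(3, 4)
jax.random.permutation(key, x)         # now allowed: permutes the rows
jax.random.choice(key, x, shape=(2,))  # now allowed: samples 2 rows, shape (2, 4)
```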
Peter Hawkins
2388804abc Add a regression test for #7461.
Fixes #7461
2021-10-13 11:11:24 -04:00
Skye Wanderman-Milne
962c496b25 Update jax version and CHANGELOG for 0.2.22 release 2021-10-12 18:46:37 -07:00
Yash Katariya
66a4a9ff3f Remove CUDA 10.2 support
PiperOrigin-RevId: 402707900
2021-10-12 18:44:07 -07:00
Skye Wanderman-Milne
0072c32546 Update CHANGELOG and version numbers for jaxlib 0.1.72 release 2021-10-12 17:37:29 -07:00
Jake VanderPlas
0b93c46c71 jnp.unique: add fill_value for when size is not None 2021-10-06 16:28:36 -07:00
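A small sketch of the new argument; the padded output shown assumes sorted unique values:

```
import jax.numpy as jnp

x = jnp.array([4, 1, 4, 2])
jnp.unique(x, size=6, fill_value=-1)  # -> [1, 2, 4, -1, -1, -1]
```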
Peter Hawkins
b466187bbe Add note to changelog about deprecation of jax.ops.index_... 2021-10-06 17:11:35 -04:00
Jean-Baptiste Lespiau
803b83ee15 Enable C++ pmap.
On CPU:
```
name                                     old cpu/op  new cpu/op  delta
pmap_trivial_2_devices                    128µs ± 6%    14µs ± 3%  -89.06%  (p=0.008 n=5+5)
pmap_trivial_dispatch_8_devices           212µs ± 2%    35µs ± 1%  -83.54%  (p=0.008 n=5+5)
pmap_trivial_8_devices                    215µs ± 1%    40µs ± 4%  -81.31%  (p=0.008 n=5+5)
pmap_simple_2_devices                     123µs ± 5%    15µs ± 6%  -87.70%  (p=0.008 n=5+5)
pmap_simple_dispatch_8_devices            211µs ± 3%    35µs ± 2%  -83.24%  (p=0.008 n=5+5)
pmap_simple_8_devices                     217µs ± 5%    40µs ± 2%  -81.68%  (p=0.008 n=5+5)
pmap_simple_dispatch_8_devices_100_args  5.42ms ± 7%  0.52ms ± 2%  -90.44%  (p=0.008 n=5+5)
pmap_simple_8_devices_100_args           26.5ms ±21%  17.5ms ±37%  -34.18%  (p=0.008 n=5+5)
sda_index_1                              7.45µs ± 6%  7.53µs ± 6%     ~     (p=0.222 n=5+5)
sda_index_2                              14.1µs ± 1%  14.3µs ± 4%     ~     (p=0.690 n=5+5)
sda_index_8                              56.0µs ± 3%  56.9µs ± 4%     ~     (p=0.310 n=5+5)

name                                     old time/op             new time/op             delta
pmap_trivial_2_devices                    136µs ± 8%               19µs ± 3%  -86.08%          (p=0.008 n=5+5)
pmap_trivial_dispatch_8_devices           216µs ± 3%               39µs ± 2%  -81.94%          (p=0.008 n=5+5)
pmap_trivial_8_devices                    219µs ± 2%               49µs ±38%  -77.67%          (p=0.008 n=5+5)
pmap_simple_2_devices                     130µs ± 5%               20µs ± 5%  -84.38%          (p=0.008 n=5+5)
pmap_simple_dispatch_8_devices            216µs ± 3%               39µs ± 5%  -81.71%          (p=0.008 n=5+5)
pmap_simple_8_devices                     221µs ± 6%               43µs ± 1%  -80.41%          (p=0.016 n=5+4)
pmap_simple_dispatch_8_devices_100_args  5.52ms ± 7%             0.59ms ± 2%  -89.28%          (p=0.008 n=5+5)
pmap_simple_8_devices_100_args           26.6ms ±21%             17.6ms ±37%  -34.04%          (p=0.008 n=5+5)
sda_index_1                              7.48µs ± 8%             7.53µs ± 6%     ~             (p=0.310 n=5+5)
sda_index_2                              14.1µs ± 1%             14.3µs ± 4%     ~             (p=0.690 n=5+5)
sda_index_8                              56.0µs ± 3%             56.9µs ± 4%     ~             (p=0.310 n=5+5)
```
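For context, a minimal sketch of the kind of trivial dispatch these benchmarks measure; this is an assumed reconstruction, not the benchmark's actual source:

```
import jax
import jax.numpy as jnp

f = jax.pmap(lambda x: x + 1)
x = jnp.arange(jax.local_device_count())
f(x).block_until_ready()  # per-call dispatch cost is what the C++ path reduces
```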

PiperOrigin-RevId: 401274089
2021-10-06 10:08:28 -07:00
Peter Hawkins
a11d957e61 Disallow non-hashable static arguments in pmap().
* Don't wrap static arguments in hashable wrappers in pmap.
* Delete wrap_hashably().
* In argnums_partial, either enforce hashability or wrap values with an explicitly unhashable wrapper. The intent here is that we should either check for hashability early, or make it clear that it's not something we intended.
* Delete argnames_partial, which appears unused.
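A hedged sketch of the new behavior; the error raised for unhashable statics is an assumption:

```
import jax
import jax.numpy as jnp

f = jax.pmap(lambda x, scale: x * scale, static_broadcasted_argnums=(1,))
xs = jnp.ones(jax.local_device_count())
f(xs, 2.0)      # fine: floats are hashable
# f(xs, [2.0])  # now rejected up front: lists are not hashable
```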
2021-09-30 15:50:07 -04:00
Jake VanderPlas
48157a7c1e Update v0.2.21 changelog for #7927 2021-09-27 11:38:36 -07:00
Yash Katariya
dbeb97d394 Create 0.2.21 jax release
PiperOrigin-RevId: 398528427
2021-09-23 11:00:31 -07:00
jax authors
fc7775e1d1 Merge pull request #7968 from hawkinsp:partial
PiperOrigin-RevId: 398025545
2021-09-21 10:21:13 -07:00
Peter Hawkins
1163e218e8 Attempt to land https://github.com/google/jax/pull/6400 again.
This PR changes `jax.numpy.array()` to avoid creating any on-device arrays during tracing. As a consequence, calls to `jnp.array()` in a traced context, such as under `jax.jit`, will always be staged into the trace.

This change may break code that depends on the current (undocumented and unintentional) behavior of `jnp.array()` to perform shape or index calculations that must be known statically (at trace time). The workaround for such cases is to use classic NumPy to perform shape/index calculations.
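A sketch of the workaround, using a hypothetical `flatten` helper:

```
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def flatten(x):
    # jnp.array(x.shape) would now be staged into the trace, so use classic
    # NumPy for shape math that must be static at trace time:
    n = np.prod(x.shape)
    return jnp.reshape(x, (n,))
```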

PiperOrigin-RevId: 398008511
2021-09-21 09:06:40 -07:00
Peter Hawkins
58c7ee46bc Remove jax.util.partial. 2021-09-20 20:32:49 -04:00
Peter Hawkins
f35ab3693d Remove jax.partial from the JAX API.
Use functools.partial instead.
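Migration sketch (the wrapped function is a placeholder):

```
from functools import partial

add = lambda a, b: a + b
add_one = partial(add, 1)  # drop-in replacement for jax.partial(add, 1)
```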
2021-09-20 09:19:53 -04:00
jax authors
f47926a23d Merge pull request #7940 from hawkinsp:api
PiperOrigin-RevId: 397319298
2021-09-17 07:58:17 -07:00
Jake VanderPlas
9a2697437e Update changelog for several recent PRs 2021-09-16 14:10:08 -07:00
Peter Hawkins
6a1b626564 Remove jax.api.
Functions exported as jax.api were aliases for names in jax.*. Use the jax.* names instead.
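Migration sketch (the function is a placeholder):

```
import jax

# old: from jax import api; api.jit(f)
f = jax.jit(lambda x: x + 1)  # use the top-level jax.* names instead
```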
2021-09-16 16:29:06 -04:00
Jake VanderPlas
abeeb48ba1 jnp.array: raise TypeError on boolean scalar indices 2021-09-15 12:50:44 -07:00
Jake VanderPlas
404e22ec67 Add Changelog for jax v0.2.21 development 2021-09-15 12:10:30 -07:00
Peter Hawkins
b56c2ccadd Remove export of jax.lax.partial. 2021-09-14 16:17:50 -04:00
yashkatariya
765746b60e update version and changelog for pypi 2021-09-02 15:38:47 -07:00
yashkatariya
14a02c6880 Remove new features 2021-09-01 11:26:41 -07:00
yashkatariya
84edde2f9b Add new features section 2021-09-01 10:56:54 -07:00
yashkatariya
be824a792e Update files after new jaxlib release 0.1.71 2021-09-01 10:43:20 -07:00
Jake VanderPlas
fb30fa852d update CHANGELOG for #7662 & #7732 2021-08-27 16:43:58 -07:00
Matthew Johnson
6f7be1fad9 update version and changelog for pypi 2021-08-12 21:17:53 -07:00
Peter Hawkins
beddf598bd Add @jit decorators to jax.numpy operators.
By wrapping common operators in `jit`, we get a number of benefits:
* `jit` has a faster, more optimized dispatch path compared to the primitive dispatch path in JAX. It's faster to dispatch a `jit` computation than a single primitive.
* `jit` allows us to cache and reuse logic such as broadcasting and type promotion.

One downside is that we now report an error when large Python integer scalars (e.g. `2**32 - 1`) are passed as arguments to JAX array operators. The workaround to this is to use explicitly typed constants instead of Python scalars.
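A hedged sketch of the workaround; the exact error raised for the oversized scalar is an assumption:

```
import numpy as np
import jax.numpy as jnp

x = jnp.uint32(1)
# x + (2**32 - 1)         # large Python scalar: now an error
x + np.uint32(2**32 - 1)  # explicitly typed constant works
```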

On my laptop, this benchmark improves from 95us to 4us:

```
In [1]: import jax.numpy as jnp, jax

In [2]: x = jax.device_put(7)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

In [3]: %timeit jnp.add(x, x).block_until_ready()
4.18 µs ± 159 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
```

PiperOrigin-RevId: 389871450
2021-08-10 06:49:28 -07:00
Yash Katariya
bf967d88d8 Upgrade versions after jaxlib release
PiperOrigin-RevId: 389753047
2021-08-09 16:37:44 -07:00
elliotwaite
7392a57b75 DOC: many small fixes 2021-08-04 16:55:13 -07:00
Peter Hawkins
6e9169d100 Drop support for NumPy 1.17. 2021-07-29 09:18:01 -04:00
George Necula
b62ceba91c [jax2tf] Expand shape polymorphism support to use dimension polynomials as values.
The goal of this change is to support shape polymorphism for operations
such as average (which needs to divide by the size of a dimension) or
indexing (which needs to normalize indices by comparing them with 0 and
adding dimension size for negative indices). In both of these cases
the size of a dimension needs to be used as a value in the array
computation. In general, the size of a dimension is used only to
customize primitives.

This change introduces `core.dim_as_value` which must be used on
a dimension size before using it as a value in the array computation.
E.g.,

```
def average(x):
   return jnp.sum(x, axis=0) / core.dim_as_value(x.shape[0])
```

This function is the identity function if the dimension size is
constant, otherwise it uses a new primitive `shape_poly.dim_as_value_p`.

Note that this does not fundamentally change the flavor of shape
polymorphism supported in jax2tf: intermediate shapes and their values
may depend on the input shapes, but a shape never depends on the
input values. In fact, one could have expressed `dim_as_value`
already:

```
def dim_as_value(d):
   return jnp.sum(jnp.broadcast_to(jnp.array(1), shape=(d,)))
```

We were able to support `jnp.mean`, `jnp.average`, `jnp.take`,
`lax.dynamic_slice`, `lax.dynamic_update_slice` by using
`core.dim_as_value` internally, but to fully roll up the solution
we need to make `core.dim_as_value` a public API and teach the
users how to use it when they want to use shape polymorphism.
Alternatively, perhaps there is a way to automatically convert
dimension polynomials to values when passed to the lax primitives.
2021-07-27 09:02:15 +03:00
Skye Wanderman-Milne
a7916f1428 Bump jax version and CHANGELOG to 0.2.18 2021-07-21 11:56:24 -07:00
Peter Hawkins
0dfd76af97 Remove additional info return value from jax.scipy.linalg.polar(). 2021-07-20 13:13:31 -04:00
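The resulting call shape, as a hedged sketch:

```
import jax.numpy as jnp
from jax.scipy.linalg import polar

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])
u, p = polar(a)  # now exactly two results: a ~= u @ p
```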
George Necula
a21683605d [host_callback] Increase number of threads for callback processing.
Previously there was one thread per device for receiving the outfeed from
devices, but there was a single global thread that was calling into the Python
callbacks. This meant that if one of the callbacks was slow, it was blocking
processing of all other callbacks.

One situation when this created difficulties was when one wanted to break a host_callback into two operations: a quick one to enqueue work on a threadpool,
and a subsequent slow one to wait for and retrieve the result. The first slow callback would block all other callbacks, including possibly some quick ones, thus missing the opportunity to start the slow work.

With this change there is a separate queue of outfeeds for each device and a
separate thread per device to call into Python. This allows for concurrency
between callbacks from different devices, although the callbacks from one
device are still sequential. If the programmer wants more concurrency, they can use a threadpool. Having more concurrency by default is tricky, because it may mean that the Python callbacks for one device may be seen out of order.

PiperOrigin-RevId: 385493070
2021-07-19 00:18:06 -07:00
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
Peter Hawkins
94446ff757 Drop Python 3.6 support.
Per the deprecation policy (https://jax.readthedocs.io/en/latest/deprecation.html),
Python 3.6 support has been due for removal since June 23, 2020.
2021-07-15 14:20:29 -04:00
Qiao Zhang
82e74959fe Update changelog for jaxlib-0.1.69. 2021-07-12 12:06:41 -07:00
George Necula
0beef34d25 [jax2tf] Fix conversion for argmin/argmax; add conversion for reduce
The previous conversion for argmin/argmax simply used tf.argmin and tf.argmax.
Those ops behave differently than JAX when the inputs contain NaN and Inf. Added
a few test cases in primitive_harness to expose the failures.

In order to implement an accurate conversion of argmin/argmax, we need to use the
XLA Reduce op.

Also tightened the shape checks for lax.argmin and lax.argmax, to ensure they are
not used with an empty reduced dimension. E.g., with axis=-1, previously we got
an internal error:
```
RuntimeError: Invalid argument: Reducing out-of-bounds dimension -1 in shape f32[2,0,3].:
This is a bug in JAX's shape-checking rules; please report it!
```
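A hedged sketch of the tightened check; the exact error message is an assumption:

```
import jax.numpy as jnp
from jax import lax

x = jnp.zeros((2, 0, 3))
# lax.argmax(x, axis=1, index_dtype=jnp.int32)
# now fails with a clear shape error (empty reduced dimension)
# instead of the internal error shown above
```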
PiperOrigin-RevId: 384182794
2021-07-12 01:11:42 -07:00