5891 Commits

Author SHA1 Message Date
Tianjian Lu
c5f73b3d8e [JAX] Added jax.lax.linalg.qdwh.
PiperOrigin-RevId: 406453671
2021-10-29 14:45:06 -07:00
Jake VanderPlas
40d6f5ed90 Tighten up dtypes across the package 2021-10-29 13:50:30 -07:00
jax authors
853fca2245 Merge pull request #8385 from jakevdp:fix-reshape
PiperOrigin-RevId: 406441883
2021-10-29 13:48:56 -07:00
Matthew Johnson
2cb74e1f97 make djax run again 2021-10-29 10:56:39 -07:00
Peter Hawkins
d0065d8a76 Forbid collapsing of size-0 dimensions in gather() operations.
The shape rule for gather should not allow collapsing size-0 dimensions because it is nonsensical: "collapsing" a size 0 dimension might turn an empty array into a non-empty array. And it's quite unclear what that non-empty array should contain. Forbid such collapsing in the JAX shape rule.

This appears to have arisen in practice when the size of the array is known to be 0 in another dimension, e.g., batching with a size 0 batch dimension. Instead, avoid using a gather to create these arrays. This isn't an ideal solution because it isn't polymorphic in the shape, but I think to do better we would need to change the definition of `gather` more extensively.

PiperOrigin-RevId: 406346374
2021-10-29 06:34:34 -07:00
jax authors
345ab50963 Merge pull request #8389 from tamaranorman:patch-1
PiperOrigin-RevId: 406342759
2021-10-29 06:08:27 -07:00
jax authors
2ab00151ed Copybara import of the project:
--
b40245e38d7837a7777735ad60f3b5b1ac2d499d by Sharad Vikram <sharad.vikram@gmail.com>:

Use `SourceInfo` named tuple to keep track of source information

PiperOrigin-RevId: 406293469
2021-10-28 23:07:56 -07:00
jax authors
af3c1acdec Merge pull request #8392 from sharadmv:source-info
PiperOrigin-RevId: 406275194
2021-10-28 20:18:46 -07:00
Peter Hawkins
954cb9983b [JAX] Update JAX users in preparation for a change that makes iteration over a JAX array return JAX arrays, instead of NumPy arrays.
See https://github.com/google/jax/pull/8043 for context as to why we are making this change.

The upshot for most users is that the values returned by iteration over a JAX array are now themselves JAX arrays, with the semantics of JAX arrays, which sometimes differ from the semantics of NumPy scalars and arrays. In particular:

* Unlike NumPy scalars 0-dimensional JAX arrays are not hashable. This change updates users to call `.tolist()` or `np.asarray(...)` when the output of iterating over a JAX array is hashed, used as a dictionary key, or passed to `set(...)`. In some instances, we can just call `numpy` functions instead of `jax.numpy` functions to build the array in the first place.
* This change confuses Pandas and PIL when a JAX array is converted to a Pandas dataframe or a PIL image. For now, cast JAX arrays to a NumPy array first before passing them into those libraries.
* We now need to use `numpy.testing.assert_array_equal` instead of `numpy.testing.assert_equal` to compare JAX arrays.

PiperOrigin-RevId: 406247725
2021-10-28 16:49:37 -07:00
Sharad Vikram
b40245e38d Use SourceInfo named tuple to keep track of source information 2021-10-28 13:31:26 -07:00
jax authors
934bfc0f24 Merge pull request #8364 from zhangqiaorjc:dsys
PiperOrigin-RevId: 406205507
2021-10-28 13:22:04 -07:00
tamaranorman
d890ae9068
Use default backend if no backend supplied to xla_computation 2021-10-28 18:37:15 +01:00
Peter Hawkins
9ea55468ab [JAX] Update users of jax.ops.index... functions, which are deprecated.
* replace uses of `jax.ops.index[...]` with `jax.numpy.index_exp[...]`, which is a standard NumPy function that does the same thing.
* remove some redundant uses of `jax.ops.index[...]`, where the expression is passed directly to an indexed accessor function like `.at[...]`.
* update some remaining users of `jax.ops.index_update(x, jax.ops.index[idx], y)` to use the `x.at[idx].set(y)` APIs.

PiperOrigin-RevId: 406162068
2021-10-28 09:54:26 -07:00
Jake VanderPlas
aa5e3c3b65 typo: primitive names do not need _p 2021-10-28 06:40:22 -07:00
jax authors
4ccd72bc81 Merge pull request #8309 from google:variance_scaling_axes
PiperOrigin-RevId: 406088360
2021-10-28 02:15:57 -07:00
Jake VanderPlas
723361f8f4 lax_numpy: replace some reshapes with expand_dims 2021-10-27 20:36:50 -07:00
jax authors
fd750ebcac Merge pull request #8370 from jakevdp:sparse-random
PiperOrigin-RevId: 405972112
2021-10-27 13:49:23 -07:00
jax authors
f3b1a3010a Merge pull request #8115 from sharadmv:all-gather-grad
PiperOrigin-RevId: 405924058
2021-10-27 10:40:07 -07:00
Sharad Vikram
ae9e69814a Broadcast unmapped values in all_to_all batching rule
Fixes #7965.

Co-authored-by: Sharad Vikram<sharad.vikram@gmail.com>
Co-authored-by: Adam Paszke <apaszke@google.com>
2021-10-27 10:10:41 -07:00
Jake VanderPlas
94a1feea25 [sparse] add sparse.random_bcoo() utility 2021-10-26 15:21:00 -07:00
Jake VanderPlas
4d4dae184f [sparse] implement broadcasted sparse-dense multiplication 2021-10-26 14:45:21 -07:00
Qiao Zhang
0be30fbf96 Add jax.distributed.initialize for multi-host GPU. 2021-10-26 14:37:54 -07:00
Jake VanderPlas
2259e2b0a8 [sparse] add todense() primitive for use in sparsify transform 2021-10-26 13:52:48 -07:00
Jake VanderPlas
c62452f2d2 benchmarks: add JIT versions of sparse.BCOO benchmarks
PiperOrigin-RevId: 405696495
2021-10-26 11:39:01 -07:00
Jake VanderPlas
2f0df85dcc [sparse] avoid default jit-compilation of sparse array functions 2021-10-26 10:58:37 -07:00
jax authors
08cf9fd862 Merge pull request #8317 from jakevdp:dynamic-index-dtype
PiperOrigin-RevId: 405683915
2021-10-26 10:48:30 -07:00
jax authors
7ae47cbc50 [JAX] Polish doc formatting for approx_top_k.
PiperOrigin-RevId: 405637140
2021-10-26 06:56:42 -07:00
Marc van Zee
0f477121e9 Implements a lowering function for tf.expm1 and adds tests.
PiperOrigin-RevId: 404958939
2021-10-21 23:59:04 -07:00
jax authors
66009d69b4 Merge pull request #8331 from jakevdp:odeint-args
PiperOrigin-RevId: 404909637
2021-10-21 16:48:40 -07:00
jax authors
beb473dae0 Merge pull request #8320 from jakevdp:hist-density
PiperOrigin-RevId: 404909058
2021-10-21 16:48:24 -07:00
Jake VanderPlas
3338947fee odeint: args validation allows pytrees 2021-10-21 16:34:59 -07:00
Jake VanderPlas
20062e514c [sparse] add .block_until_ready() to sparse objects 2021-10-21 10:26:14 -07:00
Marc van Zee
7a5e84311c Internal change
PiperOrigin-RevId: 404812711
2021-10-21 09:30:16 -07:00
jax authors
bc1c6b1090 [JAX] Fix sphinx formatting issue
PiperOrigin-RevId: 404797143
2021-10-21 08:10:17 -07:00
George Karpenkov
17e165929d Reland: Use variadic reduce on GPU for argmax/argmin
Fixed underlying bug

PiperOrigin-RevId: 404713314
2021-10-20 22:14:09 -07:00
Jake VanderPlas
eedf6e823d jnp.histogramdd: more succinct density computation 2021-10-20 16:54:06 -07:00
jax authors
09c4cf7bc3 Merge pull request #8288 from jakevdp:fix-multivariate-normal
PiperOrigin-RevId: 404661655
2021-10-20 16:15:34 -07:00
jax authors
720151c5f4 Merge pull request #8308 from iolloj:fft_enhanced
PiperOrigin-RevId: 404661598
2021-10-20 16:14:56 -07:00
Jake VanderPlas
fd4d987aa9 dynamic_slice: ensure start_indices dtypes match 2021-10-20 15:52:43 -07:00
Jake VanderPlas
eacec81915 odeint: validate *args 2021-10-20 14:22:43 -07:00
iollo jacopo
67dc16fc24 add fft normalisation 2021-10-20 22:15:35 +01:00
jax authors
0453e0de0f Merge pull request #8305 from jakevdp:bcoo-error
PiperOrigin-RevId: 404626396
2021-10-20 13:28:22 -07:00
jax authors
06b595321f [XLA:TPU] Support jvp/vjp in approx_top_k
Copies the jvp implementation lax.sort uses.
Left some comments for future optimizations

PiperOrigin-RevId: 404608289
2021-10-20 12:08:04 -07:00
Jake VanderPlas
b8d3035d20 [sparse] improve error for BCOO.fromdense if nse is not specified 2021-10-20 10:43:59 -07:00
James Bradbury
f5f0581281
update docstring 2021-10-20 09:16:55 -07:00
James Bradbury
eaf9eca617
Support multiple in/out axes in scaled inits 2021-10-20 09:12:37 -07:00
Marc van Zee
1b80feea6a Fixes a dtype bug in the conversion of dynamic_slice when enable_xla=False.
I tried adding a test, but in this specific case the TFLite converter uses for parameter `operand` the dtype `float32`, and for `start_indices` a tuple consisting of `tf.consts` of dtype `uint32`. I didn't know how to set up this test, but the examples eval shows that the bug is fixed for the TFLite examples.

PiperOrigin-RevId: 404527169
2021-10-20 07:24:05 -07:00
Alexander Kolesnikov
142be1348b Update jax2tf documentation with leading underscore when setting tf.Module() variables, e.g. m._variables = .... Also added a test for this.
PiperOrigin-RevId: 404483990
2021-10-20 03:22:23 -07:00
Marc van Zee
8dbd51d3f6 Implements padding support for lax.reduce_window when enable_xla=False.
Also does a few cleanups and adds some tests.

PiperOrigin-RevId: 404468307
2021-10-20 01:44:35 -07:00
Roy Frostig
623c201054 [JAX] move example libraries from jax.experimental into jax.example_libraries
The `jax.experimental.stax` and `jax.experimental.optimizers` modules are standalone examples libraries. By contrast, the remaining modules in `jax.experimental` are experimental features of the JAX core system. This change moves the two example libraries, and the README that describes them, to `jax.example_libraries` to reflect this distinction.

PiperOrigin-RevId: 404405186
2021-10-19 17:30:45 -07:00