8205 Commits

Author SHA1 Message Date
tlu7
095e6507b9 Support value computation of associated Legendre functions.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2021-06-14 14:51:37 -07:00
jax authors
1e4d28a2d9 Merge pull request #6965 from hawkinsp:mypy
PiperOrigin-RevId: 379345896
2021-06-14 14:01:09 -07:00
Peter Hawkins
07277f0785 Bump mypy version to 0.902. 2021-06-14 10:05:34 -04:00
jax authors
16c4c49909 Merge pull request #6961 from gnecula:tf_errors
PiperOrigin-RevId: 379253554
2021-06-14 04:34:03 -07:00
George Necula
4f607c5e48 [jax2tf] Improve dtype coverage for neg; update limitation documentation
PiperOrigin-RevId: 379237975
2021-06-14 02:24:10 -07:00
jax authors
eb83ff42cf Merge pull request #6934 from pschuh:propagate-name
PiperOrigin-RevId: 379218767
2021-06-13 23:18:40 -07:00
jax authors
411466955a Merge pull request #6956 from gnecula:tf_bool
PiperOrigin-RevId: 379216799
2021-06-13 22:54:48 -07:00
George Necula
07cc58122d [jax2tf] Change the InconclusiveDimensionOperation error to include link to documentation 2021-06-13 17:58:22 +03:00
George Necula
dd8ab85121 [jax2tf] Support inequality and min/max for booleans.
For inequalities we add casts to int8. For min/max we rewrite
to logical operations and/or.
2021-06-12 21:08:37 +03:00
jax authors
5e3be94d8c Merge pull request #6955 from gnecula:tf_poly_add_transpose
PiperOrigin-RevId: 379061224
2021-06-12 10:06:58 -07:00
George Necula
edd96884c7 [jax2tf] Extend shape polymorphism to handle add_transpose with broadcasting 2021-06-12 11:42:15 +03:00
jax authors
f6e0297c22 Merge pull request #6944 from jakevdp:bcoo-reduce
PiperOrigin-RevId: 378940175
2021-06-11 13:36:56 -07:00
Jake VanderPlas
2113d9c34d add bcoo_reduce_sum() function 2021-06-11 13:19:54 -07:00
jax authors
31e9c65f2a Merge pull request #6952 from jakevdp:hstack-reshape
PiperOrigin-RevId: 378916971
2021-06-11 11:43:38 -07:00
jax authors
8fcdb85f09 Merge pull request #6940 from jakevdp:fix-sinc
PiperOrigin-RevId: 378904804
2021-06-11 10:47:57 -07:00
Jake VanderPlas
17710c0711 add efficient path for array input to jnp.stack, jnp.[hvd]stack, jnp.column_stack 2021-06-11 10:42:06 -07:00
jax authors
3550732a74 Merge pull request #6946 from jakevdp:concat-reshape
PiperOrigin-RevId: 378900980
2021-06-11 10:32:58 -07:00
jax authors
750f586400 Merge pull request #6941 from hawkinsp:numpy
PiperOrigin-RevId: 378863899
2021-06-11 06:53:59 -07:00
Peter Hawkins
b130257ee1 Drop support for NumPy 1.16. 2021-06-11 09:03:09 -04:00
jax authors
87a533e4ea Merge pull request #6947 from gnecula:tf_call_tf
PiperOrigin-RevId: 378837807
2021-06-11 03:05:14 -07:00
George Necula
1994f6df4a [jax2tf] Fix the round-trip call_tf(convert)
Also cleaned the handling of global state in jax2tf.
2021-06-11 11:57:27 +03:00
jax authors
3d1a6a308e Merge pull request #6945 from skye:version
PiperOrigin-RevId: 378756818
jax-v0.2.14
2021-06-10 16:11:49 -07:00
jax authors
e8068c0802 Merge pull request #6943 from jakevdp:bcoo-todense-fix
PiperOrigin-RevId: 378722733
2021-06-10 13:31:09 -07:00
Jake VanderPlas
0470f4f368 jnp.concatenate: add fast path based on lax.reshape for array inputs 2021-06-10 13:25:33 -07:00
Skye Wanderman-Milne
063401f3ef Update jax version to 0.2.14 2021-06-10 13:15:53 -07:00
jax authors
42b540c2f4 Merge pull request #6942 from skye:tpu_driver_version
PiperOrigin-RevId: 378703719
2021-06-10 12:02:54 -07:00
Jake VanderPlas
72fe3babee bcoo_todense: fix corner case 2021-06-10 12:02:04 -07:00
Skye Wanderman-Milne
4abac4f170 Pin the tpu_driver version used for Cloud TPU Colabs, instead of using nightly.
There have been some recent breakages affecting the nightly driver,
causing JAX operations to fail on Cloud TPU Colabs. Pinning to a
specific version will alleviate these problems. This version may need
to be updated if there are breaking changes to the tpu_driver
client/server boundary, but that doesn't happen very often.
2021-06-10 10:51:01 -07:00
jax authors
d622d5c824 Merge pull request #6939 from 8bitmp3:patch-1
PiperOrigin-RevId: 378675118
2021-06-10 10:03:40 -07:00
Jake VanderPlas
80d8f2d56c jnp.sinc: fix NaNs at x=0 2021-06-10 09:14:07 -07:00
George Necula
888db31ede [jax2tf] Fix passing of indices_are_sorted to the TF XlaGather op
PiperOrigin-RevId: 378660840
2021-06-10 08:54:17 -07:00
8bitmp3
8568aee800
Add missing back ticks to fix jax2tf README Markdown rendering in Different 64-bit precision in JAX and TensorFlow 2021-06-10 16:09:10 +01:00
jax authors
f67acb9379 Merge pull request #6937 from mariosasko:specify-zip-safe
PiperOrigin-RevId: 378647488
2021-06-10 07:34:18 -07:00
mariosasko
55b421ff36 Specify zip_safe for mypy 2021-06-10 16:06:11 +02:00
jax authors
7540690157 Merge pull request #6897 from hawkinsp:indexunique
PiperOrigin-RevId: 378550369
2021-06-09 18:49:00 -07:00
jax authors
69f6d5e3d2 Merge pull request #6781 from lukepfister:resize_weight
PiperOrigin-RevId: 378550280
2021-06-09 18:45:49 -07:00
Peter Hawkins
1ff12f05b3 Add unique/sorted annotations to gather().
XLA itself does not consume these, but they can be propagated onto scatter() when computing gradients.

Compute unique/sorted information on indexed accesses and indexed updates. Non-advanced indexes are always sorted and unique.
2021-06-09 21:05:41 -04:00
Parker Schuh
10c0b8d94a Jaxpr transform was losing name. 2021-06-09 17:11:31 -07:00
jax authors
7d0bda604a Merge pull request #6930 from jakevdp:union1d-size
PiperOrigin-RevId: 378530731
2021-06-09 16:51:19 -07:00
Peter Hawkins
73c47dce6e [XLA] Revert to using the textbook algorithm to construct the 2x2 Jacobi rotations in Eigendecomposition.
The current version is causing wrong outputs when the diagonal elements are exactly zero.

https://github.com/tensorflow/tensorflow/issues/50017

PiperOrigin-RevId: 378506479
2021-06-09 14:56:45 -07:00
Jake VanderPlas
79d0852145 Add optional size argument to jnp.union1d for JIT compatibility 2021-06-09 11:36:34 -07:00
jax authors
2a1936e6f9 Merge pull request #6824 from jakevdp:sparse-bcoo
PiperOrigin-RevId: 378437750
2021-06-09 10:25:28 -07:00
jax authors
8362db6ef8 Merge pull request #6925 from jakevdp:nonzero-test
PiperOrigin-RevId: 378420194
2021-06-09 09:10:59 -07:00
Jake VanderPlas
69f29a631a Add experimental batched COO sparse format.
This is an implementation of a batch-friendly multidimensional COO format for JAX. It contains implementations of four primitives (bcoo_todense, bcoo_fromdense, bcoo_extract, bcoo_dot_general), as well as batching, JVP, and transpose rules for each.

For convenience, this also adds class BCOO, which is a pytree wrapper around these.
2021-06-09 09:10:53 -07:00
Marc van Zee
b749e78d2c Adds support for enable_xla=False for shape polymorphism tests and adds such tests for dynamic_slice.
It turned out that, in jax2tf._dynamic_slice, tf.constant doesn't work with polymorphic shapes, so I replaced it with a tf.cast.

PiperOrigin-RevId: 378392273
2021-06-09 06:35:07 -07:00
jax authors
86d2da44c0 Merge pull request #6919 from marcvanzee:patch-3
PiperOrigin-RevId: 378352041
2021-06-09 01:53:27 -07:00
jax authors
aac8d7434c Merge pull request #6860 from gnecula:tf_source
PiperOrigin-RevId: 378351073
2021-06-09 01:45:19 -07:00
George Necula
59ae45a83c [jax2tf] Add support for generating HLO OpMetadata in the TF graph
The goal is to ensure that the HLO that
jax2tf->TF/XLA generates has the same metadata as what JAX generates.
This includes `op_type`, `op_name`, and source information, which are
used for debugging and profiling.

In order to ensure that this metadata is carried from the JAX tracing
time to TF/XLA, we save the metadata in custom TF op attributes. These
attributes are automatically preserved through SavedModel. This relies
on a separate change in TF/XLA to look for these custom attributes
and override its default.

For the source information, we use pretty much the same code that
xla.py uses. HLO OpMetadata has room for only one source location.
JAX (xla.py) picks the top-most user frame, which is obtained by
filtering out the stack frames in the JAX source tree. When used
with jax2tf we also need to filter out stack frames in the
TensorFlow source tree.

The hardest part is to generate the `op_name`, which is a hierarchical
name with components separated by '/', e.g., `jax2tf(top_func)/while/cond/le`.
We carry the current `name_stack` in thread-local state. Unfortunately, there
is no easy way to share the exact code that achieves this in xla.py. At the
same time it is not crucial that we have exactly identical name stacks as in
JAX.

I attempted to also carry this state in the JAX `MainTrace`, but could not
fully control the name stack. E.g., when calling a jitted-function we
have to reuse the current `MainTrace` although we want to push an element
on the name stack.

For now this option is not yet enabled until we make the necessary
changes in TensorFlow.
2021-06-09 08:08:42 +02:00
Jake VanderPlas
0f4f4102ce Add more complete test for jnp.nonzero size argument 2021-06-08 16:40:53 -07:00
jax authors
30b00095a9 Merge pull request #6915 from jakevdp:argwhere-size
PiperOrigin-RevId: 378275722
2021-06-08 16:39:57 -07:00