5069 Commits

jax authors
3d1a6a308e Merge pull request #6945 from skye:version
PiperOrigin-RevId: 378756818
2021-06-10 16:11:49 -07:00
jax authors
e8068c0802 Merge pull request #6943 from jakevdp:bcoo-todense-fix
PiperOrigin-RevId: 378722733
2021-06-10 13:31:09 -07:00
Skye Wanderman-Milne
063401f3ef Update jax version to 0.2.14 2021-06-10 13:15:53 -07:00
Jake VanderPlas
72fe3babee bcoo_todense: fix corner case 2021-06-10 12:02:04 -07:00
Skye Wanderman-Milne
4abac4f170 Pin the tpu_driver version used for Cloud TPU Colabs, instead of using nightly.
There have been some recent breakages affecting the nightly driver,
causing JAX operations to fail on Cloud TPU Colabs. Pinning to a
specific version will alleviate these problems. This version may need
to be updated if there are breaking changes to the tpu_driver
client/server boundary, but that doesn't happen very often.
2021-06-10 10:51:01 -07:00
jax authors
d622d5c824 Merge pull request #6939 from 8bitmp3:patch-1
PiperOrigin-RevId: 378675118
2021-06-10 10:03:40 -07:00
George Necula
888db31ede [jax2tf] Fix passing of indices_are_sorted to the TF XlaGather op
PiperOrigin-RevId: 378660840
2021-06-10 08:54:17 -07:00
8bitmp3
8568aee800 Add missing backticks to fix jax2tf README Markdown rendering in the "Different 64-bit precision in JAX and TensorFlow" section 2021-06-10 16:09:10 +01:00
jax authors
7540690157 Merge pull request #6897 from hawkinsp:indexunique
PiperOrigin-RevId: 378550369
2021-06-09 18:49:00 -07:00
jax authors
69f6d5e3d2 Merge pull request #6781 from lukepfister:resize_weight
PiperOrigin-RevId: 378550280
2021-06-09 18:45:49 -07:00
Peter Hawkins
1ff12f05b3 Add unique/sorted annotations to gather().
XLA itself does not consume these, but they can be propagated onto scatter() when computing gradients.

Compute unique/sorted information on indexed accesses and indexed updates. Non-advanced indexes are always sorted and unique.
2021-06-09 21:05:41 -04:00
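Not from the commit itself: a minimal sketch of supplying these annotations explicitly through `lax.gather`, assuming the `unique_indices` / `indices_are_sorted` keyword arguments it exposes.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.arange(5.0)
start_indices = jnp.array([[0], [2], [4]])  # strictly increasing, no repeats

dnums = lax.GatherDimensionNumbers(
    offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))

# XLA ignores these hints on the gather itself, but they can be carried
# over to the scatter produced when differentiating through this gather.
out = lax.gather(operand, start_indices, dnums, slice_sizes=(1,),
                 unique_indices=True, indices_are_sorted=True)
print(out)  # [0. 2. 4.]
```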
Jake VanderPlas
79d0852145 Add optional size argument to jnp.union1d for JIT compatibility 2021-06-09 11:36:34 -07:00
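Not part of the commit; a small sketch of why `size` matters under `jit`: without it the output shape depends on the data, which `jit` cannot trace.

```python
import jax
import jax.numpy as jnp

@jax.jit
def jitted_union(a, b):
    # A static `size` fixes the output length (padding with the minimum
    # element if there are fewer unique values), making the shape static.
    return jnp.union1d(a, b, size=4)

print(jitted_union(jnp.array([1, 2, 3]), jnp.array([2, 4])))  # [1 2 3 4]
```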
Jake VanderPlas
69f29a631a Add experimental batched COO sparse format.
This is an implementation of a batch-friendly multidimensional COO format for JAX. It contains implementations of four primitives (bcoo_todense, bcoo_fromdense, bcoo_extract, bcoo_dot_general), as well as batching, JVP, and transpose rules for each.

For convenience, this also adds class BCOO, which is a pytree wrapper around these.
2021-06-09 09:10:53 -07:00
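Not from the commit: a rough usage sketch, assuming the `BCOO` class is reachable via `jax.experimental.sparse` (the module path may differ in the release that first shipped it).

```python
import jax.numpy as jnp
from jax.experimental import sparse  # module path is an assumption here

M = jnp.array([[0., 1., 0.],
               [2., 0., 3.]])
M_sp = sparse.BCOO.fromdense(M)   # wraps the bcoo_fromdense primitive
print(M_sp.data, M_sp.indices)    # nonzero values and their coordinates
print(M_sp.todense())             # bcoo_todense recovers the dense array
```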
Marc van Zee
b749e78d2c Adds support for enable_xla=False for shape polymorphism tests and adds such tests for dynamic_slice.
It turned out that, in jax2tf._dynamic_slice, tf.constant doesn't work with polymorphic shapes, so I replaced it with a tf.cast.

PiperOrigin-RevId: 378392273
2021-06-09 06:35:07 -07:00
jax authors
86d2da44c0 Merge pull request #6919 from marcvanzee:patch-3
PiperOrigin-RevId: 378352041
2021-06-09 01:53:27 -07:00
George Necula
59ae45a83c [jax2tf] Add support for generating HLO OpMetadata in the TF graph
The goal is to ensure that the HLO that
jax2tf->TF/XLA generates has the same metadata as what JAX generates.
This includes `op_type`, `op_name`, and source information, which are
used for debugging and profiling.

In order to ensure that this metadata is carried from the JAX tracing
time to TF/XLA, we save the metadata in custom TF op attributes. These
attributes are automatically preserved through SavedModel. This relies
on a separate change in TF/XLA to look for these custom attributes
and override its default.

For the source information, we use pretty much the same code that
xla.py uses. HLO OpMetadata has room for only one source location.
JAX (xla.py) picks the top-most user frame, which is obtained by
filtering out the stack frames in the JAX source tree. When used
with jax2tf we also need to filter out stack frames in the
TensorFlow source tree.

The hardest part is to generate the `op_name`, which is a hierarchical
name with components separated by '/', e.g., `jax2tf(top_func)/while/cond/le`.
We carry the current `name_stack` in thread-local state. Unfortunately, there
is no easy way to share the exact code that achieves this in xla.py. At the
same time it is not crucial that we have exactly identical name stacks as in
JAX.

I attempted to also carry this state in the JAX `MainTrace`, but could not
fully control the name stack. E.g., when calling a jitted function we
have to reuse the current `MainTrace` although we want to push an element
on the name stack.

For now this option is not yet enabled until we make the necessary
changes in TensorFlow.
2021-06-09 08:08:42 +02:00
jax authors
30b00095a9 Merge pull request #6915 from jakevdp:argwhere-size
PiperOrigin-RevId: 378275722
2021-06-08 16:39:57 -07:00
jax authors
c6d389387e Merge pull request #6926 from jakevdp:fix-random-validation
PiperOrigin-RevId: 378271745
2021-06-08 16:21:18 -07:00
Jake VanderPlas
f97e2f945f jnp.argwhere: add optional size parameter for JIT compatibility 2021-06-08 16:17:37 -07:00
Jake VanderPlas
022464b91b jnp.where: add optional size argument 2021-06-08 15:53:12 -07:00
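Not from these commits: a sketch of the shared pattern. The one-argument form of `jnp.where` (like `jnp.argwhere` and `jnp.flatnonzero`) normally has a value-dependent output shape; a static `size` makes it usable under `jit`.

```python
import jax
import jax.numpy as jnp

@jax.jit
def first_three_nonzero(x):
    # `size=3` fixes the output length; missing entries are padded
    # (with 0 by default), so the shape no longer depends on the values.
    (idx,) = jnp.where(x != 0, size=3)
    return idx

print(first_three_nonzero(jnp.array([0, 5, 0, 7, 9, 0])))  # [1 3 4]
```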
Allen Lavoie
e7fe44e9fd Fix jax2tf._is_tfval after tf.constant placement changes
complex128 isn't supported on TPUs in TF, and tf.constant now places constants on the TPU by default; _is_tfval saw the resulting exception and assumed the value wasn't convertible to a TF type.

PiperOrigin-RevId: 378240447
2021-06-08 14:06:22 -07:00
Jake VanderPlas
119c9bc0dd jax.random: improve input validation (fixes #6922) 2021-06-08 13:37:21 -07:00
Jake VanderPlas
1296dc3f1e jnp.flatnonzero: add optional size argument for JIT compatibility 2021-06-08 13:16:51 -07:00
jax authors
72cd6d0072 Merge pull request #6912 from jakevdp:jittable-unique
PiperOrigin-RevId: 378217815
2021-06-08 12:34:24 -07:00
jax authors
d38def4660 Merge pull request #6923 from jakevdp:gamma-doc
PiperOrigin-RevId: 378217732
2021-06-08 12:31:06 -07:00
jax authors
648b5d3265 Merge pull request #6066 from apaszke:xmap-no-mesh-slicing
PiperOrigin-RevId: 378209333
2021-06-08 11:54:05 -07:00
Jake VanderPlas
d198ad0ac1 jnp.unique: add optional size argument for JIT compatibility 2021-06-08 11:31:42 -07:00
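Also not from the commit; the same idea sketched for `jnp.unique` under `jit`.

```python
import jax
import jax.numpy as jnp

@jax.jit
def unique_padded(x):
    # With `size=5`, extra slots are filled with the minimum element by
    # default, so the output shape is static and jit-compatible.
    return jnp.unique(x, size=5)

print(unique_padded(jnp.array([3, 1, 3, 2, 1, 2])))  # [1 2 3 1 1]
```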
Jake VanderPlas
22dbe80255 DOC: state that digamma only accepts float 2021-06-08 10:47:27 -07:00
Adam Paszke
54ba051631 Always run xmap over all mesh axes
Automatic mesh slicing might be surprising, and can always be
performed manually.
2021-06-08 13:36:13 +00:00
Adam Paszke
490f9778c8 Raise a friendlier error message when using loop axes in collectives 2021-06-08 11:55:03 +00:00
Marc van Zee
80e69d456e Update README.md 2021-06-08 07:21:40 +02:00
jax authors
7a3a160b26 Merge pull request #6869 from colemanliyah:file_system_cache
PiperOrigin-RevId: 378012890
2021-06-07 14:59:49 -07:00
Peter Hawkins
e9611eb090 Move jax.ad_util to jax._src.ad_util.
Expose ad_util.stop_gradient_p as jax.lax.stop_gradient_p. stop_gradient() is already under the external lax namespace.

PiperOrigin-RevId: 378011152
2021-06-07 14:51:34 -07:00
Peter Hawkins
d5ba87ad7f Add a device_put handler for tokens.
Fixes bug with tokens passed to trivial computations.
2021-06-07 16:19:14 -04:00
Liyah Coleman
75cc734c8e completed FilesystemCache class with corresponding unit tests 2021-06-07 17:26:23 +00:00
jax authors
fe3de5ab72 Merge pull request #6906 from apaszke:xmap-loops
PiperOrigin-RevId: 377939118
2021-06-07 09:49:24 -07:00
Marc van Zee
9fed620119 Adds support for lax.dynamic_slice_p when XLA is disabled.
PiperOrigin-RevId: 377909682
2021-06-07 07:22:48 -07:00
Adam Paszke
4fc4a3e471 Add support for sequential loop resources in xmap
This is especially useful because it makes it easy to implement
"memory-limited vmaps". It might also come in handy for pipelining,
as that could represent the microbatching loop.

Note that at the moment the xmap has to be a pure map along all axes
that are assigned to loop resources. No collectives are supported.
2021-06-07 12:58:08 +00:00
George Necula
d243258b86 [jax2tf] Implement inequalities and friends for complex numbers.
This requires reusing JAX's lowering rule for complex-number
comparisons, which lowers them to lexicographic comparison.
2021-06-04 17:56:44 +03:00
George Necula
ede457f1a5 [jax2tf] Fix bug with max_int for uint64 2021-06-04 15:29:54 +03:00
George Necula
293ca655e0 [jax2tf] Update limitations to account for tf.math improvements for trigonometric functions.
PiperOrigin-RevId: 377436077
2021-06-03 21:17:56 -07:00
Peter Hawkins
5dc9df386c [JAX] Attach a priority to JAX backends. Use the backend with the highest priority when choosing a default backend.
PiperOrigin-RevId: 377351657
2021-06-03 12:48:24 -07:00
Jake VanderPlas
21dbe30fbb BUG: return JAX arrays rather than NumPy arrays in jnp.unravel_index 2021-06-03 09:15:01 -07:00
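Not from the commit; a quick way to see the behavior it fixes (types and values shown are an illustration):

```python
import jax.numpy as jnp

# After the fix, the outputs are JAX arrays rather than NumPy arrays.
rows, cols = jnp.unravel_index(jnp.array([1, 5, 7]), (3, 3))
print(type(rows).__name__)  # a JAX array type, not numpy.ndarray
print(rows, cols)           # [0 1 2] [1 2 1]
```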
Peter Hawkins
b2c7ae728d [JAX] Catch all exceptions from backend initialization.
PiperOrigin-RevId: 377278098
2021-06-03 06:49:56 -07:00
Adam Paszke
c7a98b3b62 Fix a typo in shape checks for associative_scan
Fixes #6884.

PiperOrigin-RevId: 377276183
2021-06-03 06:37:31 -07:00
jax authors
39526a0d08 Merge pull request #6873 from ROCmSoftwarePlatform:fix_rocm_linalg
PiperOrigin-RevId: 377273325
2021-06-03 06:16:39 -07:00
Adam Paszke
bca3d61b3b Insert xmap SPMD axes into pjit sharding annotations
This should let us emit good XLA annotations for `xmap(pjit)`. Previously
we might have been overestimating the set of replicated mesh dimensions.

PiperOrigin-RevId: 377259226
2021-06-03 04:13:29 -07:00
jax authors
ecab743e5c Merge pull request #6877 from hawkinsp:tracebacks
PiperOrigin-RevId: 377247694
2021-06-03 02:47:21 -07:00
George Necula
2ccda70d83 [jax2tf] Improved coverage of shape polymorphism by allowing dimension polynomials.
Previously we allowed a dimension variable in lieu of a dimension. Now we
allow multi-variate dimension polynomials. These polynomials overload addition, subtraction,
and multiplication. They also partially support equality and inequality checking.

Equality and inequality are supported only when the operation result is the
same for all valuations of variables greater than 0. For example, `a == a`,
`a * b + 1 == 1 + b * a`, `a >= 1`, `2 * a + b >= 3`, `a >= a`. However, for
the following a `core.InconclusiveDimensionOperation` is raised: `a == b`,
`a >= 2`.

Division is supported only in the cases when either there is no remainder,
or the divisor is a constant.

This change allows us to support more general cases of `jnp.reshape(-1)`,
such as those used in the internal implementation of `random_gamma`:

```
   y = x.reshape((2, -1))
   z = ... y ...
   return z.reshape(x.shape)
```
2021-06-03 10:58:06 +03:00
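Not part of the commit: a rough sketch of the kind of reshape this enables under shape polymorphism, assuming the `polymorphic_shapes` argument of `jax2tf.convert` and a TensorFlow install.

```python
import tensorflow as tf
import jax.numpy as jnp
from jax.experimental import jax2tf

def f(x):
    y = jnp.reshape(x, (-1,))            # shape (4*b,), a dimension polynomial
    return jnp.reshape(y * 2.0, x.shape)

# "b" is a symbolic dimension; the reshape arithmetic (4*b divided by 4)
# has no remainder, so it is accepted.
f_tf = tf.function(jax2tf.convert(f, polymorphic_shapes=["(b, 4)"]),
                   input_signature=[tf.TensorSpec([None, 4], tf.float32)])
print(f_tf(tf.ones((2, 4))))
```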
George Necula
d03d849a19 [jax2tf] Fix the 32/64-bit behavior to follow JAX rules
JAX and TensorFlow have different behavior w.r.t. 32-64 bit
computations. This PR cleans up the handling of types in jax2tf
to ensure that we follow the same behavior in jax2tf and in JAX.

This means that f_jax(args) always does the computation with the
same precision as jax2tf.convert(f_jax)(args). This may mean that
the result of the conversion depends on the value of JAX_ENABLE_X64.

See README.md for more details.
2021-06-03 10:12:58 +03:00
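Not from the commit: a small sketch of the behavior it describes, assuming default 32-bit mode (JAX_ENABLE_X64 unset).

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf

def f(x):
    return jnp.sin(x) + 1.0

x = np.arange(3, dtype=np.float64)

# Under default 32-bit mode JAX computes in float32; after this change the
# converted function matches that, instead of silently keeping float64.
print(f(x).dtype)                   # float32
print(jax2tf.convert(f)(x).dtype)   # float32 as well, matching JAX
```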