314 Commits

Author SHA1 Message Date
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
Adam Paszke
64510bd5b6 Add axis and tiled options to lax.all_gather.
This is especially convenient when using JAX as an HLO generator, because the
HLO AllGather defaults to the tiling behavior.

PiperOrigin-RevId: 384897270
2021-07-15 04:22:36 -07:00
jax authors
4d026e06b1 Merge pull request #7255 from jakevdp:remove-broadcast-p
PiperOrigin-RevId: 384888218
2021-07-15 03:12:07 -07:00
Jake VanderPlas
c45acd70a8 Cleanup: use PEP 448 unpacking to simplify some code 2021-07-12 16:30:53 -07:00
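The kind of simplification PEP 448 enables can be illustrated with a small pure-Python sketch (illustrative only, not code from this commit): `*` and `**` may appear inside tuple, list, set, and dict displays, replacing `itertools.chain` or copy-and-update boilerplate.

```python
# Before PEP 448: concatenation via itertools.chain or explicit copies.
import itertools

xs = (1, 2)
ys = (3, 4)
old_style = tuple(itertools.chain(xs, ys))

# With PEP 448 unpacking, the same result is a single display.
new_style = (*xs, *ys)
assert old_style == new_style == (1, 2, 3, 4)

# The same works for dicts with **; later entries win on key collisions.
defaults = {"axis": 0, "keepdims": False}
overrides = {"axis": -1}
merged = {**defaults, **overrides}
assert merged == {"axis": -1, "keepdims": False}
```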
Jake VanderPlas
12e435f71e remove lax.broadcast_p
Why? It has been subsumed by lax.broadcast_in_dim_p
2021-07-12 15:33:26 -07:00
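Why `broadcast_p` is a special case of `broadcast_in_dim_p` can be seen at the shape level. A pure-Python sketch of the shape rules (hypothetical helpers, not JAX's implementation): `lax.broadcast` prepends new leading dimensions, which is exactly `broadcast_in_dim` with the operand dimensions mapped to the trailing output positions.

```python
def broadcast_in_dim_shape(operand_shape, out_shape, broadcast_dimensions):
    # Shape rule only: operand dim i lands at output position
    # broadcast_dimensions[i]; its size must match or be 1.
    for i, d in enumerate(broadcast_dimensions):
        if operand_shape[i] not in (1, out_shape[d]):
            raise ValueError("incompatible broadcast")
    return tuple(out_shape)

def broadcast_shape(operand_shape, sizes):
    # lax.broadcast(x, sizes) prepends `sizes`: the same as broadcast_in_dim
    # with operand dims mapped to the trailing output positions.
    out_shape = tuple(sizes) + tuple(operand_shape)
    dims = tuple(range(len(sizes), len(out_shape)))
    return broadcast_in_dim_shape(operand_shape, out_shape, dims)

assert broadcast_shape((3, 4), (2,)) == (2, 3, 4)
```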
George Necula
0beef34d25 [jax2tf] Fix conversion for argmin/argmax; add conversion for reduce
The previous conversion for argmin/argmax simply used tf.argmin and tf.argmax.
Those ops behave differently from JAX when the inputs contain NaN and Inf. Added
a few test cases in primitive_harness to expose the failures.

In order to implement an accurate conversion of argmin/argmax, we need to use the
XLA Reduce op.

Also tightened the shape checks for lax.argmin and lax.argmax to ensure they are
not used with an empty reduced dimension. E.g., with axis=-1, previously we got
an internal error:
```
RuntimeError: Invalid argument: Reducing out-of-bounds dimension -1 in shape f32[2,0,3].:
This is a bug in JAX's shape-checking rules; please report it!
```
PiperOrigin-RevId: 384182794
2021-07-12 01:11:42 -07:00
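The tightened check can be sketched in pure Python (a hypothetical helper, not the actual JAX code): normalize a possibly negative axis, then reject a zero-sized reduced dimension with a user-facing error instead of an internal one.

```python
def check_argminmax_shape(shape, axis):
    # Normalize a negative axis the way NumPy-style APIs do.
    ndim = len(shape)
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of bounds for rank {ndim}")
    axis = axis % ndim
    # Reject reductions over an empty dimension up front.
    if shape[axis] == 0:
        raise ValueError(
            f"argmin/argmax requires a non-empty reduced dimension: {shape}")
    return axis

assert check_argminmax_shape((2, 5, 3), -1) == 2
try:
    check_argminmax_shape((2, 0, 3), 1)
except ValueError as e:
    assert "non-empty" in str(e)
```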
George Necula
1f946ad51e Fix grad of conv 0D.
This bug was introduced in #6345, and was not caught by existing tests.
Add a reproducing test.
2021-07-11 10:46:42 +03:00
James Bradbury
8e86952ee4 AWN-enabled reduction over named axes in reverse-mode AD
Previously, reverse-mode AD operators inside JAX maps always meant "compute
a gradient (or VJP, etc.) for each axis index in the map". For instance,
`vmap(grad(f))` is the standard JAX spelling of the per-example gradient of `f`.

In batching tracer terms, this "elementwise" behavior means that, if any inputs
to a function being transposed are mapped, the cotangents of all inputs, even
unmapped ones, would also be mapped. But a user might want them to be unmapped
(if, for instance, they're interested in a total gradient rather than a
per-example gradient). They could always reduce (`psum`) the cotangents
afterwards, but computing mapped cotangents in the first place would likely be
an unacceptable waste of memory and can't necessarily be optimized away.

If we want to fuse these reductions into reverse-mode autodiff itself, we need
the backward_pass logic and/or transpose rules to know about whether primal
values are mapped or unmapped. This is made possible by avals-with-names,
which encodes that information in the avals of the primal jaxpr.

Putting things together, **this change adds an option to reverse-mode AD APIs
that indicates which named axes should be reduced over in the backward pass in
situations where they were broadcasted over in the forward pass**. All other
named axes will be treated in the current elementwise way. This has the effect
of making APIs like `grad` behave akin to collectives like `psum`: they act
collectively over axes that are named explicitly, and elementwise otherwise.

Since avals-with-names is currently enabled only in `xmap`, this behavior is
only available in that context for now. It's also missing some optimizations:
  - reductions aren't fused into any first-order primitives (e.g. a `pdot`
    should have a named contracting axis added rather than being followed by a
    `psum`; this can be implemented by putting these primitives into
    `reducing_transposes`)
  - reductions are performed eagerly, even over axes that are mapped to
    hardware resources (the optimal thing to do would be to reduce eagerly
    over any vectorized axis component while delaying the reduction over any
    hardware-mapped component until the end of the overall backward pass; this
    would require a way to represent these partially-reduced values)

PiperOrigin-RevId: 383685336
2021-07-08 12:06:29 -07:00
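The per-example versus reduced distinction above can be modeled with a tiny pure-Python sketch (no JAX, all names hypothetical): with per-example losses f_i(w) = w * x_i, the mapped cotangents are one value per axis index, while the collective behavior sums them (psum-style) into a single unmapped total gradient.

```python
xs = [1.0, 2.0, 3.0]
w = 4.0

# Per-example loss f_i(w) = w * x_i, so d f_i / d w = x_i.
per_example_grads = [x for x in xs]   # mapped cotangents: one per axis index

# Reducing over the named axis in the backward pass (a psum of cotangents)
# yields the gradient of the summed loss, without keeping the mapped values.
total_grad = sum(per_example_grads)   # a single unmapped cotangent

assert per_example_grads == [1.0, 2.0, 3.0]
assert total_grad == 6.0
```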
David Majnemer
5f11bf571a Use XLA atan2 for complex atan
PiperOrigin-RevId: 382831891
2021-07-02 16:19:00 -07:00
David Majnemer
781f85b09c
Fix broken markdown
A backtick was missing.
2021-06-28 23:35:26 -07:00
Peter Hawkins
d658108d36 Fix type errors with current mypy and NumPy.
Enable type stubs for jaxlib.

Fix a nondeterminism problem in jax2tf tests.
2021-06-24 10:51:06 -04:00
Qiao Zhang
7cc277d634 Test for gtsv2 attr on cusparse. 2021-06-23 14:22:08 -07:00
jax authors
28977761d5 Merge pull request #6849 from tomhennigan:changelist/376000598
PiperOrigin-RevId: 381010658
2021-06-23 05:46:01 -07:00
Qiao Zhang
e8e6138c75 Add vmap rule for lax.clamp. 2021-06-21 14:59:34 -07:00
jax authors
50f48e8a48 Merge pull request #7027 from jakevdp:transpose-validation
PiperOrigin-RevId: 380364559
2021-06-19 09:38:21 -07:00
Jake VanderPlas
7d2057b5c7 BUG: fix validation of permutation in transpose 2021-06-18 21:54:28 -07:00
Nicholas Junge
ccc8bb7f19 Add auxiliary data support to lax.custom_root 2021-06-17 19:25:46 +02:00
George Necula
dd8ab85121 [jax2tf] Support inequality and min/max for booleans.
For inequalities we add casts to int8. For min/max we rewrite
them as the logical operations and/or.
2021-06-12 21:08:37 +03:00
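The rewrites are easy to check in pure Python (a sketch of the identities, not the jax2tf code): for booleans, min coincides with logical AND and max with logical OR, and inequalities agree with comparing the values cast to a small integer type (int8 in jax2tf; plain `int` here).

```python
import itertools

for a, b in itertools.product([False, True], repeat=2):
    # min/max rewritten as logical operations.
    assert min(a, b) == (a and b)
    assert max(a, b) == (a or b)
    # Inequalities via a cast to integers.
    assert (a < b) == (int(a) < int(b))
    assert (a >= b) == (int(a) >= int(b))
```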
George Necula
edd96884c7 [jax2tf] Extend shape polymorphism to handle add_transpose with broadcasting 2021-06-12 11:42:15 +03:00
Peter Hawkins
1ff12f05b3 Add unique/sorted annotations to gather().
XLA itself does not consume these, but they can be propagated onto scatter() when computing gradients.

Compute unique/sorted information on indexed accesses and indexed updates. Non-advanced indexes are always sorted and unique.
2021-06-09 21:05:41 -04:00
Adam Paszke
490f9778c8 Raise a friendlier error message when using loop axes in collectives 2021-06-08 11:55:03 +00:00
Peter Hawkins
e9611eb090 Move jax.ad_util to jax._src.ad_util.
Expose ad_util.stop_gradient_p as jax.lax.stop_gradient_p. stop_gradient() is already under the external lax namespace.

PiperOrigin-RevId: 378011152
2021-06-07 14:51:34 -07:00
jax authors
fe3de5ab72 Merge pull request #6906 from apaszke:xmap-loops
PiperOrigin-RevId: 377939118
2021-06-07 09:49:24 -07:00
Adam Paszke
4fc4a3e471 Add support for sequential loop resources in xmap
This is especially useful because it makes it easy to implement
"memory-limited vmaps". It might also come in handy for pipelining,
as the sequential loop could represent the microbatching loop.

Note that at the moment the xmap has to be a pure map along all axes
that are assigned to loop resources. No collectives are supported.
2021-06-07 12:58:08 +00:00
George Necula
d243258b86 [jax2tf] Implement inequalities and friends for complex numbers.
This is done by re-using JAX's lowering rule for comparisons of
complex numbers, which uses lexicographic comparison.
2021-06-04 17:56:44 +03:00
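Lexicographic comparison orders complex numbers by real part first, with the imaginary part breaking ties. A pure-Python sketch of the rule (hypothetical helper, not JAX's lowering):

```python
def complex_lt(x, y):
    # Compare real parts; fall back to imaginary parts on a tie.
    if x.real != y.real:
        return x.real < y.real
    return x.imag < y.imag

assert complex_lt(1 + 5j, 2 + 0j)       # real part decides
assert complex_lt(1 + 1j, 1 + 2j)       # tie on real, imag decides
assert not complex_lt(2 + 0j, 1 + 5j)
```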
George Necula
293ca655e0 [jax2tf] Update limitations to account for tf.math improvements for trigonometric functions.
PiperOrigin-RevId: 377436077
2021-06-03 21:17:56 -07:00
Adam Paszke
c7a98b3b62 Fix a typo in shape checks for associative_scan
Fixes #6884.

PiperOrigin-RevId: 377276183
2021-06-03 06:37:31 -07:00
jax authors
39526a0d08 Merge pull request #6873 from ROCmSoftwarePlatform:fix_rocm_linalg
PiperOrigin-RevId: 377273325
2021-06-03 06:16:39 -07:00
George Necula
2ccda70d83 [jax2tf] Improved coverage of shape polymorphism by allowing dimension polynomials.
Previously we allowed a dimension variable in lieu of a dimension. Now we
allow multi-variate dimension polynomials. These polynomials overload addition,
subtraction, and multiplication. They also partially support equality and
inequality checking.

Equality and inequality are supported only when the operation result is the
same for all valuations of variables greater than 0. For example, `a == a`,
`a * b + 1 == 1 + b * a`, `a >= 1`, `2 * a + b >= 3`, `a >= a`. However, for
the following a `core.InconclusiveDimensionOperation` is raised: `a == b`,
`a >= 2`.

Division is supported only in the cases when either there is no remainder,
or the divisor is a constant.

This change allows us to support more general cases of `jnp.reshape(-1)`,
such as those used in the internal implementation of `random_gamma`:

```
   y = x.reshape((2, -1))
   z = ... y ...
   return z.reshape(x.shape)
```
2021-06-03 10:58:06 +03:00
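A minimal sketch of such symbolic dimensions (hypothetical and far simpler than jax2tf's implementation): represent a polynomial as a mapping from monomials to integer coefficients, so that identical normal forms give a definite answer and anything depending on variable values is inconclusive.

```python
class InconclusiveDimensionOperation(Exception):
    pass

class DimPoly:
    """A sketch of a dimension polynomial as {monomial: coefficient}."""
    def __init__(self, monomials):
        # Keys are sorted tuples of variable names; () is the constant term.
        self.monomials = {m: c for m, c in monomials.items() if c != 0}

    @classmethod
    def var(cls, name):
        return cls({(name,): 1})

    @classmethod
    def const(cls, c):
        return cls({(): c})

    def __add__(self, other):
        out = dict(self.monomials)
        for m, c in other.monomials.items():
            out[m] = out.get(m, 0) + c
        return DimPoly(out)

    def __mul__(self, other):
        out = {}
        for m1, c1 in self.monomials.items():
            for m2, c2 in other.monomials.items():
                m = tuple(sorted(m1 + m2))
                out[m] = out.get(m, 0) + c1 * c2
        return DimPoly(out)

    def __eq__(self, other):
        diff = self + (other * DimPoly.const(-1))
        if not diff.monomials:
            return True      # equal for all valuations
        if set(diff.monomials) == {()}:
            return False     # differ by a nonzero constant
        raise InconclusiveDimensionOperation("depends on variable values")

a, b = DimPoly.var("a"), DimPoly.var("b")
one = DimPoly.const(1)
assert (a * b + one) == (one + b * a)   # same normal form
try:
    a == b                              # inconclusive: depends on a, b
except InconclusiveDimensionOperation:
    pass
```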
Adam Paszke
ed96e5305f Fix incorrect handling of axis_index_groups in parallel primitive fallbacks
PiperOrigin-RevId: 377139424
2021-06-02 14:03:47 -07:00
Tom Hennigan
ffac40a2c0 Add lax.linalg.tridiagonal_solve(..), lowering to cusparse_gtsv2<T>() on GPU.
Fixes #6830.
2021-06-02 13:49:02 +00:00
Peter Hawkins
46cc654537 Move jax.abstract_arrays to jax._src.abstract_arrays.
PiperOrigin-RevId: 377044255
2021-06-02 06:25:22 -07:00
Reza Rahimi
012da545f7 add gpu to the rocsolver backend 2021-06-02 04:03:06 +00:00
jax authors
edd203e305 Merge pull request #6726 from njunge94:auxiliary_solver_data
PiperOrigin-RevId: 376899659
2021-06-01 12:58:39 -07:00
Rebecca Chen
5065e1bb93 Add missing typing.Optional type annotations to function parameters.
PiperOrigin-RevId: 376300297
2021-05-27 20:10:23 -07:00
Nicholas Junge
0308527f55 Add auxiliary data support in custom_linear_solve 2021-05-25 18:00:46 +02:00
Lukas Geiger
971eb86fa4 Reduce redundant calculations of tan, erfc and rsqrt jvp 2021-05-21 18:36:42 +01:00
George Necula
e7766838db Minor cleanup of the translation rules for cumred primitives 2021-05-21 17:45:33 +03:00
jax authors
3319cd0ed0 Merge pull request #6740 from hawkinsp:scatter
PiperOrigin-RevId: 374718181
2021-05-19 13:35:01 -07:00
Peter Hawkins
99de57f5d9 Enable Hermitian Eigendecompositions on TPU. 2021-05-19 14:39:46 -04:00
jax authors
683289c4ad Merge pull request #6764 from hawkinsp:argmax
PiperOrigin-RevId: 374437439
2021-05-18 09:31:53 -07:00
Peter Hawkins
6cc440b79d Fix handling of NaNs in GPU argmax translation rule. 2021-05-18 11:35:54 -04:00
George Necula
afe5ec3df4 Improve accuracy of the jax2tf convolution conversion.
Some of the discrepancies were due to JAX using a workaround for
missing complex convolutions on CPU/GPU, while jax2tf was not.
We now apply the same lowering as JAX on all platforms.

This allows us to remove custom numeric tolerances and enables complex
convolutions on GPU.

PiperOrigin-RevId: 374199441
2021-05-17 08:18:51 -07:00
Lena Martens
73e9302fc3 Fix jsp.linalg.lu translation rule to pass backend arg to lower_fun.
If it doesn't, trying to run `lu` with a custom CPU backend when a GPU is
present results in an `Unable to resolve runtime symbol:
cuda_lu_pivots_to_permutation` fatal error.
2021-05-14 17:37:09 +01:00
Peter Hawkins
44c98ad4e8 Improve JVP rule for scatters with non-overlapping indices.
If the scattered values don't overlap, we don't need complicated masking logic to work out which of the two overlapping values "win".
2021-05-12 14:16:35 -04:00
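The simplification can be modeled in pure Python (a sketch, not JAX's rule): when scatter indices are disjoint, each output slot receives at most one update, so the JVP of the scatter is just a scatter of the tangents at the same indices, with no masking to decide which update "wins".

```python
def scatter(operand, indices, updates):
    # A toy 1-D scatter: write each update at its index.
    out = list(operand)
    for i, u in zip(indices, updates):
        out[i] = u
    return out

# Primal and tangent inputs; the indices are unique (non-overlapping).
x, dx = [0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]
ups, dups = [5.0, 7.0], [0.5, 0.25]
idx = [1, 3]

primal = scatter(x, idx, ups)
tangent = scatter(dx, idx, dups)   # no masking needed when idx is unique

assert primal == [0.0, 5.0, 0.0, 7.0]
assert tangent == [1.0, 0.5, 1.0, 0.25]
```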
Peter Hawkins
ecaeb94655 Make associative_scan work for boolean arguments. 2021-05-12 10:28:55 -04:00
George Necula
ba5e11f86f [jax2tf] Improve the conversion of integer_pow for better numerical accuracy.
Previously we simply converted integer_pow to tf.math.pow. JAX instead uses
a series of multiplications. We now use the same lowering strategy as JAX, so
that we have the same numerical result.

Also improved the error messages for assertion failures.

PiperOrigin-RevId: 373351147
2021-05-12 05:45:39 -07:00
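The "series of multiplications" strategy can be sketched in pure Python via binary exponentiation (one plausible multiplication scheme; a sketch, not the exact JAX multiplication sequence):

```python
def integer_pow(x, n):
    # Exponentiation by squaring: O(log n) multiplications, matching the
    # idea of lowering x**n to an explicit chain of multiplies.
    assert n >= 0
    acc = 1.0
    base = x
    while n:
        if n & 1:
            acc *= base
        base *= base
        n >>= 1
    return acc

assert integer_pow(3.0, 5) == 243.0
assert integer_pow(2.0, 0) == 1.0
```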
George Necula
235eb8c2b4 Copybara import of the project:
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:

[jax2tf] Change the conversion of dot_general to use XLA op.

Instead of converting the dot_general to a sea of TF ops, when
we enable_xla we just use the XLA op. This has the advantage
that it also supports the preferred_element_type.

Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
2021-05-12 02:30:15 -07:00
Matthew Johnson
b9d72a480f improve concreteness error from arguments
also tweak some error message wording
2021-05-03 17:37:34 -07:00
Jake VanderPlas
71a25cdac1 DOC: add examples to lax function docstrings 2021-04-29 09:48:52 -07:00