189 Commits

Author SHA1 Message Date
George Necula
b62ceba91c [jax2tf] Expand shape polymorphism support to use dimension polynomials as values.
The goal of this change is to support shape polymorphism for operations
such as average (which needs to divide by the size of a dimension) or
indexing (which needs to normalize indices by comparing them with 0 and
adding dimension size for negative indices). In both of these cases
the size of a dimension needs to be used as a value in the array
computation. In general, the size of a dimension is used only to
customize primitives.

This change introduces `core.dim_as_value` which must be used on
a dimension size before using it as a value in the array computation.
E.g.,

```
def average(x):
   return jnp.sum(x, axis=0) / core.dim_as_value(x.shape[0])
```

This function is the identity function if the dimension size is
constant, otherwise it uses a new primitive `shape_poly.dim_as_value_p`.

Note that this does not fundamentally change the flavor of shape
polymorphism supported in jax2tf: intermediate shapes and their values
may depend on the input shapes, but a shape never depends on the
input values. In fact, one could already have expressed `dim_as_value`
as:

```
def dim_as_value(d):
   return jnp.sum(jnp.broadcast_to(jnp.array(1), shape=(d,)))
```

We were able to support `jnp.mean`, `jnp.average`, `jnp.take`,
`lax.dynamic_slice`, and `lax.dynamic_update_slice` by using
`core.dim_as_value` internally, but to fully roll out the solution
we need to make `core.dim_as_value` a public API and teach
users how to use it when they want shape polymorphism.
Alternatively, perhaps there is a way to automatically convert
dimension polynomials to values when passed to the lax primitives.
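For illustration, here is a self-contained sketch of `average` using the `broadcast_to`-based fallback shown above (plain JAX with constant shapes, where `dim_as_value` behaves as the identity):

```python
import jax.numpy as jnp

def dim_as_value(d):
    # Fallback from above: a length-d vector of ones sums to d.
    return jnp.sum(jnp.broadcast_to(jnp.array(1), (d,)))

def average(x):
    return jnp.sum(x, axis=0) / dim_as_value(x.shape[0])

x = jnp.arange(6.0).reshape(3, 2)
# average(x) agrees with jnp.mean(x, axis=0)
```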
2021-07-27 09:02:15 +03:00
George Necula
4a3e8e99e6 Fix for numpy 1.17.5 2021-07-26 21:37:13 +03:00
George Necula
2749b63524 Ensure zeros from AD are generated on device.
Fixes: #7093
Also fixes type checking in jax2tf, because now we have to be careful
about computations whose results are float0 (the broadcast_in_dim used
to compute the zeros).
2021-07-26 20:40:13 +03:00
Peter Hawkins
278ff13b66 Improve implementation of cbrt() in JAX.
Lower to XLA cbrt() operator in sufficiently new jaxlibs.
On TPU, use a Newton-Raphson step to improve the cube root.
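For reference, a single Newton-Raphson refinement of a cube-root estimate looks like this (an illustrative JAX sketch assuming x != 0; not the actual TPU lowering):

```python
import jax.numpy as jnp

def cbrt_refined(x):
    # Initial estimate via pow (illustrative only; assumes x != 0).
    y = jnp.sign(x) * jnp.abs(x) ** (1.0 / 3.0)
    # One Newton-Raphson step for f(y) = y**3 - x:
    #   y <- y - (y**3 - x) / (3 * y**2)
    return y - (y ** 3 - x) / (3.0 * y ** 2)
```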

Remove support for complex cbrt() in jax.numpy; the existing lowering was wrong and it is not entirely clear to me that we actually want to support complex `jnp.cbrt()`. NumPy itself does not support complex numbers in this case.

Add testing for `sqrt`/`rsqrt` for more types.

[XLA:Python] Add cbrt to XLA:Python bindings.

PiperOrigin-RevId: 386316949
2021-07-22 14:01:28 -07:00
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
jax authors
4d026e06b1 Merge pull request #7255 from jakevdp:remove-broadcast-p
PiperOrigin-RevId: 384888218
2021-07-15 03:12:07 -07:00
Jake VanderPlas
12e435f71e remove lax.broadcast_p
Why? It has been subsumed by lax.broadcast_in_dim_p
2021-07-12 15:33:26 -07:00
George Necula
0beef34d25 [jax2tf] Fix conversion for argmin/argmax; add conversion for reduce
The previous conversion for argmin/argmax simply used tf.argmin and tf.argmax.
Those ops behave differently than JAX when the inputs contain NaN and Inf. Added
a few test cases in primitive_harness to expose the failures.

In order to implement an accurate conversion of argmin/argmax, we need to use the
XLA Reduce op.
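The essence of the Reduce-based approach can be sketched as a pairwise reduction over (value, index) pairs (plain Python, not the actual jax2tf reducer): NaNs win over numbers, and ties keep the smaller index, matching `np.argmax` semantics:

```python
import math

def argmax_via_reduce(xs):
    # Pairwise reduction mirroring the XLA Reduce idea (sketch only):
    # a NaN beats any number, an earlier NaN beats a later one, and
    # ties between equal values keep the smaller index.
    best_v, best_i = xs[0], 0
    for i in range(1, len(xs)):
        if math.isnan(best_v):
            break  # an earlier NaN always wins
        v = xs[i]
        if math.isnan(v) or v > best_v:
            best_v, best_i = v, i
    return best_i
```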

Also tightened the shape checks for lax.argmin and lax.argmax, to ensure they are
not used with an empty reduced dimension. E.g., with axis=-1, previously we got
an internal error:
```
RuntimeError: Invalid argument: Reducing out-of-bounds dimension -1 in shape f32[2,0,3].:
This is a bug in JAX's shape-checking rules; please report it!
```
PiperOrigin-RevId: 384182794
2021-07-12 01:11:42 -07:00
George Necula
1f946ad51e Fix grad of conv 0D.
This bug was introduced in #6345, and was not caught by existing tests.
Add a reproducing test.
2021-07-11 10:46:42 +03:00
David Majnemer
5f11bf571a Use XLA atan2 for complex atan
PiperOrigin-RevId: 382831891
2021-07-02 16:19:00 -07:00
Peter Hawkins
d658108d36 Fix type errors with current mypy and NumPy.
Enable type stubs for jaxlib.

Fix a nondeterminism problem in jax2tf tests.
2021-06-24 10:51:06 -04:00
Qiao Zhang
e8e6138c75 Add vmap rule for lax.clamp. 2021-06-21 14:59:34 -07:00
Jake VanderPlas
7d2057b5c7 BUG: fix validation of permutation in transpose 2021-06-18 21:54:28 -07:00
George Necula
dd8ab85121 [jax2tf] Support inequality and min/max for booleans.
For inequalities we add casts to int8. For min/max we rewrite
them to the logical operations and/or.
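These rewrites rest on simple boolean identities, which can be checked in NumPy (illustration only, not the TF conversion code):

```python
import numpy as np

a = np.array([True, True, False, False])
b = np.array([True, False, True, False])

# min/max on booleans reduce to logical and/or:
assert np.array_equal(np.minimum(a, b), a & b)
assert np.array_equal(np.maximum(a, b), a | b)

# Inequalities after casting to int8 (False < True):
assert np.array_equal(a.astype(np.int8) < b.astype(np.int8), ~a & b)
```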
2021-06-12 21:08:37 +03:00
George Necula
edd96884c7 [jax2tf] Extend shape polymorphism to handle add_transpose with broadcasting 2021-06-12 11:42:15 +03:00
Peter Hawkins
1ff12f05b3 Add unique/sorted annotations to gather().
XLA itself does not consume these, but they can be propagated onto scatter() when computing gradients.

Compute unique/sorted information on indexed accesses and indexed updates. Non-advanced indexes are always sorted and unique.
2021-06-09 21:05:41 -04:00
Peter Hawkins
e9611eb090 Move jax.ad_util to jax._src.ad_util.
Expose ad_util.stop_gradient_p as jax.lax.stop_gradient_p. stop_gradient() is already under the external lax namespace.

PiperOrigin-RevId: 378011152
2021-06-07 14:51:34 -07:00
George Necula
d243258b86 [jax2tf] Implement inequalities and friends for complex numbers.
This reuses JAX's lowering rule for comparisons of
complex numbers, which uses lexicographic comparison.
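A sketch of lexicographic comparison for complex numbers (illustrative NumPy; `complex_lt` is a hypothetical name):

```python
import numpy as np

def complex_lt(x, y):
    # Lexicographic order: compare real parts first,
    # break ties on the imaginary parts.
    x, y = np.asarray(x), np.asarray(y)
    return np.where(x.real == y.real, x.imag < y.imag, x.real < y.real)
```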
2021-06-04 17:56:44 +03:00
George Necula
293ca655e0 [jax2tf] Update limitations to account for tf.math improvements for trigonometric functions.
PiperOrigin-RevId: 377436077
2021-06-03 21:17:56 -07:00
George Necula
2ccda70d83 [jax2tf] Improved coverage of shape polymorphism by allowing dimension polynomials.
Previously we allowed a dimension variable in lieu of a dimension. Now we
allow multi-variate dimension polynomials. These polynomials overload addition, subtraction,
and multiplication. They also partially support equality and inequality checking.

Equality and inequality are supported only when the operation result is the
same for all valuations of the variables greater than 0. For example, `a == a`,
`a * b + 1 == 1 + b * a`, `a >= 1`, `2 * a + b >= 3`, `a >= a`. However, for
the following a `core.InconclusiveDimensionOperation` is raised: `a == b`,
`a >= 2`.

Division is supported only when either there is no remainder
or the divisor is a constant.
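The idea can be illustrated with a toy polynomial class (hypothetical names; not jax2tf's actual `shape_poly` implementation): a polynomial is a map from monomials to coefficients, and structural equality of the normal forms coincides with agreement on all positive valuations:

```python
class DimPoly:
    """Toy multivariate dimension polynomial (illustrative only)."""

    def __init__(self, terms):
        # terms: dict mapping a sorted tuple of variable names (a monomial,
        # () for the constant term) to a nonzero integer coefficient.
        self.terms = {m: c for m, c in terms.items() if c != 0}

    @classmethod
    def var(cls, name):
        return cls({(name,): 1})

    @classmethod
    def const(cls, c):
        return cls({(): c})

    def __add__(self, other):
        out = dict(self.terms)
        for m, c in other.terms.items():
            out[m] = out.get(m, 0) + c
        return DimPoly(out)

    def __mul__(self, other):
        out = {}
        for m1, c1 in self.terms.items():
            for m2, c2 in other.terms.items():
                m = tuple(sorted(m1 + m2))
                out[m] = out.get(m, 0) + c1 * c2
        return DimPoly(out)

    def __eq__(self, other):
        # Structural equality of normal forms: same result for
        # all valuations of the variables greater than 0.
        return self.terms == other.terms

a, b = DimPoly.var("a"), DimPoly.var("b")
one = DimPoly.const(1)
assert a * b + one == one + b * a
```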

This change allows us to support more general cases of `jnp.reshape(-1)`,
such as those used in the internal implementation of `random_gamma`:

```
   y = x.reshape((2, -1))
   z = ... y ...
   return z.reshape(x.shape)
```
2021-06-03 10:58:06 +03:00
Peter Hawkins
46cc654537 Move jax.abstract_arrays to jax._src.abstract_arrays.
PiperOrigin-RevId: 377044255
2021-06-02 06:25:22 -07:00
Lukas Geiger
971eb86fa4 Reduce redundant calculations of tan, erfc and rsqrt jvp 2021-05-21 18:36:42 +01:00
jax authors
3319cd0ed0 Merge pull request #6740 from hawkinsp:scatter
PiperOrigin-RevId: 374718181
2021-05-19 13:35:01 -07:00
jax authors
683289c4ad Merge pull request #6764 from hawkinsp:argmax
PiperOrigin-RevId: 374437439
2021-05-18 09:31:53 -07:00
Peter Hawkins
6cc440b79d Fix handling of NaNs in GPU argmax translation rule. 2021-05-18 11:35:54 -04:00
George Necula
afe5ec3df4 Improve accuracy of the jax2tf convolution conversion.
Some of the discrepancies were due to JAX using a workaround for
missing complex convolutions on CPU/GPU, while jax2tf was not using
it. We apply the same lowering as JAX, on all platforms.

This allows us to remove custom numeric tolerances and enables complex
convolutions on GPU.

PiperOrigin-RevId: 374199441
2021-05-17 08:18:51 -07:00
Peter Hawkins
44c98ad4e8 Improve JVP rule for scatters with non-overlapping indices.
If the scattered values don't overlap, we don't need complicated masking logic to work out which of the two overlapping values "wins".
2021-05-12 14:16:35 -04:00
George Necula
ba5e11f86f [jax2tf] Improve the conversion of integer_pow for better numerical accuracy.
Previously we simply converted integer_pow to tf.math.pow. JAX instead uses
a series of multiplications. We now use the same lowering strategy as JAX, so
that we have the same numerical result.
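The multiplication-chain strategy amounts to exponentiation by squaring, sketched here in plain Python (illustrative; not the actual JAX lowering code):

```python
def integer_pow(x, n):
    # Expand x**n into a chain of multiplications by repeated squaring
    # (a sketch of the strategy, for n >= 1).
    assert n >= 1
    acc = None
    base = x
    while n > 0:
        if n & 1:
            acc = base if acc is None else acc * base
        n >>= 1
        if n:
            base = base * base
    return acc
```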

Also improved the error messages for assertion failures.

PiperOrigin-RevId: 373351147
2021-05-12 05:45:39 -07:00
George Necula
235eb8c2b4 Copybara import of the project:
--
1ecf4f02891cad70cc8f094b49cf2458105ca366 by George Necula <gcnecula@gmail.com>:

[jax2tf] Change the conversion of dot_general to use XLA op.

Instead of converting dot_general to a sea of TF ops, when
enable_xla is set we just use the XLA op. This has the advantage
that it also supports preferred_element_type.

Fixed bug with passing the precision parameter to TF.
Also improved tests to print the HLO in case of numerical errors.

COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/6717 from gnecula:tf_dot 1ecf4f02891cad70cc8f094b49cf2458105ca366
PiperOrigin-RevId: 373326655
2021-05-12 02:30:15 -07:00
Jake VanderPlas
71a25cdac1 DOC: add examples to lax function docstrings 2021-04-29 09:48:52 -07:00
Jake VanderPlas
ca684df0e9 DOC: add example for lax.dynamic_update_slice 2021-04-23 09:10:43 -07:00
Lukas Geiger
f7f42694d9 Add support for preferred_element_type arg in convolutions 2021-04-22 10:29:31 +02:00
Skye Wanderman-Milne
feb79e5698 Fix some Cloud TPU test failures.
The new select_and_gather_add logic was inspired by
3a35f7072a.
2021-04-21 00:37:02 +00:00
Lena Martens
fa5e19b630 Fix Zero handling in select_jvp. 2021-04-19 17:03:07 +01:00
Peter Hawkins
14d991dd90 Move jax.config to jax._src.config.
PiperOrigin-RevId: 369230109
2021-04-19 08:53:12 -07:00
Matthew Johnson
9d6263a743 support implicit broadcasting in transpose rules 2021-04-16 12:51:11 -07:00
Peter Hawkins
26e9ebcdae Move jax.api to jax._src.api.
PiperOrigin-RevId: 368233837
2021-04-13 09:43:24 -07:00
Peter Hawkins
0f1520b6d2 Enable variadic select_and_gather on TPU. 2021-04-13 09:09:10 -04:00
jax authors
8f2502324a Merge pull request #6408 from LenaMartens:changelist/367979622
PiperOrigin-RevId: 367991796
2021-04-12 06:42:24 -07:00
jax authors
ce67e563a1 Merge pull request #6375 from gnecula:mask_clean
PiperOrigin-RevId: 367985125
2021-04-12 05:50:19 -07:00
Lena Martens
b4f66d2676 Fix handling of ad.Zero in _select_and_scatter_add_transpose.
Fixes #6403.
2021-04-12 13:07:40 +01:00
Peter Hawkins
3a35f7072a Implement select_and_gather_add using variadic reducewindow on CPU. 2021-04-09 14:40:43 -04:00
jax authors
438b56c483 Fix typo in rng_bit_generator comment.
PiperOrigin-RevId: 367460802
2021-04-08 10:42:45 -07:00
George Necula
0e280bbac0 [masking] Remove references to masking.Poly from the lax.py and lax_numpy.py
Previously, in order to increase the coverage of masking we added special
cases in lax.py and lax_numpy.py to avoid exceptions in the presence of
masking.Poly.

For example:
```
if not isinstance(d, masking.Poly):
   if some_check(d):
      raise ValueError
```

All such conditionals make the code behave potentially differently when
tracing with masking.Poly than when tracing with concrete shapes, which
makes it hard to ensure soundness.

Perhaps the most egregious was:
```
if type(i) is Poly:
  # dummy index if i is polynomial, doesn't matter for shape inference
  i = 0
```
2021-04-08 17:45:14 +03:00
jax authors
3a9ce3990e Merge pull request #6345 from gnecula:shape_poly
PiperOrigin-RevId: 367416742
2021-04-08 06:21:12 -07:00
George Necula
2e9e824289 Cleanup and fix triangular_solve 2021-04-08 10:42:38 +03:00
George Necula
99d5f09b29 Fix select and eigh 2021-04-08 10:42:38 +03:00
George Necula
5750ec074a Fix scatter 2021-04-08 10:42:38 +03:00
George Necula
551a89cfe9 Fixes for slice 2021-04-08 10:42:38 +03:00
George Necula
cbe5f54cca Added support for lax.pad, and more error checking 2021-04-08 10:42:38 +03:00