126 Commits

Author SHA1 Message Date
Peter Hawkins
7e1b826ef5 Enable fast TPU LU decomposition for complex types. 2020-09-22 22:03:02 -04:00
Benjamin Chetioui
d478e346ac
Fix conditional in eig and expand eig test suite. (#4320)
* Fix conditional in eig and expand eig test suite.
2020-09-18 10:30:19 +03:00
Peter Hawkins
cefa93f2ed
Lower LU decomposition to a custom TPU implementation for float32 types. (#4291) 2020-09-15 09:04:54 -04:00
Benjamin Chetioui
58a117fe0d
Modifies eig_p and related operations to take advantage of the new jaxlib geev API (#4266)
* Add options to compute L/R eigenvectors in geev.

The new arguments are by default set to True to ensure backwards
compatibility between jaxlib and jax.

Reformulate eig-related operations based on the new geev API.

* Addressed hawkinsp's comments from google/jax#3882.

Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
2020-09-15 11:45:15 +03:00
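A minimal sketch of the geev options described above, using the current public `jax.lax.linalg` location (at the time of this commit the function lived in `lax_linalg`); the matrix values are illustrative:

```python
import jax.numpy as jnp
from jax.lax.linalg import eig  # CPU-only implementation

a = jnp.array([[0., 1.], [-2., -3.]], dtype=jnp.complex64)

# eigenvalues plus both eigenvector sets (the backwards-compatible default)
w, vl, vr = eig(a)

# eigenvalues only: the new flags let geev skip the eigenvector work entirely
(w_only,) = eig(a, compute_left_eigenvectors=False, compute_right_eigenvectors=False)
```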
Peter Hawkins
cf65f6b24e
Change lax_linalg.lu to return a permutation representation of the partial pivoting information. (#4241)
The permutation is more efficiently computed during the decomposition on TPU, and the only use case that would not require us to compute it would be for evaluating determinants.
2020-09-10 11:16:35 -04:00
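A minimal sketch of what the returned permutation gives callers, again using the current `jax.lax.linalg` import path rather than the `lax_linalg` module referenced in the commit:

```python
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.array([[4., 3.], [6., 3.]])
lu, pivots, perm = lax_linalg.lu(a)

l = jnp.tril(lu, -1) + jnp.eye(a.shape[0], dtype=lu.dtype)
u = jnp.triu(lu)

# the permuted rows of `a` factor as L @ U, so callers can use `perm`
# directly instead of converting LAPACK-style pivots themselves
print(jnp.allclose(a[perm], l @ u))  # True
```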
Benjamin Chetioui
7065d07166
Fix svd_p to stop returning garbage values of u/vt when compute_uv ==… (#3895)
* Fix svd_p to stop returning garbage values of u/vt when compute_uv == False.

The custom call made by svd_p does not compute u and vt when
compute_uv is set to False. Returning them using the primitive
means that it is up to the caller to throw away these values.

* Added documentation to svd.
2020-08-05 13:35:46 +03:00
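At the user level, the relevant code path is the `compute_uv=False` variant of SVD; a minimal sketch with an illustrative matrix:

```python
import jax.numpy as jnp

a = jnp.arange(12.0).reshape(3, 4)

# only the singular values are computed; no u/vt placeholders to discard
s = jnp.linalg.svd(a, compute_uv=False)
print(s.shape)  # (3,)
```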
Jake Vanderplas
82dbaca0f1
Revert #3610 & #3684 (#3688)
* Revert "linalg_test: define test matrices lazily (#3684)"

This reverts commit 2be1baa41a170192c209c94b060d0d034d1de2c2.

* Revert "Make LU gradient work for low-rank matrices (#3610)"

This reverts commit 23deefa71838ceeab41977ac0ab781164c914a8c.
2020-07-07 16:19:43 -07:00
David Pfau
23deefa718
Make LU gradient work for low-rank matrices (#3610)
* Correct LU JVP rule to handle low-rank matrices

This requires one additional triangular_solve, though the timing should still be dominated by the LU decomposition itself. Also, the problem of computing the JVP of L for a low-rank matrix is underdetermined and likely depends on the backend used. We do not enforce any form on L'.
Still need to fix:
* NaNs in reverse mode
* Non-square matrices
* Complex matrices

* Updated LU JVP rule for VJP and nonsquare, complex matrices

* Added singular test matrices to linalg_test header

* Add tests of LU for singular matrices

* Upgrade order for Det and LU grad tests

* Increased matmul precision for better TPU accuracy

* Added comment explaining fix to LU gradient

* Remove trailing whitespace

* Moved nonsquare singular matrices next to testLuGradOfNonSquareSingularMatrix

* Fixed linter issues

* Revert changes to tests
2020-07-06 13:13:56 -07:00
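A minimal sketch of the kind of program this change exercises (note it was later reverted, per the entry above): differentiating through the LU decomposition of a rank-deficient matrix. The import path below is the current public location; at the time the primitive lived in `lax_linalg`.

```python
import jax
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

# a rank-1 (hence singular) matrix, the case this change targets
a = jnp.outer(jnp.arange(1.0, 4.0), jnp.arange(1.0, 4.0))
da = jnp.eye(3)

# JVP through the packed LU factor; the pivot and permutation outputs are
# integer-valued, so their tangents are symbolic zeros and are discarded here
(lu, pivots, perm), (dlu, _, _) = jax.jvp(lax_linalg.lu, (a,), (da,))
print(dlu)
```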
Peter Hawkins
32e419d189
Fix eigh JVP to ensure that both the primal and tangents of the eigen… (#3550)
* Fix eigh JVP to ensure that both the primal and tangents of the eigenvalues are real.

Add a test to jax.test_util.check_jvp that ensures the primals and the tangents produced by a JVP rule have identical types.

* Cast the input to the static indexing grad tests to a JAX array so the new type check passes.
2020-06-25 08:14:54 -04:00
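A minimal sketch of the property being fixed: for a complex Hermitian input, both the eigenvalues and their tangents should come out with a real dtype. The Hermitian test matrix below is illustrative.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[2.0, 1j], [-1j, 3.0]], dtype=jnp.complex64)   # Hermitian
da = jnp.eye(2, dtype=jnp.complex64)                          # Hermitian perturbation

f = lambda a: jnp.linalg.eigh(a)[0]     # eigenvalues only
w, dw = jax.jvp(f, (a,), (da,))

# both the primal eigenvalues and their tangents should be real
print(w.dtype, dw.dtype)  # float32 float32
```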
Matthew Johnson
2f7108f78b
remove the lower_fun default multiple_results=True (#3524) 2020-06-22 17:50:33 -07:00
Adam Paszke
e36c72b983
Make ad_util.zero a class that carries avals (similar to UndefinedPrimal) (#3222) 2020-06-08 17:50:14 +02:00
Jake Vanderplas
2a10dbbf37
deflake remainder of jax (#3343) 2020-06-06 10:51:34 -07:00
Peter Hawkins
841f21fcad
Enable SVD on TPU. (#3334) 2020-06-05 12:21:30 -04:00
Adam Paszke
adb442eb8a Make ad_util.zero a class that carries avals (similar to UndefinedPrimal)
This is useful for the remat transpose rule submitted in #3162 and e.g.
allowed me to catch a slight overuse of defjvp2 for `random_gamma_p` (it
was unnecessarily declared as having multiple outputs).
2020-06-05 15:51:30 +00:00
Peter Hawkins
d55ea510e2
Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46. (#3046)
* Update JAX to avoid XLA:Python API names deprecated in jaxlib 0.1.46.

* Bump minimum jaxlib version to 0.1.47.
2020-05-11 17:43:55 -04:00
Peter Hawkins
2f09e89e72
Update internal aliases to lax_numpy to jnp instead of np. (#2975) 2020-05-05 20:41:57 -04:00
Jamie Townsend
f6e9060379
Qr complex jvp fix (#2872)
* Fix qr jvp for complex input

* Fix qr jvp for complex64 inputs when jax_enable_x64=True

* Reenable complex jvp test for qr
2020-04-28 12:58:49 -04:00
Peter Hawkins
e287f98c3a
Fix definition of qr primitive to return only the upper triangular part of r. (#2870)
Issue #2863.
2020-04-28 12:01:54 -04:00
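A minimal sketch of the fixed invariant, with an illustrative full-rank matrix: the `r` factor should contain only its upper triangle.

```python
import jax.numpy as jnp

a = jnp.arange(9.0).reshape(3, 3) + jnp.eye(3)
q, r = jnp.linalg.qr(a)

# r should be exactly upper triangular (issue #2863)
print(jnp.allclose(r, jnp.triu(r)))  # True
print(jnp.allclose(q @ r, a))        # True
```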
Peter Hawkins
5290c03a17
Remove usage of xla_client.{Computation,ComputationBuilder}. (#2808)
* Remove usage of xla_client.{Computation,ComputationBuilder}.

ComputationBuilder is a fairly pointless wrapper class that mimics an outdated version of the C++ XLA API. It dates back to when we used to have SWIG bindings and needed to write a non-trivial Python shim to keep the interface pleasant to use. Now that we have pybind11-based bindings that are reasonably ergonomic by themselves, we don't need the wrapper class. Instead, we can simply call the pybind11-wrapped C++ API directly, removing the impedance mismatch between the C++ and Python APIs and allowing us to delete the Python ComputationBuilder class.

Similarly we can delete xla_client.Computation for the same reasons; it doesn't do anything useful on top of the C++ API.
2020-04-23 18:30:47 -04:00
Peter Hawkins
8fe3c59ced
Add explicit derivative for jax.numpy.linalg.pinv. (#2794)
* Add explicit derivative for jax.numpy.linalg.pinv.

* Fix type confusion problems in the JVP rule for SVD that meant it produced 64-bit tangents for 32-bit primals.
2020-04-22 20:15:04 -04:00
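With an explicit derivative rule in place, both AD modes work on `pinv` and should agree; a minimal sketch with an illustrative well-conditioned matrix:

```python
import jax
import jax.numpy as jnp

a = jnp.array([[3., 1., 0.],
               [1., 4., 1.],
               [0., 1., 5.],
               [1., 0., 2.]])

# forward- and reverse-mode Jacobians of pinv should match
j_fwd = jax.jacfwd(jnp.linalg.pinv)(a)
j_rev = jax.jacrev(jnp.linalg.pinv)(a)
print(jnp.allclose(j_fwd, j_rev, atol=1e-4))  # True
```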
Peter Hawkins
f371bfc0bf
Improve speed of LU decomposition on TPU. (#2526)
Increase the block size, which helps with compilation time.
Merge the two row permutations in the outer loop, which means we do row-at-a-time gathers.
2020-03-27 21:24:26 -04:00
Matthew Johnson
93d3e34721 make lax_linalg.solve_triangular allow vector rhs
also add tests for jax.scipy.linalg.cho_solve
2020-03-21 10:46:07 -07:00
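A minimal sketch of the user-facing pattern covered by the new tests: a Cholesky factorization followed by cho_solve with a vector right-hand side.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

a = jnp.array([[4., 1.], [1., 3.]])   # symmetric positive definite
b = jnp.array([1., 2.])               # vector right-hand side

c, lower = cho_factor(a)
x = cho_solve((c, lower), b)
print(jnp.allclose(a @ x, b))  # True
```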
Peter Hawkins
68b32bf704
Add mypy type checking (#2430)
* Add type annotations to make mypy pass.

* Add mypy to .travis.yml.
2020-03-18 17:06:05 -04:00
Matthew Johnson
47df7b95c4
change the xla representation of JAX's unit (#2416)
* change the xla representation of JAX's unit

Previously the representation of JAX's unit value (a sentinel /
placeholder) was an empty tuple, but by changing the representation to
something else we can further reduce our dependence on runtime tuples.

This commit makes the representation fairly easy to change. There are
three functions in xla.py that define the representation. Here are
versions that would keep the old XLA representation as an empty tuple:

```
def _make_unit(c): return c.Tuple()
def _make_abstract_unit(_): return xc.Shape.tuple_shape(())
def _device_put_unit(_, device):
  return xc.Buffer.make_tuple((), device, backend=xb.get_device_backend(device))
```

The new representation is as a trivial array. An alternative
representation would be nothing at all: we don't need to generate XLA
computations that have representations of JAX units. While that
alternative is probably the best choice, it seemed like it would require
a bit more refactoring/bookkeeping (e.g. to allow XLA computations to
have a smaller number of outputs than the corresponding JAX function),
and would also mean the XLA representation would be a step further
removed from the jaxpr representation. So I stuck with a trivial array
for now.

The mapping from JAX types to XLA types need not be invertible. However,
XLA translation rules currently don't take as arguments the
corresponding JAX types (abstract values), and there were a few cases
where we relied on checking whether an argument's XLA type was that of
an empty tuple so as to determine if we were effectively operating on a
JAX unit.

In particular, the AD-related primitive add_jaxvals_p could in principle
add two units, and get lowered to an XLA addition on the unit
representation. Previously, the translation rule for add_jaxvals_p
checked the XLA type so that adding two empty tuples didn't produce any
XLA operation; now it adds its inputs, and so if unit is represented as
a trivial array we could be inserting trivial scalar adds where we had
none before. However, if that case is ever possible, it doesn't come up
in our tests (which I checked by keeping the representation as an empty
tuple and then asserting an XLA tuple type is never seen by that
translation rule).

* add comment about JAX<->XLA array types assumption
2020-03-14 12:33:14 -07:00
Matthew Johnson
7f0463e2c9
remove input shapes from params of some primitives (#2410)
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:

  { lambda  ; a.
    let b = reduce_sum[ axes=(0,) ] a
    in b }

The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!

Our hack was to have the primitives that required shape information for
transposition to acquire it into their parameters, so that we'd produce
jaxprs like this one:

  { lambda  ; a.
    let b = reduce_sum[ axes=(0,)
                        input_shape=(3,) ] a
    in b }

That's not only aesthetically unpleasant, but also it meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)

But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!

That's exactly what this commit does!

Co-authored-by: Roy Frostig <frostig@google.com>
2020-03-13 07:13:29 -07:00
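A minimal sketch of the transposition fact the commit describes: pulling a scalar cotangent back through a reduce-sum broadcasts it to the input shape, which is exactly the shape information that now comes from the avals in the jaxpr rather than from a primitive parameter.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(3)
out, vjp_fn = jax.vjp(lambda x: x.sum(axis=0), x)

# the cotangent of the scalar output is broadcast back to shape (3,)
(ct,) = vjp_fn(jnp.ones_like(out))
print(ct)  # [1. 1. 1.]
```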
George Necula
c52f32b59d
Removed unused imports (#2385)
Also disabled a couple more linalg tests that crash on my Mac
2020-03-09 20:42:08 +01:00
Stephan Hoyer
218a1711d2
Add a jit around lax_linalg.lu_pivots_to_permutation (#2277)
I think this is almost always called inside a jit already, but adding this
results in more interpretable jaxprs.
2020-02-20 16:10:09 -08:00
Stephan Hoyer
aca7bccefd
Consolidate LU solve logic from scipy/numpy in lax_linalg.lu_solve (#2144)
* Consolidate LU solve logic from scipy/numpy in lax_linalg.lu_solve

This single implementation supports broadcasting like NumPy in both the NumPy
and SciPy interfaces to LU solve, even though only original NumPy supports
broadcasting.

This change is technically backwards incompatible in the SciPy wrapper, which
previously supported adding extra dimensions to the end of `b`, e.g.,
`b.shape == (8, 4, 2)` when `a.shape == (8, 8)`. There was a testcase for this,
but it isn't documented in either JAX or SciPy.

* fix recursive import

* Use np.vectorize instead of experimental.vectorize
2020-02-12 17:05:18 -08:00
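The consolidated lu_solve sits behind both the NumPy- and SciPy-style wrappers; a minimal sketch through the SciPy-style interface, with an illustrative system:

```python
import jax.numpy as jnp
from jax.scipy.linalg import lu_factor, lu_solve

a = jnp.array([[4., 3.], [6., 3.]])
b = jnp.array([1., 2.])

lu_and_piv = lu_factor(a)
x = lu_solve(lu_and_piv, b)
print(jnp.allclose(a @ x, b))  # True
```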
Stephan Hoyer
0644f5c561
Better batching rule for triangular_solve (#2138)
* Better batching rule for triangular_solve

Now, if only the right hand side argument `b` is batched, we leverage
triangular solve's builtin batching for handling multiple right-hand-side
vectors.

This makes the performance of `vmap` over only the second argument of linear
solves equivalent to relying on builtin batching::

    rs = onp.random.RandomState(0)
    a = rs.randn(500, 500) + 0.1 * np.eye(500)
    b_mat = jax.device_put(rs.randn(500, 10))
    solve1 = jax.jit(np.linalg.solve)
    solve2 = jax.jit(jax.vmap(np.linalg.solve, in_axes=(None, 1), out_axes=1))

Before::

    In [6]: %timeit jax.device_get(solve1(a, b_mat))
    3.88 ms ± 293 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

    # 8x slower :(
    In [9]: %timeit jax.device_get(solve2(a, b_mat))
    23.5 ms ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Now::

    In [2]: %timeit jax.device_get(solve1(a, b_mat))
    3.76 ms ± 304 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

    # same speed :)
    In [3]: %timeit jax.device_get(solve2(a, b_mat))
    3.72 ms ± 296 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

* Test failures

* Check b.ndim == 2 in triangular solve shape rule
2020-02-03 09:27:03 -08:00
Peter Hawkins
991324f8df
Increase minimum jaxlib version to 0.1.38. (#2120) 2020-01-29 14:16:58 -05:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Aidan Dang
e0ed5adc75 Allow JVP for SVD when not computing singular vectors (#2076)
* Allow SVD JVP when not computing singular vectors

* Test SVD JVP when not computing full singular vecs
2020-01-27 11:57:43 -08:00
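A minimal sketch of the now-supported case: differentiating through singular values alone, without requesting the singular vectors.

```python
import jax
import jax.numpy as jnp

a = jnp.array([[3., 0.], [0., 1.], [1., 1.]])
da = 0.1 * jnp.ones_like(a)

# JVP through the singular values only; u/vt are never computed
f = lambda a: jnp.linalg.svd(a, compute_uv=False)
s, ds = jax.jvp(f, (a,), (da,))
print(s, ds)
```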
Srinivas Vasudevan
80b35dd4e5 Add betainc to JAX (#1998)
Adds betaln, a wrapper for the log of the Beta function (scipy.special.betaln).
2020-01-15 16:13:11 -05:00
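A minimal usage sketch of the two additions, with illustrative parameters:

```python
import jax.numpy as jnp
from jax.scipy.special import betainc, betaln

a, b = 2.0, 3.0

# regularized incomplete beta function evaluated on a grid of x values
print(betainc(a, b, jnp.linspace(0.0, 1.0, 5)))

# log of the (complete) Beta function: gammaln(a) + gammaln(b) - gammaln(a + b)
print(betaln(a, b))
```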
AmKhan
dcda87d0e7 added batching to LAPACK triangular_solve (#1985)
* Added batching to cpu triangular_solver

* Addressed comments about int overflows and returned triangular solve to using XLA over LAPACK

* Add TODO to benchmark LAPACK vs XLA
2020-01-14 11:18:47 -05:00
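From the user side, the batched CPU kernel is reached through vmapped triangular solves; a minimal sketch with an illustrative batch of lower-triangular systems:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

# a batch of two lower-triangular 3x3 systems
a = jnp.tril(jnp.arange(1.0, 19.0).reshape(2, 3, 3)) + jnp.eye(3)
b = jnp.ones((2, 3))

x = jax.vmap(lambda a, b: solve_triangular(a, b, lower=True))(a, b)
print(jnp.allclose(jnp.einsum('bij,bj->bi', a, x), b))  # True
```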
Peter Hawkins
c5a9eba3a8
Implement batched cholesky decomposition using LAPACK/Cusolver (#1956)
* Implement batched Cholesky decomposition on CPU and GPU using LAPACK and cuSolver.

Adds support for complex batched Cholesky decomposition on both platforms.
Fix a concurrency bug in batched cuBLAS kernels where a host-to-device memcpy could take place too early, before the device buffer was ready.
2020-01-07 10:56:15 -05:00
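A minimal sketch of the supported case: a batch of complex Hermitian positive-definite matrices factored in one call, with the leading dimension treated as the batch.

```python
import jax.numpy as jnp

a = jnp.array([[[4.0, 1.0 + 1.0j], [1.0 - 1.0j, 3.0]],
               [[5.0, 2.0 - 1.0j], [2.0 + 1.0j, 4.0]]], dtype=jnp.complex64)

l = jnp.linalg.cholesky(a)

# L @ L^H reconstructs each matrix in the batch
print(jnp.allclose(l @ jnp.conj(jnp.swapaxes(l, -1, -2)), a, atol=1e-5))  # True
```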
Peter Hawkins
a52dc452d2
Change jax.numpy scalar types to return 0D JAX arrays when instantiated. (#1836)
* Change jax.numpy scalar types to return 0D JAX arrays rather than NumPy scalars when instantiated.

jax.numpy and numpy have slightly different promotion behaviors. For consistency with JAX arrays, we would like the result of, say, `jax.numpy.int32(7)` to have the same promotion behavior as `jax.numpy.array(7, dtype=jax.numpy.int32)`. The easiest way to do this is to have the jax.numpy scalars return 0D arrays when instantiated; the difference between NumPy scalars and arrays is not a fundamental one and we do not need to distinguish between them in JAX.
2019-12-18 11:57:22 -05:00
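A minimal sketch of the new behavior: instantiating a jax.numpy scalar type yields a 0D JAX array, so it promotes exactly like the equivalent array construction.

```python
import jax.numpy as jnp

x = jnp.int32(7)
print(type(x), x.shape, x.dtype)   # a 0D JAX array, not a NumPy scalar

# instantiation now matches explicit array construction, so promotion agrees
y = jnp.array(7, dtype=jnp.int32)
print((x + 1.5).dtype == (y + 1.5).dtype)  # True
```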
Peter Hawkins
8f00e3f511
Disable SVD GPU test because it is failing due to an LLVM integrate. (#1890)
Remove some stale jaxlib version tests.
2019-12-18 11:07:39 -05:00
Stephan Hoyer
6ac1c569e8
Use HIGHEST precision for dot_general in linalg JVP rules (#1835) 2019-12-10 00:38:18 -08:00
George Necula
120270cb47 Refined the test disabling for only TPU 2019-12-04 15:38:17 +01:00
Peter Hawkins
f4aa5150e8
Move internal type-related functions into a new (internal) jax.types … (#1695)
* Move internal type-related functions into a new (internal) jax.types module.

Avoid calling onp type functions in lieu of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.

Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.

* Rename jax.types to jax.dtypes.

* s/types/dtypes/ in tests.
2019-11-15 10:02:51 -05:00
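A minimal sketch of the module under its final name, jax.dtypes; the specific inputs are illustrative:

```python
import numpy as onp
from jax import dtypes

# canonicalize_dtype maps dtypes onto the enabled set: with the default
# 32-bit mode, float64 is canonicalized down to float32
print(dtypes.canonicalize_dtype(onp.float64))  # float32 unless jax_enable_x64 is set

# JAX's own promotion rules, which may diverge from classic NumPy's
print(dtypes.result_type(onp.int8, onp.float16))
```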
Peter Hawkins
f7a44523be
Add some type helpers to lax_numpy. (#1593)
Prefer to use jax.numpy type helpers rather than numpy type helpers in various places.
Cleanup in preparation for adding bfloat16 support to jax.
2019-10-29 20:53:20 -04:00
Stephan Hoyer
5b724e0019
Fix grad(jit(custom_linear_solve)) encounters ad_util.Zero in backwards pass (#1537)
Fixes https://github.com/google/jax/issues/1536

I wrote a regression test, but I could not figure out how to trigger this directly
with triangular_solve alone.
2019-10-21 18:03:36 -07:00
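A minimal sketch of the pattern the fix covers, assuming a symmetric positive-definite system so that `symmetric=True` lets the same solver be reused for the transpose in the backward pass; the system and solver choice are illustrative:

```python
import jax
import jax.numpy as jnp
from jax import lax

def solve_spd(a, b):
  # custom linear solve: matvec defines A @ x, solve produces x given a rhs
  matvec = lambda x: a @ x
  solve = lambda matvec, rhs: jnp.linalg.solve(a, rhs)
  return lax.custom_linear_solve(matvec, b, solve, symmetric=True)

a = 2.0 * jnp.eye(3) + 0.1
b = jnp.ones(3)

# grad through a jitted custom_linear_solve: the grad(jit(...)) path this fixes
loss = jax.jit(lambda b: solve_spd(a, b).sum())
print(jax.grad(loss)(b))
```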
Stephan Hoyer
16219358e2
Remove unnecessary zero checks from JVP rules in lax_linalg (#1490)
I'm pretty sure all of these are dead code that can no longer be triggered,
e.g., as evidenced by their use of no longer existing `core.pack` and
`ad.TangentTuple`.
2019-10-14 09:49:10 -07:00
Stephan Hoyer
c82477862c
Speedup JVP for triangular solve (#1466)
* Speedup JVP for triangular solve

There is still room for improvement, e.g., by combining the currently separate
JVPs for a and b into a single expression (which will allow for saving an inner
triangular solve when both arguments are being differentiated), but this is
already significantly faster in the typical case of only solving a single
vector.

On my laptop's CPU, I measure 2.98 ms before vs 1.18 ms after on
a 500x500 matrix:

    rs = onp.random.RandomState(0)
    a = rs.randn(500, 500)
    b = rs.randn(500)

    @jax.jit
    def loss(a, b):
      return np.sum(jax.scipy.linalg.solve_triangular(a, b))
    grad = jax.jit(jax.grad(loss))

    %timeit jax.device_get(grad(a, b))

* comment

* Optimal order for left_side=False, too

* Test the JVP for lax_linalg.triangular_solve directly
2019-10-10 08:27:21 -07:00
Peter Hawkins
d83200ca53 Make eigh JVP work for batched inputs. 2019-10-09 14:35:30 -04:00
Peter Hawkins
62b459d1ac Use the XLA HLO implementation of symmetric Eigendecomposition as a fallback for backends that don't have a custom kernel. 2019-10-08 16:09:50 -04:00
Matthew Johnson
762b602f33
Merge pull request #1394 from j-towns/fix-scatter-caching
Ensure all ops get cache hits on second op-by-op mode call
2019-09-26 06:48:42 -07:00
Jamie Townsend
57d9fd6ab4 Googly indentation 2019-09-26 13:39:35 +02:00
Jamie Townsend
b24d6cacd3 Ensure LU decomposition cache hits in op-by-op 2019-09-26 11:25:51 +02:00
Stephan Hoyer
298d838be1 Update error message for eigh. 2019-09-24 18:25:17 -07:00