* Add options to compute L/R eigenvectors in geev.
The new arguments are by default set to True to ensure backwards
compatibility between jaxlib and jax.
Reformulate eig-related operations based on the new geev API.
* Addressed hawkinsp's comments from google/jax#3882.
Co-authored-by: Skye Wanderman-Milne <skyewm@google.com>
On TPU the permutation is computed more efficiently during the decomposition itself, and the only use case that would not require computing it is evaluating determinants.
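Determinant evaluation needs only the permutation's parity: with A = P L U, det(A) is the sign of P times the product of U's diagonal. A NumPy sketch of the parity computation (the helper name is ours, not a JAX API):

```
import numpy as np

def permutation_sign(perm):
    # Sign of a permutation via cycle counting: (-1)**(n - num_cycles).
    perm = np.asarray(perm)
    n = len(perm)
    seen = np.zeros(n, dtype=bool)
    cycles = 0
    for i in range(n):
        if not seen[i]:
            cycles += 1
            j = i
            while not seen[j]:
                seen[j] = True
                j = perm[j]
    return 1 if (n - cycles) % 2 == 0 else -1

# The determinant of a permutation matrix equals the permutation's sign.
perm = [2, 0, 1, 3]                 # a 3-cycle plus a fixed point: even
P = np.eye(4)[perm]
assert permutation_sign(perm) == round(np.linalg.det(P))
```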
* Fix svd_p to stop returning garbage values of u/vt when compute_uv == False.
The custom call made by svd_p does not compute u and vt when
compute_uv is set to False. Returning them using the primitive
means that it is up to the caller to throw away these values.
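For reference, the NumPy analogue of the call (jax.numpy.linalg.svd mirrors this signature) already behaves this way: with compute_uv=False only the singular values are returned, so there are no placeholder u/vt values for the caller to discard:

```
import numpy as np

a = np.arange(6.0).reshape(2, 3)
# Only the singular values come back; no u/vt placeholders.
s = np.linalg.svd(a, compute_uv=False)
assert s.shape == (2,)
```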
* Added documentation to svd.
* Revert "linalg_test: define test matrices lazily (#3684)"
This reverts commit 2be1baa41a170192c209c94b060d0d034d1de2c2.
* Revert "Make LU gradient work for low-rank matrices (#3610)"
This reverts commit 23deefa71838ceeab41977ac0ab781164c914a8c.
* Correct LU JVP rule to handle low-rank matrices
This requires one additional triangular_solve, though the timing should still be dominated by the LU decomposition itself. Also, the problem of computing the JVP of L for a low-rank matrix is underdetermined and likely depends on the backend used. We do not enforce any form on L'.
Still need to fix:
  * NaNs in reverse mode
  * Non-square matrices
  * Complex matrices
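For reference, the square full-rank JVP (a sketch in standard notation, not code from this change) follows from differentiating A = P L U:

```
\dot{A} = P\,(\dot{L} U + L \dot{U})
\quad\Longrightarrow\quad
F := L^{-1} P^{T} \dot{A}\, U^{-1} = L^{-1} \dot{L} + \dot{U} U^{-1}.
```

Since L is unit lower triangular, L^{-1}\dot{L} is strictly lower triangular and \dot{U} U^{-1} is upper triangular, so \dot{L} = L\,\mathrm{tril}_{-1}(F) and \dot{U} = \mathrm{triu}(F)\,U. In the low-rank case this split is underdetermined, which is why no particular form is enforced on L'.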
* Updated LU JVP rule for VJP and nonsquare, complex matrices
* Added singular test matrices to linalg_test header
* Add tests of LU for singular matrices
* Upgrade order for Det and LU grad tests
* Increased matmul precision for better TPU accuracy
* Added comment explaining fix to LU gradient
* Remove trailing whitespace
* Moved nonsquare singular matrices next to testLuGradOfNonSquareSingularMatrix
* Fixed linter issues
* Revert changes to tests
* Fix eigh JVP to ensure that both the primals and tangents of the eigenvalues are real.
Add a test to jax.test_util.check_jvp that ensures the primals and tangents produced by a JVP rule have identical types.
* Cast input to static indexing grad tests to a JAX array so new type check passes.
This is useful for the remat transpose rule submitted in #3162 and e.g.
allowed me to catch a slight overuse of defjvp2 for `random_gamma_p` (it
was unnecessarily declared as having multiple outputs).
* Remove usage of xla_client.{Computation,ComputationBuilder}.
ComputationBuilder is a fairly pointless wrapper class that mimics an outdated version of the C++ XLA API. It dates back to when we used to have SWIG bindings and needed to write a non-trivial Python shim to keep the interface pleasant to use. Now that we have pybind11-based bindings that are reasonably ergonomic by themselves, we don't need the wrapper class. Instead, we can simply call the pybind11-wrapped C++ API directly, removing the impedance mismatch between the C++ and Python APIs and allowing us to delete the Python ComputationBuilder class.
Similarly we can delete xla_client.Computation for the same reasons; it doesn't do anything useful on top of the C++ API.
* Add explicit derivative for jax.numpy.linalg.pinv.
* Fix type confusion problems in the JVP rule for SVD that meant it produced 64-bit tangents for 32-bit primals.
* change the xla representation of JAX's unit
Previously the representation of JAX's unit value (a sentinel /
placeholder) was an empty tuple, but by changing the representation to
something else we can further reduce our dependence on runtime tuples.
This commit makes the representation fairly easy to change. There are
three functions in xla.py that define the representation. Here are
versions that would keep the old XLA representation as an empty tuple:
```
def _make_unit(c): return c.Tuple()
def _make_abstract_unit(_): return xc.Shape.tuple_shape(())
def _device_put_unit(_, device):
  return xc.Buffer.make_tuple((), device, backend=xb.get_device_backend(device))
```
The new representation is as a trivial array. An alternative
representation would be nothing at all: we don't need to generate XLA
computations that have representations of JAX units. While that
alternative is probably the best choice, it seemed like it would require
a bit more refactoring/bookkeeping (e.g. to allow XLA computations to
have a smaller number of outputs than the corresponding JAX function),
and would also mean the XLA representation would be a step further
removed from the jaxpr representation. So I stuck with a trivial array
for now.
The mapping from JAX types to XLA types need not be invertible. However,
XLA translation rules currently don't take as arguments the
corresponding JAX types (abstract values), and there were a few cases
where we relied on checking whether an argument's XLA type was that of
an empty tuple so as to determine if we were effectively operating on a
JAX unit.
In particular, the AD-related primitive add_jaxvals_p could in principle
add two units, and get lowered to an XLA addition on the unit
representation. Previously, the translation rule for add_jaxvals_p
checked the XLA type so that adding two empty tuples didn't produce any
XLA operation; now it adds its inputs, and so if unit is represented as
a trivial array we could be inserting trivial scalar adds where we had
none before. However, if that case is ever possible, it doesn't come up
in our tests (which I checked by keeping the representation as an empty
tuple and then asserting an XLA tuple type is never seen by that
translation rule).
* add comment about JAX<->XLA array types assumption
Long, long ago, when JAX was first born, we realized that we couldn't
transpose this jaxpr:
{ lambda ; a.
  let b = reduce_sum[ axes=(0,) ] a
  in b }
The problem was that the transpose of a reduce-sum is a broadcast, but
because jaxprs didn't have shape information available, we didn't know
what input shape to broadcast to!
Our hack was to have the primitives that required shape information for
transposition to acquire it into their parameters, so that we'd produce
jaxprs like this one:
{ lambda ; a.
  let b = reduce_sum[ axes=(0,)
                      input_shape=(3,) ] a
  in b }
That's not only aesthetically unpleasant, but also it meant we were
limiting an (unused) capability of the system: ideally we should be able
to trace a reduce-sum jaxpr without specializing on shape information
(e.g. at the Unshaped level) and only require shape specialization for
transposition. (Good thing no one actually traces at Unshaped...)
But at long last @chr1sj0nes in #2299 added avals to jaxprs, so that
shape information (or whatever information with which the jaxpr was
specialized out of Python) is in the jaxpr itself. So we could finally
remove these shapes-in-params warts!
That's exactly what this commit does!
Co-authored-by: Roy Frostig <frostig@google.com>
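To make the original problem concrete, here is the transpose-of-reduce-sum-is-broadcast relationship as a NumPy sketch (variable names are ours):

```
import numpy as np

x = np.arange(6.0).reshape(2, 3)
y = x.sum(axis=0)            # forward: reduce_sum over axis 0 -> shape (3,)
ct = np.ones_like(y)         # a cotangent for y
# The transpose of reduce_sum broadcasts the cotangent back to the input
# shape -- which is only possible if that shape is known somewhere.
ct_x = np.broadcast_to(ct, x.shape)
assert ct_x.shape == (2, 3)
```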
* Consolidate LU solve logic from scipy/numpy in lax_linalg.lu_solve
This single implementation supports broadcasting like NumPy in both the NumPy
and SciPy interfaces to LU solve, even though only original NumPy supports
broadcasting.
This change is technically backwards incompatible in the SciPy wrapper, which
previously supported adding extra dimensions to the end of `b`, e.g.,
`b.shape == (8, 4, 2)` when `a.shape == (8, 8)`. There was a testcase for this,
but it isn't documented in either JAX or SciPy.
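The NumPy-style broadcasting being matched here looks like the following (a sketch using np.linalg.solve, which broadcasts the same way over leading dimensions):

```
import numpy as np

rng = np.random.RandomState(0)
a = rng.randn(3, 4, 4)        # a stack of three 4x4 systems
b = rng.randn(3, 4, 1)        # matching right-hand sides
x = np.linalg.solve(a, b)     # broadcasts over the leading dimension
for i in range(3):
    np.testing.assert_allclose(a[i] @ x[i], b[i], atol=1e-6)
```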
* fix recursive import
* Use np.vectorize instead of experimental.vectorize
* Better batching rule for triangular_solve
Now, if only the right hand side argument `b` is batched, we leverage
triangular solve's builtin batching for handling multiple right-hand-side
vectors.
This makes the performance of `vmap` over only the second argument of linear
solves equivalent to relying on builtin batching::
    rs = onp.random.RandomState(0)
    a = rs.randn(500, 500) + 0.1 * np.eye(500)
    b_mat = jax.device_put(rs.randn(500, 10))
    solve1 = jax.jit(np.linalg.solve)
    solve2 = jax.jit(jax.vmap(np.linalg.solve, in_axes=(None, 1), out_axes=1))

Before::

    In [6]: %timeit jax.device_get(solve1(a, b_mat))
    3.88 ms ± 293 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

    # 8x slower :(
    In [9]: %timeit jax.device_get(solve2(a, b_mat))
    23.5 ms ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Now::

    In [2]: %timeit jax.device_get(solve1(a, b_mat))
    3.76 ms ± 304 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

    # same speed :)
    In [3]: %timeit jax.device_get(solve2(a, b_mat))
    3.72 ms ± 296 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
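The builtin multiple-right-hand-side batching being leveraged here is the same behavior NumPy exposes: one solve against a matrix b matches solving each column separately, which is what a vmap over only b amounts to (sketch with our variable names):

```
import numpy as np

rng = np.random.RandomState(0)
a = np.tril(rng.randn(5, 5)) + 5 * np.eye(5)   # well-conditioned lower triangular
b_mat = rng.randn(5, 3)
# One solve with a matrix right-hand side...
x_all = np.linalg.solve(a, b_mat)
# ...matches solving each column on its own.
for j in range(3):
    np.testing.assert_allclose(x_all[:, j], np.linalg.solve(a, b_mat[:, j]),
                               atol=1e-8)
```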
* Fix test failures
* Check b.ndim == 2 in triangular solve shape rule
* Added batching to cpu triangular_solver
* Addressed comments about int overflows and reverted triangular solve to use XLA instead of LAPACK
* add todo to benchmark LAPACK vs XLA
* Implement batched Cholesky decomposition on CPU and GPU using LAPACK and cuSolver.
Adds support for complex batched Cholesky decomposition on both platforms.
Fix concurrency bug in batched cuBLAS kernels where a host-to-device memcpy could take place too early, before the device buffer was ready.
* Change jax.numpy scalar types to return 0D JAX arrays rather than NumPy scalars when instantiated.
jax.numpy and numpy have slightly different promotion behaviors. For consistency with JAX arrays, we would like the result of, say, `jax.numpy.int32(7)` to have the same promotion behavior as `jax.numpy.array(7, dtype=jax.numpy.int32)`. The easiest way to do this is to have the jax.numpy scalars return 0D arrays when instantiated; the difference between NumPy scalars and arrays is not a fundamental one and we do not need to distinguish between them in JAX.
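The distinction in plain NumPy, for comparison (this change removes it in JAX by having the scalar constructors return 0D arrays):

```
import numpy as np

# A NumPy scalar is not an ndarray, while a 0-d array is.
scalar = np.int32(7)
zero_d = np.array(7, dtype=np.int32)
assert not isinstance(scalar, np.ndarray)
assert isinstance(zero_d, np.ndarray) and zero_d.ndim == 0
```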
* Move internal type-related functions into a new (internal) jax.types module.
Avoid calling onp type functions in lieu of the wrappers in jax.types. Currently these do the same thing, but future changes will make the behavior of the jax type functions diverge from the classic NumPy versions in some cases.
Move xla_bridge.canonicalize_dtype into jax.types, since it fits there more naturally.
* Rename jax.types to jax.dtypes.
* s/types/dtypes/ in tests.
I'm pretty sure all of these are dead code that can no longer be triggered,
e.g., as evidenced by their use of no longer existing `core.pack` and
`ad.TangentTuple`.
* Speedup JVP for triangular solve
There is still room for improvement, e.g., by combining the currently separate
JVPs for a and b into a single expression (which will allow for saving an inner
triangular solve when both arguments are being differentiated), but this is
already significantly faster in the typical case of only solving a single
vector.
On my laptop's CPU, I measure 2.98 ms before vs 1.18 ms after on
a 500x500 matrix::

    rs = onp.random.RandomState(0)
    a = rs.randn(500, 500)
    b = rs.randn(500)

    @jax.jit
    def loss(a, b):
      return np.sum(jax.scipy.linalg.solve_triangular(a, b))

    grad = jax.jit(jax.grad(loss))
    %timeit jax.device_get(grad(a, b))
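For reference, the gradient in this example is itself a single transposed triangular solve. A NumPy sketch checking that against finite differences (np.linalg.solve stands in for the triangular solve; names are ours):

```
import numpy as np

rng = np.random.RandomState(0)
t = np.tril(rng.randn(4, 4)) + 4 * np.eye(4)   # well-conditioned lower triangular
b = rng.randn(4)
x = np.linalg.solve(t, b)

ct = np.ones_like(x)                  # cotangent of loss = sum(x)
grad_b = np.linalg.solve(t.T, ct)     # d loss / d b: one transposed solve

# Finite-difference check (the solve is linear in b, so this is nearly exact).
eps = 1e-6
fd = np.array([(np.linalg.solve(t, b + eps * np.eye(4)[i]).sum() - x.sum()) / eps
               for i in range(4)])
np.testing.assert_allclose(grad_b, fd, rtol=1e-3, atol=1e-4)
```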
* comment
* Optimal order for left_side=False, too
* Test the JVP for lax_linalg.triangular_solve directly