32 Commits

Author SHA1 Message Date
Jake VanderPlas
b5ba210097 [x64] make linalg functions & tests compatible with strict dtype promotion 2022-06-16 10:32:20 -07:00
Sergei Lebedev
c5d3ece6f5 MAINT Fixed new mypy errors
mypy seems to handle lambdas and named functions differently. So, I had to
promote a few helpers to named functions to get them to type check.
2022-05-23 20:21:00 +01:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
1bcb5e073c Add an implementation of jnp.linalg.slogdet based on QR decomposition.
Adds a non-standard `method` argument to `jnp.linalg.slogdet` to select between the current LU decomposition based implementation (like NumPy) and the QR decomposition implementation.

QR decomposition is more amenable to a high performance batched implementation particularly on TPU hardware because it does not need row pivoting. The same may be true on other hardware also, and having the option is nice either way!

PiperOrigin-RevId: 449271317
2022-05-17 11:24:11 -07:00
Peter Hawkins
909c0328b0 Decompose lax.linalg.qr into two subprimitives geqrf and orgqr.
In essence, this lifts the implementation of QR decomposition out of the lowering rules and into the JAX level instead.

This is useful because it allows direct access to the raw form of the decomposition returned by geqrf; sometimes we actually want access to the Householder reflectors instead of their product. Currently neither geqrf nor orgqr are differentiable in isolation.

Change in preparation for adding an implementation of jnp.linalg.slogdet that uses QR decomposition instead of LU decomposition.

Fixes https://github.com/google/jax/issues/2322

PiperOrigin-RevId: 449033350
2022-05-16 12:59:57 -07:00
Peter Hawkins
705e241409 Change non-array arguments to jax.lax.linalg functions to be keyword-only arguments.
PiperOrigin-RevId: 448066207
2022-05-11 13:06:54 -07:00
YouJiacheng
bb2682db6d remove numpy.linalg._promote_arg_dtypes
in favor of numpy.util._promote_dtypes_inexact
2022-04-21 00:23:56 +08:00
Lukas Geiger
50e8bc4514 Replace reshape with expand_dims if possible 2022-03-31 01:34:26 +01:00
jax authors
cf1161ff8b Merge pull request #9826 from froystig:lax-cleanup2
PiperOrigin-RevId: 433827272
2022-03-10 12:48:34 -08:00
Peter Hawkins
051f4dd0cf Suggest eigh() in the eig() not implemented error. 2022-03-10 08:51:13 -05:00
Roy Frostig
8f93629e87 remove _convert_element_type from public jax.lax module 2022-03-09 18:46:38 -08:00
Roy Frostig
f7731bf959 remove _const from public jax.lax module
Modify all internal call sites to use `jax._src.lax.lax._const`.
2022-03-07 12:26:25 -08:00
Jake VanderPlas
27f285782b linalg_test: disable implicit rank promotion 2022-01-26 09:29:06 -08:00
Jake VanderPlas
f8e18e9a00 [x64] minor weak_type changes to linalg.py 2021-12-07 16:27:29 -08:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Peter Hawkins
05e6f84919 Implement hermitian=... option on jax.numpy.linalg.svd. 2021-11-01 09:55:30 -04:00
Peter Hawkins
2eb20357db Add @jit decorators to jax.numpy.linalg and jax.scipy.linalg. 2021-09-24 15:52:11 -04:00
Peter Hawkins
a84426cb8f Switch internal users of jax.ops.index_... to use x.at[x].set() APIs. 2021-09-13 19:48:29 -04:00
ho-oto
160c3e9357
rename v to vh 2021-05-31 22:18:24 +09:00
cdfreeman-google
d69bb535b1 Added special cases for 2x2 and 3x3 determinant, and added test coverage for these cases. 2021-04-22 13:48:05 -07:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Jake VanderPlas
796f1bde7b DOC: add note about return type in jnp.linalg.eig 2021-03-22 10:43:04 -07:00
Roy Frostig
e0b3ef0f65 fix broadcasted 1x1 cofactor solve, called by linalg.det jvp 2021-03-19 19:13:45 -07:00
Jake VanderPlas
91872497f9 Support empty inputs in jnp.linalg.svd() and jnp.linalg.pinv() 2021-01-13 10:39:00 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Jake VanderPlas
65954cebeb Fix linalg.solve() and linalg.inv() for empty matrices 2021-01-05 11:51:32 -08:00
Jake VanderPlas
9d2f6148ed Call asarray() rather than array() to avoid host round-trips. 2020-11-24 16:05:48 -08:00
David Pfau
8d1daba901 Add complex types to gradient of slogdet 2020-11-24 16:57:35 +00:00
Jamie Townsend
5fccc89a42 Add derivatives for eigenvalues (not eigenvectors)
We aren't supporting eigenvectors for now because eigenvectors are not
uniquely determined by the input matrix, they're only determined up to
'gauge' (that is multiplication by a complex scalar with absolute value
1). Note, this means that second derivatives aren't supported, because
they involve differentiating the eigvals jvp, which itself depends on
eigenvectors.
2020-11-20 16:41:40 +00:00
Peter Hawkins
c57bbb3cea [JAX] Move jax/lax_linalg.py to jax/_src/lax/linalg.py.
Because we now have a facade around the lax library, we can expose the lax_linalg primitives directly in lax without creating circular dependency problems.

Leave a few forwarding stubs to be removed later.

PiperOrigin-RevId: 340658800
2020-11-04 08:59:36 -08:00
Peter Hawkins
3ddd3905a4 Move jax.third_party to jax._src.third_party.
PiperOrigin-RevId: 337675377
2020-10-17 11:43:07 -07:00
Peter Hawkins
aa107cf1f4 Move jax.numpy internals into jax._src.numpy. 2020-10-16 20:35:19 -04:00