59 Commits

Author SHA1 Message Date
tilakrayal
3bbd141d3d
Fixing the naming conventions in linalg.py 2024-06-27 13:22:39 +05:30
Dan Foreman-Mackey
6d35b109fd Rename "Example" to "Examples" in docstrings.
This PR updates all docstrings that previously had a section heading
called "Example" and replaces that with "Examples" to be consistent.
2024-06-21 11:43:16 -04:00
rajasekharporeddy
7989c70572 Add example code snippets to jax.scipy.linalg.expm and jax.scipy.linalg.polar docs 2024-06-08 03:30:12 +05:30
rajasekharporeddy
aee310b54d Fix some doc typos 2024-06-06 15:19:53 +05:30
Jake VanderPlas
d51ccdf628 DOC: Improve docstrings for jax.scipy.linalg 2024-05-01 17:36:24 -07:00
rajasekharporeddy
4d6a53fb63 Add Hilbert matrix to jax.scipy.linalg 2024-03-20 22:55:03 +05:30
Jake VanderPlas
85f205bdc7 typing: fix incorrect tuple annotations 2024-02-26 10:53:19 -08:00
Jake VanderPlas
43a9faa06a Rename _wraps to implements 2024-01-24 14:14:19 -08:00
Jake VanderPlas
e356d76913 Remove a number of deprecated APIs
All of these were deprecated prior to the JAX 0.4.16 release, on Sept 18 2023.
As of Monday Dec 18, we have met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 591933493
2023-12-18 10:08:47 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Jake VanderPlas
340e655ac2 Remove deprecated sym_pos argument from jax.scipy.linalg.solve
PiperOrigin-RevId: 580940755
2023-11-09 09:53:37 -08:00
Peter Hawkins
3082109a59 Add a type stub for jax.numpy.
This type stub is intended to match what pytype currently infers for jax.numpy, which is not particularly accurate in many cases. Future changes will add more accurate types to this stub.

Fix a number of new type errors this reveals to mypy.

PiperOrigin-RevId: 559179804
2023-08-22 11:50:49 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Jake VanderPlas
222b951b19 Use new matrix_transpose in linalg code 2023-05-25 09:32:14 -07:00
Jake VanderPlas
ad0fc8979b jax.scipy.linalg.expm: support batched inputs 2023-03-27 16:39:48 -07:00
Jake VanderPlas
760deb310e Remove leading underscores in jax._src.numpy.util 2023-03-13 12:18:36 -07:00
Jake VanderPlas
c8c269f5f5 internal: avoid unused imports in lax_numpy 2023-03-08 10:29:04 -08:00
tttc3
96707f09b1 Removed deprecated polar_unitary as per comment. 2023-01-27 07:24:55 +00:00
Yotaro Kubo
1ade5f8592 Add jax.scipy.linalg.toeplitz. 2022-12-09 01:03:21 +09:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00
Jake VanderPlas
26d9837b36 Switch to new-style f-strings 2022-12-01 09:14:16 -08:00
Peter Hawkins
1cead779a3 Add support for Hessenberg and tridiagonal matrix reductions on CPU.
* Implement jax.scipy.linalg.hessenberg and jax.lax.linalg.hessenberg.
* Export what was previously jax._src.lax.linalg.orgqr as jax.lax.linalg.householder_product, since it can be used with some minor tweaks to compute the unitary matrix of a Hessenberg reduction.
* Implement jax.lax.linalg.tridiagonal, which is the symmetric (Hermitian) equivalent of Hessenberg reduction.

None of these primitives are differentiable at the moment.

PiperOrigin-RevId: 487224934
2022-11-09 06:23:55 -08:00
Peter Hawkins
807269990e Enable more GPU and TPU tests that pass at head.
Increase precision of matmuls in LU decompositions, pseudo-inverse solves, and their gradients. It is unlikely users want to use low precision for these operations and high precision is probably the right default.

PiperOrigin-RevId: 482071629
2022-10-18 18:09:44 -07:00
Jake VanderPlas
afe74b4710 [typing] add type annotations to jax.scipy.linalg 2022-10-10 16:54:29 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
29d03160e3 Remove _ prefix from functions in jax._src.dtypes.
to_inexact_dtype and to_complex_dtype are used across the JAX code base,
so they shouldn't have _ prefixes.
2022-08-12 12:51:09 +00:00
Jake VanderPlas
114b03670c Add missing f-string marker 2022-07-20 10:48:07 -07:00
Jake VanderPlas
9090dd179d jax.scipy.linalg.solve: deprecate the sym_pos argument following scipy 1.9.0 2022-07-19 13:57:49 -07:00
Jake VanderPlas
b5ba210097 [x64] make linalg functions & tests compatible with strict dtype promotion 2022-06-16 10:32:20 -07:00
Jake VanderPlas
e888e7c10c [x64] make lax_scipy_test.py compatible with strict dtype promotion 2022-06-14 10:02:45 -07:00
Jeppe Klitgaard
838a05329d feat: validate jit args 2022-05-18 21:54:47 +01:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
7ba36fc178 Change implementation of jax.scipy.linalg.polar() and jax._src.scipy.eigh to use the QDWH decomposition from jax._src.lax.qdwh.
Remove jax._src.lax.polar.

PiperOrigin-RevId: 448241206
2022-05-12 07:20:52 -07:00
Peter Hawkins
705e241409 Change non-array arguments to jax.lax.linalg functions to be keyword-only arguments.
PiperOrigin-RevId: 448066207
2022-05-11 13:06:54 -07:00
jax authors
39ac9faf61 Merge pull request #10419 from YouJiacheng:jsp.linalg-unused-params
PiperOrigin-RevId: 444840248
2022-04-27 06:27:09 -07:00
Jake VanderPlas
822b6aad3b jax.scipy.qr: fix return type for mode='r' 2022-04-26 11:26:56 -07:00
YouJiacheng
ecab8e00e8 try to improve docs for scipy.linalg with unused parameters 2022-04-26 14:56:26 +08:00
YouJiacheng
af7b94b110 Fix typo of #10381
and add a basic regression test
2022-04-21 02:17:09 +08:00
YouJiacheng
bb2682db6d remove numpy.linalg._promote_arg_dtypes
in favor of numpy.util._promote_dtypes_inexact
2022-04-21 00:23:56 +08:00
Alex Riley
869596fc2c Add jax.scipy.linalg.rsf2csf 2022-04-06 21:06:23 +01:00
Leello Tadesse Dadi
cb732323f3 adds jax.scipy.linalg.sqrtm 2022-02-16 22:33:47 +01:00
Leello Tadesse Dadi
514d8883ce adds jax.scipy.schur 2022-02-16 22:33:37 +01:00
Jake VanderPlas
27f285782b linalg_test: disable implicit rank promotion 2022-01-26 09:29:06 -08:00
Attila Szabó
70bf281250 Fix max_squarings in expm 2022-01-25 09:54:23 +00:00
DanPuzzuoli
2d2ac12aa0 switching to use jnp.digitize for index identification 2022-01-19 10:02:06 -08:00
DanPuzzuoli
cbb2f7baab changed to use argwhere 2022-01-18 16:28:36 -08:00
DanPuzzuoli
9d43ccafa3 converting to use switch 2022-01-18 14:17:17 -08:00
Peter Hawkins
2eb20357db Add @jit decorators to jax.numpy.linalg and jax.scipy.linalg. 2021-09-24 15:52:11 -04:00
Peter Hawkins
b232d09440 Enable flake8 checks for spaces around operators. 2021-07-30 08:45:38 -04:00
Peter Hawkins
0dfd76af97 Remove additional info return value from jax.scipy.linalg.polar(). 2021-07-20 13:13:31 -04:00