22 Commits

Author SHA1 Message Date
Sergei Lebedev
92b1f71314 Removed various ununsed functions
To rerun the analysis do

    python -m vulture jax/_src --ignore-names "[A-Za-z]*" --ignore-decorators "*"
2024-08-01 11:18:19 +01:00
Peter Hawkins
84c1e825c0 Make jax.numpy.where()'s condition, x, y arguments positional-only to match numpy.where.
PiperOrigin-RevId: 584377134
2023-11-21 11:10:12 -08:00
Peter Hawkins
cd0533cab0 Replace uses of jnp.ndarray with jax.Array inside JAX.
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00
Roy Frostig
e199b35f4e Revert "Merge pull request #14113 from botev:main"
This reverts commit 69d18cc7b58ae4ed82246605d66ed07a49fad676, reversing
changes made to 13e875f8b8d8dd9152045c7e3b5045a9bb0d7db0.

Reverting until we address https://github.com/google/jax/issues/14249
2023-02-01 19:50:27 -08:00
botev
73ed511d39 Adding info to CG and BICGSTAB 2023-01-22 21:47:34 +00:00
Jake VanderPlas
924894fdd6 [x64] make tests more type-safe 2022-12-02 13:21:35 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
9090dd179d jax.scipy.linalg.solve: deprecate the sym_pos argument following scipy 1.9.0 2022-07-19 13:57:49 -07:00
Jake VanderPlas
4c4a83b108 [x64] make jax.scipy.sparse.linalg compatible with strict dtype promotion 2022-06-17 14:04:05 -07:00
Jake VanderPlas
ebd53a48ab [x64] gmres: avoid problematic type promotions 2022-06-09 14:05:56 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Jake VanderPlas
30fd817130 jax.scipy.sparse.linalg: support sparse matrices as operators 2022-05-05 10:33:08 -07:00
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
Jake VanderPlas
998d60dd07 DOC: clarify parameter types in cg/bicgstab 2022-03-23 08:35:25 -07:00
Roy Frostig
8f93629e87 remove _convert_element_type from public jax.lax module 2022-03-09 18:46:38 -08:00
Jake VanderPlas
f8e18e9a00 [x64] minor weak_type changes to linalg.py 2021-12-07 16:27:29 -08:00
Jake VanderPlas
022f8ac2ee [x64] preserve weak types in jax.scipy.sparse solvers 2021-11-30 10:36:28 -08:00
sunilkpai
997ad31670 added bicgstab to new jax repo
fixed some bugs in the bicgstab method and adjusted tolerance for scipy comparison

fixed flake8

added some tests for gradients, fixed symmetry checks, modified lax.cond -> jnp.where

comment out gmres grad check, to be addressed on future PR

increasing tolerance for bicgstab grad test

change to order 1 checks for bicgstab (gmres still fails in order 1) for internal CI check

remove grad checks for now

changing tolerance to pass numpy comparison test
2021-02-18 18:01:28 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Stephan Hoyer
cd9f6cccbf Support ndarrays as arguments to cg and gmres
This is consistent with SciPy, and makes things a little bit less
surprising for users.
2020-12-04 12:53:45 -08:00
Stephan Hoyer
6cc5b28327 Cleanup/fixup jax.scipy.sparse.linalg.gmres and expose it publicly. 2020-12-03 09:23:00 -08:00
Peter Hawkins
94cd2046fa [JAX] Move implementation of jax.scipy.sparse.linalg into jax._src.
PiperOrigin-RevId: 343276958
2020-11-19 06:18:09 -08:00