17 Commits

Author SHA1 Message Date
Jake VanderPlas
245581411e Add PEP484-compatible export for jax and its subpackages 2021-09-13 14:08:48 -07:00
sunilkpai
997ad31670 added bicgstab to new jax repo
fixed some bugs in the bicgstab method and adjusted tolerance for scipy comparison

fixed flake8

added some tests for gradients, fixed symmetry checks, modified lax.cond -> jnp.where

comment out gmres grad check, to be addressed on future PR

increasing tolerance for bicgstab grad test

change to order 1 checks for bicgstab (gmres still fails in order 1) for internal CI check

remove grad checks for now

changing tolerance to pass numpy comparison test
2021-02-18 18:01:28 -08:00
Stephan Hoyer
6cc5b28327 Cleanup/fixup jax.scipy.sparse.linalg.gmres and expose it publicly. 2020-12-03 09:23:00 -08:00
Peter Hawkins
94cd2046fa [JAX] Move implementation of jax.scipy.sparse.linalg into jax._src.
PiperOrigin-RevId: 343276958
2020-11-19 06:18:09 -08:00
Stephan Hoyer
7e62270e5a More unit-tests + mark gmres as internal for now 2020-11-10 21:02:51 -08:00
Stephan Hoyer
c5c71a0a37 Cleanup _safe_normalize 2020-11-08 17:04:00 -08:00
Adam GM Lewis
7ed9fe70ea Corrections to GMRES - now gives correct result.
Co-authored-by: gehring <clement.gehring@gmail.com>

Co-authored-by: Stephan Hoyer <shoyer@google.com>
2020-11-08 15:37:50 -08:00
gehring
342cc36051 Initial implementation of GMRES 2020-11-08 15:34:56 -08:00
Stephan Hoyer
36eb137dd3
Refine argument validation inside jax.scipy.sparse.linalg.cg (#3630)
Now we check tree structure and leaf shapes separately. This allow us to
support pytrees that either don't define equality or that define it
inconsistently (e.g., elementwise like NumPy) with builtin data structures like
list/dict.
2020-07-06 09:24:44 -07:00
Jake Vanderplas
bc51e9c7f6
deflake jax/scipy/* and add to setup.cfg (#3316) 2020-06-04 14:38:41 -07:00
Yusuke Oda
ccb8d45975
Uses jnp.square instead of power. (#3036)
* Uses multiplication instead of power.

* Uses jnp.square instead of mul and adds check if jnp.square is implemented by mul.
2020-05-12 11:04:53 -04:00
Stephan Hoyer
8fa707af98
Fixup complex values and tol in tests for jax.scipy.linalg.sparse.cg (#2717)
* Fixup complex values and tol in tests for jax.scipy.linalg.sparse.cg

The tests for CG were failing on TPUs:

- `test_cg_pytree` is fixed by requiring slightly less precision than the
  unit-test default.
- `test_cg_against_scipy` is fixed for complex values in two independent ways:
  1. We don't set both `tol=0` and `atol=0`, which made the termination
     behavior of CG (convergence or NaN) dependent on exactly how XLA handles
     arithmetic with denormals.
  2. We make use of *real valued* inner products inside `cg`, even for complex
     values. It turns that all these inner products are mathematically
     guaranteed to yield a real number anyways, so we can save some flops and
     avoid ill-defined comparisons of complex-values (see
     https://github.com/numpy/numpy/issues/15981) by ignoring the complex part
     of the result from `jnp.vdot`. (Real numbers also happen to have the
     desired rounding behavior for denormals on TPUs, so this on its own would
     also fix these failures.)

* comment fixup

* fix my comment
2020-04-14 22:35:48 -07:00
Stephan Hoyer
9cc5e9018c Renable custom_linear_solve and cg with complex values 2020-04-09 00:53:00 -07:00
Stephan Hoyer
e8f989e38f
Add import from scipy.sparse (#2621)
* Add import from scipy.sparse

* Fix formatting in cg docstring
2020-04-06 14:45:02 -07:00
Stephan Hoyer
1472eb3ade
DOC: note how derivatives are computed for CG (#2619) 2020-04-06 12:49:11 -07:00
Stephan Hoyer
1cf708ea77
Support pytrees in jax.scipy.linalg.cg (#2600)
* Support pytrees in jax.scipy.linalg.cg

Ideally there would be an easier way to write this, but for now this will do.

* Fixup test
2020-04-04 15:55:46 -07:00
Stephan Hoyer
1b93bb51a8
Implement scipy.sparse.linalg.cg (second try) (#2566)
* super minimal starter code

* Update optimizers.py

* implement flip with axis = None

* Create sparse.py

* fix some imports

* Update sparse.py

* add partial function & test

* Update lax_scipy_sparse_test.py

* Update lax_scipy_sparse_test.py

* add a test case for sparse pd matrix & add bigger dim

* address comments

* fix info return & create matrix with rng_factory

* Update lax_scipy_sparse_test.py

* Update lax_scipy_sparse_test.py

* Update sparse.py

* Update sparse.py

* Update sparse.py

* Update lax_scipy_sparse_test.py

* Update lax_scipy_sparse_test.py

* cast jax arrays into numpy array for scipy compatibility

* Update sparse.py

* Update sparse.py

* fix None issue, but algo is not working

* fix return of build_and_solve and output of while_loop

* fix condition func of while loop

* clearer variable names

* mismatch error

* Update lax_scipy_sparse_test.py

* Fixes to jax.experimental.sparse.cg

* Fix tests for gradients

* Add support for preconditioners to cg

* Move cg into scipy, update docs

* doc tweak

Co-authored-by: Tuan Nguyen <anhtuan277@gmail.com>
2020-04-03 13:37:11 -07:00