52 Commits

Author SHA1 Message Date
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Sergei Lebedev
2f70ae700a Migrate another subset of internal modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.

PiperOrigin-RevId: 572587137
2023-10-11 08:46:06 -07:00
Peter Hawkins
0adfafe293 Relax test tolerances.
This makes the tests pass on CPU with a slightly different seed (+ 1).

PiperOrigin-RevId: 542877795
2023-06-23 09:22:11 -07:00
Peter Hawkins
803c729b57 Fix some test failures under H100.
It seems that under H100 matmul precisions are a little lower by default than they historically were on A100. Opt out of tensorcore matmuls for tests that fail due to precision issues if they are enabled.

Happily, this also allows us to remove a number of TPU special cases for the same reason.

PiperOrigin-RevId: 539101155
2023-06-09 09:23:36 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Roy Frostig
26b75ff4ae add "linear solve batching via jacrev" test from github.com/google/jax/issues/14249 2023-02-01 20:01:53 -08:00
Peter Hawkins
72f4f389be Migrate remaining tests from jtu.cases_from_list to jtu.sample_product.
Delete jtu.cases_from_list.
2022-10-12 15:20:53 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Peter Hawkins
335b2cfb26 [JAX] Prepare not to export jax._src by default.
Currently
```
import jax
```
populates `jax._src` in the names exported from JAX. This change prepares for not exporting `jax._src` by default.

In particular, explicitly import modules from jax._src and refer to those imports rather than assuming jax._src contents will be around later. This is a common pattern in tests.

This change does not yet remove any exported names.

Issue https://github.com/google/jax/issues/11951

PiperOrigin-RevId: 469480816
2022-08-23 09:36:47 -07:00
Jake VanderPlas
4c4a83b108 [x64] make jax.scipy.sparse.linalg compatible with strict dtype promotion 2022-06-17 14:04:05 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Jake VanderPlas
30fd817130 jax.scipy.sparse.linalg: support sparse matrices as operators 2022-05-05 10:33:08 -07:00
Jake VanderPlas
97512e9e44 JaxTestCase: set jax_numpy_rank_promotion='raise' by default 2022-02-14 09:22:05 -08:00
jax authors
5691010d2f Copybara import of the project:
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:

JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
2022-02-10 19:08:29 -08:00
Jake VanderPlas
6324577a63 JaxTestCase: set numpy_rank_promotion='raise' by default 2022-02-10 16:54:31 -08:00
Jake VanderPlas
2e713a7b91 Add missing gmres test (part of #8849) 2021-12-08 16:32:33 -08:00
Jake VanderPlas
022f8ac2ee [x64] preserve weak types in jax.scipy.sparse solvers 2021-11-30 10:36:28 -08:00
Peter Hawkins
267a4ca4cb Reenable a test that was disabled due to an (apparently fixed) LLVM bug.
PiperOrigin-RevId: 403623977
2021-10-16 10:34:36 -07:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Jake VanderPlas
6114e6a0d3 test_util: add decorator to set config values in test cases 2021-08-05 14:06:37 -07:00
Jake VanderPlas
30ea76cb6b disable rank promotion for jax scipy tests 2021-08-04 10:44:23 -07:00
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
Qiao Zhang
730ec1b7f4 Disable test_gmres_against_scipy due to LLVM changes. 2021-06-02 13:56:08 -07:00
Peter Hawkins
73df92bbad Reenable a GMRES test on CPU.
It appears to pass once again with jaxlib 0.1.67. Unfortunately it just missed the 0.1.66 release.
2021-05-11 22:00:06 -04:00
Peter Hawkins
fb74246151 Disable gmres test on CPU.
This test has started failing at LLVM head; disabling it while we debug.
2021-04-23 15:13:40 -04:00
Skye Wanderman-Milne
346df9c557 Disable lax_scipy_sparse_test.py cases that are hanging on GPU.
See #6471.
2021-04-15 15:54:15 -07:00
Matthew Johnson
fd7b286ec9 unify configuration state handling 2021-03-23 18:56:01 -07:00
sunilkpai
d35ae4c9bf removing x64 enable for testing
removing commented-out x64 line
2021-02-19 11:17:27 -08:00
sunilkpai
997ad31670 added bicgstab to new jax repo
fixed some bugs in the bicgstab method and adjusted tolerance for scipy comparison

fixed flake8

added some tests for gradients, fixed symmetry checks, modified lax.cond -> jnp.where

comment out gmres grad check, to be addressed on future PR

increasing tolerance for bicgstab grad test

change to order 1 checks for bicgstab (gmres still fails in order 1) for internal CI check

remove grad checks for now

changing tolerance to pass numpy comparison test
2021-02-18 18:01:28 -08:00
Jake VanderPlas
2fd682ef2a Make jax_enable_x64 a thread-local value. 2021-02-04 09:48:22 -08:00
Jake VanderPlas
f74235cdae X32 tests: fail on dtype warnings 2020-12-08 13:03:30 -08:00
Peter Hawkins
03f423bb4c Relax some test tolerances that appear to be sensitive to the random seed. 2020-12-06 15:44:44 -05:00
Stephan Hoyer
cd9f6cccbf Support ndarrays as arguments to cg and gmres
This is consistent with SciPy, and makes things a little bit less
surprising for users.
2020-12-04 12:53:45 -08:00
Stephan Hoyer
6cc5b28327 Cleanup/fixup jax.scipy.sparse.linalg.gmres and expose it publicly. 2020-12-03 09:23:00 -08:00
Peter Hawkins
94cd2046fa [JAX] Move implementation of jax.scipy.sparse.linalg into jax._src.
PiperOrigin-RevId: 343276958
2020-11-19 06:18:09 -08:00
Stephan Hoyer
fca5666382 Relax GMRES test tolerances 2020-11-12 11:02:12 -08:00
Stephan Hoyer
7e62270e5a More unit-tests + mark gmres as internal for now 2020-11-10 21:02:51 -08:00
Adam GM Lewis
7ed9fe70ea Corrections to GMRES - now gives correct result.
Co-authored-by: gehring <clement.gehring@gmail.com>

Co-authored-by: Stephan Hoyer <shoyer@google.com>
2020-11-08 15:37:50 -08:00
Stephan Hoyer
36eb137dd3
Refine argument validation inside jax.scipy.sparse.linalg.cg (#3630)
Now we check tree structure and leaf shapes separately. This allow us to
support pytrees that either don't define equality or that define it
inconsistently (e.g., elementwise like NumPy) with builtin data structures like
list/dict.
2020-07-06 09:24:44 -07:00
Jake VanderPlas
afce718eb1 Add ability to specify individual test targets 2020-06-29 11:08:57 -07:00
Peter Hawkins
fffdb2daa8
Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_… (#3280)
* Make check_dtypes, atol, and rtol keyword-only arguments in jax.test_util APIs.
Default to check_dtypes=True.

Remove explicit usages of check_dtypes=True from tests. This mostly just removes visual noise from tests. Testing for exact type equality is the sensible default, although there are cases where opting out makes sense.

No functional changes intended.

* Fix a number of lax reference implementations to preserve types.
2020-06-01 17:19:23 -04:00
Stephan Hoyer
2be471cc7e
Remove redundant flag configuration from lax_scipy_sparse_test (#3128) 2020-05-17 22:12:21 -07:00
Peter Hawkins
7116cc5b41
Improve JAX test PRNG APIs to fix correlations between test cases. (#2957)
* Improve JAX test PRNG APIs to fix correlations between test cases.

In #2863, we observed that we were missing gradient problems because the random test cases being generated were too similar because they were formed with identically seeded PRNGs. This change updates the test_util.rand_...() functions to take an explicit numpy.random.RandomState, and adds a rng() method to JaxTestCase to form a RandomState seeded on the test case name.

This gives the following properties:
* different test cases receive different seeds
* PRNG seeding is deterministic and independent of execution order and sharding.
* PRNG seeding is deterministic across runs.

* Fix some failing tests.

* Fix more test failures.

Simplify ediff1d implementation and make it more permissive when casting.

* Relax test tolerance of laplace CDF test.
2020-05-04 23:00:20 -04:00
Peter Hawkins
d61d6f44dc
Fix a number of flaky tests. (#2953)
* relax some test tolerances.
* disable 'random' preconditioner in CG test (#2951).
* ensure that scatter and top-k tests don't create ties.
2020-05-04 14:34:08 -04:00
Matthew Johnson
1bcaef142f
apply is_stable=True to sort translation rules (#2789)
fixes #2779
2020-04-21 17:47:28 -07:00
Stephan Hoyer
8fa707af98
Fixup complex values and tol in tests for jax.scipy.linalg.sparse.cg (#2717)
* Fixup complex values and tol in tests for jax.scipy.linalg.sparse.cg

The tests for CG were failing on TPUs:

- `test_cg_pytree` is fixed by requiring slightly less precision than the
  unit-test default.
- `test_cg_against_scipy` is fixed for complex values in two independent ways:
  1. We don't set both `tol=0` and `atol=0`, which made the termination
     behavior of CG (convergence or NaN) dependent on exactly how XLA handles
     arithmetic with denormals.
  2. We make use of *real valued* inner products inside `cg`, even for complex
     values. It turns that all these inner products are mathematically
     guaranteed to yield a real number anyways, so we can save some flops and
     avoid ill-defined comparisons of complex-values (see
     https://github.com/numpy/numpy/issues/15981) by ignoring the complex part
     of the result from `jnp.vdot`. (Real numbers also happen to have the
     desired rounding behavior for denormals on TPUs, so this on its own would
     also fix these failures.)

* comment fixup

* fix my comment
2020-04-14 22:35:48 -07:00
Peter Hawkins
1298e9e8c4
Fix some test failures. (#2713) 2020-04-14 18:23:19 -04:00
Peter Hawkins
2dc81fb40c
Make pytest run over JAX tests warning clean, and error on warnings. (#2674)
* Make pytest run over JAX tests warning clean, and error on warnings.

Remove global warning suppression in travis.yml. Instead add a pytest.ini that converts warnings to errors, with the exception of a whitelist.
Either fix or locally suppress warnings in tests.

Also fix crashes on Mac related to a preexisting linear algebra bug.

* Fix some type errors in the FFT transpose rules revealed by the convert_element_type transpose rule change.
2020-04-12 15:35:35 -04:00
Stephan Hoyer
9cc5e9018c Renable custom_linear_solve and cg with complex values 2020-04-09 00:53:00 -07:00