Benjamin Chetioui
4e748e1fda
Bump up tolerance for tests using _testQdwh
.
...
PiperOrigin-RevId: 643547107
2024-06-14 23:18:20 -07:00
jax authors
c839b268d2
Get rid of the is_hermitian argument for lax.qdwh. If it was known that H was also positive semi-definite, the polar decomposition would be I*H. But for indefinite H, the QDWH algorithm does not differ from the general case for Hermitian inputs.
...
PiperOrigin-RevId: 643141687
2024-06-13 15:33:49 -07:00
jax authors
dd3b0a6981
Add test for QDWH with dynamic shapes.
...
PiperOrigin-RevId: 643087130
2024-06-13 12:33:20 -07:00
jax authors
95e2c17b61
Update test of QDWH to use stricter tolerances and test more shapes and types.
...
Get rid of comparison with scipy.linalg.polar, since its outputs are significantly less accurate than QDWH. Since the polar decomposition is unique, comparing to a less accurate implementation does not add value.
PiperOrigin-RevId: 642423757
2024-06-11 16:04:38 -07:00
James Lottes
9fd5f7c6a2
Refactor QDWH to be more efficient when run batched under vmap.
...
In particular, avoid using lax.cond to switch to CholeskyQR for later iterations, as under vmap this can result in both branches being executed.
PiperOrigin-RevId: 628144162
2024-04-25 11:48:21 -07:00
Sergei Lebedev
cbcaac2756
MAINT Migrate remaining internal/test modules to use state objects
...
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
This is a follow up to #18008 .
2023-10-12 17:32:15 +01:00
Peter Hawkins
ef3f2abfd2
Fix test failures in JAX under NumPy 1.25.0rc1.
...
`jnp.finfo(...)` of an Array type yields:
```
TypeError: unhashable type: 'ArrayImpl'
```
However, `np.finfo(...)` no longer accepts NumPy arrays as input either, so it would be consistent to require the user to pass a dtype where they are currently passing an array.
PiperOrigin-RevId: 539174254
2023-06-09 14:10:35 -07:00
Jake VanderPlas
fbe4f10403
Change to simpler import for jax.config
2023-04-21 11:51:22 -07:00
jax authors
f25b701b26
[XLA] Change criterion for annihilating off-diagonal elements in the 2x2 symmetric Schur decomposition used by eigh. This significantly improves the accuracy, and makes eigh exact for the identity matrix.
...
Modify the QDWH test so it doesn't have a dependence on eigh.
PiperOrigin-RevId: 523171958
2023-04-10 11:43:56 -07:00
Peter Hawkins
cd0533cab0
Replace uses of jnp.ndarray with jax.Array inside JAX.
...
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00
Jake VanderPlas
f09fd8a4e9
[x64] minor test-only updates for better type safety
2022-11-30 15:18:40 -08:00
Peter Hawkins
2ba0396ddb
Add changes accidentally omitted from
...
https://github.com/google/jax/pull/12717
2022-10-10 19:11:58 +00:00
Peter Hawkins
c657449528
Copybara import of the project:
...
--
d39bdefb33a19e407c352df27fb04127f4fe8a1d by Peter Hawkins <phawkins@google.com>:
Migrate more tests from jtu.cases_from_list to jtu.sample_product.
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/12717 from hawkinsp:sampletest d39bdefb33a19e407c352df27fb04127f4fe8a1d
PiperOrigin-RevId: 480136538
2022-10-10 11:35:32 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Matthew Johnson
68e3f58041
un-skip polar/qdwh decomp tests skipped on gpu in ad6ce74
...
On an A100 machine, these tests seem to run fine now. See https://github.com/google/jax/issues/8628#issuecomment-1215651697 .
2022-08-15 12:31:43 -07:00
Jeppe Klitgaard
17de89b16a
feat: refactor code using pyupgrade
...
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```
a
2022-05-17 22:14:05 +01:00
Peter Hawkins
db73670ec3
Add support for padded arrays in QDWH algorithm.
...
This change is in preparation for adding a jit-table QDWH-eig implementation.
PiperOrigin-RevId: 448571523
2022-05-13 13:57:36 -07:00
Peter Hawkins
7ba36fc178
Change implementation of jax.scipy.linalg.polar() and jax._src.scipy.eigh to use the QDWH decomposition from jax._src.lax.qdwh.
...
Remove jax._src.lax.polar.
PiperOrigin-RevId: 448241206
2022-05-12 07:20:52 -07:00
Rohit Santhanam
8d9f17df19
Disabled one and enabled several unit tests for ROCm.
2022-05-10 19:47:26 +00:00
Tianjian Lu
d57e36416f
[linalg] Update qdwh to prevent underflow in norm estimation.
...
PiperOrigin-RevId: 446887070
2022-05-05 20:12:32 -07:00
Reza Rahimi
a0d9d81f92
Update JAX to use new math libraries in ROCm-5.0.
2022-03-01 20:02:15 +00:00
Jake VanderPlas
97512e9e44
JaxTestCase: set jax_numpy_rank_promotion='raise' by default
2022-02-14 09:22:05 -08:00
jax authors
5691010d2f
Copybara import of the project:
...
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:
JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
2022-02-10 19:08:29 -08:00
Jake VanderPlas
6324577a63
JaxTestCase: set numpy_rank_promotion='raise' by default
2022-02-10 16:54:31 -08:00
Jake VanderPlas
e376df29be
disable implicit rank promotion in a number of remaining tests
2022-01-28 08:16:30 -08:00
Tianjian Lu
19554e21d3
Enable QDWH TPU tests.
2021-11-30 15:47:50 -08:00
Peter Hawkins
f4351e8419
Disable QDWH tests that fail on GPU and TPU.
...
PiperOrigin-RevId: 411591003
2021-11-22 10:21:41 -08:00
Tianjian Lu
c5f73b3d8e
[JAX] Added jax.lax.linalg.qdwh
.
...
PiperOrigin-RevId: 406453671
2021-10-29 14:45:06 -07:00