258 Commits

Author SHA1 Message Date
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Blair-Johnson
802a14cd61 Re-pack gradients of jax.experimental.sparse.grad() to match original pytrees & test cases 2024-07-29 13:04:05 -04:00
Ruturaj4
637370baa0 [ROCM] Fix version checks after rocm pjrt integration 2024-07-12 08:07:52 -05:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Meekail Zain
a7737ca618 Clean up sparse test run conditions 2024-04-09 23:16:12 +00:00
Jake VanderPlas
cddee4654c tests: access tree utilities via jax.tree.* 2024-02-26 14:17:18 -08:00
Sergei Lebedev
cbcaac2756 MAINT Migrate remaining internal/test modules to use state objects
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.

This is a follow up to #18008.
2023-10-12 17:32:15 +01:00
Sergei Lebedev
5d9c39f4b0 MAINT Use a generator expression with all() and any()
There is no reason to allocate a list only for the purpose of iteration.
2023-10-10 22:33:03 +01:00
jax authors
1c37f5091c sparse_test: Split into two so that each target is small enough to fit within a medium timeout.
PiperOrigin-RevId: 570882867
2023-10-04 19:59:03 -07:00
Peter Hawkins
6be860bda8 Clean up some device opt-in/opt-outs in test suite.
Use allowlists rather than denylists in a few places.

PiperOrigin-RevId: 568968749
2023-09-27 14:56:00 -07:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
Peter Hawkins
9b447aa3ec Relax test tolerance to fix BCSR sparse matmul test failure on P100 GPU.
PiperOrigin-RevId: 563441383
2023-09-07 08:37:31 -07:00
Peter Hawkins
429422dfea Reverts 5fcd9265b1e20c41d684659af3d52c41f25ae2f3
PiperOrigin-RevId: 563426073
2023-09-07 07:35:44 -07:00
Peter Hawkins
4f805c2d8f [JAX] Change jax.test_util utilities to have identical tolerances on all platforms.
In cases where this causes TPU tests to fail, relax test tolerances in the test cases themselves.

TPUs are less precise only for specific operations, notably matrix multiplication (for which usually enabling higher-precision matrix multiplication is the right choice if precision is needed), and certain special functions (e.g., log/exp/pow).

The net effect of this change is mostly to tighten up many test tolerances on TPU.

PiperOrigin-RevId: 562953488
2023-09-05 18:48:55 -07:00
Jake VanderPlas
7d29ed6bdd Lower jax.numpy matmul functions to mixed-precision dot_general 2023-09-05 08:37:51 -07:00
Jake VanderPlas
2f878a7168 Tests: set jax_legacy_prng_key='error' 2023-08-28 10:56:09 -07:00
jax authors
5fcd9265b1 Merge pull request #16975 from hawkinsp:win
PiperOrigin-RevId: 554479893
2023-08-07 08:10:21 -07:00
Jake VanderPlas
eb1ab2b101 Add autodiff rules for spsolve_p
Fixes #16935
2023-08-04 12:08:57 -07:00
Peter Hawkins
622e7da635 Disable some tests that appear to crash on Windows. 2023-08-04 11:24:58 -04:00
Adam Paszke
d7940ee9a1 Always skip the BCOO matmul test on CUDA 12
We seem to be consistently hitting cuSPARSE bugs in this test, so disabling it
for now.

PiperOrigin-RevId: 553801002
2023-08-04 07:29:49 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
jax authors
68ea651ae4 Merge pull request #16740 from jakevdp:spdot-general-args
PiperOrigin-RevId: 548773744
2023-07-17 12:52:33 -07:00
Artem Belevich
3a7857130f Disable tests triggering a known bug in cuda-12.
PiperOrigin-RevId: 548727901
2023-07-17 10:26:12 -07:00
Jake VanderPlas
7986ba75c6 [sparse] support preferred_element_type in dot_general 2023-07-14 18:23:34 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Peter Hawkins
0adfafe293 Relax test tolerances.
This makes the tests pass on CPU with a slightly different seed (+ 1).

PiperOrigin-RevId: 542877795
2023-06-23 09:22:11 -07:00
Anton Geraschenko
27aa5fb774 Make dimensions argument of bcoo_reshape optional. 2023-05-09 10:38:18 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Jake VanderPlas
4180f8bf7b [sparse] improve worst-case nse in spdot_general 2023-03-07 20:46:27 -08:00
Jake VanderPlas
b527bcaa3c [sparse] fix GPU warnings in cusparse test 2023-03-06 17:39:28 -08:00
Tianjian Lu
7bcd490b69 [sparse] add low-level primitives wrapping cuda csr spmv and spmm.
PiperOrigin-RevId: 514473374
2023-03-06 11:34:30 -08:00
Jake VanderPlas
33c0a103c6 [sparse] more robust correction for out-of-bound indices in BCOO
Previously we were setting out-of-bound indices to zero, which works in most (but not all) cases. The problem is that if (0, 0) is a defined matrix element, these subsequent zeros effectively overwrite this element in some cusparse routines.

The fix here is to add another row or column to the matrix as necessary, and to push these undefined values into that row/col, where they can be sliced off at the end of the cusparse operation so that they will not affect the computation of interest.

PiperOrigin-RevId: 513639921
2023-03-02 14:26:25 -08:00
Jake VanderPlas
6cc0545006 [sparse] simplify bcoo_dot_general GPU lowering rule
Also, remove the cusparse lowering for batched matmul, because in testing I found that it returns incorrect results for CUSPARSE_SPMM_COO_ALG4. Our tests haven't revealed that because we currently only test for a single batch. To re-land this, we can add it to the private primitives in _lowerings and add another elif clause in the GPU impl.

PiperOrigin-RevId: 513604587
2023-03-02 12:11:39 -08:00
Jake VanderPlas
abc6c9bf49 [sparse] adjust tolerance on bcoo_dot_general_sampled
PiperOrigin-RevId: 513342523
2023-03-01 14:29:39 -08:00
Jake VanderPlas
a6409e85fb [sparse] fix expected warning in batched_matmat test case
PiperOrigin-RevId: 513274151
2023-03-01 10:25:20 -08:00
Jake VanderPlas
ae6c4676d4 [sparse] add low-level primitives wrapping cuda SpMV & SpMM
This is in preparation for cleaning up our bcoo_dot_general GPU lowering rules: by creating private primitives that closely follow the API of the cusparse kernels, we will be able to better express lowered translation rules that preprocess that data appropriately.

PiperOrigin-RevId: 513212715
2023-03-01 05:56:31 -08:00
Jake VanderPlas
97f819b1ed [sparse] fix dot_general precision in test
PiperOrigin-RevId: 513205756
2023-03-01 05:10:42 -08:00
Jake VanderPlas
06441883b9 [sparse] temporarily disable bcoo_dot_general_sampled fast cases test on GPU
This is failing with precision issues on some GPU architectures; it's not clear why.

PiperOrigin-RevId: 513021864
2023-02-28 13:23:54 -08:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Jake VanderPlas
aad6a70ee9 [sparse] bcoo_dot_general_sampled: another special case 2023-02-24 10:50:54 -08:00
Jake VanderPlas
bf1f5d21a2 [sparse] remove handling of padded indices from COO/CSR 2023-02-23 12:39:12 -08:00
jax authors
2d93b28b18 Merge pull request #14630 from jakevdp:bcoo-dot-general-sampled
PiperOrigin-RevId: 511856372
2023-02-23 12:32:59 -08:00
Jake VanderPlas
54bd631c1a [sparse] bcoo_dot_general_sampled: faster special case 2023-02-22 13:17:16 -08:00
Adam Paszke
1638313a99 Slightly increase the tolerance in sparse tests to avoid flakiness
PiperOrigin-RevId: 511548667
2023-02-22 11:22:02 -08:00
Jake VanderPlas
df358242ff [sparse] test coo/csr extra nse 2023-02-16 16:27:31 -08:00
Tianjian Lu
4fa69e60a0 [sparse] Correct BCOO out-of-bound indices before calling cusparse SpMM.
PiperOrigin-RevId: 510248091
2023-02-16 14:40:18 -08:00
Jake VanderPlas
d1334c80d2 [sparse] bring sparse.csr API in line with sparse.coo 2023-02-16 12:47:38 -08:00
Jake VanderPlas
29f91c5038 [sparse] add bcsr_matmul batching tests 2023-02-15 15:46:37 -08:00
jax authors
7fa24703ec Merge pull request #14496 from jakevdp:bcsr-concatenate
PiperOrigin-RevId: 509949683
2023-02-15 15:32:19 -08:00
Peter Hawkins
cd0533cab0 Replace uses of jnp.ndarray with jax.Array inside JAX.
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00