223 Commits

Author SHA1 Message Date
Jake VanderPlas
ae6c4676d4 [sparse] add low-level primitives wrapping cuda SpMV & SpMM
This is in preparation for cleaning up our bcoo_dot_general GPU lowering rules: by creating private primitives that closely follow the API of the cusparse kernels, we will be able to better express lowered translation rules that preprocess that data appropriately.

PiperOrigin-RevId: 513212715
2023-03-01 05:56:31 -08:00
Jake VanderPlas
97f819b1ed [sparse] fix dot_general precision in test
PiperOrigin-RevId: 513205756
2023-03-01 05:10:42 -08:00
Jake VanderPlas
06441883b9 [sparse] temporarily disable bcoo_dot_general_sampled fast cases test on GPU
This is failing with precision issues on some GPU architectures; it's not clear why.

PiperOrigin-RevId: 513021864
2023-02-28 13:23:54 -08:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Jake VanderPlas
aad6a70ee9 [sparse] bcoo_dot_general_sampled: another special case 2023-02-24 10:50:54 -08:00
Jake VanderPlas
bf1f5d21a2 [sparse] remove handling of padded indices from COO/CSR 2023-02-23 12:39:12 -08:00
jax authors
2d93b28b18 Merge pull request #14630 from jakevdp:bcoo-dot-general-sampled
PiperOrigin-RevId: 511856372
2023-02-23 12:32:59 -08:00
Jake VanderPlas
54bd631c1a [sparse] bcoo_dot_general_sampled: faster special case 2023-02-22 13:17:16 -08:00
Adam Paszke
1638313a99 Slightly increase the tolerance in sparse tests to avoid flakiness
PiperOrigin-RevId: 511548667
2023-02-22 11:22:02 -08:00
Jake VanderPlas
df358242ff [sparse] test coo/csr extra nse 2023-02-16 16:27:31 -08:00
Tianjian Lu
4fa69e60a0 [sparse] Correct BCOO out-of-bound indices before calling cusparse SpMM.
PiperOrigin-RevId: 510248091
2023-02-16 14:40:18 -08:00
Jake VanderPlas
d1334c80d2 [sparse] bring sparse.csr API in line with sparse.coo 2023-02-16 12:47:38 -08:00
Jake VanderPlas
29f91c5038 [sparse] add bcsr_matmul batching tests 2023-02-15 15:46:37 -08:00
jax authors
7fa24703ec Merge pull request #14496 from jakevdp:bcsr-concatenate
PiperOrigin-RevId: 509949683
2023-02-15 15:32:19 -08:00
Peter Hawkins
cd0533cab0 Replace uses of jnp.ndarray with jax.Array inside JAX.
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00
Jake VanderPlas
f3e5024787 [sparse] implement bcsr_concatenate 2023-02-15 14:10:56 -08:00
Jake VanderPlas
d688b6d6f3 [sparse] implement bcsr_broadcast_in_dim 2023-02-15 10:26:10 -08:00
Peter Hawkins
33bed1e520 Opt into higher matmul precision for A100 and TPU tests.
PiperOrigin-RevId: 509598465
2023-02-14 12:03:12 -08:00
Jake VanderPlas
15196bc1aa [sparse] enable bcsr_dot_general cusparse lowering
PiperOrigin-RevId: 509537223
2023-02-14 08:32:04 -08:00
jax authors
1bdcd5e138 Merge pull request #14415 from jakevdp:bcsr-matmul
PiperOrigin-RevId: 508785095
2023-02-10 16:55:05 -08:00
Jake VanderPlas
de8a77a3eb [sparse] implement BCSR.__matmul__ 2023-02-10 16:11:57 -08:00
Jake VanderPlas
552fc2c5a3 [sparse] add CPU lowering rule for sparse.linalg.spsolve 2023-02-10 15:35:42 -08:00
Jake VanderPlas
ac647b9459 [sparse] implement autodiff rules for bcsr_dot_general 2023-02-10 12:00:30 -08:00
Jake VanderPlas
7651866b1d [sparse] implement autodiff rules for bcsr primitives 2023-02-09 14:19:22 -08:00
Rahul Batra
01a10a1d06 [ROCm] Re-enable some linalg and sparse tests 2023-02-07 22:05:14 +00:00
Jake VanderPlas
428713e88e [sparse] support all unbatched 1D convolutions 2023-02-03 15:59:42 -08:00
Jake VanderPlas
4fa80b44cd [sparse] implement sparse rule for lax.rev 2023-02-01 15:43:47 -08:00
Jake VanderPlas
27c068e7b7 [sparse] implement __len__ on sparse objects 2023-02-01 11:46:02 -08:00
Jake VanderPlas
5b0329daa8 [sparse] add BCSR.to_bcoo and from_bcoo methods 2023-01-30 10:42:05 -08:00
Tianjian Lu
5aea7d95e0 [sparse] Add function that fixes out-of-bound indices.
PiperOrigin-RevId: 504335149
2023-01-24 11:46:46 -08:00
Jake VanderPlas
b00890b036 [sparse] refactor tests to improve runtime 2023-01-20 11:15:37 -08:00
Jake VanderPlas
7a8781db1c [sparse] add higher-level version of bcoo_extract & improve tests 2023-01-13 07:13:13 -08:00
Jake VanderPlas
e37e3a9b0f [sparse] bcoo_extract: add assume_unique keyword 2023-01-12 15:21:11 -08:00
Jake VanderPlas
f314f2d504 [sparse] generalize batch rule for bcoo_sum_duplicates & improve tests 2023-01-12 14:28:12 -08:00
Jake VanderPlas
a6a3b59748 [sparse] generalize batch rule for bcoo_dot_general 2023-01-12 12:03:21 -08:00
Jake VanderPlas
eaf6179594 [sparse] generalize batch rule for bcoo_spdot_general 2023-01-12 10:59:28 -08:00
Jake VanderPlas
841d7a9cb3 [sparse] mark several slow tests 2023-01-04 12:03:02 -08:00
Jake VanderPlas
275ec3362d [sparse] automate gradient tests for bcoo_todense/bcoo_fromdense 2023-01-03 12:28:13 -08:00
Jake VanderPlas
ea51d074d7 [sparse] handle general batch dimensions in bcoo transpose 2022-12-28 10:15:05 -08:00
Jake VanderPlas
2b82e7a0a1 [sparse] improve implementation of sparse.grad & sparse.value_and_grad 2022-12-28 09:55:11 -08:00
Jake VanderPlas
cca83ae2b5 [sparse] add _CheckBatchingSparse utility 2022-12-27 13:55:50 -08:00
Jake VanderPlas
cee27797da [sparse] propagate metadata through vmappable 2022-12-22 14:46:40 -08:00
Jake VanderPlas
e661c3f0cf [sparse] Handle general batch dimensions in bcoo todense/fromdense 2022-12-22 09:08:30 -08:00
Tianjian Lu
e89b60e383 [sparse] Propagate SparseInfo to BCSR todense() and tree_(un)flatten().
PiperOrigin-RevId: 496945167
2022-12-21 09:55:21 -08:00
Jake VanderPlas
71f861ae73 [sparse] improve tests for bcoo_dot_general_sampled 2022-12-21 08:14:42 -08:00
jax authors
80998dfcd4 Merge pull request #13675 from jakevdp:bcoo-dot-general-vjp
PiperOrigin-RevId: 496902013
2022-12-21 06:01:03 -08:00
Jake VanderPlas
54871ccf02 [sparse] add gradient test for bcoo_concatenate 2022-12-19 15:26:15 -08:00
Tianjian Lu
bc34af9f30 [sparse] Add bcsr dot_general
PiperOrigin-RevId: 496489368
2022-12-19 14:17:44 -08:00
Jake VanderPlas
435f338838 [sparse] fix typo in test_coo_fromdense 2022-12-19 10:34:39 -08:00
Jake VanderPlas
85f30c13c5 [sparse] implement more cases for vjp of bcoo_dot_general 2022-12-19 09:16:06 -08:00