Jake VanderPlas
ae6c4676d4
[sparse] add low-level primitives wrapping cuda SpMV & SpMM
...
This is in preparation for cleaning up our bcoo_dot_general GPU lowering rules: by creating private primitives that closely follow the API of the cusparse kernels, we will be able to better express lowered translation rules that preprocess that data appropriately.
PiperOrigin-RevId: 513212715
2023-03-01 05:56:31 -08:00
Jake VanderPlas
97f819b1ed
[sparse] fix dot_general precision in test
...
PiperOrigin-RevId: 513205756
2023-03-01 05:10:42 -08:00
Jake VanderPlas
06441883b9
[sparse] temporarily disable bcoo_dot_general_sampled fast cases test on GPU
...
This is failing with precision issues on some GPU architectures; it's not clear why.
PiperOrigin-RevId: 513021864
2023-02-28 13:23:54 -08:00
Peter Hawkins
f66f6ec98a
[JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
...
Limit jax._src.lib to shims around jaxlib and nothing else.
The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.
PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
Jake VanderPlas
aad6a70ee9
[sparse] bcoo_dot_general_sampled: another special case
2023-02-24 10:50:54 -08:00
Jake VanderPlas
bf1f5d21a2
[sparse] remove handling of padded indices from COO/CSR
2023-02-23 12:39:12 -08:00
jax authors
2d93b28b18
Merge pull request #14630 from jakevdp:bcoo-dot-general-sampled
...
PiperOrigin-RevId: 511856372
2023-02-23 12:32:59 -08:00
Jake VanderPlas
54bd631c1a
[sparse] bcoo_dot_general_sampled: faster special case
2023-02-22 13:17:16 -08:00
Adam Paszke
1638313a99
Slightly increase the tolerance in sparse tests to avoid flakiness
...
PiperOrigin-RevId: 511548667
2023-02-22 11:22:02 -08:00
Jake VanderPlas
df358242ff
[sparse] test coo/csr extra nse
2023-02-16 16:27:31 -08:00
Tianjian Lu
4fa69e60a0
[sparse] Correct BCOO out-of-bound indices before calling cusparse SpMM.
...
PiperOrigin-RevId: 510248091
2023-02-16 14:40:18 -08:00
Jake VanderPlas
d1334c80d2
[sparse] bring sparse.csr API in line with sparse.coo
2023-02-16 12:47:38 -08:00
Jake VanderPlas
29f91c5038
[sparse] add bcsr_matmul batching tests
2023-02-15 15:46:37 -08:00
jax authors
7fa24703ec
Merge pull request #14496 from jakevdp:bcsr-concatenate
...
PiperOrigin-RevId: 509949683
2023-02-15 15:32:19 -08:00
Peter Hawkins
cd0533cab0
Replace uses of jnp.ndarray with jax.Array inside JAX.
...
PiperOrigin-RevId: 509939691
2023-02-15 14:53:00 -08:00
Jake VanderPlas
f3e5024787
[sparse] implement bcsr_concatenate
2023-02-15 14:10:56 -08:00
Jake VanderPlas
d688b6d6f3
[sparse] implement bcsr_broadcast_in_dim
2023-02-15 10:26:10 -08:00
Peter Hawkins
33bed1e520
Opt into higher matmul precision for A100 and TPU tests.
...
PiperOrigin-RevId: 509598465
2023-02-14 12:03:12 -08:00
Jake VanderPlas
15196bc1aa
[sparse] enable bcsr_dot_general cusparse lowering
...
PiperOrigin-RevId: 509537223
2023-02-14 08:32:04 -08:00
jax authors
1bdcd5e138
Merge pull request #14415 from jakevdp:bcsr-matmul
...
PiperOrigin-RevId: 508785095
2023-02-10 16:55:05 -08:00
Jake VanderPlas
de8a77a3eb
[sparse] implement BCSR.__matmul__
2023-02-10 16:11:57 -08:00
Jake VanderPlas
552fc2c5a3
[sparse] add CPU lowering rule for sparse.linalg.spsolve
2023-02-10 15:35:42 -08:00
Jake VanderPlas
ac647b9459
[sparse] implement autodiff rules for bcsr_dot_general
2023-02-10 12:00:30 -08:00
Jake VanderPlas
7651866b1d
[sparse] implement autodiff rules for bcsr primitives
2023-02-09 14:19:22 -08:00
Rahul Batra
01a10a1d06
[ROCm] Re-enable some linalg and sparse tests
2023-02-07 22:05:14 +00:00
Jake VanderPlas
428713e88e
[sparse] support all unbatched 1D convolutions
2023-02-03 15:59:42 -08:00
Jake VanderPlas
4fa80b44cd
[sparse] implement sparse rule for lax.rev
2023-02-01 15:43:47 -08:00
Jake VanderPlas
27c068e7b7
[sparse] implement __len__ on sparse objects
2023-02-01 11:46:02 -08:00
Jake VanderPlas
5b0329daa8
[sparse] add BCSR.to_bcoo and from_bcoo methods
2023-01-30 10:42:05 -08:00
Tianjian Lu
5aea7d95e0
[sparse] Add function that fixes out-of-bound indices.
...
PiperOrigin-RevId: 504335149
2023-01-24 11:46:46 -08:00
Jake VanderPlas
b00890b036
[sparse] refactor tests to improve runtime
2023-01-20 11:15:37 -08:00
Jake VanderPlas
7a8781db1c
[sparse] add higher-level version of bcoo_extract & improve tests
2023-01-13 07:13:13 -08:00
Jake VanderPlas
e37e3a9b0f
[sparse] bcoo_extract: add assume_unique keyword
2023-01-12 15:21:11 -08:00
Jake VanderPlas
f314f2d504
[sparse] generalize batch rule for bcoo_sum_duplicates & improve tests
2023-01-12 14:28:12 -08:00
Jake VanderPlas
a6a3b59748
[sparse] generalize batch rule for bcoo_dot_general
2023-01-12 12:03:21 -08:00
Jake VanderPlas
eaf6179594
[sparse] generalize batch rule for bcoo_spdot_general
2023-01-12 10:59:28 -08:00
Jake VanderPlas
841d7a9cb3
[sparse] mark several slow tests
2023-01-04 12:03:02 -08:00
Jake VanderPlas
275ec3362d
[sparse] automate gradient tests for bcoo_todense/bcoo_fromdense
2023-01-03 12:28:13 -08:00
Jake VanderPlas
ea51d074d7
[sparse] handle general batch dimensions in bcoo transpose
2022-12-28 10:15:05 -08:00
Jake VanderPlas
2b82e7a0a1
[sparse] improve implementation of sparse.grad & sparse.value_and_grad
2022-12-28 09:55:11 -08:00
Jake VanderPlas
cca83ae2b5
[sparse] add _CheckBatchingSparse utility
2022-12-27 13:55:50 -08:00
Jake VanderPlas
cee27797da
[sparse] propagate metadata through vmappable
2022-12-22 14:46:40 -08:00
Jake VanderPlas
e661c3f0cf
[sparse] Handle general batch dimensions in bcoo todense/fromdense
2022-12-22 09:08:30 -08:00
Tianjian Lu
e89b60e383
[sparse] Propagate SparseInfo to BCSR todense() and tree_(un)flatten().
...
PiperOrigin-RevId: 496945167
2022-12-21 09:55:21 -08:00
Jake VanderPlas
71f861ae73
[sparse] improve tests for bcoo_dot_general_sampled
2022-12-21 08:14:42 -08:00
jax authors
80998dfcd4
Merge pull request #13675 from jakevdp:bcoo-dot-general-vjp
...
PiperOrigin-RevId: 496902013
2022-12-21 06:01:03 -08:00
Jake VanderPlas
54871ccf02
[sparse] add gradient test for bcoo_concatenate
2022-12-19 15:26:15 -08:00
Tianjian Lu
bc34af9f30
[sparse] Add bcsr dot_general
...
PiperOrigin-RevId: 496489368
2022-12-19 14:17:44 -08:00
Jake VanderPlas
435f338838
[sparse] fix typo in test_coo_fromdense
2022-12-19 10:34:39 -08:00
Jake VanderPlas
85f30c13c5
[sparse] implement more cases for vjp of bcoo_dot_general
2022-12-19 09:16:06 -08:00