58 Commits

Author SHA1 Message Date
jax authors
735637e313 Previously, using sparse.todense on a BCSR matrix with sparse.sparsify would raise NotImplementedError: sparse rule for todense is not implemented. By adding the sparse rule, it will resolve this issue.
PiperOrigin-RevId: 551291543
2023-07-26 13:01:02 -07:00
Jake VanderPlas
7986ba75c6 [sparse] support preferred_element_type in dot_general 2023-07-14 18:23:34 -07:00
Jake VanderPlas
b6d544549b [sparse] support custom JVP in sparsify 2023-06-23 00:27:19 -07:00
Jake VanderPlas
5521423d92 Change np.prod->math.prod
Why? This is generally used for static operations on shapes, but np.prod
has an unfortunate corner-case behavior that np.prod([]) returns a float.
math.prod is available as of Python 3.8, and is a better solution here.
2023-04-13 11:48:11 -07:00
Jake VanderPlas
05f32a7947 [sparse] allow sparse-dense add when the output is the same size as dense input 2023-04-05 10:39:43 -07:00
Jake VanderPlas
74242f06d9 [sparse] add BCOO lowering for div
We had avoiding this previously because dividing by zero is
a densifying operation, but we already support mul which has
similar issues if the operand contains infinities.
2023-03-14 11:58:43 -07:00
Jake VanderPlas
f32e72da2a [sparse] add support for integer_pow 2023-03-08 09:24:52 -08:00
Jake VanderPlas
f3e5024787 [sparse] implement bcsr_concatenate 2023-02-15 14:10:56 -08:00
Jake VanderPlas
597c20173f [sparse] support BCSR in sparsify transform 2023-02-06 11:01:57 -08:00
Jake VanderPlas
428713e88e [sparse] support all unbatched 1D convolutions 2023-02-03 15:59:42 -08:00
Jake VanderPlas
038798ed25 [sparse] add support for simple 1D convolutions 2023-02-01 18:53:49 -08:00
Jake VanderPlas
4fa80b44cd [sparse] implement sparse rule for lax.rev 2023-02-01 15:43:47 -08:00
Jake VanderPlas
e673f1fd44 [sparse] avoid re-indexing for linear unary ops 2022-11-17 16:31:46 -08:00
Jake VanderPlas
66262901f0 [sparse] improve testing framework 2022-11-16 09:58:06 -08:00
Jake VanderPlas
c85230c2c6 [sparse] support dense dimensions in bcoo_reshape 2022-11-15 13:19:44 -08:00
Jake VanderPlas
7d3b1d6439 [sparse] fix bcoo_reshape under jit 2022-11-08 17:00:25 -08:00
Jake VanderPlas
af956636b8 [sparse] fix bcoo_reshape when n_sparse=0 2022-11-08 12:00:24 -08:00
Peter Hawkins
72f4f389be Migrate remaining tests from jtu.cases_from_list to jtu.sample_product.
Delete jtu.cases_from_list.
2022-10-12 15:20:53 +00:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
47b9f216bc [sparse] add sparse support for dynamic_slice 2022-09-01 13:42:02 -07:00
Jake VanderPlas
2b4f72b6f4 [sparse] fix unary operations in presence of duplicate indices 2022-07-07 13:49:50 -07:00
Jake VanderPlas
7917766828 [x64] make sparse tests compatible with strict dtype promotion 2022-06-14 12:45:26 -07:00
Tianjian Lu
cc4f2ade63 [sparse] Track unique_indices in sparse transform.
PiperOrigin-RevId: 452848200
2022-06-03 14:51:26 -07:00
Tianjian Lu
b6d4c59e64 [sparse] Trace BCOO indices_sorted in sparsifying zero_preserving_unary_ops.
PiperOrigin-RevId: 451081409
2022-05-25 21:07:29 -07:00
Tianjian Lu
1656b33ca9 [sparse] Track indices_sorted in sparsify transform.
PiperOrigin-RevId: 449833669
2022-05-19 14:29:14 -07:00
Matthew Johnson
9cd55a2bbd [remove-units] remove units 2022-05-04 10:58:56 -07:00
Jake VanderPlas
1a9a796a0c [sparse] implement sparse rule for lax.reshape_p 2022-05-02 09:11:55 -07:00
Jake VanderPlas
2d9af38a2c [sparse] implement sparse rule for lax.concatenate_p 2022-04-28 10:59:11 -07:00
Tianjian Lu
bcfa290c26 [sparse] Add BCOO attribute _indices_sorted.
PiperOrigin-RevId: 444659603
2022-04-26 14:01:17 -07:00
Jake VanderPlas
c37c1e683e [sparse] improve error messages for unimplemented primitives 2022-04-25 16:22:19 -07:00
Jake VanderPlas
8c6e001e45 [sparse] refactor internal implementation of sparsify transform 2022-03-07 12:48:03 -08:00
Jake VanderPlas
97512e9e44 JaxTestCase: set jax_numpy_rank_promotion='raise' by default 2022-02-14 09:22:05 -08:00
jax authors
5691010d2f Copybara import of the project:
--
d42fffd849a4bac0c0c11a3346c93f07f8c64c44 by Jake VanderPlas <jakevdp@google.com>:

JaxTestCase: set numpy_rank_promotion='raise' by default
PiperOrigin-RevId: 427896974
2022-02-10 19:08:29 -08:00
Jake VanderPlas
6324577a63 JaxTestCase: set numpy_rank_promotion='raise' by default 2022-02-10 16:54:31 -08:00
Jake VanderPlas
fa24395040 [sparse] avoid implicit rank promotion 2022-01-25 09:17:44 -08:00
Jake VanderPlas
b58ac44228 [sparse] add sparse rule for lax.sub_p 2021-12-13 16:46:47 -08:00
Yash Katariya
6621d4cb23 Copybara import of the project:
--
8abdd9eceb7bba66de4d3a2554e50b9bbf0b8aec by Tianjian Lu <tianjianlu@google.com>:

[sparse] Update bcoo_dot_general GPU translation rule.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
PiperOrigin-RevId: 415878536
2021-12-12 09:29:29 -08:00
Tianjian Lu
8abdd9eceb [sparse] Update bcoo_dot_general GPU translation rule.
Co-authored-by: Jake VanderPlas <jakevdp@google.com>
2021-12-10 14:08:21 -08:00
Jake VanderPlas
b00061e038 [sparse]: add tracer-based implementation of sparsify
Co-authored by: Matthew Johnson <mattjj@google.com>
2021-11-04 13:02:46 -07:00
Jake VanderPlas
2259e2b0a8 [sparse] add todense() primitive for use in sparsify transform 2021-10-26 13:52:48 -07:00
Jake VanderPlas
f2bbd51cc2 [sparse] respect weak types in sparsify transform 2021-10-18 13:35:20 -07:00
Jake VanderPlas
3a440d665f [sparse] add sparsify support for sparse-sparse matmul 2021-10-05 16:45:48 -07:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
jax authors
b9cc31e35d Merge pull request #7852 from google:sparse-jaxpr-consts
PiperOrigin-RevId: 395421332
2021-09-08 01:27:01 -07:00
Roy Frostig
8bb8bf1081 avoid constvar conversion when closing a sparse jaxpr 2021-09-07 22:02:21 -07:00
Roy Frostig
bf44398790 handle dropped output values in the sparse interpreter 2021-09-07 18:50:13 -07:00
Jake VanderPlas
4f9310088d [sparse] handle pytree inputs in sparsify transform 2021-08-10 10:31:16 -07:00
Jake VanderPlas
25b3737e81 [sparse] correctly handle units in sparsify argspecs 2021-08-09 09:15:08 -07:00
Jake VanderPlas
f76108ba0e [sparse] add sparsify rule for lax.cond 2021-08-06 13:32:23 -07:00