73 Commits

Author SHA1 Message Date
Peter Hawkins
48bbdbc890 Change jax.core.DropVar to be a non-singleton.
Previously jax.core.DropVar was a singleton value (jax.core.dropvar) whose type was always jax.core.AbstractUnit. However, this type is misleading: a DropVar is an equation output, and typically we would expect it to have an array type. In particular, the unit type confuses new-style translation rules that expect to use the output aval on an equation as part of the lowering logic.

Instead, change DropVar to be a non-singleton subclass of Var instead with a flexible choice of aval.

PiperOrigin-RevId: 404071001
2021-10-18 15:02:54 -07:00
Jake VanderPlas
f2bbd51cc2 [sparse] respect weak types in sparsify transform 2021-10-18 13:35:20 -07:00
Jake VanderPlas
c2dd90e3a0 [sparse] Factor BCOO-related routines into a separate submodule 2021-10-06 08:06:18 -07:00
Jake VanderPlas
3a440d665f [sparse] add sparsify support for sparse-sparse matmul 2021-10-05 16:45:48 -07:00
Peter Hawkins
e869e5e0f8 Move contents of jax.api_util to jax._src.api_util and add a forwarding shim.
One of many changes to codify the set of exported symbols in the jax.* namespace.

PiperOrigin-RevId: 395484706
2021-09-08 09:00:56 -07:00
jax authors
b9cc31e35d Merge pull request #7852 from google:sparse-jaxpr-consts
PiperOrigin-RevId: 395421332
2021-09-08 01:27:01 -07:00
Roy Frostig
8bb8bf1081 avoid constvar conversion when closing a sparse jaxpr 2021-09-07 22:02:21 -07:00
Roy Frostig
bf44398790 handle dropped output values in the sparse interpreter 2021-09-07 18:50:13 -07:00
Jake VanderPlas
82a7b7ee4d DOC: add documentation of jax.experimental.sparse 2021-09-02 17:08:10 -07:00
Jake VanderPlas
c5fed9c3b5 [sparse] Change BCOO index order 2021-09-01 13:48:55 -07:00
Jake VanderPlas
4f9310088d [sparse] handle pytree inputs in sparsify transform 2021-08-10 10:31:16 -07:00
Jake VanderPlas
0b7c0daee2 [sparse] bug: thread through params in sparsify 2021-08-09 12:14:37 -07:00
Jake VanderPlas
25b3737e81 [sparse] correctly handle units in sparsify argspecs 2021-08-09 09:15:08 -07:00
Jake VanderPlas
f76108ba0e [sparse] add sparsify rule for lax.cond 2021-08-06 13:32:23 -07:00
Jake VanderPlas
1d359f8c61 [sparse]: add sparse rule for scan/fori_loop 2021-08-05 15:19:43 -07:00
Jake VanderPlas
1eb3b5f8d6 [sparse] support sparse arguments in xla_call 2021-07-13 15:23:14 -07:00
Jake VanderPlas
9af8676341 [sparse] support dense xla_call within sparsify jaxpr interpreter 2021-07-13 13:31:21 -07:00
Jake VanderPlas
5db97e0bf9 [sparse] add sparse transform rule for lax.while_p 2021-07-12 16:53:24 -07:00
Jake VanderPlas
087da553cd [sparse]: add support for rdot_general in sparsify transform 2021-07-09 06:00:05 -07:00
Jake VanderPlas
76f9e6f016 [sparse] globally change nnz->nse 2021-06-30 17:46:02 -07:00
Jake VanderPlas
5ed9471b9a flake: fix unused import 2021-06-28 11:40:23 -07:00
jax authors
cb8582c63d Merge pull request #6929 from jakevdp:sparsify
PiperOrigin-RevId: 381897788
2021-06-28 10:43:07 -07:00
Jake VanderPlas
0401d2be57 Add experimental sparsify transform
Co-authored-by: Roy Frostig <frostig@google.com>
2021-06-25 10:45:16 -07:00