[sparse] support custom JVP in sparsify

This commit is contained in:
Jake VanderPlas 2023-06-21 02:36:07 -07:00
parent fc0dcd15a2
commit b6d544549b
2 changed files with 43 additions and 0 deletions

View File

@ -53,8 +53,10 @@ from typing import (
import numpy as np
import jax
from jax import lax
from jax._src import core
from jax._src.custom_derivatives import lift_jvp
from jax._src import linear_util as lu
from jax._src import pjit
from jax._src import sharding_impls
@ -343,6 +345,11 @@ class SparseTrace(core.Trace):
setnewattr(self.main, 'spenv', SparsifyEnv(bufs_out))
return [SparseTracer(self, spvalue=spvalue) for spvalue in out_spvalues()]
def process_custom_jvp_call(self, primitive, fun, jvp, tracers, *, symbolic_zeros):
  """Trace a custom_jvp call by evaluating only its primal function.

  The custom JVP rule itself is currently discarded under this trace; the
  wrapped primal computation ``fun`` is invoked directly on the tracers.
  """
  # TODO(jakevdp): handle the jvp here instead of dropping it.
  del primitive, jvp, symbolic_zeros  # unused until the TODO is addressed
  return fun.call_wrapped(*tracers)
@lu.transformation_with_aux
def sparsify_subtrace(main, spvalues, *bufs):
setnewattr(main, 'spenv', SparsifyEnv(bufs))
@ -852,6 +859,24 @@ def _todense_sparse_rule(spenv, spvalue, *, tree):
sparse_rules_bcoo[sparse.todense_p] = _todense_sparse_rule
def _custom_jvp_sparse_rule(spenv, *spvalues, **params):
  """Sparse evaluation rule for ``custom_jvp_call_p``.

  Re-binds the custom_jvp primitive with a primal function that evaluates the
  original jaxpr under the sparse interpreter, and a JVP lifted from the
  stored jvp-jaxpr thunk, so custom derivatives survive sparsification.
  """
  call_jaxpr = params.pop('call_jaxpr')
  jvp_jaxpr_thunk = params.pop('jvp_jaxpr_thunk')
  num_consts = params.pop('num_consts')
  # NOTE(review): sp_call_jaxpr and out_tree are not referenced below;
  # _sparsify_jaxpr may be called for its side effect of updating spenv —
  # confirm before treating this line as dead code.
  sp_call_jaxpr, out_tree = _sparsify_jaxpr(spenv, call_jaxpr, *spvalues)
  @lu.wrap_init
  def fun(*arrs):
    # Wrap the dense-style inputs as spvalues, run the original jaxpr through
    # the sparse interpreter, and convert the outputs back for the caller.
    sparrs = arrays_to_spvalues(spenv, arrs)
    out = eval_sparse(call_jaxpr.jaxpr, call_jaxpr.consts, sparrs, spenv)
    return spvalues_to_arrays(spenv, out)
  jvp = lift_jvp(num_consts, jvp_jaxpr_thunk)
  invals = spvalues_to_arrays(spenv, spvalues)
  # Re-bind the custom_jvp primitive so downstream transformations (e.g. grad)
  # still see the user-provided JVP rule rather than tracing through fun.
  outvals = jax.custom_derivatives.custom_jvp_call_p.bind(fun, jvp, *invals, **params)
  return arrays_to_spvalues(spenv, outvals)
# Register the custom-JVP sparse rule for both BCOO and BCSR representations.
sparse_rules_bcoo[jax.custom_derivatives.custom_jvp_call_p] = _custom_jvp_sparse_rule
sparse_rules_bcsr[jax.custom_derivatives.custom_jvp_call_p] = _custom_jvp_sparse_rule
# ------------------------------------------------------------------------------
# BCOO methods derived from sparsify

View File

@ -655,6 +655,24 @@ class SparsifyTest(jtu.JaxTestCase):
if fmt == "BCSR":
self.assertArraysAllClose(sparse_result.indptr, mat.indptr)
def testCustomJVP(self):
  """sparsify() should work on functions carrying a custom_jvp rule."""
  square = jax.custom_derivatives.custom_jvp(lambda x: x ** 2)
  square.defjvp(
      lambda primals, tangents: (primals[0] ** 2, 2 * tangents[0] * primals[0]))
  mat = BCOO.fromdense(jnp.arange(5.0))

  # Primal evaluation under sparsify matches a plain elementwise square.
  actual = self.sparsify(square)(mat)
  expected = self.sparsify(lambda x: x ** 2)(mat)
  self.assertArraysEqual(actual.indices, expected.indices)
  self.assertArraysAllClose(actual.data, expected.data)

  # Gradient evaluation routes through the custom JVP rule.
  actual = self.sparsify(jax.grad(lambda x: square(x).sum()))(mat)
  expected = self.sparsify(jax.grad(lambda x: jnp.sum(x ** 2)))(mat)
  self.assertArraysEqual(actual.indices, expected.indices)
  self.assertArraysAllClose(actual.data, expected.data)
class SparsifyTracerTest(SparsifyTest):
@classmethod