propagate symbolic zeros in sparse op JVPs

This commit is contained in:
Roy Frostig 2021-05-12 13:02:16 -07:00 committed by Jake VanderPlas
parent 926de5a2bc
commit 25ff2f4a94

View File

@ -306,22 +306,12 @@ def _coo_todense_abstract_eval(data, row, col, *, shape):
def _coo_todense_gpu_translation_rule(c, data, row, col, *, shape):
return cusparse.coo_todense(c, data, row, col, shape=shape)
def _coo_todense_jvp(primals, tangents, *, shape):
data, row, col = primals
data_dot, row_dot, col_dot = tangents
assert isinstance(row_dot, ad.Zero)
assert isinstance(col_dot, ad.Zero)
# TODO: propagate symbolic zeros if possible.
data_dot = lax.zeros_like_array(data) if isinstance(data_dot, ad.Zero) else data_dot
# Note: we assume that transpose has the same sparsity pattern. Can we assert this?
primals_out = coo_todense(data, row, col, shape=shape)
tangents_out = coo_todense(data_dot, row, col, shape=shape)
return primals_out, tangents_out
def _coo_todense_jvp(data_dot, data, row, col, *, shape):
return coo_todense(data_dot, row, col, shape=shape)
def _coo_todense_transpose(ct, data, row, col, *, shape):
# Note: we assume that transpose has the same sparsity pattern.
# Can we check this?
assert ad.is_undefined_primal(data)
if ad.is_undefined_primal(row) or ad.is_undefined_primal(col):
raise ValueError("Cannot transpose with respect to sparse indices")
@ -330,7 +320,7 @@ def _coo_todense_transpose(ct, data, row, col, *, shape):
assert ct.dtype == data.aval.dtype
return _coo_extract(row, col, ct), row, col
ad.primitive_jvps[coo_todense_p] = _coo_todense_jvp
ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
xla.translations[coo_todense_p] = xla.lower_fun(
_coo_todense_impl, multiple_results=False)
@ -389,12 +379,16 @@ def _coo_fromdense_jvp(primals, tangents, *, nnz, index_dtype):
M, = primals
Mdot, = tangents
# TODO: propagate symbolic zeros if possible.
Mdot = lax.zeros_like_array(M) if isinstance(Mdot, ad.Zero) else Mdot
primals_out = coo_fromdense(M, nnz=nnz, index_dtype=index_dtype)
_, row, col = primals_out
tangents_out = _coo_extract(row, col, Mdot), ad.Zero(row.aval), ad.Zero(col.aval)
data, row, col = primals_out
if type(Mdot) is ad.Zero:
data_dot = ad.Zero.from_value(data)
else:
data_dot = _coo_extract(row, col, Mdot)
tangents_out = (data_dot, ad.Zero.from_value(row), ad.Zero.from_value(col))
return primals_out, tangents_out
def _coo_fromdense_transpose(ct, M, *, nnz, index_dtype):
@ -461,24 +455,11 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, shape, transpose):
def _coo_matvec_gpu_translation_rule(c, data, row, col, v, *, shape, transpose):
return cusparse.coo_matvec(c, data, row, col, v, shape=shape, transpose=transpose)
def _coo_matvec_jvp(primals, tangents, *, shape, transpose):
data, row, col, v = primals
data_dot, row_dot, col_dot, v_dot = tangents
def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, shape, transpose):
return coo_matvec(data_dot, row, col, v, shape=shape, transpose=transpose)
assert isinstance(row_dot, ad.Zero)
assert isinstance(col_dot, ad.Zero)
# TODO: propagate symbolic zeros if possible.
_zero = lambda p, t: lax.zeros_like_array(p) if isinstance(t, ad.Zero) else t
data_dot = _zero(data, data_dot)
v_dot = _zero(v, v_dot)
primals_out = coo_matvec(data, row, col, v, shape=shape, transpose=transpose)
tangents_out = (
coo_matvec(data_dot, row, col, v, shape=shape, transpose=transpose) +
coo_matvec(data, row, col, v_dot, shape=shape, transpose=transpose)
)
return primals_out, tangents_out
def _coo_matvec_jvp_vec(v_dot, data, row, col, v, *, shape, transpose):
return coo_matvec(data, row, col, v_dot, shape=shape, transpose=transpose)
def _coo_matvec_transpose(ct, data, row, col, v, *, shape, transpose):
assert not ad.is_undefined_primal(row)
@ -491,7 +472,7 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, shape, transpose):
# return _coo_extract(row, col, jnp.outer(ct, v)), row, col, v
return ct[row] * v[col], row, col, v
ad.primitive_jvps[coo_matvec_p] = _coo_matvec_jvp
ad.defjvp(coo_matvec_p, _coo_matvec_jvp_mat, None, None, _coo_matvec_jvp_vec)
ad.primitive_transposes[coo_matvec_p] = _coo_matvec_transpose
xla.translations[coo_matvec_p] = xla.lower_fun(
_coo_matvec_impl, multiple_results=False)