Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)

propagate symbolic zeros in sparse op JVPs

commit 25ff2f4a94
parent 926de5a2bc
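The hunks below do two things: the coo_todense and coo_matvec primitives switch from full primitive_jvps rules, which densified symbolic ad.Zero tangents via lax.zeros_like_array, to per-argument rules registered with ad.defjvp, where None marks the non-differentiable row/col index arguments; and the coo_fromdense JVP now returns symbolic ad.Zero.from_value tangents instead of materialized zeros. Below is a minimal sketch of the ad.defjvp pattern on a toy primitive, assuming a JAX version where jax.core.Primitive and jax.interpreters.ad.defjvp are importable as they are in this file; the scale primitive and its rules are illustrative only, not part of the commit.

import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import ad

# Toy primitive: scale(x, s) = x * s, linear in each argument.
scale_p = core.Primitive("scale")
scale_p.def_impl(lambda x, s: x * s)
scale_p.def_abstract_eval(lambda x, s: core.ShapedArray(x.shape, x.dtype))

def scale(x, s):
  return scale_p.bind(x, s)

# One JVP rule per positional argument. A rule is only invoked when its
# argument actually carries a tangent; symbolic ad.Zero tangents are skipped,
# so no dense array of zeros is ever materialized for them.
ad.defjvp(scale_p,
          lambda x_dot, x, s: scale(x_dot, s),  # tangent rule for x
          lambda s_dot, x, s: scale(x, s_dot))  # tangent rule for s

x = jnp.arange(3.0)
# s is a closed-over constant here, so its tangent is a symbolic zero and its
# rule never runs; only the rule for x contributes.
print(jax.jvp(lambda x: scale(x, 2.0), (x,), (jnp.ones(3),)))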
@@ -306,22 +306,12 @@ def _coo_todense_abstract_eval(data, row, col, *, shape):
 def _coo_todense_gpu_translation_rule(c, data, row, col, *, shape):
   return cusparse.coo_todense(c, data, row, col, shape=shape)
 
-def _coo_todense_jvp(primals, tangents, *, shape):
-  data, row, col = primals
-  data_dot, row_dot, col_dot = tangents
-
-  assert isinstance(row_dot, ad.Zero)
-  assert isinstance(col_dot, ad.Zero)
-  # TODO: propagate symbolic zeros if possible.
-  data_dot = lax.zeros_like_array(data) if isinstance(data_dot, ad.Zero) else data_dot
-
-  # Note: we assume that transpose has the same sparsity pattern. Can we assert this?
-  primals_out = coo_todense(data, row, col, shape=shape)
-  tangents_out = coo_todense(data_dot, row, col, shape=shape)
-
-  return primals_out, tangents_out
+def _coo_todense_jvp(data_dot, data, row, col, *, shape):
+  return coo_todense(data_dot, row, col, shape=shape)
 
 def _coo_todense_transpose(ct, data, row, col, *, shape):
+  # Note: we assume that transpose has the same sparsity pattern.
+  # Can we check this?
   assert ad.is_undefined_primal(data)
   if ad.is_undefined_primal(row) or ad.is_undefined_primal(col):
     raise ValueError("Cannot transpose with respect to sparse indices")
@@ -330,7 +320,7 @@ def _coo_todense_transpose(ct, data, row, col, *, shape):
   assert ct.dtype == data.aval.dtype
   return _coo_extract(row, col, ct), row, col
 
-ad.primitive_jvps[coo_todense_p] = _coo_todense_jvp
+ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None)
 ad.primitive_transposes[coo_todense_p] = _coo_todense_transpose
 xla.translations[coo_todense_p] = xla.lower_fun(
     _coo_todense_impl, multiple_results=False)
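After this change, ad.defjvp(coo_todense_p, _coo_todense_jvp, None, None) registers a tangent rule only for data; the None entries mark the integer row and col buffers as non-differentiable. A hedged usage sketch, assuming the public coo_fromdense and coo_todense wrappers defined in this file can be imported from jax.experimental.sparse_ops (the module path is an assumption; it is not shown in the diff):

import jax
import jax.numpy as jnp
from jax.experimental import sparse_ops  # assumed module path for this file

M = jnp.array([[1., 0.], [0., 2.]])
data, row, col = sparse_ops.coo_fromdense(M, nnz=2, index_dtype=jnp.int32)

# Differentiate through coo_todense with respect to the sparse values only;
# row and col are index buffers and receive no tangents.
f = lambda d: sparse_ops.coo_todense(d, row, col, shape=M.shape).sum()
print(jax.grad(f)(data))  # one gradient entry per stored value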
@@ -389,12 +379,16 @@ def _coo_fromdense_jvp(primals, tangents, *, nnz, index_dtype):
   M, = primals
   Mdot, = tangents
 
-  # TODO: propagate symbolic zeros if possible.
-  Mdot = lax.zeros_like_array(M) if isinstance(Mdot, ad.Zero) else Mdot
-
   primals_out = coo_fromdense(M, nnz=nnz, index_dtype=index_dtype)
-  _, row, col = primals_out
-  tangents_out = _coo_extract(row, col, Mdot), ad.Zero(row.aval), ad.Zero(col.aval)
+  data, row, col = primals_out
+
+  if type(Mdot) is ad.Zero:
+    data_dot = ad.Zero.from_value(data)
+  else:
+    data_dot = _coo_extract(row, col, Mdot)
+
+  tangents_out = (data_dot, ad.Zero.from_value(row), ad.Zero.from_value(col))
+
   return primals_out, tangents_out
 
 def _coo_fromdense_transpose(ct, M, *, nnz, index_dtype):
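In the updated _coo_fromdense_jvp, the integer outputs row and col always get ad.Zero.from_value tangents, and data does too when the incoming tangent Mdot is itself a symbolic zero. An ad.Zero records only the abstract value (shape and dtype), so no buffer of zeros is allocated. A small sketch of that constructor, assuming the jax.interpreters.ad API of this commit's era (the helper may be named differently in newer JAX releases):

import jax.numpy as jnp
from jax.interpreters import ad

row = jnp.array([0, 1], dtype=jnp.int32)
zero = ad.Zero.from_value(row)
# The symbolic zero carries only row's aval, not an array of zeros.
print(zero.aval)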
@@ -461,24 +455,11 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, shape, transpose):
 def _coo_matvec_gpu_translation_rule(c, data, row, col, v, *, shape, transpose):
   return cusparse.coo_matvec(c, data, row, col, v, shape=shape, transpose=transpose)
 
-def _coo_matvec_jvp(primals, tangents, *, shape, transpose):
-  data, row, col, v = primals
-  data_dot, row_dot, col_dot, v_dot = tangents
+def _coo_matvec_jvp_mat(data_dot, data, row, col, v, *, shape, transpose):
+  return coo_matvec(data_dot, row, col, v, shape=shape, transpose=transpose)
 
-  assert isinstance(row_dot, ad.Zero)
-  assert isinstance(col_dot, ad.Zero)
-
-  # TODO: propagate symbolic zeros if possible.
-  _zero = lambda p, t: lax.zeros_like_array(p) if isinstance(t, ad.Zero) else t
-  data_dot = _zero(data, data_dot)
-  v_dot = _zero(v, v_dot)
-
-  primals_out = coo_matvec(data, row, col, v, shape=shape, transpose=transpose)
-  tangents_out = (
-      coo_matvec(data_dot, row, col, v, shape=shape, transpose=transpose) +
-      coo_matvec(data, row, col, v_dot, shape=shape, transpose=transpose)
-  )
-  return primals_out, tangents_out
+def _coo_matvec_jvp_vec(v_dot, data, row, col, v, *, shape, transpose):
+  return coo_matvec(data, row, col, v_dot, shape=shape, transpose=transpose)
 
 def _coo_matvec_transpose(ct, data, row, col, v, *, shape, transpose):
   assert not ad.is_undefined_primal(row)
@@ -491,7 +472,7 @@ def _coo_matvec_transpose(ct, data, row, col, v, *, shape, transpose):
   # return _coo_extract(row, col, jnp.outer(ct, v)), row, col, v
   return ct[row] * v[col], row, col, v
 
-ad.primitive_jvps[coo_matvec_p] = _coo_matvec_jvp
+ad.defjvp(coo_matvec_p, _coo_matvec_jvp_mat, None, None, _coo_matvec_jvp_vec)
 ad.primitive_transposes[coo_matvec_p] = _coo_matvec_transpose
 xla.translations[coo_matvec_p] = xla.lower_fun(
     _coo_matvec_impl, multiple_results=False)
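For coo_matvec, the JVP now splits into two per-argument rules: _coo_matvec_jvp_mat for the sparse values and _coo_matvec_jvp_vec for the dense vector, with None for the row and col indices. A usage sketch under the same assumptions as the earlier examples (the sparse_ops import path is assumed, not shown in this diff):

import jax
import jax.numpy as jnp
from jax.experimental import sparse_ops  # assumed module path for this file

M = jnp.array([[1., 0.], [0., 2.]])
data, row, col = sparse_ops.coo_fromdense(M, nnz=2, index_dtype=jnp.int32)
v = jnp.array([3., 4.])

# Both the stored values and the dense vector are differentiable; each has
# its own defjvp rule, so an argument without a tangent keeps a symbolic zero
# instead of a materialized array of zeros.
f = lambda data, v: sparse_ops.coo_matvec(data, row, col, v,
                                          shape=M.shape, transpose=False).sum()
print(jax.grad(f, argnums=(0, 1))(data, v))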