Merge pull request #19760 from Blair-Johnson:fix-pytree-grads-sparse

PiperOrigin-RevId: 657258194
This commit is contained in:
jax authors 2024-07-29 10:59:57 -07:00
commit 9beb4f1474
2 changed files with 97 additions and 2 deletions

View File

@ -67,8 +67,12 @@ def flatten_fun_for_sparse_ad(fun, argnums: int | tuple[int, ...], args: tuple[A
return f_recons(grad_out)
def postprocess_gradients(grads_out):
out = [reconstruct(*args) for args in safe_zip(argnums_flat1, grads_out)]
return out[0] if isinstance(argnums, int) else out
leaf_grads = [None] * tree1.num_leaves
for i, grad in safe_zip(argnums_flat1, grads_out):
leaf_grads[i] = reconstruct(i, grad)
grad_tree = tree_util.tree_unflatten(tree1, leaf_grads)
grad_tree = tuple(filter(lambda x: jax.tree.leaves(x), grad_tree))
return grad_tree[0] if len(grad_tree) == 1 else grad_tree
return fun_flat, argnums_flat, args_flat, postprocess_gradients

View File

@ -712,6 +712,97 @@ class SparseGradTest(sptu.SparseTestCase):
self.assertAllClose(jac_dense(f, argnums=1, has_aux=has_aux)(X, y),
jac_sparse(f, argnums=1, has_aux=has_aux)(Xsp, y), rtol=rtol)
@jtu.sample_product(has_aux=[True, False],
deep=[True,False],
arg0=[True,False],
bias=[True,False])
def test_sparse_pytree_grad(self, has_aux, deep, arg0, bias):
rng_sparse = sptu.rand_sparse(self.rng())
rng = jtu.rand_default(self.rng())
y = rng(5, "float32")
X = rng_sparse((10, 5), "float32")
b = rng(10, "float32")
Xsp = sparse.BCOO.fromdense(X)
Xtree_sp = {'deep':{'X':Xsp},
'X':Xsp,
'list':[None,(b,None)]}
Xtree_de = {'deep':{'X':X},
'X':X,
'list':[None,(b,None)]}
def f(Xtree, y):
if deep:
out = Xtree['deep']['X'] @ y
else:
out = Xtree['X'] @ y
# Other grad variables
if bias:
out += Xtree['list'][1][0]
out = jnp.sum(out)
if has_aux:
return out, {'y': y.shape}
else:
return out
def g(y, Xtree):
if deep:
out = Xtree['deep']['X'] @ y
else:
out = Xtree['X'] @ y
# Other grad variables
if bias:
out += Xtree['list'][1][0]
out = jnp.sum(out)
if has_aux:
return out, {'y': y.shape}
return out
with self.subTest("wrt sparse"):
# Argument ordering
if arg0:
grad_de = jax.grad(f, argnums=0, has_aux=has_aux)(Xtree_de, y)
grad_sp = sparse.grad(f, argnums=0, has_aux=has_aux)(Xtree_sp, y)
else:
grad_de = jax.grad(g, argnums=1, has_aux=has_aux)(y, Xtree_de)
grad_sp = sparse.grad(g, argnums=1, has_aux=has_aux)(y, Xtree_sp)
if has_aux:
grad_de, aux_de = grad_de
grad_sp, aux_sp = grad_sp
self.assertAllClose(aux_de, aux_sp)
# Pytree structure
is_bcoo = lambda x: isinstance(x, sparse.bcoo.BCOO)
grad_densified = jax.tree_util.tree_map(sparse.todense, grad_sp,
is_leaf=is_bcoo)
self.assertEqual(jax.tree_util.tree_structure(grad_de),
jax.tree_util.tree_structure(grad_densified))
# Depth in tree
if deep:
grad_sp_arr = grad_sp['deep']['X']
grad_de_arr = grad_de['deep']['X']
else:
grad_sp_arr = grad_sp['X']
grad_de_arr = grad_de['X']
self.assertIsInstance(grad_sp_arr, sparse.BCOO)
self.assertAllClose(grad_sp_arr.data,
sparse_bcoo._bcoo_extract(grad_sp_arr.indices,
grad_de_arr))
# Other grad variables
if bias:
self.assertAllClose(grad_sp['list'][1][0],
grad_de['list'][1][0])
with self.subTest("wrt dense"):
# Argument ordering
if arg0:
self.assertAllClose(jax.grad(f, argnums=1, has_aux=has_aux)(Xtree_de, y),
sparse.grad(f, argnums=1, has_aux=has_aux)(Xtree_sp, y))
else:
self.assertAllClose(jax.grad(g, argnums=0, has_aux=has_aux)(y, Xtree_de),
sparse.grad(g, argnums=0, has_aux=has_aux)(y, Xtree_sp))
class SparseObjectTest(sptu.SparseTestCase):
@parameterized.named_parameters(