mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #19760 from Blair-Johnson:fix-pytree-grads-sparse
PiperOrigin-RevId: 657258194
This commit is contained in:
commit
9beb4f1474
@ -67,8 +67,12 @@ def flatten_fun_for_sparse_ad(fun, argnums: int | tuple[int, ...], args: tuple[A
|
||||
return f_recons(grad_out)
|
||||
|
||||
def postprocess_gradients(grads_out):
|
||||
out = [reconstruct(*args) for args in safe_zip(argnums_flat1, grads_out)]
|
||||
return out[0] if isinstance(argnums, int) else out
|
||||
leaf_grads = [None] * tree1.num_leaves
|
||||
for i, grad in safe_zip(argnums_flat1, grads_out):
|
||||
leaf_grads[i] = reconstruct(i, grad)
|
||||
grad_tree = tree_util.tree_unflatten(tree1, leaf_grads)
|
||||
grad_tree = tuple(filter(lambda x: jax.tree.leaves(x), grad_tree))
|
||||
return grad_tree[0] if len(grad_tree) == 1 else grad_tree
|
||||
|
||||
return fun_flat, argnums_flat, args_flat, postprocess_gradients
|
||||
|
||||
|
@ -712,6 +712,97 @@ class SparseGradTest(sptu.SparseTestCase):
|
||||
self.assertAllClose(jac_dense(f, argnums=1, has_aux=has_aux)(X, y),
|
||||
jac_sparse(f, argnums=1, has_aux=has_aux)(Xsp, y), rtol=rtol)
|
||||
|
||||
  @jtu.sample_product(has_aux=[True, False],
                      deep=[True,False],
                      arg0=[True,False],
                      bias=[True,False])
  def test_sparse_pytree_grad(self, has_aux, deep, arg0, bias):
    """Gradients w.r.t. pytree arguments containing BCOO leaves match dense grads.

    Compares sparse.grad against jax.grad when the differentiated argument is
    a pytree mixing sparse (BCOO) leaves, dense leaves, and None entries,
    across four axes: has_aux (auxiliary output), deep (sparse leaf nested one
    level down), arg0 (pytree passed as first vs. second positional argument),
    and bias (whether a dense leaf inside the pytree contributes to the output).
    """
    rng_sparse = sptu.rand_sparse(self.rng())
    rng = jtu.rand_default(self.rng())

    y = rng(5, "float32")
    X = rng_sparse((10, 5), "float32")
    b = rng(10, "float32")
    Xsp = sparse.BCOO.fromdense(X)
    # Structurally identical pytrees: the sparse one holds BCOO leaves where
    # the dense one holds ndarrays. 'list' contains None entries to exercise
    # non-differentiable / empty leaves in the tree.
    Xtree_sp = {'deep':{'X':Xsp},
                'X':Xsp,
                'list':[None,(b,None)]}
    Xtree_de = {'deep':{'X':X},
                'X':X,
                'list':[None,(b,None)]}

    def f(Xtree, y):
      # Variant with the pytree as the FIRST positional argument.
      if deep:
        out = Xtree['deep']['X'] @ y
      else:
        out = Xtree['X'] @ y
      # Other grad variables
      if bias:
        out += Xtree['list'][1][0]
      out = jnp.sum(out)
      if has_aux:
        return out, {'y': y.shape}
      else:
        return out

    def g(y, Xtree):
      # Same computation as f, but with the pytree as the SECOND argument.
      if deep:
        out = Xtree['deep']['X'] @ y
      else:
        out = Xtree['X'] @ y
      # Other grad variables
      if bias:
        out += Xtree['list'][1][0]
      out = jnp.sum(out)
      if has_aux:
        return out, {'y': y.shape}
      return out

    with self.subTest("wrt sparse"):
      # Argument ordering
      if arg0:
        grad_de = jax.grad(f, argnums=0, has_aux=has_aux)(Xtree_de, y)
        grad_sp = sparse.grad(f, argnums=0, has_aux=has_aux)(Xtree_sp, y)
      else:
        grad_de = jax.grad(g, argnums=1, has_aux=has_aux)(y, Xtree_de)
        grad_sp = sparse.grad(g, argnums=1, has_aux=has_aux)(y, Xtree_sp)

      if has_aux:
        # Unpack (grad, aux) pairs and check aux agrees between the two paths.
        grad_de, aux_de = grad_de
        grad_sp, aux_sp = grad_sp
        self.assertAllClose(aux_de, aux_sp)

      # Pytree structure
      # Densify BCOO leaves so the two trees have comparable structure.
      is_bcoo = lambda x: isinstance(x, sparse.bcoo.BCOO)
      grad_densified = jax.tree_util.tree_map(sparse.todense, grad_sp,
                                              is_leaf=is_bcoo)
      self.assertEqual(jax.tree_util.tree_structure(grad_de),
                       jax.tree_util.tree_structure(grad_densified))

      # Depth in tree
      if deep:
        grad_sp_arr = grad_sp['deep']['X']
        grad_de_arr = grad_de['deep']['X']
      else:
        grad_sp_arr = grad_sp['X']
        grad_de_arr = grad_de['X']
      # The sparse gradient keeps the BCOO format; its stored values must match
      # the dense gradient extracted at the same indices.
      self.assertIsInstance(grad_sp_arr, sparse.BCOO)
      self.assertAllClose(grad_sp_arr.data,
                          sparse_bcoo._bcoo_extract(grad_sp_arr.indices,
                                                    grad_de_arr))
      # Other grad variables
      if bias:
        self.assertAllClose(grad_sp['list'][1][0],
                            grad_de['list'][1][0])

    with self.subTest("wrt dense"):
      # Differentiating w.r.t. the dense argument (the one NOT holding the
      # pytree) should be unaffected by sparsity in the other argument.
      # Argument ordering
      if arg0:
        self.assertAllClose(jax.grad(f, argnums=1, has_aux=has_aux)(Xtree_de, y),
                            sparse.grad(f, argnums=1, has_aux=has_aux)(Xtree_sp, y))
      else:
        self.assertAllClose(jax.grad(g, argnums=0, has_aux=has_aux)(y, Xtree_de),
                            sparse.grad(g, argnums=0, has_aux=has_aux)(y, Xtree_sp))
|
||||
|
||||
class SparseObjectTest(sptu.SparseTestCase):
|
||||
@parameterized.named_parameters(
|
||||
|
Loading…
x
Reference in New Issue
Block a user