document and test CustomVJPPrimal type as API symbol

This commit is contained in:
Roy Frostig 2023-07-12 09:18:05 -07:00
parent 0860c24767
commit 598e311191
2 changed files with 17 additions and 7 deletions

View File

@ -530,13 +530,16 @@ class custom_vjp(Generic[ReturnValue]):
derivative rules to detect when certain inputs, and when certain
output cotangents, are not involved in differentiation. If ``True``:
* ``fwd`` must accept, in place of each leaf value ``x`` in the pytree
comprising an argument to the original function, an object with two
attributes instead: ``value`` and ``perturbed``. The ``value`` field
is the original primal argument, and ``perturbed`` is a boolean.
The ``perturbed`` bit indicates whether the argument is involved in
differentiation (i.e., if it is ``False``, then the corresponding
Jacobian "column" is zero).
* ``fwd`` must accept, in place of each leaf value ``x`` in
the pytree comprising an argument to the original function,
an object (of type
``jax.custom_derivatives.CustomVJPPrimal``) with two
attributes instead: ``value`` and ``perturbed``. The
``value`` field is the original primal argument, and
``perturbed`` is a boolean. The ``perturbed`` bit indicates
whether the argument is involved in differentiation (i.e.,
if it is ``False``, then the corresponding Jacobian "column"
is zero).
* ``bwd`` will be passed objects representing static symbolic zeros in
its cotangent argument in correspondence with unperturbed values;
@ -621,6 +624,7 @@ class custom_vjp(Generic[ReturnValue]):
@dataclasses.dataclass
class CustomVJPPrimal:
"""Primal to a ``custom_vjp``'s forward rule when ``symbolic_zeros`` is set"""
value: Any
perturbed: bool

View File

@ -8597,12 +8597,18 @@ class CustomVJPTest(jtu.JaxTestCase):
return x, x
def fwd(x, y, z):
self.assertIsInstance(x, jax.custom_derivatives.CustomVJPPrimal)
self.assertIsInstance(y, jax.custom_derivatives.CustomVJPPrimal)
self.assertIsInstance(z, jax.custom_derivatives.CustomVJPPrimal)
self.assertTrue(x.perturbed)
self.assertFalse(y.perturbed)
self.assertFalse(z.perturbed)
return (x.value, x.value), None
def fwd_all(x, y, z):
self.assertIsInstance(x, jax.custom_derivatives.CustomVJPPrimal)
self.assertIsInstance(y, jax.custom_derivatives.CustomVJPPrimal)
self.assertIsInstance(z, jax.custom_derivatives.CustomVJPPrimal)
self.assertTrue(x.perturbed)
self.assertTrue(y.perturbed)
self.assertTrue(z.perturbed)