mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
document and test CustomVJPPrimal
type as API symbol
This commit is contained in:
parent
0860c24767
commit
598e311191
@ -530,13 +530,16 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
derivative rules to detect when certain inputs, and when certain
|
||||
output cotangents, are not involved in differentiation. If ``True``:
|
||||
|
||||
* ``fwd`` must accept, in place of each leaf value ``x`` in the pytree
|
||||
comprising an argument to the original function, an object with two
|
||||
attributes instead: ``value`` and ``perturbed``. The ``value`` field
|
||||
is the original primal argument, and ``perturbed`` is a boolean.
|
||||
The ``perturbed`` bit indicates whether the argument is involved in
|
||||
differentiation (i.e., if it is ``False``, then the corresponding
|
||||
Jacobian "column" is zero).
|
||||
* ``fwd`` must accept, in place of each leaf value ``x`` in
|
||||
the pytree comprising an argument to the original function,
|
||||
an object (of type
|
||||
``jax.custom_derivatives.CustomVJPPrimal``) with two
|
||||
attributes instead: ``value`` and ``perturbed``. The
|
||||
``value`` field is the original primal argument, and
|
||||
``perturbed`` is a boolean. The ``perturbed`` bit indicates
|
||||
whether the argument is involved in differentiation (i.e.,
|
||||
if it is ``False``, then the corresponding Jacobian "column"
|
||||
is zero).
|
||||
|
||||
* ``bwd`` will be passed objects representing static symbolic zeros in
|
||||
its cotangent argument in correspondence with unperturbed values;
|
||||
@ -621,6 +624,7 @@ class custom_vjp(Generic[ReturnValue]):
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CustomVJPPrimal:
|
||||
"""Primal to a ``custom_vjp``'s forward rule when ``symbolic_zeros`` is set"""
|
||||
value: Any
|
||||
perturbed: bool
|
||||
|
||||
|
@ -8597,12 +8597,18 @@ class CustomVJPTest(jtu.JaxTestCase):
|
||||
return x, x
|
||||
|
||||
def fwd(x, y, z):
|
||||
self.assertIsInstance(x, jax.custom_derivatives.CustomVJPPrimal)
|
||||
self.assertIsInstance(y, jax.custom_derivatives.CustomVJPPrimal)
|
||||
self.assertIsInstance(z, jax.custom_derivatives.CustomVJPPrimal)
|
||||
self.assertTrue(x.perturbed)
|
||||
self.assertFalse(y.perturbed)
|
||||
self.assertFalse(z.perturbed)
|
||||
return (x.value, x.value), None
|
||||
|
||||
def fwd_all(x, y, z):
|
||||
self.assertIsInstance(x, jax.custom_derivatives.CustomVJPPrimal)
|
||||
self.assertIsInstance(y, jax.custom_derivatives.CustomVJPPrimal)
|
||||
self.assertIsInstance(z, jax.custom_derivatives.CustomVJPPrimal)
|
||||
self.assertTrue(x.perturbed)
|
||||
self.assertTrue(y.perturbed)
|
||||
self.assertTrue(z.perturbed)
|
||||
|
Loading…
x
Reference in New Issue
Block a user