Mirror of https://github.com/ROCm/jax.git, synced 2025-04-17 12:26:07 +00:00
stop_gradient_p -> ad_util.py, re-enable some mypy (#2806)
parent 903010b7b9
commit 13a17286df
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 
-from .core import lattice_join, Primitive, Unit, unit, AbstractUnit
+from .core import (lattice_join, Primitive, Unit, unit, AbstractUnit,
+                   valid_jaxtype)
 from .tree_util import register_pytree_node
 from typing import Any, Dict
 from .util import safe_map
@@ -64,3 +65,14 @@ class Zero(object):
 zero = Zero()
 
 register_pytree_node(Zero, lambda z: ((), None), lambda _, xs: zero)
+
+
+def _stop_gradient_impl(x):
+  if not valid_jaxtype(x):
+    raise TypeError("stop_gradient only works on valid JAX arrays, but "
+                    f"input argument is: {x}")
+  return x
+
+stop_gradient_p = Primitive('stop_gradient')
+stop_gradient_p.def_impl(_stop_gradient_impl)
+stop_gradient_p.def_abstract_eval(lambda x: x)
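For context, the block added to ad_util.py above follows JAX's usual recipe for defining a primitive: def_impl supplies the eager (concrete) evaluation and def_abstract_eval supplies the rule used under tracing. A minimal sketch of the same pattern, using a hypothetical primitive (the name my_identity_p is illustrative and not part of this change):

    from jax import core

    # Hypothetical primitive, mirroring the stop_gradient_p definition above.
    my_identity_p = core.Primitive('my_identity')
    my_identity_p.def_impl(lambda x: x)            # concrete/eager evaluation
    my_identity_p.def_abstract_eval(lambda x: x)   # abstract values pass through unchanged

    print(my_identity_p.bind(3.0))  # 3.0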
@@ -18,14 +18,13 @@ import inspect
 import itertools as it
 import operator as op
 
-import jax
 from . import core
 from . import linear_util as lu
 from .tree_util import tree_flatten, tree_unflatten, tree_map, tree_multimap
 from .util import safe_zip, safe_map, unzip2, split_list, curry
 from .api_util import flatten_fun_nokwargs, argnums_partial, wrap_hashably
 from .abstract_arrays import raise_to_shaped
-from .ad_util import zero
+from .ad_util import zero, stop_gradient_p
 from .interpreters import partial_eval as pe
 from .interpreters import ad
 from .interpreters import batching
@@ -88,7 +87,7 @@ def stop_gradient(x):
 
 def _stop_gradient(x):
   if isinstance(x, core.Tracer) or core.valid_jaxtype(x):
-    return jax.lax.stop_gradient(x)
+    return stop_gradient_p.bind(x)
   else:
     return x
 
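As a usage-level note (not part of the change itself), the helper above stops gradients only for values JAX can trace; other Python objects are returned untouched. A minimal sketch of that behavior, with an illustrative name stop_if_jaxtype and assuming the public jax.core.valid_jaxtype and jax.lax.stop_gradient entry points:

    import jax
    import jax.numpy as jnp
    from jax import core

    def stop_if_jaxtype(x):
      # Same branch as _stop_gradient above: arrays and tracers are
      # gradient-stopped; non-jaxtypes (strings, callables, ...) pass through.
      if isinstance(x, core.Tracer) or core.valid_jaxtype(x):
        return jax.lax.stop_gradient(x)
      else:
        return x

    print(stop_if_jaxtype(jnp.ones(3)))     # array of ones, gradient-stopped
    print(stop_if_jaxtype("static label"))  # returned unchanged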
@@ -20,6 +20,7 @@ import numpy as onp
 import jax
 from jax import core
 from jax.util import unzip2
+from jax import ad_util
 from jax.tree_util import (register_pytree_node, tree_structure,
                            treedef_is_leaf, tree_flatten, tree_unflatten)
 import jax.linear_util as lu
@@ -177,7 +178,7 @@ defzero(lax.floor_p)
 defzero(lax.ceil_p)
 defzero(lax.round_p)
 defzero(lax.sign_p)
-defzero(lax.stop_gradient_p)
+defzero(ad_util.stop_gradient_p)
 
 
 def deflinear(prim):
@@ -1292,7 +1292,7 @@ def stop_gradient(x):
   >>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
   array(0., dtype=float32)
   """
-  return tree_map(stop_gradient_p.bind, x)
+  return tree_map(ad_util.stop_gradient_p.bind, x)
 
 
 ### convenience wrappers around traceables
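Because stop_gradient maps the primitive's bind over every leaf with tree_map, it works on arbitrary pytrees, not just single arrays. A small sketch of the expected behavior (the params/loss names are made up for the example):

    import jax
    import jax.numpy as jnp

    params = {"w": jnp.array(2.0), "b": jnp.array(1.0)}

    def loss(p):
      frozen = jax.lax.stop_gradient(p)   # applied leaf-by-leaf via tree_map
      return p["w"] * frozen["b"] + frozen["w"]

    grads = jax.grad(loss)(params)
    # grads["w"] == 1.0 (only the direct use of p["w"] contributes)
    # grads["b"] == 0.0 (b only appears inside stop_gradient)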
@@ -4600,14 +4600,6 @@ masking.shape_rules[tie_in_p] = lambda x, y: y.shape
 masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1]
 
 
-### stop-gradient
-
-def _stop_gradient_impl(x):
-  if not core.valid_jaxtype(x):
-    raise TypeError("stop_gradient only works on valid JAX arrays, but "
-                    f"input argument is: {x}")
-  return x
-
 def _stop_gradient_jvp_rule(primals, tangents):
   # if we don't call stop_gradient here, we'd only peel off one autodiff tracer
   x, = primals
@@ -4618,12 +4610,10 @@ def _stop_gradient_batch_rule(batched_args, batch_dims):
   dim, = batch_dims
   return stop_gradient(x), dim
 
-stop_gradient_p = Primitive('stop_gradient')
-stop_gradient_p.def_impl(_stop_gradient_impl)
-stop_gradient_p.def_abstract_eval(_identity)
-xla.translations[stop_gradient_p] = lambda c, x: x
-ad.primitive_jvps[stop_gradient_p] = _stop_gradient_jvp_rule
-batching.primitive_batchers[stop_gradient_p] = _stop_gradient_batch_rule
+xla.translations[ad_util.stop_gradient_p] = lambda c, x: x
+ad.primitive_jvps[ad_util.stop_gradient_p] = _stop_gradient_jvp_rule
+batching.primitive_batchers[ad_util.stop_gradient_p] = _stop_gradient_batch_rule
 
 
 def create_token(x):
   """Creates an XLA token value with no preconditions for sequencing effects.
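The XLA, JVP, and batching registrations above are what let stop_gradient compose with jit, grad, and vmap even though the primitive now lives in ad_util. A quick sketch of the behavior those rules are meant to preserve:

    import jax
    import jax.numpy as jnp

    xs = jnp.arange(4.0)

    # The batching rule lets stop_gradient appear under vmap ...
    ys = jax.vmap(jax.lax.stop_gradient)(xs)   # identical to xs

    # ... and the JVP rule makes its derivative contribution vanish.
    gs = jax.vmap(jax.grad(lambda x: jax.lax.stop_gradient(x) ** 2))(xs)
    # gs == [0., 0., 0., 0.]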
@@ -1388,7 +1388,7 @@ def _memcpy(axis, num, src, dst, offset):
     return lax.dynamic_update_index_in_dim(dst, update, i + offset, axis)
   return fori_loop(0, num, body, dst)
 
-masking.masking_rules[lax.concatenate_p] = _concat_masking_rule  # type: ignore
+masking.masking_rules[lax.concatenate_p] = _concat_masking_rule
 
 
 def _check_tree(func_name, expected_name, actual_tree, expected_tree):