stop_gradient_p -> ad_util.py, re-enable some mypy (#2806)

This commit is contained in:
Matthew Johnson 2020-04-23 13:12:24 -07:00 committed by GitHub
parent 903010b7b9
commit 13a17286df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 23 additions and 23 deletions

View File

@ -13,7 +13,8 @@
# limitations under the License.
from .core import lattice_join, Primitive, Unit, unit, AbstractUnit
from .core import (lattice_join, Primitive, Unit, unit, AbstractUnit,
valid_jaxtype)
from .tree_util import register_pytree_node
from typing import Any, Dict
from .util import safe_map
@ -64,3 +65,14 @@ class Zero(object):
zero = Zero()
register_pytree_node(Zero, lambda z: ((), None), lambda _, xs: zero)
def _stop_gradient_impl(x):
if not valid_jaxtype(x):
raise TypeError("stop_gradient only works on valid JAX arrays, but "
f"input argument is: {x}")
return x
stop_gradient_p = Primitive('stop_gradient')
stop_gradient_p.def_impl(_stop_gradient_impl)
stop_gradient_p.def_abstract_eval(lambda x: x)

View File

@ -18,14 +18,13 @@ import inspect
import itertools as it
import operator as op
import jax
from . import core
from . import linear_util as lu
from .tree_util import tree_flatten, tree_unflatten, tree_map, tree_multimap
from .util import safe_zip, safe_map, unzip2, split_list, curry
from .api_util import flatten_fun_nokwargs, argnums_partial, wrap_hashably
from .abstract_arrays import raise_to_shaped
from .ad_util import zero
from .ad_util import zero, stop_gradient_p
from .interpreters import partial_eval as pe
from .interpreters import ad
from .interpreters import batching
@ -88,7 +87,7 @@ def stop_gradient(x):
def _stop_gradient(x):
if isinstance(x, core.Tracer) or core.valid_jaxtype(x):
return jax.lax.stop_gradient(x)
return stop_gradient_p.bind(x)
else:
return x

View File

@ -20,6 +20,7 @@ import numpy as onp
import jax
from jax import core
from jax.util import unzip2
from jax import ad_util
from jax.tree_util import (register_pytree_node, tree_structure,
treedef_is_leaf, tree_flatten, tree_unflatten)
import jax.linear_util as lu
@ -177,7 +178,7 @@ defzero(lax.floor_p)
defzero(lax.ceil_p)
defzero(lax.round_p)
defzero(lax.sign_p)
defzero(lax.stop_gradient_p)
defzero(ad_util.stop_gradient_p)
def deflinear(prim):

View File

@ -1292,7 +1292,7 @@ def stop_gradient(x):
>>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.)
array(0., dtype=float32)
"""
return tree_map(stop_gradient_p.bind, x)
return tree_map(ad_util.stop_gradient_p.bind, x)
### convenience wrappers around traceables
@ -4600,14 +4600,6 @@ masking.shape_rules[tie_in_p] = lambda x, y: y.shape
masking.masking_rules[tie_in_p] = lambda vals, logical_shapes: vals[1]
### stop-gradient
def _stop_gradient_impl(x):
if not core.valid_jaxtype(x):
raise TypeError("stop_gradient only works on valid JAX arrays, but "
f"input argument is: {x}")
return x
def _stop_gradient_jvp_rule(primals, tangents):
# if we don't call stop_gradient here, we'd only peel off one autodiff tracer
x, = primals
@ -4618,12 +4610,10 @@ def _stop_gradient_batch_rule(batched_args, batch_dims):
dim, = batch_dims
return stop_gradient(x), dim
stop_gradient_p = Primitive('stop_gradient')
stop_gradient_p.def_impl(_stop_gradient_impl)
stop_gradient_p.def_abstract_eval(_identity)
xla.translations[stop_gradient_p] = lambda c, x: x
ad.primitive_jvps[stop_gradient_p] = _stop_gradient_jvp_rule
batching.primitive_batchers[stop_gradient_p] = _stop_gradient_batch_rule
xla.translations[ad_util.stop_gradient_p] = lambda c, x: x
ad.primitive_jvps[ad_util.stop_gradient_p] = _stop_gradient_jvp_rule
batching.primitive_batchers[ad_util.stop_gradient_p] = _stop_gradient_batch_rule
def create_token(x):
"""Creates an XLA token value with no preconditions for sequencing effects.

View File

@ -1388,7 +1388,7 @@ def _memcpy(axis, num, src, dst, offset):
return lax.dynamic_update_index_in_dim(dst, update, i + offset, axis)
return fori_loop(0, num, body, dst)
masking.masking_rules[lax.concatenate_p] = _concat_masking_rule # type: ignore
masking.masking_rules[lax.concatenate_p] = _concat_masking_rule
def _check_tree(func_name, expected_name, actual_tree, expected_tree):

View File

@ -12,5 +12,3 @@ ignore_missing_imports = True
ignore_missing_imports = True
[mypy-jax.interpreters.autospmd]
ignore_errors = True
[mypy-jax.lax.lax_parallel]
ignore_errors = True