instantiate zeros (#2924)

fix dtype

remove TODO
This commit is contained in:
Jacob Kelly 2020-05-01 20:10:20 -04:00 committed by GitHub
parent f8fa589511
commit a821e67d60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 18 additions and 9 deletions

View File

@ -22,7 +22,7 @@ from jax import core
from jax.util import unzip2
from jax import ad_util
from jax.tree_util import (register_pytree_node, tree_structure,
treedef_is_leaf, tree_flatten, tree_unflatten)
treedef_is_leaf, tree_flatten, tree_unflatten, tree_map)
import jax.linear_util as lu
from jax.interpreters import xla
from jax.lax import lax
@ -59,6 +59,8 @@ def jet_fun(primals, series):
with core.new_master(JetTrace) as master:
out_primals, out_terms = yield (master, primals, series), {}
del master
out_terms = [tree_map(lambda x: onp.zeros_like(x, dtype=onp.result_type(out_primals[0])), series[0])
if s is zero_series else s for s in out_terms]
yield out_primals, out_terms
@lu.transformation

View File

@ -61,10 +61,6 @@ class JetTest(jtu.JaxTestCase):
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)
# TODO(duvenaud): Lower zero_series to actual zeros automatically.
if terms == zero_series:
terms = tree_map(np.zeros_like, expected_terms)
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)
@ -86,10 +82,6 @@ class JetTest(jtu.JaxTestCase):
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)
# TODO(duvenaud): Lower zero_series to actual zeros automatically.
if terms == zero_series:
terms = tree_map(np.zeros_like, expected_terms)
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
check_dtypes=check_dtypes)
@ -291,6 +283,21 @@ class JetTest(jtu.JaxTestCase):
series_in = (terms_b, terms_x, terms_y)
self.check_jet(np.where, primals, series_in)
def test_inst_zero(self):
def f(x):
return 2.
def g(x):
return 2. + 0 * x
x = np.ones(1)
order = 3
f_out_primals, f_out_series = jet(f, (x, ), ([np.ones_like(x) for _ in range(order)], ))
assert f_out_series is not zero_series
g_out_primals, g_out_series = jet(g, (x, ), ([np.ones_like(x) for _ in range(order)], ))
assert g_out_primals == f_out_primals
assert g_out_series == f_out_series
if __name__ == '__main__':
absltest.main()