mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
f8fa589511
commit
a821e67d60
@ -22,7 +22,7 @@ from jax import core
|
||||
from jax.util import unzip2
|
||||
from jax import ad_util
|
||||
from jax.tree_util import (register_pytree_node, tree_structure,
|
||||
treedef_is_leaf, tree_flatten, tree_unflatten)
|
||||
treedef_is_leaf, tree_flatten, tree_unflatten, tree_map)
|
||||
import jax.linear_util as lu
|
||||
from jax.interpreters import xla
|
||||
from jax.lax import lax
|
||||
@ -59,6 +59,8 @@ def jet_fun(primals, series):
|
||||
with core.new_master(JetTrace) as master:
|
||||
out_primals, out_terms = yield (master, primals, series), {}
|
||||
del master
|
||||
out_terms = [tree_map(lambda x: onp.zeros_like(x, dtype=onp.result_type(out_primals[0])), series[0])
|
||||
if s is zero_series else s for s in out_terms]
|
||||
yield out_primals, out_terms
|
||||
|
||||
@lu.transformation
|
||||
|
@ -61,10 +61,6 @@ class JetTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
|
||||
check_dtypes=check_dtypes)
|
||||
|
||||
# TODO(duvenaud): Lower zero_series to actual zeros automatically.
|
||||
if terms == zero_series:
|
||||
terms = tree_map(np.zeros_like, expected_terms)
|
||||
|
||||
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
|
||||
check_dtypes=check_dtypes)
|
||||
|
||||
@ -86,10 +82,6 @@ class JetTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(y, expected_y, atol=atol, rtol=rtol,
|
||||
check_dtypes=check_dtypes)
|
||||
|
||||
# TODO(duvenaud): Lower zero_series to actual zeros automatically.
|
||||
if terms == zero_series:
|
||||
terms = tree_map(np.zeros_like, expected_terms)
|
||||
|
||||
self.assertAllClose(terms, expected_terms, atol=atol, rtol=rtol,
|
||||
check_dtypes=check_dtypes)
|
||||
|
||||
@ -291,6 +283,21 @@ class JetTest(jtu.JaxTestCase):
|
||||
series_in = (terms_b, terms_x, terms_y)
|
||||
self.check_jet(np.where, primals, series_in)
|
||||
|
||||
def test_inst_zero(self):
|
||||
def f(x):
|
||||
return 2.
|
||||
def g(x):
|
||||
return 2. + 0 * x
|
||||
x = np.ones(1)
|
||||
order = 3
|
||||
f_out_primals, f_out_series = jet(f, (x, ), ([np.ones_like(x) for _ in range(order)], ))
|
||||
assert f_out_series is not zero_series
|
||||
|
||||
g_out_primals, g_out_series = jet(g, (x, ), ([np.ones_like(x) for _ in range(order)], ))
|
||||
|
||||
assert g_out_primals == f_out_primals
|
||||
assert g_out_series == f_out_series
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user