[x64] make jax.experimental.loops consistent with default dtype

This commit is contained in:
Jake VanderPlas 2021-12-08 12:08:49 -08:00
parent 1b5630eed6
commit b49c75c0d7
2 changed files with 2 additions and 3 deletions

View File

@ -115,7 +115,6 @@ from typing import Any, Dict, List, cast
from jax import lax, core
from jax._src.lax import control_flow as lax_control_flow
from jax import tree_util
from jax import numpy as jnp
from jax.errors import UnexpectedTracerError
from jax.interpreters import partial_eval as pe
from jax._src.util import safe_map
@ -500,7 +499,7 @@ class _BoundedLoopBuilder(_LoopBuilder):
def build_output_vals(self, scope, carried_state_names, carried_tree,
init_vals, body_closed_jaxpr, body_const_vals):
arange_val = jnp.arange(self.start, stop=self.stop, step=self.step)
arange_val = np.arange(self.start, stop=self.stop, step=self.step)
return lax_control_flow.scan_p.bind(*body_const_vals, *init_vals, arange_val,
reverse=False, length=arange_val.shape[0],
jaxpr=body_closed_jaxpr,

View File

@ -62,7 +62,7 @@ class LoopsTest(jtu.JaxTestCase):
self.assertAllClose(5., jax.grad(f_op)(2.))
self.assertAllClose(5., jax.grad(f_op)(2.))
inc_batch = np.arange(5.0)
self.assertAllClose(jnp.array([f_expected(inc) for inc in inc_batch]),
self.assertAllClose(np.array([f_expected(inc) for inc in inc_batch]),
jax.vmap(f_op)(inc_batch))