mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
[x64] make jax.experimental.loops consistent with default dtype
This commit is contained in:
parent
1b5630eed6
commit
b49c75c0d7
@ -115,7 +115,6 @@ from typing import Any, Dict, List, cast
|
||||
from jax import lax, core
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
from jax import tree_util
|
||||
from jax import numpy as jnp
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.util import safe_map
|
||||
@ -500,7 +499,7 @@ class _BoundedLoopBuilder(_LoopBuilder):
|
||||
|
||||
def build_output_vals(self, scope, carried_state_names, carried_tree,
|
||||
init_vals, body_closed_jaxpr, body_const_vals):
|
||||
arange_val = jnp.arange(self.start, stop=self.stop, step=self.step)
|
||||
arange_val = np.arange(self.start, stop=self.stop, step=self.step)
|
||||
return lax_control_flow.scan_p.bind(*body_const_vals, *init_vals, arange_val,
|
||||
reverse=False, length=arange_val.shape[0],
|
||||
jaxpr=body_closed_jaxpr,
|
||||
|
@ -62,7 +62,7 @@ class LoopsTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(5., jax.grad(f_op)(2.))
|
||||
self.assertAllClose(5., jax.grad(f_op)(2.))
|
||||
inc_batch = np.arange(5.0)
|
||||
self.assertAllClose(jnp.array([f_expected(inc) for inc in inc_batch]),
|
||||
self.assertAllClose(np.array([f_expected(inc) for inc in inc_batch]),
|
||||
jax.vmap(f_op)(inc_batch))
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user