mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Avoid index out of range error in carry structure check
This commit is contained in:
parent
259194a69f
commit
65b6088411
@ -10,7 +10,14 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
|
||||
-->
|
||||
|
||||
## jax 0.4.36
|
||||
## jax 0.4.37
|
||||
|
||||
* Bug fixes
|
||||
* Fix a bug that will throw `index out of range` error in
|
||||
{func}`jax.lax.while_loop` if the user register pytree node class with
|
||||
different aux data for the flatten and flatten_with_path.
|
||||
|
||||
## jax 0.4.36 (Dec 5, 2024)
|
||||
|
||||
* Breaking Changes
|
||||
* This release lands "stackless", an internal change to JAX's tracing
|
||||
|
@ -376,6 +376,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
|
||||
f'of the carry output is a {thing2}, so {explanation}'
|
||||
for path, thing1, thing2, explanation
|
||||
in equality_errors(in_carry, out_carry)]
|
||||
if len(diffs) == 0:
|
||||
# The trees may have different aux data but structures are the same.
|
||||
return
|
||||
if len(diffs) == 1:
|
||||
differences = f'{diffs[0]}.\n'.capitalize()
|
||||
else:
|
||||
@ -393,6 +396,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
|
||||
f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}'
|
||||
for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
|
||||
if not core.typematch(in_aval, out_aval)]
|
||||
if len(diffs) == 0:
|
||||
# The trees may have different aux data but structures are the same.
|
||||
return
|
||||
if len(diffs) == 1:
|
||||
differences = f'{diffs[0]}.\n'.capitalize()
|
||||
else:
|
||||
|
@ -322,6 +322,19 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
lax.while_loop(lambda c: True, lambda c: (True, True),
|
||||
(np.bool_(True), np.float32(0.)))
|
||||
|
||||
def testWhileLoopCustomPytreeDiffAuxData(self):
|
||||
class Node:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
tree_util.register_pytree_with_keys(
|
||||
Node,
|
||||
lambda o: ((("x", o.x), ("y", o.y)), 'with_keys'), # flatten_with_keys
|
||||
lambda _, xy: Node(xy[0], xy[1]), # unflatten (no key involved)
|
||||
lambda o: ((o.x, o.y), 'without_keys'), # flatten
|
||||
)
|
||||
lax.while_loop(lambda o: o.x > 0., lambda c: Node(0., 0.), Node(1., 1.))
|
||||
|
||||
def testNestedWhileWithDynamicUpdateSlice(self):
|
||||
num = 5
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user