Avoid index out of range error in carry structure check

This commit is contained in:
IvyZX 2024-12-09 10:44:28 -08:00 committed by Peter Hawkins
parent 259194a69f
commit 65b6088411
3 changed files with 27 additions and 1 deletions

View File

@ -10,7 +10,14 @@ Remember to align the itemized text with the first line of an item within a list
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
-->
## jax 0.4.36
## jax 0.4.37
* Bug fixes
* Fix a bug that will throw `index out of range` error in
{func}`jax.lax.while_loop` if the user register pytree node class with
different aux data for the flatten and flatten_with_path.
## jax 0.4.36 (Dec 5, 2024)
* Breaking Changes
* This release lands "stackless", an internal change to JAX's tracing

View File

@ -376,6 +376,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
f'of the carry output is a {thing2}, so {explanation}'
for path, thing1, thing2, explanation
in equality_errors(in_carry, out_carry)]
if len(diffs) == 0:
# The trees may have different aux data but structures are the same.
return
if len(diffs) == 1:
differences = f'{diffs[0]}.\n'.capitalize()
else:
@ -393,6 +396,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}'
for path, in_aval, out_aval in zip(paths, in_avals, out_avals)
if not core.typematch(in_aval, out_aval)]
if len(diffs) == 0:
# The trees may have different aux data but structures are the same.
return
if len(diffs) == 1:
differences = f'{diffs[0]}.\n'.capitalize()
else:

View File

@ -322,6 +322,19 @@ class LaxControlFlowTest(jtu.JaxTestCase):
lax.while_loop(lambda c: True, lambda c: (True, True),
(np.bool_(True), np.float32(0.)))
def testWhileLoopCustomPytreeDiffAuxData(self):
class Node:
def __init__(self, x, y):
self.x = x
self.y = y
tree_util.register_pytree_with_keys(
Node,
lambda o: ((("x", o.x), ("y", o.y)), 'with_keys'), # flatten_with_keys
lambda _, xy: Node(xy[0], xy[1]), # unflatten (no key involved)
lambda o: ((o.x, o.y), 'without_keys'), # flatten
)
lax.while_loop(lambda o: o.x > 0., lambda c: Node(0., 0.), Node(1., 1.))
def testNestedWhileWithDynamicUpdateSlice(self):
num = 5