diff --git a/CHANGELOG.md b/CHANGELOG.md index b6d0f97f4..92dcfe6cc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,14 @@ Remember to align the itemized text with the first line of an item within a list When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md. --> -## jax 0.4.36 +## jax 0.4.37 + +* Bug fixes + * Fix a bug that will throw `index out of range` error in + {func}`jax.lax.while_loop` if the user register pytree node class with + different aux data for the flatten and flatten_with_path. + +## jax 0.4.36 (Dec 5, 2024) * Breaking Changes * This release lands "stackless", an internal change to JAX's tracing diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index f62ce2434..9b2d688c3 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -376,6 +376,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): f'of the carry output is a {thing2}, so {explanation}' for path, thing1, thing2, explanation in equality_errors(in_carry, out_carry)] + if len(diffs) == 0: + # The trees may have different aux data but structures are the same. + return if len(diffs) == 1: differences = f'{diffs[0]}.\n'.capitalize() else: @@ -393,6 +396,9 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals): f'{out_aval.str_short()}{_aval_mismatch_extra(in_aval, out_aval)}' for path, in_aval, out_aval in zip(paths, in_avals, out_avals) if not core.typematch(in_aval, out_aval)] + if len(diffs) == 0: + # The trees may have different aux data but structures are the same. + return if len(diffs) == 1: differences = f'{diffs[0]}.\n'.capitalize() else: diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index d383e4c6a..4b0420fda 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -322,6 +322,19 @@ class LaxControlFlowTest(jtu.JaxTestCase): lax.while_loop(lambda c: True, lambda c: (True, True), (np.bool_(True), np.float32(0.))) + def testWhileLoopCustomPytreeDiffAuxData(self): + class Node: + def __init__(self, x, y): + self.x = x + self.y = y + tree_util.register_pytree_with_keys( + Node, + lambda o: ((("x", o.x), ("y", o.y)), 'with_keys'), # flatten_with_keys + lambda _, xy: Node(xy[0], xy[1]), # unflatten (no key involved) + lambda o: ((o.x, o.y), 'without_keys'), # flatten + ) + lax.while_loop(lambda o: o.x > 0., lambda c: Node(0., 0.), Node(1., 1.)) + def testNestedWhileWithDynamicUpdateSlice(self): num = 5