[XLA:Python] Improve error checking for the return value of the to_iterable function of custom pytree nodes.

PiperOrigin-RevId: 617066587
This commit is contained in:
Peter Hawkins 2024-03-18 23:23:17 -07:00 committed by jax authors
parent 0e1b3e5ba6
commit f759452219

View File

@ -17,6 +17,7 @@ import dataclasses
import functools
import pickle
import re
import unittest
from absl.testing import absltest
from absl.testing import parameterized
@ -25,6 +26,7 @@ import jax
from jax import tree_util
from jax import flatten_util
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
from jax._src.tree_util import prefix_errors, flatten_one_level
import jax.numpy as jnp
@ -42,6 +44,19 @@ ATuple2 = collections.namedtuple("ATuple2", ("foo", "bar"))
tree_util.register_pytree_node(ATuple2, lambda o: ((o.foo,), o.bar),
lambda bar, foo: ATuple2(foo[0], bar))
BadFlattenNonTuple = collections.namedtuple("ATuple2", ("foo", "bar"))
tree_util.register_pytree_node(BadFlattenNonTuple, lambda o: "hello",
lambda bar, foo: ATuple2(foo[0], bar))
BadFlattenBadArityTuple = collections.namedtuple("ATuple2", ("foo", "bar"))
tree_util.register_pytree_node(BadFlattenBadArityTuple, lambda o: (2, 3, 4),
lambda bar, foo: ATuple2(foo[0], bar))
BadFlattenNonIterableLeaves = collections.namedtuple("ATuple2", ("foo", "bar"))
tree_util.register_pytree_node(BadFlattenNonIterableLeaves, lambda o: (7, 7),
lambda bar, foo: ATuple2(foo[0], bar))
class AnObject:
def __init__(self, x, y, z):
@ -762,6 +777,37 @@ class TreeTest(jtu.JaxTestCase):
leaves, _ = tree_util.tree_flatten_with_path(ATuple2(1, 'hi'))
self.assertLen(leaves, 1)
@unittest.skipIf(xla_extension_version < 247, "Requires jaxlib>=0.4.26")
def testBadFlattenNonTuple(self):
t = BadFlattenNonTuple(3, 4)
with self.assertRaisesRegex(
ValueError,
"The to_iterable function for a custom PyTree node should return a"
r" \(children, aux_data\) tuple, got 'hello'",
):
tree_util.tree_flatten(t)
@unittest.skipIf(xla_extension_version < 247, "Requires jaxlib>=0.4.26")
def testBadFlattenBadArityTuple(self):
t = BadFlattenBadArityTuple(3, 4)
with self.assertRaisesRegex(
ValueError,
"The to_iterable function for a custom PyTree node should return a"
r" \(children, aux_data\) tuple, got \(2, 3, 4\)",
):
tree_util.tree_flatten(t)
@unittest.skipIf(xla_extension_version < 247, "Requires jaxlib>=0.4.26")
def testBadFlattenNonIterableLeaves(self):
t = BadFlattenNonIterableLeaves(3, 4)
with self.assertRaisesRegex(
ValueError,
"The to_iterable function for a custom PyTree node should return a"
r" \(children, aux_data\) tuple where 'children' is iterable, got "
r"\(7, 7\)",
):
tree_util.tree_flatten(t)
class StaticTest(parameterized.TestCase):