mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[XLA:Python] Improve error checking for the return value of the to_iterable function of custom pytree nodes.
PiperOrigin-RevId: 617066587
This commit is contained in:
parent
0e1b3e5ba6
commit
f759452219
@ -17,6 +17,7 @@ import dataclasses
|
||||
import functools
|
||||
import pickle
|
||||
import re
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
@ -25,6 +26,7 @@ import jax
|
||||
from jax import tree_util
|
||||
from jax import flatten_util
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.tree_util import prefix_errors, flatten_one_level
|
||||
import jax.numpy as jnp
|
||||
|
||||
@ -42,6 +44,19 @@ ATuple2 = collections.namedtuple("ATuple2", ("foo", "bar"))
|
||||
tree_util.register_pytree_node(ATuple2, lambda o: ((o.foo,), o.bar),
|
||||
lambda bar, foo: ATuple2(foo[0], bar))
|
||||
|
||||
BadFlattenNonTuple = collections.namedtuple("ATuple2", ("foo", "bar"))
|
||||
tree_util.register_pytree_node(BadFlattenNonTuple, lambda o: "hello",
|
||||
lambda bar, foo: ATuple2(foo[0], bar))
|
||||
|
||||
BadFlattenBadArityTuple = collections.namedtuple("ATuple2", ("foo", "bar"))
|
||||
tree_util.register_pytree_node(BadFlattenBadArityTuple, lambda o: (2, 3, 4),
|
||||
lambda bar, foo: ATuple2(foo[0], bar))
|
||||
|
||||
BadFlattenNonIterableLeaves = collections.namedtuple("ATuple2", ("foo", "bar"))
|
||||
tree_util.register_pytree_node(BadFlattenNonIterableLeaves, lambda o: (7, 7),
|
||||
lambda bar, foo: ATuple2(foo[0], bar))
|
||||
|
||||
|
||||
class AnObject:
|
||||
|
||||
def __init__(self, x, y, z):
|
||||
@ -762,6 +777,37 @@ class TreeTest(jtu.JaxTestCase):
|
||||
leaves, _ = tree_util.tree_flatten_with_path(ATuple2(1, 'hi'))
|
||||
self.assertLen(leaves, 1)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 247, "Requires jaxlib>=0.4.26")
|
||||
def testBadFlattenNonTuple(self):
|
||||
t = BadFlattenNonTuple(3, 4)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"The to_iterable function for a custom PyTree node should return a"
|
||||
r" \(children, aux_data\) tuple, got 'hello'",
|
||||
):
|
||||
tree_util.tree_flatten(t)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 247, "Requires jaxlib>=0.4.26")
|
||||
def testBadFlattenBadArityTuple(self):
|
||||
t = BadFlattenBadArityTuple(3, 4)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"The to_iterable function for a custom PyTree node should return a"
|
||||
r" \(children, aux_data\) tuple, got \(2, 3, 4\)",
|
||||
):
|
||||
tree_util.tree_flatten(t)
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 247, "Requires jaxlib>=0.4.26")
|
||||
def testBadFlattenNonIterableLeaves(self):
|
||||
t = BadFlattenNonIterableLeaves(3, 4)
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
"The to_iterable function for a custom PyTree node should return a"
|
||||
r" \(children, aux_data\) tuple where 'children' is iterable, got "
|
||||
r"\(7, 7\)",
|
||||
):
|
||||
tree_util.tree_flatten(t)
|
||||
|
||||
|
||||
class StaticTest(parameterized.TestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user