Add register_static decorator to tree_util to facilitate creating leafless classes.

PiperOrigin-RevId: 558937697
This commit is contained in:
jax authors 2023-08-21 16:53:48 -07:00
parent 9f5999d545
commit f2d7798a0c
3 changed files with 129 additions and 3 deletions

View File

@ -21,7 +21,7 @@ import functools
from functools import partial
import operator as op
import textwrap
from typing import Any, Callable, NamedTuple, TypeVar, Union, overload
from typing import Any, Callable, NamedTuple, Type, TypeVar, Union, overload
import warnings
from jax._src import traceback_util
@ -34,6 +34,7 @@ traceback_util.register_exclusion(__file__)
T = TypeVar("T")
U = TypeVar("U", bound=type[Any])
H = TypeVar("H", bound=Hashable)
Leaf = Any
PyTreeDef = pytree.PyTreeDef
@ -115,7 +116,6 @@ def treedef_tuple(treedefs: Iterable[PyTreeDef]) -> PyTreeDef:
return pytree.tuple(list(treedefs)) # type: ignore
def treedef_children(treedef: PyTreeDef) -> list[PyTreeDef]:
return treedef.children()
@ -757,6 +757,30 @@ def register_pytree_with_keys_class(cls: U) -> U:
return cls
def register_static(cls: Type[H]) -> Type[H]:
"""Registers `cls` as a pytree with no leaves.
Instances are treated as static by `jax.jit`, `jax.pmap`, etc. This can be an
alternative to labeling inputs as static using `jax.jit`'s `static_argnums`
and `static_argnames` kwargs, `jax.pmap`'s `static_broadcasted_argnums`, etc.
`cls` must be hashable, as defined in
https://docs.python.org/3/glossary.html#term-hashable.
`register_static` can be applied to subclasses of builtin hashable classes
such as `str`, like this:
```
@tree_util.register_static
class StaticStr(str):
pass
```
"""
flatten = lambda obj: ((), obj)
unflatten = lambda obj, empty_iter_children: obj
register_pytree_with_keys(cls, flatten, unflatten)
return cls
def tree_flatten_with_path(
tree: Any, is_leaf: Callable[[Any], bool] | None = None
) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]:
@ -866,7 +890,6 @@ def _child_keys(pytree: Any) -> KeyPath:
return tuple(FlattenedIndexKey(i) for i in range(num_children))
def _prefix_error(
key_path: KeyPath,
prefix_tree: Any,

View File

@ -66,6 +66,7 @@ from jax._src.tree_util import (
DictKey as DictKey,
GetAttrKey as GetAttrKey,
FlattenedIndexKey as FlattenedIndexKey,
register_static as register_static,
# TODO(ivyzheng): Remove these old APIs after June 10 2023.
register_keypaths,
AttributeKeyPathEntry,

View File

@ -13,6 +13,7 @@
# limitations under the License.
import collections
import dataclasses
import functools
import pickle
import re
@ -128,6 +129,36 @@ class FlatCache:
data, meta = tree_util.tree_flatten(tree_util.tree_unflatten(meta, data))
return FlatCache(None, leaves=data, treedef=meta)
@tree_util.register_static
class StaticInt(int):
pass
@tree_util.register_static
class StaticTuple(tuple):
pass
@tree_util.register_static
class StaticDict(dict):
pass
@tree_util.register_static
@dataclasses.dataclass
class BlackBox:
"""Stores a value but pretends to be equal to every other black box."""
value: int
def __hash__(self):
return 0
def __eq__(self, other):
return isinstance(other, BlackBox)
TREES = (
(None,),
((None,),),
@ -141,6 +172,9 @@ TREES = (
([AnObject2(3, None, [4, "foo"])],),
(Special(2, 3.),),
({"a": 1, "b": 2},),
(StaticInt(1),),
(StaticTuple((2, 3)),),
(StaticDict(foo=4, bar=5),),
(collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),),
(collections.defaultdict(dict,
[("foo", 34), ("baz", 101), ("something", -42)]),),
@ -148,6 +182,7 @@ TREES = (
(FlatCache(None),),
(FlatCache(1),),
(FlatCache({"a": [1, 2]}),),
(BlackBox(value=2),),
)
@ -165,6 +200,9 @@ TREE_STRINGS = (
"PyTreeDef([CustomNode(AnObject2[[4, 'foo']], [*, None])])",
"PyTreeDef(CustomNode(Special[None], [*, *]))",
"PyTreeDef({'a': *, 'b': *})",
"PyTreeDef(CustomNode(StaticInt[1], []))",
"PyTreeDef(CustomNode(StaticTuple[(2, 3)], []))",
"PyTreeDef(CustomNode(StaticDict[{'foo': 4, 'bar': 5}], []))",
)
# pytest expects "tree_util_test.ATuple"
@ -199,6 +237,10 @@ TREES_WITH_KEYPATH = (
(collections.defaultdict(dict,
[("foo", 34), ("baz", 101), ("something", -42)]),),
(ANamedTupleSubclass(foo="hello", bar=3.5),),
(StaticInt(1),),
(StaticTuple((2, 3)),),
(StaticDict(foo=4, bar=5),),
(BlackBox(value=2),),
)
@ -592,6 +634,66 @@ class TreeTest(jtu.JaxTestCase):
self.assertLen(leaves, 1)
class StaticTest(parameterized.TestCase):
@parameterized.parameters(
(StaticInt(2),),
(StaticTuple((2, None)),),
(StaticDict(foo=2),),
)
def test_trace_just_once_with_same_static(self, y):
num_called = 0
@jax.jit
def fn(x: int, static_y: StaticInt):
nonlocal num_called
num_called += 1
unstatic_y = type(static_y).__base__(static_y)
[y] = tree_util.tree_leaves(unstatic_y)
return x + y
fn(1, y)
fn(3, y)
self.assertEqual(num_called, 1)
def test_name(self):
self.assertEqual(StaticInt.__name__, "StaticInt")
self.assertEqual(BlackBox.__name__, "BlackBox")
@parameterized.parameters(
(StaticInt(2), StaticInt(4)),
(StaticTuple((2, None)), StaticTuple((4, None))),
(StaticDict(foo=2), StaticDict(foo=4)),
)
def test_trace_twice_with_different_static(self, y1, y2):
num_called = 0
@jax.jit
def fn(x: int, static_y: StaticInt):
nonlocal num_called
num_called += 1
unstatic_y = type(static_y).__base__(static_y)
[y] = tree_util.tree_leaves(unstatic_y)
return x + y
fn(1, y1)
fn(3, y2)
self.assertEqual(num_called, 2)
def test_trace_just_once_if_static_looks_constant(self):
num_called = 0
@jax.jit
def fn(x: int, static_y: BlackBox):
nonlocal num_called
num_called += 1
return x + static_y.value
self.assertEqual(fn(1, BlackBox(2)), 3)
self.assertEqual(fn(3, BlackBox(1)), 5)
self.assertEqual(num_called, 1)
class RavelUtilTest(jtu.JaxTestCase):
def testFloats(self):