mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add register_static
decorator to tree_util to facilitate creating leafless classes.
PiperOrigin-RevId: 558937697
This commit is contained in:
parent
9f5999d545
commit
f2d7798a0c
@ -21,7 +21,7 @@ import functools
|
||||
from functools import partial
|
||||
import operator as op
|
||||
import textwrap
|
||||
from typing import Any, Callable, NamedTuple, TypeVar, Union, overload
|
||||
from typing import Any, Callable, NamedTuple, Type, TypeVar, Union, overload
|
||||
import warnings
|
||||
|
||||
from jax._src import traceback_util
|
||||
@ -34,6 +34,7 @@ traceback_util.register_exclusion(__file__)
|
||||
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U", bound=type[Any])
|
||||
H = TypeVar("H", bound=Hashable)
|
||||
|
||||
Leaf = Any
|
||||
PyTreeDef = pytree.PyTreeDef
|
||||
@ -115,7 +116,6 @@ def treedef_tuple(treedefs: Iterable[PyTreeDef]) -> PyTreeDef:
|
||||
return pytree.tuple(list(treedefs)) # type: ignore
|
||||
|
||||
|
||||
|
||||
def treedef_children(treedef: PyTreeDef) -> list[PyTreeDef]:
|
||||
return treedef.children()
|
||||
|
||||
@ -757,6 +757,30 @@ def register_pytree_with_keys_class(cls: U) -> U:
|
||||
return cls
|
||||
|
||||
|
||||
def register_static(cls: Type[H]) -> Type[H]:
|
||||
"""Registers `cls` as a pytree with no leaves.
|
||||
|
||||
Instances are treated as static by `jax.jit`, `jax.pmap`, etc. This can be an
|
||||
alternative to labeling inputs as static using `jax.jit`'s `static_argnums`
|
||||
and `static_argnames` kwargs, `jax.pmap`'s `static_broadcasted_argnums`, etc.
|
||||
|
||||
`cls` must be hashable, as defined in
|
||||
https://docs.python.org/3/glossary.html#term-hashable.
|
||||
|
||||
`register_static` can be applied to subclasses of builtin hashable classes
|
||||
such as `str`, like this:
|
||||
```
|
||||
@tree_util.register_static
|
||||
class StaticStr(str):
|
||||
pass
|
||||
```
|
||||
"""
|
||||
flatten = lambda obj: ((), obj)
|
||||
unflatten = lambda obj, empty_iter_children: obj
|
||||
register_pytree_with_keys(cls, flatten, unflatten)
|
||||
return cls
|
||||
|
||||
|
||||
def tree_flatten_with_path(
|
||||
tree: Any, is_leaf: Callable[[Any], bool] | None = None
|
||||
) -> tuple[list[tuple[KeyPath, Any]], PyTreeDef]:
|
||||
@ -866,7 +890,6 @@ def _child_keys(pytree: Any) -> KeyPath:
|
||||
return tuple(FlattenedIndexKey(i) for i in range(num_children))
|
||||
|
||||
|
||||
|
||||
def _prefix_error(
|
||||
key_path: KeyPath,
|
||||
prefix_tree: Any,
|
||||
|
@ -66,6 +66,7 @@ from jax._src.tree_util import (
|
||||
DictKey as DictKey,
|
||||
GetAttrKey as GetAttrKey,
|
||||
FlattenedIndexKey as FlattenedIndexKey,
|
||||
register_static as register_static,
|
||||
# TODO(ivyzheng): Remove these old APIs after June 10 2023.
|
||||
register_keypaths,
|
||||
AttributeKeyPathEntry,
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import dataclasses
|
||||
import functools
|
||||
import pickle
|
||||
import re
|
||||
@ -128,6 +129,36 @@ class FlatCache:
|
||||
data, meta = tree_util.tree_flatten(tree_util.tree_unflatten(meta, data))
|
||||
return FlatCache(None, leaves=data, treedef=meta)
|
||||
|
||||
|
||||
@tree_util.register_static
|
||||
class StaticInt(int):
|
||||
pass
|
||||
|
||||
|
||||
@tree_util.register_static
|
||||
class StaticTuple(tuple):
|
||||
pass
|
||||
|
||||
|
||||
@tree_util.register_static
|
||||
class StaticDict(dict):
|
||||
pass
|
||||
|
||||
|
||||
@tree_util.register_static
|
||||
@dataclasses.dataclass
|
||||
class BlackBox:
|
||||
"""Stores a value but pretends to be equal to every other black box."""
|
||||
|
||||
value: int
|
||||
|
||||
def __hash__(self):
|
||||
return 0
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, BlackBox)
|
||||
|
||||
|
||||
TREES = (
|
||||
(None,),
|
||||
((None,),),
|
||||
@ -141,6 +172,9 @@ TREES = (
|
||||
([AnObject2(3, None, [4, "foo"])],),
|
||||
(Special(2, 3.),),
|
||||
({"a": 1, "b": 2},),
|
||||
(StaticInt(1),),
|
||||
(StaticTuple((2, 3)),),
|
||||
(StaticDict(foo=4, bar=5),),
|
||||
(collections.OrderedDict([("foo", 34), ("baz", 101), ("something", -42)]),),
|
||||
(collections.defaultdict(dict,
|
||||
[("foo", 34), ("baz", 101), ("something", -42)]),),
|
||||
@ -148,6 +182,7 @@ TREES = (
|
||||
(FlatCache(None),),
|
||||
(FlatCache(1),),
|
||||
(FlatCache({"a": [1, 2]}),),
|
||||
(BlackBox(value=2),),
|
||||
)
|
||||
|
||||
|
||||
@ -165,6 +200,9 @@ TREE_STRINGS = (
|
||||
"PyTreeDef([CustomNode(AnObject2[[4, 'foo']], [*, None])])",
|
||||
"PyTreeDef(CustomNode(Special[None], [*, *]))",
|
||||
"PyTreeDef({'a': *, 'b': *})",
|
||||
"PyTreeDef(CustomNode(StaticInt[1], []))",
|
||||
"PyTreeDef(CustomNode(StaticTuple[(2, 3)], []))",
|
||||
"PyTreeDef(CustomNode(StaticDict[{'foo': 4, 'bar': 5}], []))",
|
||||
)
|
||||
|
||||
# pytest expects "tree_util_test.ATuple"
|
||||
@ -199,6 +237,10 @@ TREES_WITH_KEYPATH = (
|
||||
(collections.defaultdict(dict,
|
||||
[("foo", 34), ("baz", 101), ("something", -42)]),),
|
||||
(ANamedTupleSubclass(foo="hello", bar=3.5),),
|
||||
(StaticInt(1),),
|
||||
(StaticTuple((2, 3)),),
|
||||
(StaticDict(foo=4, bar=5),),
|
||||
(BlackBox(value=2),),
|
||||
)
|
||||
|
||||
|
||||
@ -592,6 +634,66 @@ class TreeTest(jtu.JaxTestCase):
|
||||
self.assertLen(leaves, 1)
|
||||
|
||||
|
||||
class StaticTest(parameterized.TestCase):
|
||||
|
||||
@parameterized.parameters(
|
||||
(StaticInt(2),),
|
||||
(StaticTuple((2, None)),),
|
||||
(StaticDict(foo=2),),
|
||||
)
|
||||
def test_trace_just_once_with_same_static(self, y):
|
||||
num_called = 0
|
||||
|
||||
@jax.jit
|
||||
def fn(x: int, static_y: StaticInt):
|
||||
nonlocal num_called
|
||||
num_called += 1
|
||||
unstatic_y = type(static_y).__base__(static_y)
|
||||
[y] = tree_util.tree_leaves(unstatic_y)
|
||||
return x + y
|
||||
|
||||
fn(1, y)
|
||||
fn(3, y)
|
||||
self.assertEqual(num_called, 1)
|
||||
|
||||
def test_name(self):
|
||||
self.assertEqual(StaticInt.__name__, "StaticInt")
|
||||
self.assertEqual(BlackBox.__name__, "BlackBox")
|
||||
|
||||
@parameterized.parameters(
|
||||
(StaticInt(2), StaticInt(4)),
|
||||
(StaticTuple((2, None)), StaticTuple((4, None))),
|
||||
(StaticDict(foo=2), StaticDict(foo=4)),
|
||||
)
|
||||
def test_trace_twice_with_different_static(self, y1, y2):
|
||||
num_called = 0
|
||||
|
||||
@jax.jit
|
||||
def fn(x: int, static_y: StaticInt):
|
||||
nonlocal num_called
|
||||
num_called += 1
|
||||
unstatic_y = type(static_y).__base__(static_y)
|
||||
[y] = tree_util.tree_leaves(unstatic_y)
|
||||
return x + y
|
||||
|
||||
fn(1, y1)
|
||||
fn(3, y2)
|
||||
self.assertEqual(num_called, 2)
|
||||
|
||||
def test_trace_just_once_if_static_looks_constant(self):
|
||||
num_called = 0
|
||||
|
||||
@jax.jit
|
||||
def fn(x: int, static_y: BlackBox):
|
||||
nonlocal num_called
|
||||
num_called += 1
|
||||
return x + static_y.value
|
||||
|
||||
self.assertEqual(fn(1, BlackBox(2)), 3)
|
||||
self.assertEqual(fn(3, BlackBox(1)), 5)
|
||||
self.assertEqual(num_called, 1)
|
||||
|
||||
|
||||
class RavelUtilTest(jtu.JaxTestCase):
|
||||
|
||||
def testFloats(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user