Merge pull request #24700 from jakevdp:register-dataclass

PiperOrigin-RevId: 694224696
This commit is contained in:
jax authors 2024-11-07 13:14:58 -08:00
commit 0b022aa741
3 changed files with 97 additions and 40 deletions

View File

@ -43,6 +43,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
passing compilation options to XLA. For the moment it's undocumented and
may be in flux.
* {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
declared inline via {func}`dataclasses.field`. See the function documentation
for examples.
## jax 0.4.35 (Oct 22, 2024)

View File

@ -927,8 +927,8 @@ def register_pytree_with_keys_class(cls: Typ) -> Typ:
@export
def register_dataclass(
nodetype: Typ,
data_fields: Sequence[str],
meta_fields: Sequence[str],
data_fields: Sequence[str] | None = None,
meta_fields: Sequence[str] | None = None,
drop_fields: Sequence[str] = (),
) -> Typ:
"""Extends the set of types that are considered internal nodes in pytrees.
@ -945,24 +945,33 @@ def register_dataclass(
attributes represent the whole of the object state, and can be passed
as keywords to the class constructor to create a copy of the object.
All defined attributes should be listed among ``meta_fields`` or ``data_fields``.
meta_fields: auxiliary data field names. These fields *must* contain static,
hashable, immutable objects, as these objects are used to generate JIT cache
keys. In particular, ``meta_fields`` cannot contain :class:`jax.Array` or
:class:`numpy.ndarray` objects.
data_fields: data field names. These fields *must* be JAX-compatible objects
such as arrays (:class:`jax.Array` or :class:`numpy.ndarray`), scalars, or
pytrees whose leaves are arrays or scalars. Note that ``data_fields`` may be
``None``, as this is recognized by JAX as an empty pytree.
meta_fields: metadata field names: these are attributes which will be treated as
:term:`static` when this pytree is passed to :func:`jax.jit`. ``meta_fields`` is
optional only if ``nodetype`` is a dataclass, in which case individual fields can
be marked static via :func:`dataclasses.field` (see examples below).
Metadata fields *must* be static, hashable, immutable objects, as these objects
are used to generate JIT cache keys. In particular, metadata fields cannot contain
:class:`jax.Array` or :class:`numpy.ndarray` objects.
data_fields: data field names: these are attributes which will be treated as non-static
when this pytree is passed to :func:`jax.jit`. ``data_fields`` is optional only if
``nodetype`` is a dataclass, in which case fields are assumed data fields unless
marked via :func:`dataclasses.field` (see examples below).
Data fields *must* be JAX-compatible objects such as arrays (:class:`jax.Array`
or :class:`numpy.ndarray`), scalars, or pytrees whose leaves are arrays or scalars.
Note that ``None`` is a valid data field, as JAX recognizes this as an empty pytree.
Returns:
The input class ``nodetype`` is returned unchanged after being added to JAX's
pytree registry. This return value allows ``register_dataclass`` to be partially
evaluated and used as a decorator as in the example below.
pytree registry, so that :func:`register_dataclass` can be used as a decorator.
Examples:
In JAX v0.4.35 or older, you must specify ``data_fields`` and ``meta_fields``
in order to use this decorator:
>>> import jax
>>> from dataclasses import dataclass
>>> from functools import partial
>>>
...
>>> @partial(jax.tree_util.register_dataclass,
... data_fields=['x', 'y'],
... meta_fields=['op'])
@ -976,7 +985,26 @@ def register_dataclass(
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
Now that this class is registered, it can be used with functions in :mod:`jax.tree_util`:
Starting in JAX v0.4.36, the ``data_fields`` and ``meta_fields`` arguments are optional
for :func:`~dataclasses.dataclass` inputs, with fields defaulting to ``data_fields``
unless marked as static using ``static`` metadata in :func:`dataclasses.field`.
>>> import jax
>>> from dataclasses import dataclass, field
...
>>> @jax.tree_util.register_dataclass
... @dataclass
... class MyStruct:
... x: jax.Array # defaults to non-static data field
... y: jax.Array # defaults to non-static data field
... op: str = field(metadata=dict(static=True)) # marked as static meta field.
...
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
>>> m
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
Once this class is registered, it can be used with functions in :mod:`jax.tree` and
:mod:`jax.tree_util`:
>>> leaves, treedef = jax.tree.flatten(m)
>>> leaves
@ -987,7 +1015,8 @@ def register_dataclass(
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
In particular, this registration allows ``m`` to be passed seamlessly through code
wrapped in :func:`jax.jit` and other JAX transformations:
wrapped in :func:`jax.jit` and other JAX transformations, with ``data_fields`` being
treated as dynamic arguments, and ``meta_fields`` being treated as static arguments:
>>> @jax.jit
... def compiled_func(m):
@ -999,6 +1028,21 @@ def register_dataclass(
>>> compiled_func(m)
Array([1., 2., 3.], dtype=float32)
"""
if data_fields is None or meta_fields is None:
if (data_fields is None) != (meta_fields is None):
raise TypeError("register_dataclass: data_fields and meta_fields must both be specified"
f" when either is specified. Got {data_fields=} {meta_fields=}.")
if not dataclasses.is_dataclass(nodetype):
raise TypeError("register_dataclass: data_fields and meta_fields are required when"
f" nodetype is not a dataclass. Got {nodetype=}.")
data_fields = [f.name for f in dataclasses.fields(nodetype)
if not f.metadata.get('static', False)]
meta_fields = [f.name for f in dataclasses.fields(nodetype)
if f.metadata.get('static', False)]
assert meta_fields is not None
assert data_fields is not None
# Store inputs as immutable tuples in this scope, because we close over them
# for later evaluation. This prevents potentially confusing behavior if the
# caller were to pass in lists that are later mutated.

View File

@ -17,7 +17,6 @@ import dataclasses
import functools
import pickle
import re
from typing import TypeVar
from absl.testing import absltest
from absl.testing import parameterized
@ -142,27 +141,6 @@ class FlatCache:
data, meta = tree_util.tree_flatten(tree_util.tree_unflatten(meta, data))
return FlatCache(None, leaves=data, treedef=meta)
_T = TypeVar("_T")
# Inspired by Flax.
def pytree_node_dataclass(clz: _T, **kwargs) -> _T:
  """Turn ``clz`` into a dataclass and register it as a pytree node.

  ``kwargs`` are forwarded to :func:`dataclasses.dataclass`. Each field is
  treated as a data field unless its ``pytree_node`` metadata entry is falsy,
  in which case it is registered as a meta field.

  Returns:
    The newly created dataclass, after registration with
    ``jax.tree_util.register_dataclass``.
  """
  data_clz = dataclasses.dataclass(**kwargs)(clz)  # type: ignore
  all_fields = dataclasses.fields(data_clz)
  # Fields are pytree nodes by default; only an explicit falsy
  # ``pytree_node`` metadata entry moves a field to the meta side.
  data_fields = [
      info.name for info in all_fields
      if info.metadata.get("pytree_node", True)
  ]
  meta_fields = [
      info.name for info in all_fields
      if not info.metadata.get("pytree_node", True)
  ]
  jax.tree_util.register_dataclass(data_clz, data_fields, meta_fields)
  return data_clz
@tree_util.register_static
class StaticInt(int):
@ -231,16 +209,18 @@ TREE_STRINGS = (
"PyTreeDef(CustomNode(StaticDict[{'foo': 4, 'bar': 5}], []))",
)
@pytree_node_dataclass
@jax.tree_util.register_dataclass
@dataclasses.dataclass
class ADataclass:
x: tuple[int, int]
y: int
@pytree_node_dataclass
@jax.tree_util.register_dataclass
@dataclasses.dataclass
class ADataclassWithMeta:
x: tuple[int, int]
y: int
z: int = dataclasses.field(metadata={"pytree_node": False})
z: int = dataclasses.field(metadata={"static": True})
TREES += (
(ADataclass(x=(1, 2), y=3),),
@ -1294,6 +1274,36 @@ class TreeAliasTest(jtu.JaxTestCase):
class RegistrationTest(jtu.JaxTestCase):
def test_register_dataclass_with_field_specifier(self):
  """A field marked ``static`` via field metadata is not a pytree leaf."""
  @tree_util.register_dataclass
  @dataclasses.dataclass
  class Foo:
    x: int
    y: int = dataclasses.field(metadata={"static": True})

  instance = Foo(2, 3)
  # Only ``x`` is a data field, so exactly one leaf is expected.
  leaves = jax.tree.leaves(instance)
  self.assertLen(leaves, 1)
def test_register_dataclass_field_errors(self):
  """``register_dataclass`` raises ``TypeError`` for unusable field arguments.

  Covers three calls on a plain (non-dataclass) class: no field lists at all,
  only ``data_fields``, and only ``meta_fields``.
  """
  class Foo:  # not a dataclass
    x: int
    y: int

  # No field lists given and the class is not a dataclass.
  not_a_dataclass_re = (
      "register_dataclass: data_fields and meta_fields are required"
      " when nodetype is not a dataclass. Got nodetype=<class '.*Foo'>")
  with self.assertRaisesRegex(TypeError, not_a_dataclass_re):
    tree_util.register_dataclass(Foo)

  # Only data_fields given.
  only_data_re = (
      "register_dataclass: data_fields and meta_fields must both be specified"
      r" when either is specified. Got data_fields=\['x'\] meta_fields=None.")
  with self.assertRaisesRegex(TypeError, only_data_re):
    tree_util.register_dataclass(Foo, data_fields=['x'])

  # Only meta_fields given.
  only_meta_re = (
      "register_dataclass: data_fields and meta_fields must both be specified"
      r" when either is specified. Got data_fields=None meta_fields=\['y'\].")
  with self.assertRaisesRegex(TypeError, only_meta_re):
    tree_util.register_dataclass(Foo, meta_fields=['y'])
def test_register_dataclass_missing_fields(self):
@dataclasses.dataclass
class Foo: