mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #24700 from jakevdp:register-dataclass
PiperOrigin-RevId: 694224696
This commit is contained in:
commit
0b022aa741
@ -43,6 +43,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* {func}`jax.jit` got a new `compiler_options: dict[str, Any]` argument, for
|
||||
passing compilation options to XLA. For the moment it's undocumented and
|
||||
may be in flux.
|
||||
* {func}`jax.tree_util.register_dataclass` now allows metadata fields to be
|
||||
declared inline via {func}`dataclasses.field`. See the function documentation
|
||||
for examples.
|
||||
|
||||
## jax 0.4.35 (Oct 22, 2024)
|
||||
|
||||
|
@ -927,8 +927,8 @@ def register_pytree_with_keys_class(cls: Typ) -> Typ:
|
||||
@export
|
||||
def register_dataclass(
|
||||
nodetype: Typ,
|
||||
data_fields: Sequence[str],
|
||||
meta_fields: Sequence[str],
|
||||
data_fields: Sequence[str] | None = None,
|
||||
meta_fields: Sequence[str] | None = None,
|
||||
drop_fields: Sequence[str] = (),
|
||||
) -> Typ:
|
||||
"""Extends the set of types that are considered internal nodes in pytrees.
|
||||
@ -945,24 +945,33 @@ def register_dataclass(
|
||||
attributes represent the whole of the object state, and can be passed
|
||||
as keywords to the class constructor to create a copy of the object.
|
||||
All defined attributes should be listed among ``meta_fields`` or ``data_fields``.
|
||||
meta_fields: auxiliary data field names. These fields *must* contain static,
|
||||
hashable, immutable objects, as these objects are used to generate JIT cache
|
||||
keys. In particular, ``meta_fields`` cannot contain :class:`jax.Array` or
|
||||
:class:`numpy.ndarray` objects.
|
||||
data_fields: data field names. These fields *must* be JAX-compatible objects
|
||||
such as arrays (:class:`jax.Array` or :class:`numpy.ndarray`), scalars, or
|
||||
pytrees whose leaves are arrays or scalars. Note that ``data_fields`` may be
|
||||
``None``, as this is recognized by JAX as an empty pytree.
|
||||
meta_fields: metadata field names: these are attributes which will be treated as
|
||||
:term:`static` when this pytree is passed to :func:`jax.jit`. ``meta_fields`` is
|
||||
optional only if ``nodetype`` is a dataclass, in which case individual fields can
|
||||
be marked static via :func:`dataclasses.field` (see examples below).
|
||||
Metadata fields *must* be static, hashable, immutable objects, as these objects
|
||||
are used to generate JIT cache keys. In particular, metadata fields cannot contain
|
||||
:class:`jax.Array` or :class:`numpy.ndarray` objects.
|
||||
data_fields: data field names: these are attributes which will be treated as non-static
|
||||
when this pytree is passed to :func:`jax.jit`. ``data_fields`` is optional only if
|
||||
``nodetype`` is a dataclass, in which case fields are assumed data fields unless
|
||||
marked via :func:`dataclasses.field` (see examples below).
|
||||
Data fields *must* be JAX-compatible objects such as arrays (:class:`jax.Array`
|
||||
or :class:`numpy.ndarray`), scalars, or pytrees whose leaves are arrays or scalars.
|
||||
Note that ``None`` is a valid data field, as JAX recognizes this as an empty pytree.
|
||||
|
||||
Returns:
|
||||
The input class ``nodetype`` is returned unchanged after being added to JAX's
|
||||
pytree registry. This return value allows ``register_dataclass`` to be partially
|
||||
evaluated and used as a decorator as in the example below.
|
||||
pytree registry, so that :func:`register_dataclass` can be used as a decorator.
|
||||
|
||||
Examples:
|
||||
In JAX v0.4.35 or older, you must specify ``data_fields`` and ``meta_fields``
|
||||
in order to use this decorator:
|
||||
|
||||
>>> import jax
|
||||
>>> from dataclasses import dataclass
|
||||
>>> from functools import partial
|
||||
>>>
|
||||
...
|
||||
>>> @partial(jax.tree_util.register_dataclass,
|
||||
... data_fields=['x', 'y'],
|
||||
... meta_fields=['op'])
|
||||
@ -976,7 +985,26 @@ def register_dataclass(
|
||||
>>> m
|
||||
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
|
||||
|
||||
Now that this class is registered, it can be used with functions in :mod:`jax.tree_util`:
|
||||
Starting in JAX v0.4.36, the ``data_fields`` and ``meta_fields`` arguments are optional
|
||||
for :func:`~dataclasses.dataclass` inputs, with fields defaulting to ``data_fields``
|
||||
unless marked as static using ``static`` metadata in :func:`dataclasses.field`.
|
||||
|
||||
>>> import jax
|
||||
>>> from dataclasses import dataclass, field
|
||||
...
|
||||
>>> @jax.tree_util.register_dataclass
|
||||
... @dataclass
|
||||
... class MyStruct:
|
||||
... x: jax.Array # defaults to non-static data field
|
||||
... y: jax.Array # defaults to non-static data field
|
||||
... op: str = field(metadata=dict(static=True)) # marked as static meta field.
|
||||
...
|
||||
>>> m = MyStruct(x=jnp.ones(3), y=jnp.arange(3), op='add')
|
||||
>>> m
|
||||
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
|
||||
|
||||
Once this class is registered, it can be used with functions in :mod:`jax.tree` and
|
||||
:mod:`jax.tree_util`:
|
||||
|
||||
>>> leaves, treedef = jax.tree.flatten(m)
|
||||
>>> leaves
|
||||
@ -987,7 +1015,8 @@ def register_dataclass(
|
||||
MyStruct(x=Array([1., 1., 1.], dtype=float32), y=Array([0, 1, 2], dtype=int32), op='add')
|
||||
|
||||
In particular, this registration allows ``m`` to be passed seamlessly through code
|
||||
wrapped in :func:`jax.jit` and other JAX transformations:
|
||||
wrapped in :func:`jax.jit` and other JAX transformations, with ``data_fields`` being
|
||||
treated as dynamic arguments, and ``meta_fields`` being treated as static arguments:
|
||||
|
||||
>>> @jax.jit
|
||||
... def compiled_func(m):
|
||||
@ -999,6 +1028,21 @@ def register_dataclass(
|
||||
>>> compiled_func(m)
|
||||
Array([1., 2., 3.], dtype=float32)
|
||||
"""
|
||||
if data_fields is None or meta_fields is None:
|
||||
if (data_fields is None) != (meta_fields is None):
|
||||
raise TypeError("register_dataclass: data_fields and meta_fields must both be specified"
|
||||
f" when either is specified. Got {data_fields=} {meta_fields=}.")
|
||||
if not dataclasses.is_dataclass(nodetype):
|
||||
raise TypeError("register_dataclass: data_fields and meta_fields are required when"
|
||||
f" nodetype is not a dataclass. Got {nodetype=}.")
|
||||
data_fields = [f.name for f in dataclasses.fields(nodetype)
|
||||
if not f.metadata.get('static', False)]
|
||||
meta_fields = [f.name for f in dataclasses.fields(nodetype)
|
||||
if f.metadata.get('static', False)]
|
||||
|
||||
assert meta_fields is not None
|
||||
assert data_fields is not None
|
||||
|
||||
# Store inputs as immutable tuples in this scope, because we close over them
|
||||
# for later evaluation. This prevents potentially confusing behavior if the
|
||||
# caller were to pass in lists that are later mutated.
|
||||
|
@ -17,7 +17,6 @@ import dataclasses
|
||||
import functools
|
||||
import pickle
|
||||
import re
|
||||
from typing import TypeVar
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
@ -142,27 +141,6 @@ class FlatCache:
|
||||
data, meta = tree_util.tree_flatten(tree_util.tree_unflatten(meta, data))
|
||||
return FlatCache(None, leaves=data, treedef=meta)
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
# Inspired by Flax.
|
||||
def pytree_node_dataclass(clz: _T, **kwargs) -> _T:
|
||||
data_clz = dataclasses.dataclass(**kwargs)(clz) # type: ignore
|
||||
meta_fields = []
|
||||
data_fields = []
|
||||
for field_info in dataclasses.fields(data_clz):
|
||||
is_pytree_node = field_info.metadata.get("pytree_node", True)
|
||||
if is_pytree_node:
|
||||
data_fields.append(field_info.name)
|
||||
else:
|
||||
meta_fields.append(field_info.name)
|
||||
|
||||
jax.tree_util.register_dataclass(
|
||||
data_clz, data_fields, meta_fields
|
||||
)
|
||||
|
||||
return data_clz
|
||||
|
||||
|
||||
@tree_util.register_static
|
||||
class StaticInt(int):
|
||||
@ -231,16 +209,18 @@ TREE_STRINGS = (
|
||||
"PyTreeDef(CustomNode(StaticDict[{'foo': 4, 'bar': 5}], []))",
|
||||
)
|
||||
|
||||
@pytree_node_dataclass
|
||||
@jax.tree_util.register_dataclass
|
||||
@dataclasses.dataclass
|
||||
class ADataclass:
|
||||
x: tuple[int, int]
|
||||
y: int
|
||||
|
||||
@pytree_node_dataclass
|
||||
@jax.tree_util.register_dataclass
|
||||
@dataclasses.dataclass
|
||||
class ADataclassWithMeta:
|
||||
x: tuple[int, int]
|
||||
y: int
|
||||
z: int = dataclasses.field(metadata={"pytree_node": False})
|
||||
z: int = dataclasses.field(metadata={"static": True})
|
||||
|
||||
TREES += (
|
||||
(ADataclass(x=(1, 2), y=3),),
|
||||
@ -1294,6 +1274,36 @@ class TreeAliasTest(jtu.JaxTestCase):
|
||||
|
||||
class RegistrationTest(jtu.JaxTestCase):
|
||||
|
||||
def test_register_dataclass_with_field_specifier(self):
|
||||
@tree_util.register_dataclass
|
||||
@dataclasses.dataclass
|
||||
class Foo:
|
||||
x: int
|
||||
y: int = dataclasses.field(metadata=dict(static=True))
|
||||
|
||||
f = Foo(2, 3)
|
||||
self.assertLen(jax.tree.leaves(f), 1)
|
||||
|
||||
def test_register_dataclass_field_errors(self):
|
||||
class Foo: # not a dataclass
|
||||
x: int
|
||||
y: int
|
||||
|
||||
msg = ("register_dataclass: data_fields and meta_fields are required"
|
||||
" when nodetype is not a dataclass. Got nodetype=<class '.*Foo'>")
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
tree_util.register_dataclass(Foo)
|
||||
|
||||
msg = ("register_dataclass: data_fields and meta_fields must both be specified"\
|
||||
r" when either is specified. Got data_fields=\['x'\] meta_fields=None.")
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
tree_util.register_dataclass(Foo, data_fields=['x'])
|
||||
|
||||
msg = ("register_dataclass: data_fields and meta_fields must both be specified"\
|
||||
r" when either is specified. Got data_fields=None meta_fields=\['y'\].")
|
||||
with self.assertRaisesRegex(TypeError, msg):
|
||||
tree_util.register_dataclass(Foo, meta_fields=['y'])
|
||||
|
||||
def test_register_dataclass_missing_fields(self):
|
||||
@dataclasses.dataclass
|
||||
class Foo:
|
||||
|
Loading…
x
Reference in New Issue
Block a user