`jax.tree_util.register_dataclass now validates data_fields and meta_fields`

A well-behaved registration call must list all ``init=True`` fields in either ``data_fields`` or ``meta_fields``. Otherwise, ``flatten . unflatten`` is *not* guaranteed to be an identity.

PiperOrigin-RevId: 669244669
This commit is contained in:
Sergei Lebedev 2024-08-30 02:01:02 -07:00 committed by jax authors
parent 6a1adc842b
commit 02bb884357
3 changed files with 89 additions and 9 deletions

View File

@ -24,6 +24,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
`jax.host_ids()` function that was deprecated in JAX v0.2.13.
* To align with the behavior of `numpy.fabs`, `jax.numpy.fabs` has been
modified to no longer support `complex dtypes`.
* ``jax.tree_util.register_dataclass`` now checks that ``data_fields``
and ``meta_fields`` together include all dataclass fields with ``init=True``
and only them, if ``nodetype`` is a dataclass.
* Breaking changes
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the

View File

@ -15,6 +15,7 @@ from __future__ import annotations
import collections
from collections.abc import Callable, Hashable, Iterable, Sequence
import dataclasses
from dataclasses import dataclass
import difflib
import functools
@ -925,7 +926,10 @@ def register_pytree_with_keys_class(cls: Typ) -> Typ:
@export
def register_dataclass(
nodetype: Typ, data_fields: Sequence[str], meta_fields: Sequence[str]
nodetype: Typ,
data_fields: Sequence[str],
meta_fields: Sequence[str],
drop_fields: Sequence[str] = (),
) -> Typ:
"""Extends the set of types that are considered internal nodes in pytrees.
@ -941,16 +945,14 @@ def register_dataclass(
attributes represent the whole of the object state, and can be passed
as keywords to the class constructor to create a copy of the object.
All defined attributes should be listed among ``meta_fields`` or ``data_fields``.
meta_fields: auxiliary data field names. These fields *must* contain static,
hashable, immutable objects, as these objects are used to generate JIT cache
keys. In particular, ``meta_fields`` cannot contain :class:`jax.Array` or
:class:`numpy.ndarray` objects.
data_fields: data field names. These fields *must* be JAX-compatible objects
such as arrays (:class:`jax.Array` or :class:`numpy.ndarray`), scalars, or
pytrees whose leaves are arrays or scalars. Note that ``None`` is valid, as
this is recognized by JAX as an empty pytree.
meta_fields: auxiliary data field names. These fields will be considered static
within JAX transformations such as :func:`jax.jit`. The listed fields *must*
contain static, hashable, immutable objects, as these objects are used to
generate JIT cache keys: for example strings, Python scalars, or array shapes
and dtypes. In particular, ``meta_fields`` cannot contain :class:`jax.Array`
or :class:`numpy.ndarray` objects, as they are not hashable.
pytrees whose leaves are arrays or scalars. Note that ``data_fields`` may be
``None``, as this is recognized by JAX as an empty pytree.
Returns:
The input class ``nodetype`` is returned unchanged after being added to JAX's
@ -1003,6 +1005,23 @@ def register_dataclass(
meta_fields = tuple(meta_fields)
data_fields = tuple(data_fields)
# Validate the field lists only when ``nodetype`` is a real dataclass; other
# classes expose no field metadata to check against.
if dataclasses.is_dataclass(nodetype):
  # Only fields accepted by ``__init__`` have to be accounted for.
  init_fields = {f.name for f in dataclasses.fields(nodetype) if f.init}
  # Pass the sequence itself: unpacking it (``*drop_fields``) would hand each
  # name to ``difference_update`` as an iterable of characters, so any
  # multi-character field name would silently fail to be dropped.
  init_fields.difference_update(drop_fields)
  listed_fields = {*meta_fields, *data_fields}
  if listed_fields != init_fields:
    msg = (
        "data_fields and meta_fields must include all dataclass fields with"
        " ``init=True`` and only them."
    )
    # Fields accepted by ``__init__`` but neither listed nor dropped.
    if missing := init_fields - listed_fields:
      msg += (
          f" Missing fields: {missing}. Add them to drop_fields to suppress"
          " this error."
      )
    # Listed names that are not ``init=True`` dataclass fields (or that were
    # explicitly dropped).
    if unexpected := listed_fields - init_fields:
      msg += f" Unexpected fields: {unexpected}."
    raise ValueError(msg)
def flatten_with_keys(x):
meta = tuple(getattr(x, name) for name in meta_fields)
data = tuple((GetAttrKey(name), getattr(x, name)) for name in data_fields)

View File

@ -1257,5 +1257,63 @@ class TreeAliasTest(jtu.JaxTestCase):
)
class RegistrationTest(jtu.JaxTestCase):
  # Exercises the field validation performed by
  # ``tree_util.register_dataclass``.

  def test_register_dataclass_missing_fields(self):
    @dataclasses.dataclass
    class Foo:
      x: int
      y: int
      z: float = dataclasses.field(init=False)

    expected_error = (
        "data_fields and meta_fields must include all dataclass fields.*"
        "Missing fields: {'y'}"
    )
    # Omitting the ``init=True`` field ``y`` must be rejected.
    with self.assertRaisesRegex(ValueError, expected_error):
      tree_util.register_dataclass(Foo, data_fields=["x"], meta_fields=[])

    # ``z`` has ``init=False``, so leaving it out of both lists is accepted.
    tree_util.register_dataclass(Foo, data_fields=["x"], meta_fields=["y"])

  def test_register_dataclass_unexpected_fields(self):
    @dataclasses.dataclass
    class Foo:
      x: int
      y: float

    expected_error = (
        "data_fields and meta_fields must include all dataclass fields.*"
        "Unexpected fields: {'z'}"
    )
    # ``z`` is not a field of ``Foo``, so listing it must be rejected.
    with self.assertRaisesRegex(ValueError, expected_error):
      tree_util.register_dataclass(
          Foo, data_fields=["x"], meta_fields=["y", "z"]
      )

  def test_register_dataclass_drop_fields(self):
    @dataclasses.dataclass
    class Foo:
      x: int
      y: int = dataclasses.field(default=42)

    # Naming ``y`` in ``drop_fields`` exempts it from the completeness check.
    tree_util.register_dataclass(
        Foo, data_fields=["x"], meta_fields=[], drop_fields=["y"]
    )

  def test_register_dataclass_invalid_plain_class(self):
    class Foo:
      x: int
      y: int

      def __init__(self, x, y):
        self.x = x
        self.y = y

    # ``y`` is not listed, yet registration succeeds: plain (non-dataclass)
    # classes are not validated.
    tree_util.register_dataclass(Foo, data_fields=["x"], meta_fields=[])
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())