mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
`jax.tree_util.register_dataclass` now validates `data_fields` and `meta_fields`
A well-behaved registration call must list all ``init=True`` fields in either ``data_fields`` or ``meta_fields``. Otherwise, ``flatten . unflatten`` could potentially *not* be an identity. PiperOrigin-RevId: 669244669
This commit is contained in:
parent
6a1adc842b
commit
02bb884357
@ -24,6 +24,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
`jax.host_ids()` function that was deprecated in JAX v0.2.13.
|
||||
* To align with the behavior of `numpy.fabs`, `jax.numpy.fabs` has been
|
||||
modified to no longer support `complex dtypes`.
|
||||
* ``jax.tree_util.register_dataclass`` now checks that ``data_fields``
|
||||
and ``meta_fields`` include all dataclass fields with ``init=True``
|
||||
and only them, if ``nodetype`` is a dataclass.
|
||||
|
||||
* Breaking changes
|
||||
* The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the
|
||||
|
@ -15,6 +15,7 @@ from __future__ import annotations
|
||||
|
||||
import collections
|
||||
from collections.abc import Callable, Hashable, Iterable, Sequence
|
||||
import dataclasses
|
||||
from dataclasses import dataclass
|
||||
import difflib
|
||||
import functools
|
||||
@ -925,7 +926,10 @@ def register_pytree_with_keys_class(cls: Typ) -> Typ:
|
||||
|
||||
@export
|
||||
def register_dataclass(
|
||||
nodetype: Typ, data_fields: Sequence[str], meta_fields: Sequence[str]
|
||||
nodetype: Typ,
|
||||
data_fields: Sequence[str],
|
||||
meta_fields: Sequence[str],
|
||||
drop_fields: Sequence[str] = (),
|
||||
) -> Typ:
|
||||
"""Extends the set of types that are considered internal nodes in pytrees.
|
||||
|
||||
@ -941,16 +945,14 @@ def register_dataclass(
|
||||
attributes represent the whole of the object state, and can be passed
|
||||
as keywords to the class constructor to create a copy of the object.
|
||||
All defined attributes should be listed among ``meta_fields`` or ``data_fields``.
|
||||
meta_fields: auxiliary data field names. These fields *must* contain static,
|
||||
hashable, immutable objects, as these objects are used to generate JIT cache
|
||||
keys. In particular, ``meta_fields`` cannot contain :class:`jax.Array` or
|
||||
:class:`numpy.ndarray` objects.
|
||||
data_fields: data field names. These fields *must* be JAX-compatible objects
|
||||
such as arrays (:class:`jax.Array` or :class:`numpy.ndarray`), scalars, or
|
||||
pytrees whose leaves are arrays or scalars. Note that ``None`` is valid, as
|
||||
this is recognized by JAX as an empty pytree.
|
||||
meta_fields: auxiliary data field names. These fields will be considered static
|
||||
within JAX transformations such as :func:`jax.jit`. The listed fields *must*
|
||||
contain static, hashable, immutable objects, as these objects are used to
|
||||
generate JIT cache keys: for example strings, Python scalars, or array shapes
|
||||
and dtypes. In particular, ``meta_fields`` cannot contain :class:`jax.Array`
|
||||
or :class:`numpy.ndarray` objects, as they are not hashable.
|
||||
pytrees whose leaves are arrays or scalars. Note that ``data_fields`` may be
|
||||
``None``, as this is recognized by JAX as an empty pytree.
|
||||
|
||||
Returns:
|
||||
The input class ``nodetype`` is returned unchanged after being added to JAX's
|
||||
@ -1003,6 +1005,23 @@ def register_dataclass(
|
||||
meta_fields = tuple(meta_fields)
|
||||
data_fields = tuple(data_fields)
|
||||
|
||||
if dataclasses.is_dataclass(nodetype):
|
||||
init_fields = {f.name for f in dataclasses.fields(nodetype) if f.init}
|
||||
# Pass the sequence of names itself. Unpacking it with ``*`` would hand every
# field name to ``difference_update`` as a separate iterable, so the set would
# drop individual characters (e.g. "f", "o") instead of the named field "foo".
init_fields.difference_update(drop_fields)
|
||||
if {*meta_fields, *data_fields} != init_fields:
|
||||
msg = (
|
||||
"data_fields and meta_fields must include all dataclass fields with"
|
||||
" ``init=True`` and only them."
|
||||
)
|
||||
if missing := init_fields - {*meta_fields, *data_fields}:
|
||||
msg += (
|
||||
f" Missing fields: {missing}. Add them to drop_fields to suppress"
|
||||
" this error."
|
||||
)
|
||||
if unexpected := {*meta_fields, *data_fields} - init_fields:
|
||||
msg += f" Unexpected fields: {unexpected}."
|
||||
raise ValueError(msg)
|
||||
|
||||
def flatten_with_keys(x):
|
||||
meta = tuple(getattr(x, name) for name in meta_fields)
|
||||
data = tuple((GetAttrKey(name), getattr(x, name)) for name in data_fields)
|
||||
|
@ -1257,5 +1257,63 @@ class TreeAliasTest(jtu.JaxTestCase):
|
||||
)
|
||||
|
||||
|
||||
class RegistrationTest(jtu.JaxTestCase):
  # Exercises the field validation performed by
  # ``tree_util.register_dataclass`` at registration time.

  def test_register_dataclass_missing_fields(self):
    # An ``init=True`` field absent from both lists must be reported.
    @dataclasses.dataclass
    class Foo:
      x: int
      y: int
      z: float = dataclasses.field(init=False)

    missing_pattern = (
        "data_fields and meta_fields must include all dataclass fields.*"
        "Missing fields: {'y'}"
    )
    with self.assertRaisesRegex(ValueError, missing_pattern):
      tree_util.register_dataclass(Foo, data_fields=["x"], meta_fields=[])

    # ``z`` is not required, because it's not included in ``__init__``.
    tree_util.register_dataclass(Foo, data_fields=["x"], meta_fields=["y"])

  def test_register_dataclass_unexpected_fields(self):
    # A listed name that is not a dataclass field must be reported.
    @dataclasses.dataclass
    class Foo:
      x: int
      y: float

    unexpected_pattern = (
        "data_fields and meta_fields must include all dataclass fields.*"
        "Unexpected fields: {'z'}"
    )
    with self.assertRaisesRegex(ValueError, unexpected_pattern):
      tree_util.register_dataclass(
          Foo, data_fields=["x"], meta_fields=["y", "z"]
      )

  def test_register_dataclass_drop_fields(self):
    @dataclasses.dataclass
    class Foo:
      x: int
      y: int = dataclasses.field(default=42)

    # ``y`` is explicitly excluded.
    tree_util.register_dataclass(
        Foo, data_fields=["x"], meta_fields=[], drop_fields=["y"]
    )

  def test_register_dataclass_invalid_plain_class(self):
    class Foo:
      x: int
      y: int

      def __init__(self, x, y):
        self.x = x
        self.y = y

    # ``y`` is missing, but no validation is done for plain classes.
    tree_util.register_dataclass(Foo, data_fields=["x"], meta_fields=[])
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user