Remove obsolete python key path registry.

PiperOrigin-RevId: 747613761
This commit is contained in:
Ivy Zheng 2025-04-14 16:32:24 -07:00 committed by jax authors
parent a88486ca70
commit ab600c3e82

View File

@ -21,7 +21,7 @@ import functools
from functools import partial
import operator as op
import textwrap
from typing import Any, NamedTuple, TypeVar, overload
from typing import Any, TypeVar, overload
from jax._src import traceback_util
from jax._src.lib import pytree
@ -762,42 +762,6 @@ def _simple_entrystr(key: KeyEntry) -> str:
return str(key)
# TODO(ivyzheng): remove this after another jaxlib release.
class _RegistryWithKeypathsEntry(NamedTuple):
flatten_with_keys: Callable[..., Any]
unflatten_func: Callable[..., Any]
def _register_keypaths(
ty: type[T], handler: Callable[[T], tuple[KeyEntry, ...]]
) -> None:
def flatten_with_keys(xs):
children, treedef = _registry[ty].to_iter(xs)
return list(zip(handler(xs), children)), treedef
if ty in _registry:
_registry_with_keypaths[ty] = _RegistryWithKeypathsEntry(
flatten_with_keys, _registry[ty].from_iter
)
_registry_with_keypaths: dict[type[Any], _RegistryWithKeypathsEntry] = {}
_register_keypaths(
tuple, lambda xs: tuple(SequenceKey(i) for i in range(len(xs)))
)
_register_keypaths(
list, lambda xs: tuple(SequenceKey(i) for i in range(len(xs)))
)
_register_keypaths(dict, lambda xs: tuple(DictKey(k) for k in sorted(xs)))
_register_keypaths(
collections.defaultdict, lambda x: tuple(DictKey(k) for k in x.keys())
)
_register_keypaths(
collections.OrderedDict, lambda x: tuple(DictKey(k) for k in x.keys())
)
@export
def register_pytree_with_keys(
nodetype: type[T],
@ -867,9 +831,6 @@ def register_pytree_with_keys(
register_pytree_node(
nodetype, flatten_func, unflatten_func, flatten_with_keys
)
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
flatten_with_keys, unflatten_func
)
@export
@ -1062,11 +1023,6 @@ def register_dataclass(
msg += f" Unexpected fields: {unexpected}."
raise ValueError(msg)
def flatten_with_keys(x):
meta = tuple(getattr(x, name) for name in meta_fields)
data = tuple((GetAttrKey(name), getattr(x, name)) for name in data_fields)
return data, meta
def unflatten_func(meta, data):
meta_args = tuple(zip(meta_fields, meta))
data_args = tuple(zip(data_fields, data))
@ -1082,9 +1038,6 @@ def register_dataclass(
none_leaf_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
dispatch_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
flatten_with_keys, unflatten_func
)
return nodetype