mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove obsolete python key path registry.
PiperOrigin-RevId: 747613761
This commit is contained in:
parent
a88486ca70
commit
ab600c3e82
@ -21,7 +21,7 @@ import functools
|
||||
from functools import partial
|
||||
import operator as op
|
||||
import textwrap
|
||||
from typing import Any, NamedTuple, TypeVar, overload
|
||||
from typing import Any, TypeVar, overload
|
||||
|
||||
from jax._src import traceback_util
|
||||
from jax._src.lib import pytree
|
||||
@ -762,42 +762,6 @@ def _simple_entrystr(key: KeyEntry) -> str:
|
||||
return str(key)
|
||||
|
||||
|
||||
# TODO(ivyzheng): remove this after another jaxlib release.
|
||||
class _RegistryWithKeypathsEntry(NamedTuple):
|
||||
flatten_with_keys: Callable[..., Any]
|
||||
unflatten_func: Callable[..., Any]
|
||||
|
||||
|
||||
def _register_keypaths(
|
||||
ty: type[T], handler: Callable[[T], tuple[KeyEntry, ...]]
|
||||
) -> None:
|
||||
def flatten_with_keys(xs):
|
||||
children, treedef = _registry[ty].to_iter(xs)
|
||||
return list(zip(handler(xs), children)), treedef
|
||||
if ty in _registry:
|
||||
_registry_with_keypaths[ty] = _RegistryWithKeypathsEntry(
|
||||
flatten_with_keys, _registry[ty].from_iter
|
||||
)
|
||||
|
||||
_registry_with_keypaths: dict[type[Any], _RegistryWithKeypathsEntry] = {}
|
||||
|
||||
_register_keypaths(
|
||||
tuple, lambda xs: tuple(SequenceKey(i) for i in range(len(xs)))
|
||||
)
|
||||
_register_keypaths(
|
||||
list, lambda xs: tuple(SequenceKey(i) for i in range(len(xs)))
|
||||
)
|
||||
_register_keypaths(dict, lambda xs: tuple(DictKey(k) for k in sorted(xs)))
|
||||
|
||||
_register_keypaths(
|
||||
collections.defaultdict, lambda x: tuple(DictKey(k) for k in x.keys())
|
||||
)
|
||||
|
||||
_register_keypaths(
|
||||
collections.OrderedDict, lambda x: tuple(DictKey(k) for k in x.keys())
|
||||
)
|
||||
|
||||
|
||||
@export
|
||||
def register_pytree_with_keys(
|
||||
nodetype: type[T],
|
||||
@ -867,9 +831,6 @@ def register_pytree_with_keys(
|
||||
register_pytree_node(
|
||||
nodetype, flatten_func, unflatten_func, flatten_with_keys
|
||||
)
|
||||
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
|
||||
flatten_with_keys, unflatten_func
|
||||
)
|
||||
|
||||
|
||||
@export
|
||||
@ -1062,11 +1023,6 @@ def register_dataclass(
|
||||
msg += f" Unexpected fields: {unexpected}."
|
||||
raise ValueError(msg)
|
||||
|
||||
def flatten_with_keys(x):
|
||||
meta = tuple(getattr(x, name) for name in meta_fields)
|
||||
data = tuple((GetAttrKey(name), getattr(x, name)) for name in data_fields)
|
||||
return data, meta
|
||||
|
||||
def unflatten_func(meta, data):
|
||||
meta_args = tuple(zip(meta_fields, meta))
|
||||
data_args = tuple(zip(data_fields, data))
|
||||
@ -1082,9 +1038,6 @@ def register_dataclass(
|
||||
none_leaf_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
|
||||
dispatch_registry.register_dataclass_node(nodetype, list(data_fields), list(meta_fields))
|
||||
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
|
||||
_registry_with_keypaths[nodetype] = _RegistryWithKeypathsEntry(
|
||||
flatten_with_keys, unflatten_func
|
||||
)
|
||||
return nodetype
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user