mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[XLA:Python] Add a C++ implementation of flatten_one_level.
Also add a copy of the default registry that doesn't have None registered as a leaf, which is slightly faster than using an is_leaf function. This is mostly just doing an old TODO. PiperOrigin-RevId: 617988496
This commit is contained in:
parent
05e61ed07d
commit
5532e5505b
@ -363,11 +363,11 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
|
||||
|
||||
user_specified_in_shardings = (in_shardings is not None and
|
||||
not is_unspecified(in_shardings))
|
||||
is_none = lambda x: x is None
|
||||
in_shardings_leaves, in_shardings_treedef = tree_flatten(
|
||||
in_shardings, is_leaf=is_none)
|
||||
out_shardings_leaves, out_shardings_treedef = tree_flatten(
|
||||
out_shardings, is_leaf=is_none)
|
||||
none_leaf_registry = tree_util.none_leaf_registry
|
||||
in_shardings_leaves, in_shardings_treedef = none_leaf_registry.flatten(
|
||||
in_shardings)
|
||||
out_shardings_leaves, out_shardings_treedef = none_leaf_registry.flatten(
|
||||
out_shardings)
|
||||
|
||||
fun_sourceinfo = api_util.fun_sourceinfo(fun)
|
||||
fun_signature = api_util.fun_signature(fun)
|
||||
|
@ -25,6 +25,7 @@ from typing import Any, Callable, NamedTuple, TypeVar, Union, overload
|
||||
|
||||
from jax._src import traceback_util
|
||||
from jax._src.lib import pytree
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.util import unzip2
|
||||
|
||||
@ -44,6 +45,13 @@ default_registry = pytree.default_registry()
|
||||
default_registry.__module__ = __name__
|
||||
default_registry.__name__ = "default_registry"
|
||||
|
||||
# A copy of the default registry, where None is a leaf.
|
||||
none_leaf_registry = pytree.PyTreeRegistry(
|
||||
enable_none=False, enable_tuple=True, enable_namedtuple=True,
|
||||
enable_list=True, enable_dict=True)
|
||||
none_leaf_registry.__module__ = __name__
|
||||
none_leaf_registry.__name__ = "none_leaf_registry"
|
||||
|
||||
# A special, internal pytree registry that includes everything in
|
||||
# `default_registry`, plus internal Python-defined types that we want
|
||||
# to teach the fast dispatch path ("C++ dispatch") how to flatten and
|
||||
@ -242,6 +250,7 @@ def register_pytree_node(nodetype: type[T],
|
||||
``nodetype``.
|
||||
"""
|
||||
default_registry.register_node(nodetype, flatten_func, unflatten_func)
|
||||
none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func)
|
||||
dispatch_registry.register_node(nodetype, flatten_func, unflatten_func)
|
||||
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
|
||||
|
||||
@ -374,21 +383,9 @@ _registry = {
|
||||
|
||||
def _replace_nones(sentinel, tree):
|
||||
"""Replaces ``None`` in ``tree`` with ``sentinel``."""
|
||||
if tree is None:
|
||||
return sentinel
|
||||
else:
|
||||
handler = _registry.get(type(tree))
|
||||
if handler:
|
||||
children, metadata = handler.to_iter(tree)
|
||||
proc_children = [_replace_nones(sentinel, child) for child in children]
|
||||
return handler.from_iter(metadata, proc_children)
|
||||
elif isinstance(tree, tuple) and hasattr(tree, "_fields"):
|
||||
# handle namedtuple as a special case, based on heuristic
|
||||
children = iter(tree)
|
||||
proc_children = [_replace_nones(sentinel, child) for child in children]
|
||||
return type(tree)(*proc_children)
|
||||
else:
|
||||
return tree
|
||||
leaves, treedef = none_leaf_registry.flatten(tree)
|
||||
leaves = map(lambda x: sentinel if x is None else x, leaves)
|
||||
return treedef.unflatten(leaves)
|
||||
|
||||
|
||||
no_initializer = object()
|
||||
@ -586,29 +583,50 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any,
|
||||
tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
|
||||
return result
|
||||
|
||||
def flatten_one_level(pytree: Any) -> tuple[list[Any], Hashable]:
|
||||
"""Flatten the given pytree node by one level.
|
||||
if xla_extension_version >= 248:
|
||||
def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
|
||||
"""Flatten the given pytree node by one level.
|
||||
|
||||
Args:
|
||||
pytree: A valid pytree node, either built-in or registered via
|
||||
``register_pytree_node`` or ``register_pytree_with_keys``.
|
||||
Args:
|
||||
pytree: A valid pytree node, either built-in or registered via
|
||||
``register_pytree_node`` or ``register_pytree_with_keys``.
|
||||
|
||||
Returns:
|
||||
A pair of the pytree's flattened children and its hashable metadata.
|
||||
Returns:
|
||||
A pair of the pytree's flattened children and its hashable metadata.
|
||||
|
||||
Raises:
|
||||
ValueError: If the given pytree is not a built-in or registered container
|
||||
via ``register_pytree_node`` or ``register_pytree_with_keys``.
|
||||
"""
|
||||
handler = _registry.get(type(pytree))
|
||||
if handler:
|
||||
children, meta = handler.to_iter(pytree)
|
||||
return list(children), meta
|
||||
elif isinstance(pytree, tuple) and hasattr(pytree, '_fields'):
|
||||
# handle namedtuple as a special case, based on heuristic
|
||||
return [getattr(pytree, s) for s in pytree._fields], None
|
||||
else:
|
||||
raise ValueError(f"can't tree-flatten type: {type(pytree)}")
|
||||
Raises:
|
||||
ValueError: If the given pytree is not a built-in or registered container
|
||||
via ``register_pytree_node`` or ``register_pytree_with_keys``.
|
||||
"""
|
||||
out = default_registry.flatten_one_level(pytree)
|
||||
if out is None:
|
||||
raise ValueError(f"can't tree-flatten type: {type(pytree)}")
|
||||
else:
|
||||
return out
|
||||
else:
|
||||
def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
|
||||
"""Flatten the given pytree node by one level.
|
||||
|
||||
Args:
|
||||
pytree: A valid pytree node, either built-in or registered via
|
||||
``register_pytree_node`` or ``register_pytree_with_keys``.
|
||||
|
||||
Returns:
|
||||
A pair of the pytree's flattened children and its hashable metadata.
|
||||
|
||||
Raises:
|
||||
ValueError: If the given pytree is not a built-in or registered container
|
||||
via ``register_pytree_node`` or ``register_pytree_with_keys``.
|
||||
"""
|
||||
handler = _registry.get(type(pytree))
|
||||
if handler:
|
||||
children, meta = handler.to_iter(pytree)
|
||||
return list(children), meta
|
||||
elif isinstance(pytree, tuple) and hasattr(pytree, '_fields'):
|
||||
# handle namedtuple as a special case, based on heuristic
|
||||
return [getattr(pytree, s) for s in pytree._fields], None
|
||||
else:
|
||||
raise ValueError(f"can't tree-flatten type: {type(pytree)}")
|
||||
|
||||
def prefix_errors(prefix_tree: Any, full_tree: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None,
|
||||
@ -659,6 +677,8 @@ def _equality_errors(path, t1, t2, is_leaf):
|
||||
return # no more errors to find
|
||||
t1_children, t1_meta = flatten_one_level(t1)
|
||||
t2_children, t2_meta = flatten_one_level(t2)
|
||||
t1_children = tuple(t1_children)
|
||||
t2_children = tuple(t2_children)
|
||||
t1_keys, t2_keys = _child_keys(t1), _child_keys(t2)
|
||||
try:
|
||||
diff = ' '.join(repr(k.key) for k in
|
||||
@ -905,32 +925,64 @@ _generate_key_paths = generate_key_paths # alias for backward compat
|
||||
|
||||
|
||||
# The overall logic should be same as PyTreeDef::FlattenIntoImpl
|
||||
def _generate_key_paths_(
|
||||
key_path: KeyPath,
|
||||
tree: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None,
|
||||
) -> Iterable[tuple[KeyPath, Any]]:
|
||||
if is_leaf and is_leaf(tree):
|
||||
yield key_path, tree
|
||||
return
|
||||
key_handler = _registry_with_keypaths.get(type(tree))
|
||||
handler = _registry.get(type(tree))
|
||||
if key_handler:
|
||||
key_children, _ = key_handler.flatten_with_keys(tree)
|
||||
for k, c in key_children:
|
||||
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
|
||||
elif handler:
|
||||
children, _ = handler.to_iter(tree)
|
||||
for i, c in enumerate(children):
|
||||
if xla_extension_version >= 248:
|
||||
def _generate_key_paths_(
|
||||
key_path: KeyPath,
|
||||
tree: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None,
|
||||
) -> Iterable[tuple[KeyPath, Any]]:
|
||||
if is_leaf and is_leaf(tree):
|
||||
yield key_path, tree
|
||||
return
|
||||
key_handler = _registry_with_keypaths.get(type(tree))
|
||||
if key_handler:
|
||||
key_children, _ = key_handler.flatten_with_keys(tree)
|
||||
for k, c in key_children:
|
||||
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
|
||||
return
|
||||
|
||||
flat = default_registry.flatten_one_level(tree)
|
||||
if flat is None:
|
||||
yield key_path, tree # strict leaf type
|
||||
return
|
||||
|
||||
if (isinstance(tree, tuple) and hasattr(tree, '_fields') and
|
||||
flat[1] == type(tree)):
|
||||
# handle namedtuple as a special case, based on heuristic
|
||||
key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
|
||||
for k, c in key_children:
|
||||
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
|
||||
return
|
||||
|
||||
for i, c in enumerate(flat[0]):
|
||||
k = FlattenedIndexKey(i)
|
||||
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
|
||||
elif isinstance(tree, tuple) and hasattr(tree, '_fields'):
|
||||
# handle namedtuple as a special case, based on heuristic
|
||||
key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
|
||||
for k, c in key_children:
|
||||
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
|
||||
else:
|
||||
yield key_path, tree # strict leaf type
|
||||
else:
|
||||
def _generate_key_paths_(
|
||||
key_path: KeyPath,
|
||||
tree: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None,
|
||||
) -> Iterable[tuple[KeyPath, Any]]:
|
||||
if is_leaf and is_leaf(tree):
|
||||
yield key_path, tree
|
||||
return
|
||||
key_handler = _registry_with_keypaths.get(type(tree))
|
||||
if key_handler:
|
||||
key_children, _ = key_handler.flatten_with_keys(tree)
|
||||
for k, c in key_children:
|
||||
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
|
||||
elif handler := _registry.get(type(tree)):
|
||||
children, _ = handler.to_iter(tree)
|
||||
for i, c in enumerate(children):
|
||||
k = FlattenedIndexKey(i)
|
||||
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
|
||||
elif isinstance(tree, tuple) and hasattr(tree, '_fields'):
|
||||
# handle namedtuple as a special case, based on heuristic
|
||||
key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
|
||||
for k, c in key_children:
|
||||
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
|
||||
else:
|
||||
yield key_path, tree # strict leaf type
|
||||
|
||||
|
||||
def tree_map_with_path(f: Callable[..., Any],
|
||||
@ -1001,6 +1053,8 @@ def _prefix_error(
|
||||
# point, and because prefix_tree is not a leaf, each can be flattened once:
|
||||
prefix_tree_children, prefix_tree_meta = flatten_one_level(prefix_tree)
|
||||
full_tree_children, full_tree_meta = flatten_one_level(full_tree)
|
||||
prefix_tree_children = tuple(prefix_tree_children)
|
||||
full_tree_children = tuple(full_tree_children)
|
||||
prefix_tree_keys = _child_keys(prefix_tree)
|
||||
full_tree_keys = _child_keys(full_tree)
|
||||
# First we check special case types (list and tuple, though if they were
|
||||
|
Loading…
x
Reference in New Issue
Block a user