[XLA:Python] Add a C++ implementation of flatten_one_level.

Also add a copy of the default registry that doesn't have None registered as a leaf, which is slightly faster than using an is_leaf function.

This is mostly just doing an old TODO.

PiperOrigin-RevId: 617988496
This commit is contained in:
Peter Hawkins 2024-03-21 15:56:42 -07:00 committed by jax authors
parent 05e61ed07d
commit 5532e5505b
2 changed files with 118 additions and 64 deletions

View File

@ -363,11 +363,11 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
user_specified_in_shardings = (in_shardings is not None and
not is_unspecified(in_shardings))
is_none = lambda x: x is None
in_shardings_leaves, in_shardings_treedef = tree_flatten(
in_shardings, is_leaf=is_none)
out_shardings_leaves, out_shardings_treedef = tree_flatten(
out_shardings, is_leaf=is_none)
none_leaf_registry = tree_util.none_leaf_registry
in_shardings_leaves, in_shardings_treedef = none_leaf_registry.flatten(
in_shardings)
out_shardings_leaves, out_shardings_treedef = none_leaf_registry.flatten(
out_shardings)
fun_sourceinfo = api_util.fun_sourceinfo(fun)
fun_signature = api_util.fun_signature(fun)

View File

@ -25,6 +25,7 @@ from typing import Any, Callable, NamedTuple, TypeVar, Union, overload
from jax._src import traceback_util
from jax._src.lib import pytree
from jax._src.lib import xla_extension_version
from jax._src.util import safe_zip
from jax._src.util import unzip2
@ -44,6 +45,13 @@ default_registry = pytree.default_registry()
default_registry.__module__ = __name__
default_registry.__name__ = "default_registry"
# A copy of the default registry, where None is a leaf.
none_leaf_registry = pytree.PyTreeRegistry(
enable_none=False, enable_tuple=True, enable_namedtuple=True,
enable_list=True, enable_dict=True)
none_leaf_registry.__module__ = __name__
none_leaf_registry.__name__ = "none_leaf_registry"
# A special, internal pytree registry that includes everything in
# `default_registry`, plus internal Python-defined types that we want
# to teach the fast dispatch path ("C++ dispatch") how to flatten and
@ -242,6 +250,7 @@ def register_pytree_node(nodetype: type[T],
``nodetype``.
"""
default_registry.register_node(nodetype, flatten_func, unflatten_func)
none_leaf_registry.register_node(nodetype, flatten_func, unflatten_func)
dispatch_registry.register_node(nodetype, flatten_func, unflatten_func)
_registry[nodetype] = _RegistryEntry(flatten_func, unflatten_func)
@ -374,21 +383,9 @@ _registry = {
def _replace_nones(sentinel, tree):
"""Replaces ``None`` in ``tree`` with ``sentinel``."""
if tree is None:
return sentinel
else:
handler = _registry.get(type(tree))
if handler:
children, metadata = handler.to_iter(tree)
proc_children = [_replace_nones(sentinel, child) for child in children]
return handler.from_iter(metadata, proc_children)
elif isinstance(tree, tuple) and hasattr(tree, "_fields"):
# handle namedtuple as a special case, based on heuristic
children = iter(tree)
proc_children = [_replace_nones(sentinel, child) for child in children]
return type(tree)(*proc_children)
else:
return tree
leaves, treedef = none_leaf_registry.flatten(tree)
leaves = map(lambda x: sentinel if x is None else x, leaves)
return treedef.unflatten(leaves)
no_initializer = object()
@ -586,29 +583,50 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any,
tree_map(add_leaves, prefix_tree, full_tree, is_leaf=is_leaf)
return result
def flatten_one_level(pytree: Any) -> tuple[list[Any], Hashable]:
"""Flatten the given pytree node by one level.
if xla_extension_version >= 248:
def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
"""Flatten the given pytree node by one level.
Args:
pytree: A valid pytree node, either built-in or registered via
``register_pytree_node`` or ``register_pytree_with_keys``.
Args:
pytree: A valid pytree node, either built-in or registered via
``register_pytree_node`` or ``register_pytree_with_keys``.
Returns:
A pair of the pytree's flattened children and its hashable metadata.
Returns:
A pair of the pytree's flattened children and its hashable metadata.
Raises:
ValueError: If the given pytree is not a built-in or registered container
via ``register_pytree_node`` or ``register_pytree_with_keys``.
"""
handler = _registry.get(type(pytree))
if handler:
children, meta = handler.to_iter(pytree)
return list(children), meta
elif isinstance(pytree, tuple) and hasattr(pytree, '_fields'):
# handle namedtuple as a special case, based on heuristic
return [getattr(pytree, s) for s in pytree._fields], None
else:
raise ValueError(f"can't tree-flatten type: {type(pytree)}")
Raises:
ValueError: If the given pytree is not a built-in or registered container
via ``register_pytree_node`` or ``register_pytree_with_keys``.
"""
out = default_registry.flatten_one_level(pytree)
if out is None:
raise ValueError(f"can't tree-flatten type: {type(pytree)}")
else:
return out
else:
def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
"""Flatten the given pytree node by one level.
Args:
pytree: A valid pytree node, either built-in or registered via
``register_pytree_node`` or ``register_pytree_with_keys``.
Returns:
A pair of the pytree's flattened children and its hashable metadata.
Raises:
ValueError: If the given pytree is not a built-in or registered container
via ``register_pytree_node`` or ``register_pytree_with_keys``.
"""
handler = _registry.get(type(pytree))
if handler:
children, meta = handler.to_iter(pytree)
return list(children), meta
elif isinstance(pytree, tuple) and hasattr(pytree, '_fields'):
# handle namedtuple as a special case, based on heuristic
return [getattr(pytree, s) for s in pytree._fields], None
else:
raise ValueError(f"can't tree-flatten type: {type(pytree)}")
def prefix_errors(prefix_tree: Any, full_tree: Any,
is_leaf: Callable[[Any], bool] | None = None,
@ -659,6 +677,8 @@ def _equality_errors(path, t1, t2, is_leaf):
return # no more errors to find
t1_children, t1_meta = flatten_one_level(t1)
t2_children, t2_meta = flatten_one_level(t2)
t1_children = tuple(t1_children)
t2_children = tuple(t2_children)
t1_keys, t2_keys = _child_keys(t1), _child_keys(t2)
try:
diff = ' '.join(repr(k.key) for k in
@ -905,32 +925,64 @@ _generate_key_paths = generate_key_paths # alias for backward compat
# The overall logic should be same as PyTreeDef::FlattenIntoImpl
def _generate_key_paths_(
key_path: KeyPath,
tree: Any,
is_leaf: Callable[[Any], bool] | None = None,
) -> Iterable[tuple[KeyPath, Any]]:
if is_leaf and is_leaf(tree):
yield key_path, tree
return
key_handler = _registry_with_keypaths.get(type(tree))
handler = _registry.get(type(tree))
if key_handler:
key_children, _ = key_handler.flatten_with_keys(tree)
for k, c in key_children:
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
elif handler:
children, _ = handler.to_iter(tree)
for i, c in enumerate(children):
if xla_extension_version >= 248:
def _generate_key_paths_(
key_path: KeyPath,
tree: Any,
is_leaf: Callable[[Any], bool] | None = None,
) -> Iterable[tuple[KeyPath, Any]]:
if is_leaf and is_leaf(tree):
yield key_path, tree
return
key_handler = _registry_with_keypaths.get(type(tree))
if key_handler:
key_children, _ = key_handler.flatten_with_keys(tree)
for k, c in key_children:
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
return
flat = default_registry.flatten_one_level(tree)
if flat is None:
yield key_path, tree # strict leaf type
return
if (isinstance(tree, tuple) and hasattr(tree, '_fields') and
flat[1] == type(tree)):
# handle namedtuple as a special case, based on heuristic
key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
for k, c in key_children:
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
return
for i, c in enumerate(flat[0]):
k = FlattenedIndexKey(i)
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
elif isinstance(tree, tuple) and hasattr(tree, '_fields'):
# handle namedtuple as a special case, based on heuristic
key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
for k, c in key_children:
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
else:
yield key_path, tree # strict leaf type
else:
def _generate_key_paths_(
key_path: KeyPath,
tree: Any,
is_leaf: Callable[[Any], bool] | None = None,
) -> Iterable[tuple[KeyPath, Any]]:
if is_leaf and is_leaf(tree):
yield key_path, tree
return
key_handler = _registry_with_keypaths.get(type(tree))
if key_handler:
key_children, _ = key_handler.flatten_with_keys(tree)
for k, c in key_children:
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
elif handler := _registry.get(type(tree)):
children, _ = handler.to_iter(tree)
for i, c in enumerate(children):
k = FlattenedIndexKey(i)
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
elif isinstance(tree, tuple) and hasattr(tree, '_fields'):
# handle namedtuple as a special case, based on heuristic
key_children = [(GetAttrKey(s), getattr(tree, s)) for s in tree._fields]
for k, c in key_children:
yield from _generate_key_paths_((*key_path, k), c, is_leaf)
else:
yield key_path, tree # strict leaf type
def tree_map_with_path(f: Callable[..., Any],
@ -1001,6 +1053,8 @@ def _prefix_error(
# point, and because prefix_tree is not a leaf, each can be flattened once:
prefix_tree_children, prefix_tree_meta = flatten_one_level(prefix_tree)
full_tree_children, full_tree_meta = flatten_one_level(full_tree)
prefix_tree_children = tuple(prefix_tree_children)
full_tree_children = tuple(full_tree_children)
prefix_tree_keys = _child_keys(prefix_tree)
full_tree_keys = _child_keys(full_tree)
# First we check special case types (list and tuple, though if they were