Implement flatten one level with keys in C++ and use it for the prefix/equality error printing.

With this, we should be able to safely delete the python with-path registry after a new jaxlib release.

Also changed all `std::string_view` to `absl::string_view` per requirements of TF repository.

PiperOrigin-RevId: 705669465
This commit is contained in:
Ivy Zheng 2024-12-12 16:36:32 -08:00 committed by jax authors
parent eb3ea985b7
commit ef06607735

View File

@ -25,6 +25,7 @@ from typing import Any, NamedTuple, TypeVar, overload
from jax._src import traceback_util
from jax._src.lib import pytree
from jax._src.lib import xla_extension_version
from jax._src.util import safe_zip, set_module
from jax._src.util import unzip2
@ -607,6 +608,18 @@ def flatten_one_level(tree: Any) -> tuple[Iterable[Any], Hashable]:
return out
# flatten_one_level_with_keys is not exported.
def flatten_one_level_with_keys(
tree: Any,
) -> tuple[Iterable[KeyLeafPair], Hashable]:
"""Flatten the given pytree node by one level, with keys."""
out = default_registry.flatten_one_level_with_keys(tree)
if out is None:
raise ValueError(f"can't tree-flatten type: {type(tree)}")
else:
return out
# prefix_errors is not exported
def prefix_errors(prefix_tree: Any, full_tree: Any,
is_leaf: Callable[[Any], bool] | None = None,
@ -728,7 +741,7 @@ def keystr(keys: KeyPath):
return ''.join(map(str, keys))
# TODO(ivyzheng): remove this after _child_keys() also moved to C++.
# TODO(ivyzheng): remove this after another jaxlib release.
class _RegistryWithKeypathsEntry(NamedTuple):
flatten_with_keys: Callable[..., Any]
unflatten_func: Callable[..., Any]
@ -1146,6 +1159,8 @@ def tree_map_with_path(f: Callable[..., Any],
def _child_keys(pytree: Any) -> KeyPath:
assert not treedef_is_strict_leaf(tree_structure(pytree))
if xla_extension_version >= 301:
return tuple(k for k, _ in flatten_one_level_with_keys(pytree)[0])
handler = _registry_with_keypaths.get(type(pytree))
if handler:
return tuple(k for k, _ in handler.flatten_with_keys(pytree)[0])