mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
Implement flatten one level with keys in C++ and use it for the prefix/equality error printing.
With this, we should be able to safely delete the python with-path registry after a new jaxlib release. Also changed all `std::string_view` to `absl::string_view` per requirements of TF repository. PiperOrigin-RevId: 705669465
This commit is contained in:
parent
eb3ea985b7
commit
ef06607735
@ -25,6 +25,7 @@ from typing import Any, NamedTuple, TypeVar, overload
|
||||
|
||||
from jax._src import traceback_util
|
||||
from jax._src.lib import pytree
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.util import safe_zip, set_module
|
||||
from jax._src.util import unzip2
|
||||
|
||||
@ -607,6 +608,18 @@ def flatten_one_level(tree: Any) -> tuple[Iterable[Any], Hashable]:
|
||||
return out
|
||||
|
||||
|
||||
# flatten_one_level_with_keys is not exported.
|
||||
def flatten_one_level_with_keys(
|
||||
tree: Any,
|
||||
) -> tuple[Iterable[KeyLeafPair], Hashable]:
|
||||
"""Flatten the given pytree node by one level, with keys."""
|
||||
out = default_registry.flatten_one_level_with_keys(tree)
|
||||
if out is None:
|
||||
raise ValueError(f"can't tree-flatten type: {type(tree)}")
|
||||
else:
|
||||
return out
|
||||
|
||||
|
||||
# prefix_errors is not exported
|
||||
def prefix_errors(prefix_tree: Any, full_tree: Any,
|
||||
is_leaf: Callable[[Any], bool] | None = None,
|
||||
@ -728,7 +741,7 @@ def keystr(keys: KeyPath):
|
||||
return ''.join(map(str, keys))
|
||||
|
||||
|
||||
# TODO(ivyzheng): remove this after _child_keys() also moved to C++.
|
||||
# TODO(ivyzheng): remove this after another jaxlib release.
|
||||
class _RegistryWithKeypathsEntry(NamedTuple):
|
||||
flatten_with_keys: Callable[..., Any]
|
||||
unflatten_func: Callable[..., Any]
|
||||
@ -1146,6 +1159,8 @@ def tree_map_with_path(f: Callable[..., Any],
|
||||
|
||||
def _child_keys(pytree: Any) -> KeyPath:
|
||||
assert not treedef_is_strict_leaf(tree_structure(pytree))
|
||||
if xla_extension_version >= 301:
|
||||
return tuple(k for k, _ in flatten_one_level_with_keys(pytree)[0])
|
||||
handler = _registry_with_keypaths.get(type(pytree))
|
||||
if handler:
|
||||
return tuple(k for k, _ in handler.flatten_with_keys(pytree)[0])
|
||||
|
Loading…
x
Reference in New Issue
Block a user