Add an option to simplify keystr output and use a custom separator.

Currently `keystr` just calls `str` on the key entries, leading to quite
verbose output. For example:

    >>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
    ... for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path))
    ['foo']['bar']['bat'][0]
    ['foo']['bar']['bat'][1]
    ['foo']['bar']['baz']

This change allows for a new "simple" format where the string representation
of key entries are further simplified. Additionally we allow a custom
separator since it is very common to use `/` (for example to separate module
and parameter names):

    ... for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path, simple=True, separator='/'))
    foo/bar/bat/0
    foo/bar/bat/1
    foo/bar/baz
```

PiperOrigin-RevId: 717971583
This commit is contained in:
Tom Hennigan 2025-01-21 10:18:04 -08:00 committed by jax authors
parent 96a3ed36c7
commit 7f43316e27
2 changed files with 45 additions and 5 deletions

View File

@ -722,22 +722,49 @@ FlattenedIndexKey: Any = pytree.FlattenedIndexKey # type: ignore
@export
def keystr(keys: KeyPath):
def keystr(keys: KeyPath, *, simple: bool = False, separator: str = '') -> str:
"""Helper to pretty-print a tuple of keys.
Args:
keys: A tuple of ``KeyEntry`` or any class that can be converted to string.
simple: If True, use a simplified string representation for keys. The
simple representation of keys will be more compact than the default, but
is ambiguous in some cases (for example "0" might refer to the first item
in a list or a dictionary key for the integer 0 or string "0").
separator: The separator to use to join string representations of the keys.
Returns:
A string that joins all string representations of the keys.
Examples:
>>> import jax
>>> keys = (0, 1, 'a', 'b')
>>> jax.tree_util.keystr(keys)
'01ab'
>>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
>>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
... print(jax.tree_util.keystr(path))
['foo']['bar']['bat'][0]
['foo']['bar']['bat'][1]
['foo']['bar']['baz']
>>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
... print(jax.tree_util.keystr(path, simple=True, separator='/'))
foo/bar/bat/0
foo/bar/bat/1
foo/bar/baz
"""
return ''.join(map(str, keys))
str_fn = _simple_entrystr if simple else str
return separator.join(map(str_fn, keys))
def _simple_entrystr(key: KeyEntry) -> str:
match key:
case (
SequenceKey(idx=key)
| DictKey(key=key)
| GetAttrKey(name=key)
| FlattenedIndexKey(key=key)
):
return str(key)
case _:
return str(key)
# TODO(ivyzheng): remove this after another jaxlib release.

View File

@ -728,6 +728,19 @@ class TreeTest(jtu.JaxTestCase):
],
)
strs = [f"{tree_util.keystr(kp, simple=True, separator='/')}: {x}"
for kp, x in flattened]
self.assertEqual(
strs,
[
"0/foo: 12",
"0/bar/cin/0: 1",
"0/bar/cin/1: 4",
"0/bar/cin/2: 10",
"1: [0 1 2 3 4]",
],
)
def testTreeMapWithPathWithIsLeafArgument(self):
x = ((1, 2), [3, 4, 5])
y = (([3], jnp.array(0)), ([0], 7, [5, 6]))