mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add an option to simplify keystr
output and use a custom separator.
Currently `keystr` just calls `str` on the key entries, leading to quite verbose output. For example: >>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}} ... for path, _ in jax.tree_util.tree_leaves_with_path(params): ... print(jax.tree_util.keystr(path)) ['foo']['bar']['bat'][0] ['foo']['bar']['bat'][1] ['foo']['bar']['baz'] This change allows for a new "simple" format where the string representation of key entries are further simplified. Additionally we allow a custom separator since it is very common to use `/` (for example to separate module and parameter names): ... for path, _ in jax.tree_util.tree_leaves_with_path(params): ... print(jax.tree_util.keystr(path, simple=True, separator='/')) foo/bar/bat/0 foo/bar/bat/1 foo/bar/baz ``` PiperOrigin-RevId: 717971583
This commit is contained in:
parent
96a3ed36c7
commit
7f43316e27
@ -722,22 +722,49 @@ FlattenedIndexKey: Any = pytree.FlattenedIndexKey # type: ignore
|
||||
|
||||
|
||||
@export
|
||||
def keystr(keys: KeyPath):
|
||||
def keystr(keys: KeyPath, *, simple: bool = False, separator: str = '') -> str:
|
||||
"""Helper to pretty-print a tuple of keys.
|
||||
|
||||
Args:
|
||||
keys: A tuple of ``KeyEntry`` or any class that can be converted to string.
|
||||
simple: If True, use a simplified string representation for keys. The
|
||||
simple representation of keys will be more compact than the default, but
|
||||
is ambiguous in some cases (for example "0" might refer to the first item
|
||||
in a list or a dictionary key for the integer 0 or string "0").
|
||||
separator: The separator to use to join string representations of the keys.
|
||||
|
||||
Returns:
|
||||
A string that joins all string representations of the keys.
|
||||
|
||||
Examples:
|
||||
>>> import jax
|
||||
>>> keys = (0, 1, 'a', 'b')
|
||||
>>> jax.tree_util.keystr(keys)
|
||||
'01ab'
|
||||
>>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
|
||||
>>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
|
||||
... print(jax.tree_util.keystr(path))
|
||||
['foo']['bar']['bat'][0]
|
||||
['foo']['bar']['bat'][1]
|
||||
['foo']['bar']['baz']
|
||||
>>> for path, _ in jax.tree_util.tree_leaves_with_path(params):
|
||||
... print(jax.tree_util.keystr(path, simple=True, separator='/'))
|
||||
foo/bar/bat/0
|
||||
foo/bar/bat/1
|
||||
foo/bar/baz
|
||||
"""
|
||||
return ''.join(map(str, keys))
|
||||
str_fn = _simple_entrystr if simple else str
|
||||
return separator.join(map(str_fn, keys))
|
||||
|
||||
|
||||
def _simple_entrystr(key: KeyEntry) -> str:
|
||||
match key:
|
||||
case (
|
||||
SequenceKey(idx=key)
|
||||
| DictKey(key=key)
|
||||
| GetAttrKey(name=key)
|
||||
| FlattenedIndexKey(key=key)
|
||||
):
|
||||
return str(key)
|
||||
case _:
|
||||
return str(key)
|
||||
|
||||
|
||||
# TODO(ivyzheng): remove this after another jaxlib release.
|
||||
|
@ -728,6 +728,19 @@ class TreeTest(jtu.JaxTestCase):
|
||||
],
|
||||
)
|
||||
|
||||
strs = [f"{tree_util.keystr(kp, simple=True, separator='/')}: {x}"
|
||||
for kp, x in flattened]
|
||||
self.assertEqual(
|
||||
strs,
|
||||
[
|
||||
"0/foo: 12",
|
||||
"0/bar/cin/0: 1",
|
||||
"0/bar/cin/1: 4",
|
||||
"0/bar/cin/2: 10",
|
||||
"1: [0 1 2 3 4]",
|
||||
],
|
||||
)
|
||||
|
||||
def testTreeMapWithPathWithIsLeafArgument(self):
|
||||
x = ((1, 2), [3, 4, 5])
|
||||
y = (([3], jnp.array(0)), ([0], 7, [5, 6]))
|
||||
|
Loading…
x
Reference in New Issue
Block a user