diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 4991577c6..1ead7be9a 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -722,22 +722,49 @@ FlattenedIndexKey: Any = pytree.FlattenedIndexKey # type: ignore @export -def keystr(keys: KeyPath): +def keystr(keys: KeyPath, *, simple: bool = False, separator: str = '') -> str: """Helper to pretty-print a tuple of keys. Args: keys: A tuple of ``KeyEntry`` or any class that can be converted to string. + simple: If True, use a simplified string representation for keys. The + simple representation of keys will be more compact than the default, but + is ambiguous in some cases (for example "0" might refer to the first item + in a list or a dictionary key for the integer 0 or string "0"). + separator: The separator to use to join string representations of the keys. Returns: A string that joins all string representations of the keys. Examples: >>> import jax - >>> keys = (0, 1, 'a', 'b') - >>> jax.tree_util.keystr(keys) - '01ab' + >>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}} + >>> for path, _ in jax.tree_util.tree_leaves_with_path(params): + ... print(jax.tree_util.keystr(path)) + ['foo']['bar']['bat'][0] + ['foo']['bar']['bat'][1] + ['foo']['bar']['baz'] + >>> for path, _ in jax.tree_util.tree_leaves_with_path(params): + ... print(jax.tree_util.keystr(path, simple=True, separator='/')) + foo/bar/bat/0 + foo/bar/bat/1 + foo/bar/baz """ - return ''.join(map(str, keys)) + str_fn = _simple_entrystr if simple else str + return separator.join(map(str_fn, keys)) + + +def _simple_entrystr(key: KeyEntry) -> str: + match key: + case ( + SequenceKey(idx=key) + | DictKey(key=key) + | GetAttrKey(name=key) + | FlattenedIndexKey(key=key) + ): + return str(key) + case _: + return str(key) # TODO(ivyzheng): remove this after another jaxlib release. diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 834af9c5f..e5e649d43 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -728,6 +728,19 @@ class TreeTest(jtu.JaxTestCase): ], ) + strs = [f"{tree_util.keystr(kp, simple=True, separator='/')}: {x}" + for kp, x in flattened] + self.assertEqual( + strs, + [ + "0/foo: 12", + "0/bar/cin/0: 1", + "0/bar/cin/1: 4", + "0/bar/cin/2: 10", + "1: [0 1 2 3 4]", + ], + ) + def testTreeMapWithPathWithIsLeafArgument(self): x = ((1, 2), [3, 4, 5]) y = (([3], jnp.array(0)), ([0], 7, [5, 6]))