Currently `keystr` just calls `str` on the key entries, leading to quite
verbose output. For example:
>>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
... for path, _ in jax.tree_util.tree_leaves_with_path(params):
... print(jax.tree_util.keystr(path))
['foo']['bar']['bat'][0]
['foo']['bar']['bat'][1]
['foo']['bar']['baz']
This change allows for a new "simple" format where the string representation
of key entries are further simplified. Additionally we allow a custom
separator since it is very common to use `/` (for example to separate module
and parameter names):
... for path, _ in jax.tree_util.tree_leaves_with_path(params):
... print(jax.tree_util.keystr(path, simple=True, separator='/'))
foo/bar/bat/0
foo/bar/bat/1
foo/bar/baz
```
PiperOrigin-RevId: 717971583
* Make tree_util.tree_flatten_with_path and tree_map_with_path APIs to be C++-based, to speed up the pytree flattening.
* Moves all the key classes down to C++ level, while keeping the APIs unchanged.
* Known small caveats: they are no longer Python dataclasses, and pattern matching might make pytype unhappy.
* Registered defaultdict and ordereddict via the keypath API now.
PiperOrigin-RevId: 701613257
Because strings that got f-string-interpolated could contain curly
braces, `f_string_interpolated_str.format(name=name)` could fail with
difficult-to-understand KeyErrors.
For example, tree_map(..., None, [2, 3]) did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case.
PiperOrigin-RevId: 674019460
A well-behaved registration call must list all ``init=True`` fields in either ``data_fields`` or ``meta_fields``. Otherwise, ``flatten . unflatten`` could potentially *not* be an identity
PiperOrigin-RevId: 669244669
This PR is a follow up to #18881.
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
We have a number of potential use cases where we want different functions that interpret pytrees differently. By allowing multiple pytree registries the same tree node can be registered in registry but not another.
One motivating use case is the new opaque PRNG array type. We want `jit` to treat these objects as if they were pytrees, but we want other transformations to leave them alone or handle them specially.
PiperOrigin-RevId: 549301796
There were two bugs in the _generate_keypaths function underlying tree_flatten_with_path, leading to disagreement between `len(tree_flatten(x)[0])` and `len(tree_flatten_with_path(x)[0])` for some `x`
1. pytree nodes that weren't registered as pytree-nodes-with-keys were treated as leaves
2. namedtuples that were registered as pytree nodes were being flattened as generic namedtuples rather than using the explicitly registered flattener
Empty nodes like [] and {} have 1 node and 0 leaves. This does not make
them a leaf treedef.
Reproducer:
```
pjit.pjit(lambda x: x, None, (None, {}))((3, {'a': []}))
```
For very large trees of custom nodes this printing can be very verbose with a
lot or repetition. Our internal repository also encourages very deep package
names which exacerbates this issue.
Users encounter treedef printing when interacting with some staging APIs in JAX,
for example:
>>> params = { .. some params .. }
>>> f = jax.jit(..).lower(params).compile()
>>> f(params) # fine
>>> params['some_new_thing'] = something
>>> f(params)
TypeError: function compiled for {treedef}, called with {treedef}.
PiperOrigin-RevId: 461190971
Removing this tested but otherwise unused method makes it easier to merge https://github.com/tensorflow/tensorflow/pull/56202 which changes the API contract of (undocumented) method .walk().
Technically speaking changing the contract of .walk() breaks backward compatibility, but we've never advertised its existence and as far as I can tell it has no users in the code I have access to.
PiperOrigin-RevId: 455687311