Note that a few call sites in the diff got a ``# type: ignore``, because
the latest jaxlib does not have up-to-date signatures for the corresponding
callables.
These overloads lead pytype to infer incorrect types for code like this:
```python
from typing import Any
import jax.tree_util

def bar(x: Any) -> str:
  return repr(x)

reveal_type(jax.tree_util.tree_map(bar, [2, 3, 4]))
```
for which pytype deduces `str`, when the correct type is `List[str]`.
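A `TypeVar`-based signature along these lines avoids the collapse to `str`. This is a hypothetical sketch restricted to lists (`tree_map_list` is an illustrative name, not the actual jaxlib stub):

```python
from typing import Callable, List, TypeVar

T = TypeVar("T")
U = TypeVar("U")

def tree_map_list(f: Callable[[T], U], leaves: List[T]) -> List[U]:
  # Toy stand-in for tree_map over a flat list: the return annotation
  # keeps the container, so a checker infers List[str] below, not str.
  return [f(x) for x in leaves]

result = tree_map_list(repr, [2, 3, 4])  # inferred as List[str]
```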
* Add py.typed, which makes type annotations available to users.
* Annotate register_pytree_node, tree_map, tree_multimap, and tree_reduce.
* Add a type annotation overload for vjp
* Annotate jax.scipy.special.
* Annotate lax.scan.
and add tests for it. The change has already landed in the TensorFlow code,
where the C++ pytree components live, which is why I needed to bump the
commit.
In general, for a `kCustom` node it is not guaranteed that `a.compose(b)` will
have the same `traversal_` as some structure `c` that is the composition of
`a` and `b`. We have a real-world example in deepmind/dm-haiku with our
FlatMapping type, and I've put a simpler example in `tree_util_tests.py`.
Since this test seems largely to be for input validation, I've changed it to
compute the expected number of leaves (which is cheaper than using `compose`,
as the previous implementation did). This will catch common errors and is
guaranteed to work for any well-formed pytree. (Additionally, I had to fix the
leaf and node counts for composed pytrees, which were wrong at HEAD.)
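The leaf-count identity behind that check can be illustrated with a toy pytree where lists are the only internal nodes (the helper names here are illustrative, not the C++ API):

```python
def num_leaves(tree):
  # Count leaves of a toy pytree in which lists are the only internal nodes.
  if isinstance(tree, list):
    return sum(num_leaves(child) for child in tree)
  return 1

def compose(outer, inner):
  # Replace every leaf of `outer` with a copy of `inner`.
  if isinstance(outer, list):
    return [compose(child, inner) for child in outer]
  return inner

outer = [0, [0, 0]]
inner = [0, 0]
composed = compose(outer, inner)
# Expected leaf count of a composition: product of the two leaf counts.
assert num_leaves(composed) == num_leaves(outer) * num_leaves(inner)
```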
* Expose `functools.reduce` initializer argument to `tree_util.tree_reduce`.
`functools.reduce` takes an optional `initializer` argument, which is currently not exposed by `tree_reduce`. This can be useful, e.g., for computing an L2 penalty, where you would initialize with `0.` and then sum the squared norm of each parameter.
Example:
```python
def l2_sum(total, param):
  return total + jnp.sum(param**2)

tree_reduce(l2_sum, params, 0.)
```
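The underlying `functools.reduce` behavior being exposed here is standard library, no JAX needed:

```python
import functools

# Without an initializer, reduce starts from the first element; with one,
# it starts from the initializer, which also defines the empty-sequence result.
assert functools.reduce(lambda t, x: t + x, [1, 2, 3]) == 6
assert functools.reduce(lambda t, x: t + x, [1, 2, 3], 10) == 16
assert functools.reduce(lambda t, x: t + x, [], 0.0) == 0.0
```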
* Only call functools.reduce with initializer when it is not None.
* Change logic to check for number of args to allow None value as initializer
* Rename seq to tree, and add tree_leaves
* Change reduce to functools.reduce.
* Make tree_reduce self-documenting
* Replace jax.tree_leaves with tree_leaves
* Update to use custom sentinel instead of optional position argument
* jax.tree_leaves -> tree_leaves
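The sentinel pattern mentioned above distinguishes "no initializer given" from an explicit `None`. A minimal sketch over a flat list of leaves (`_NO_INITIALIZER` is a hypothetical name, not necessarily the one in the diff):

```python
import functools

_NO_INITIALIZER = object()  # unique sentinel; never equal to any user value

def tree_reduce(function, leaves, initializer=_NO_INITIALIZER):
  # Toy version operating on a flat list of leaves. Comparing identity
  # against the sentinel lets callers pass None as a genuine initializer.
  if initializer is _NO_INITIALIZER:
    return functools.reduce(function, leaves)
  return functools.reduce(function, leaves, initializer)

assert tree_reduce(lambda a, b: a + b, [1, 2, 3]) == 6
assert tree_reduce(lambda a, b: a + b, [1, 2, 3], 0) == 6
```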
In Haiku (https://github.com/deepmind/dm-haiku) we have `FlatMapping` which is
an immutable Mapping subclass maintaining a flat internal representation. Our
goal is to allow very cheap flatten/unflatten since these objects are used to
represent parameters/state and are often passed in and out of JAX functions that
flatten their inputs (e.g. jit/pmap).
One challenge we have is that on unflatten we need a fast way of testing whether
the list of leaves provided is flat or not (since we want to cache both the
flat structure and the leaves). Consider the following case:
```python
d = FlatMapping.from_mapping({"a": 1}) # Caches the result of jax.tree_flatten.
l, t = jax.tree_flatten(d) # Fine, leaves are flat.
l = list(map(lambda x: (x, x), l)) # leaves are no longer flat.
d2 = jax.tree_unflatten(t, l) # Needs to recompute structure.
jax.tree_leaves(d2) # Should return [1, 1] not [(1, 1)]
```
Actual implementation here: d37b486e09/haiku/_src/data_structures.py (L204-L208)
This function allows an efficient way to do this using the JAX public API.
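A pure-Python sketch of an `all_leaves`-style check; the real implementation consults the pytree registry, so here lists, tuples, and dicts merely stand in for registered node types:

```python
_CONTAINER_TYPES = (list, tuple, dict)  # stand-ins for registered pytree nodes

def all_leaves(values):
  # True iff every element would itself be a leaf, i.e. unflattening with
  # these values cannot introduce extra structure below the cached one.
  return all(not isinstance(v, _CONTAINER_TYPES) for v in values)

assert all_leaves([1, 1])       # flat leaves: cached structure is valid
assert not all_leaves([(1, 1)]) # a tuple leaf means structure must be recomputed
```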
Make the C++ version of tree_multimap accept tree suffixes of the primary tree. Document and test this behavior.
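The suffix behavior can be sketched in pure Python over list-only trees: the primary tree determines the traversal, and any extra structure in the other trees at a primary leaf is passed whole to the callback (`multimap_prefix` is an illustrative name, not the C++ entry point):

```python
def multimap_prefix(f, primary, other):
  # Recurse following `primary`'s structure only. Where `primary` has a
  # leaf, the aligned subtree of `other` is handed to `f` unflattened.
  if isinstance(primary, list):
    return [multimap_prefix(f, p, o) for p, o in zip(primary, other)]
  return f(primary, other)

# `other` has one extra level of structure below each primary leaf.
out = multimap_prefix(lambda a, b: a + sum(b), [1, 2], [[10, 10], [20, 20]])
assert out == [21, 42]
```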
Remove unnecessary locking in custom node registry; we hold the GIL already so there's no point to the additional locking.
Move jaxlib version test into jax/lib/__init__.py. Make jax/lib mirror the structure of jaxlib; e.g., xla_client is now available as jax.lib.xla_client.