73 Commits

Author SHA1 Message Date
Tom Ward
2135b51109 Expose default PyTree registry.
This allows users to deserialize PyTree definitions using `PyTreeDef.deserialize_using_proto` with the default registry.

PiperOrigin-RevId: 606567659
2024-02-13 03:59:56 -08:00
Jake VanderPlas
e356d76913 Remove a number of deprecated APIs
All of these were deprecated prior to the JAX 0.4.16 release, on Sept 18 2023.
As of Monday Dec 18, we have met the 3 month deprecation period specified by the [API Compatiblity Policy](https://jax.readthedocs.io/en/latest/api_compatibility.html).

PiperOrigin-RevId: 591933493
2023-12-18 10:08:47 -08:00
Jake VanderPlas
3fd204ca0a fix typo in deprecation message 2023-10-03 14:04:09 -07:00
Jake VanderPlas
7c9cca8b53 jax.tree_util: use standard deprecation framework for deprecated items 2023-08-29 15:14:16 -07:00
jax authors
f2d7798a0c Add register_static decorator to tree_util to facilitate creating leafless classes.
PiperOrigin-RevId: 558937697
2023-08-21 16:54:29 -07:00
Tom Hennigan
ed073aa6c9 Add jax.tree_util.tree_leaves_with_path(tree).
PiperOrigin-RevId: 539609052
2023-06-12 04:13:37 -07:00
Ivy Zheng
db025df030 Stop importing old tree_util APIs conveniently and set explicit time for removal.
PiperOrigin-RevId: 521003611
2023-03-31 13:45:10 -07:00
jax authors
ad8c39ad7c Internal change
PiperOrigin-RevId: 513953876
2023-03-04 13:24:11 +00:00
Jake VanderPlas
26f2f97805 Document why 'import name as name' is used 2022-12-14 15:07:04 -08:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Jake VanderPlas
108376d792 Remove deprecated function jax.tree_util.tree_multimap 2022-07-26 09:37:27 -07:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
Matthew Johnson
d57990ecf9 improve pjit in/out_axis_resources pytree errors
This is an application of the utilities in #9372.
2022-02-08 16:23:15 -08:00
Jake VanderPlas
245581411e Add PEP484-compatible export for jax and its subpackages 2021-09-13 14:08:48 -07:00
Sergei Lebedev
af41a959d3 Most of JAX now uses concrete types for things defined in jaxlib.xla_client
Note that a few call sites in the diff got a ``# type: ignore``, because
the latest jaxlib does not have up-to-date signatures for the correpsonding
callables.
2021-08-16 20:33:36 +01:00
Peter Hawkins
6ee6c59235 Move jax.tree_util implementation to jax._src.tree_util.
NFC intended.

PiperOrigin-RevId: 364857920
2021-03-24 12:00:38 -07:00
Jake VanderPlas
e159d67e7e Move flake8 & mypy checks to pre-commit 2021-02-10 13:44:40 -08:00
Jake VanderPlas
5e7be4a61f Cleanup: remove obsolete jaxlib version checks 2021-02-04 15:13:39 -08:00
Jake VanderPlas
2044ba126e Add example of Partial with no arguments 2021-02-04 10:12:26 -08:00
Jake VanderPlas
a921371dc7 DOC: add examples to jax.tree_util.Partial 2021-02-04 10:12:26 -08:00
Thomas Keck
28dd00b010 Adds support for is_leaf in tree_util.tree_map and tree_util.tree_multimap. 2021-01-29 12:05:16 +00:00
Jake VanderPlas
a0b12bba25 DOC: fix minor formatting issues 2021-01-20 14:38:19 -08:00
Peter Hawkins
1abf383fac Remove type overloads for tree_map and tree_multimap.
These overloads lead pytype to infer incorrect types for code like this:

```
from typing import Any
import jax.tree_util

def bar(x: Any) -> str:
  return repr(x)
reveal_type(jax.tree_util.tree_map(bar, [2, 3, 4]))
```

which deduces `str` when the output is `List[str]`.
2021-01-14 17:16:34 -05:00
Neil Girdhar
8dbb406e59 Improve type annotations
* Add py.typed, which makes type annotations available to users.
* Annotate register_pytree_node, tree_map, tree_multimap, and tree_reduce.
* Add a type annotation overload for vjp
* Annotate jax.scipy.special.
* Annotate lax.scan.
2021-01-13 10:26:35 -05:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Adam Paszke
f3bfdf8968 Expose is_leaf predicate for pytree.flatten
and add tests for it. The change has already been landed in the TF code,
where the C++ pytree components live. This is why I needed to bump the
commit.
2020-12-11 11:26:18 +00:00
Peter Hawkins
81b6cd29ff [JAX] Move traceback_util.py into jax._src.
traceback_util is a JAX-private API.

PiperOrigin-RevId: 340659195
2020-11-04 09:02:59 -08:00
Roy Frostig
5d50e19364 add path exclusion opt-in to filtered stack traces and use it throughout the codebase 2020-10-26 12:31:19 -07:00
Tom Hennigan
bf041fbdb1
Compare treedefs by num_leaves not traversal_ in tree_transpose. (#3659)
In general for a `kCustom` node it is not guaranteed that `a.compose(b)` will
have the same `traversal_` as some structure `c` (which is the composition of
`a+b`). We have a real world example in deepmind/dm-haiku with our FlatMapping
type and I've put a simpler example in `tree_util_tests.py`.

Since this test seems largely to be for input validation I've changed this to
compute the expected number of leaves (which is cheaper than using compose as
the previous implementation did) which will catch common errors and is
guaranteed to work for any well formed pytree (additionally I had to fix the
leaf and node count for composed pytrees which were wrong at HEAD).
2020-07-30 13:31:17 -04:00
George Necula
8f93607330
Fix broken links to deleted notebook (#3663)
Fixes #3662.
2020-07-06 09:04:02 +03:00
Matthew Johnson
ae9df752de add docstring to ravel_pytree 2020-06-12 15:41:07 -07:00
Joost Bastings
dc234b6f11
Expose functools.reduce initializer argument to tree_util.tree_reduce (#2935)
* Expose `functools.reduce` initializer argument to `tree_util.tree_reduce`.

`functools.reduce` takes an optional `initializer` argument (default=None) which is currently not exposed by `tree_reduce'. This can be useful e.g. for computing an L2 penalty, where you would initialize with 0., and then sum the L2 for each parameter.

Example:
```
def l2_sum(total, param):
  return total + jnp.sum(param**2)

tree_reduce(l2_sum, params, 0.)
```

* Only call functools.reduce with initializer when it is not None.

* Change logic to check for number of args to allow None value as initializer

* Rename seq to tree, and add tree_leaves

* Change reduce to functools.reduce.

* Make tree_reduce self-documenting

* Replace jax.tree_leaves with tree_leaves

* Update to use custom sentinel instead of optional position argument

* jax.tree_leaves -> tree_leaves
2020-05-05 11:11:10 +03:00
Tom Hennigan
ca23be63fb Add jax.tree_util.all_leaves(iterable).
In Haiku (https://github.com/deepmind/dm-haiku) we have `FlatMapping` which is
an immutable Mapping subclass maintaining a flat internal representation. Our
goal is to allow very cheap flatten/unflatten since these objects are used to
represent parameters/state and are often passed in and out of JAX functions that
flatten their inputs (e.g. jit/pmap).

One challenge we have is that on unflatten we need a fast way of testing whether
the list of leaves provided are flat or not (since we want to cache both the
flat structure and the leaves). Consider the following case:

```python
d = FlatMapping.from_mapping({"a": 1})  # Caches the result of jax.tree_flatten.
l, t = jax.tree_flatten(d)              # Fine, leaves are flat.
l = list(map(lambda x: (x, x), l))      # leaves are no longer flat.
d2 = jax.tree_unflatten(t, l)           # Needs to recompute structure.
jax.tree_leaves(d2)                     # Should return [1, 1] not [(1, 1)]
```

Actual implementation here: d37b486e09/haiku/_src/data_structures.py (L204-L208)

This function allows an efficient way to do this using the JAX public API.
2020-04-01 10:56:01 +03:00
Matthew Johnson
cfbdb65ad8
add register_pytree_node_class, fixes #2396 (#2400)
Co-authored-by: Stephan Hoyer <shoyer@google.com>

Co-authored-by: Stephan Hoyer <shoyer@google.com>
2020-03-10 15:01:18 -07:00
George Necula
a5c3468c93 Added the first draft of the Jaxpr documentation.
This replaces the previous Google Doc version, and is now
updated with the latest changes in Jaxpr.
2020-02-12 13:01:43 +01:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
Peter Hawkins
dcc882cf6b
Drop Python 2 support from JAX. (#1962)
Remove six dependency.
2020-01-08 13:17:55 -05:00
Tom Hennigan
0e552ff371 Register collections.defaultdict as a pytree. (#1908)
PiperOrigin-RevId: 286732270
2019-12-21 15:38:33 -08:00
George Necula
132102498b Minor edit 2019-11-25 09:08:00 +01:00
George Necula
4e89d43a75 Added JAX pytrees notebook
Also added docstrings to the tree_util module.
2019-11-24 20:29:07 +01:00
Matthew Johnson
979b38352f make vmap structured axes work for any pytree 2019-10-31 14:09:12 -07:00
Peter Hawkins
fffec81474 Register OrderedDict as a pytree type. 2019-10-10 10:19:43 -04:00
Matthew Johnson
c760b05f9b update jaxlib version in readme
fixes #1297
will update notebooks in #1260
2019-09-02 07:25:06 -07:00
Pavel Sountsov
3e0e269527 Address review comments. 2019-08-23 19:32:45 -07:00
Pavel Sountsov
b1604459ef Clarify the intended purpose of tree_util.
Most importantly, this removes the initial paragraph which was easy to
misinterpret to imply that this module was not JAX-specific.
2019-08-23 16:54:59 -07:00
Matthew Johnson
b702f8de3e De-tuplify the rest of the core
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2019-08-21 13:21:20 -07:00
Peter Hawkins
cb53ca876f Address review comments. 2019-08-01 16:48:18 -04:00
Peter Hawkins
38bffe9a8b Add a pytreedef.flatten_up_to() method that flattens a PyTree only up to the structure of a PyTreeDef.
Make the C++ version of tree_multimap accept tree suffixes of the primary tree. Document and test this behavior.
Remove unnecessary locking in custom node registry; we hold the GIL already so there's no point to the additional locking.
2019-08-01 12:17:00 -04:00
Peter Hawkins
3c3f01e6d3 Address review comments. 2019-07-30 10:15:37 -04:00