73 Commits

Author SHA1 Message Date
Tom Hennigan
7f43316e27 Add an option to simplify keystr output and use a custom separator.
Currently `keystr` just calls `str` on the key entries, leading to quite
verbose output. For example:

    >>> params = {'foo': {'bar': {'baz': 1, 'bat': [2, 3]}}}
    ... for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path))
    ['foo']['bar']['bat'][0]
    ['foo']['bar']['bat'][1]
    ['foo']['bar']['baz']

This change allows for a new "simple" format where the string representation
of key entries are further simplified. Additionally we allow a custom
separator since it is very common to use `/` (for example to separate module
and parameter names):

    ... for path, _ in jax.tree_util.tree_leaves_with_path(params):
    ...   print(jax.tree_util.keystr(path, simple=True, separator='/'))
    foo/bar/bat/0
    foo/bar/bat/1
    foo/bar/baz
```

PiperOrigin-RevId: 717971583
2025-01-21 10:18:42 -08:00
Ivy Zheng
26c40fadfd Add jax.tree shortcuts for .*_with_path calls, for convenience of users.
PiperOrigin-RevId: 705645570
2024-12-12 15:13:32 -08:00
Peter Hawkins
79318a08cf Remove dead code after minimum jaxlib version bump to v0.4.36.
New minimum xla_extension_version is 299, and the new mlir_api_version is 57.

PiperOrigin-RevId: 704280856
2024-12-09 07:35:05 -08:00
Ivy Zheng
a1dfdc1d61 C++ tree with path API
* Make tree_util.tree_flatten_with_path and tree_map_with_path APIs to be C++-based, to speed up the pytree flattening.

* Moves all the key classes down to C++ level, while keeping the APIs unchanged.
  * Known small caveats: they are no longer Python dataclasses, and pattern matching might make pytype unhappy.

* Registered defaultdict and ordereddict via the keypath API now.

PiperOrigin-RevId: 701613257
2024-11-30 21:26:48 -08:00
Jake VanderPlas
1af3b01c1c register_dataclass: allow marking static fields via field(static=True) 2024-11-06 11:18:11 -08:00
jax authors
e461c0496f Merge pull request #23684 from simonster:sjk/fix-prefix-error
PiperOrigin-RevId: 686133952
2024-10-15 09:32:30 -07:00
Peter Hawkins
d3f63a66b8 Remove code to support jaxlib <= 0.4.33. 2024-10-04 11:39:05 -04:00
Peter Hawkins
111f13e279 Reverts dffac29e63de6a51047fe77cf9d553ab762ef19b
PiperOrigin-RevId: 678748794
2024-09-25 10:14:45 -07:00
Michael Hudgins
d4d1518c3d Update references to the GitHub url in JAX codebase to reflect move from google/jax to jax-ml/jax
PiperOrigin-RevId: 676843138
2024-09-20 07:52:33 -07:00
Simon Kornblith
56a0bfa6af
Fix errors when prefix does not match pytree
Because strings that got f-string-interpolated could contain curly
braces, `f_string_interpolated_str.format(name=name)` could fail with
difficult-to-understand KeyErrors.
2024-09-16 22:25:28 -07:00
Peter Hawkins
dffac29e63 Reverts 255c30303d32e7473262b2e35348175c87e4348f
PiperOrigin-RevId: 674083626
2024-09-12 18:14:25 -07:00
Peter Hawkins
255c30303d Fix a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs.
For example, tree_map(..., None, [2, 3]) did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case.

PiperOrigin-RevId: 674019460
2024-09-12 14:49:18 -07:00
Sergei Lebedev
02bb884357 `jax.tree_util.register_dataclass now validates data_fields and meta_fields`
A well-behaved registration call must list all ``init=True`` fields in either ``data_fields`` or ``meta_fields``. Otherwise, ``flatten . unflatten`` could potentially *not* be an identity

PiperOrigin-RevId: 669244669
2024-08-30 02:01:50 -07:00
Jake VanderPlas
814b32a44b tree_all: add support for is_leaf 2024-06-10 09:46:15 -07:00
Jake VanderPlas
61ef800618 tree_util: improve tests of jax.tree aliases 2024-06-04 11:56:25 -07:00
Jake VanderPlas
6d5668db45 jax.tree_util: test serialize_using_proto 2024-05-21 12:37:37 -07:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
Enrique Piqueras
cf9c08589e Add builtin cc dataclass pytree node for performance.
PiperOrigin-RevId: 627502102
2024-04-23 14:14:49 -07:00
Peter Hawkins
f759452219 [XLA:Python] Improve error checking for the return value of the to_iterable function of custom pytree nodes.
PiperOrigin-RevId: 617066587
2024-03-18 23:23:59 -07:00
Peter Hawkins
fdbee314d3 Make JAX tests that check for errors from dict key comparators in pytrees more relaxed, in preparation for https://github.com/openxla/xla/pull/9529.
PiperOrigin-RevId: 610819296
2024-02-27 11:30:10 -08:00
Jake VanderPlas
6ffea0ba1f tree_transpose: optionally infer inner_treedef 2024-02-15 12:01:21 -08:00
Jake VanderPlas
6934a4b76b Add jax.tree module with aliases of jax.tree_util 2024-02-12 13:07:59 -08:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Matthew Johnson
d2fcf27f93 must flatten defaultdict in key-sorted order, like regular dicts 2023-12-08 10:10:09 -08:00
Neil Girdhar
3c920c0120 Switch from flake8 to Ruff 2023-11-15 22:35:52 -05:00
jax authors
d45fa22424 Add tests to cover PyTreeDef.flatten_up_to error scenarios.
Also improve coverage of `PyTreeDef.flatten_up_to` success scenarios.

PiperOrigin-RevId: 570152827
2023-10-02 13:00:55 -07:00
jax authors
f2d7798a0c Add register_static decorator to tree_util to facilitate creating leafless classes.
PiperOrigin-RevId: 558937697
2023-08-21 16:54:29 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Peter Hawkins
cdb48134e5 [JAX] Add support for multiple pytree registries.
We have a number of potential use cases where we want different functions that interpret pytrees differently. By allowing multiple pytree registries the same tree node can be registered in registry but not another.

One motivating use case is the new opaque PRNG array type. We want `jit` to treat these objects as if they were pytrees, but we want other transformations to leave them alone or handle them specially.

PiperOrigin-RevId: 549301796
2023-07-19 06:48:21 -07:00
Tom Hennigan
ed073aa6c9 Add jax.tree_util.tree_leaves_with_path(tree).
PiperOrigin-RevId: 539609052
2023-06-12 04:13:37 -07:00
Matthew Johnson
42b2a80df2 add a test for tree_reduce with is_leaf argument 2023-05-16 15:37:52 -07:00
Matthew Johnson
da3799959a separate register_pytree_node and register_pytree_with_keys tests 2023-03-20 20:05:47 -07:00
Matthew Johnson
82c0035a50 [pytrees] fix function underlying tree-flattening with keys
There were two bugs in the _generate_keypaths function underlying tree_flatten_with_path, leading to disagreement between `len(tree_flatten(x)[0])` and `len(tree_flatten_with_path(x)[0])` for some `x`
1. pytree nodes that weren't registered as pytree-nodes-with-keys were treated as leaves
2. namedtuples that were registered as pytree nodes were being flattened as generic namedtuples rather than using the explicitly registered flattener
2023-03-17 19:12:32 -07:00
Ivy Zheng
08c83369be Add an optional flatten_func argument to custom node registration even when flatten_with_keys is given, for better perf for those in need.
Fixes #14844

PiperOrigin-RevId: 517308676
2023-03-16 21:35:10 -07:00
Matthew Johnson
a6d3ae1446 use Partial to make ravel_pytree unflatteners jit-friendly
Co-authored-by: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
2023-03-13 11:06:56 -07:00
jax authors
ad8c39ad7c Internal change
PiperOrigin-RevId: 513953876
2023-03-04 13:24:11 +00:00
Matthew Johnson
cea2b6b6f8 specialize tree prefix error message for list/tuple 2023-01-20 10:51:02 -08:00
Qiao Zhang
d58266eac7 Store sorted flattened dict keys in PyTree as a c++ vector instead of py::list to avoid creating new python object on every single dict flatten. For deeply nested dict, this avoids excessive gc pressure and avoids the slowdown whenever gc needs to sweep too many live python objects.
PiperOrigin-RevId: 502967020
2023-01-18 13:40:43 -08:00
Peter Hawkins
320d531521 Increase the minimum jaxlib version to 0.3.22.
The minimum xla_extension_version is now 98 and the minimum mlir_api_version is now 32.
2022-10-27 10:24:11 -04:00
Matthew Johnson
b27acedf1f add more info to pytree prefix key errors
fixes #12643
2022-10-11 12:34:03 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
jax authors
e1b31f82fd Make PyTreeDef pickleable
PiperOrigin-RevId: 465306184
2022-08-04 07:13:46 -07:00
Roy Frostig
8677d99267 promise to flatten trees in left-to-right order 2022-07-28 19:28:20 -07:00
Parker Schuh
6c4da65af4 Add treedef_is_strict_leaf to fix _prefix_error's semantics.
Empty nodes like [] and {} have 1 node and 0 leaves. This does not make
them a leaf treedef.

Reproducer:
```
pjit.pjit(lambda x: x, None, (None, {}))((3, {'a': []}))
```
2022-07-20 17:02:59 -07:00
Tom Hennigan
10720258ea Reduce the verbosity of treedef printing for custom nodes.
For very large trees of custom nodes this printing can be very verbose with a
lot or repetition. Our internal repository also encourages very deep package
names which exacerbates this issue.

Users encounter treedef printing when interacting with some staging APIs in JAX,
for example:

    >>> params = { .. some params .. }
    >>> f = jax.jit(..).lower(params).compile()
    >>> f(params)  # fine
    >>> params['some_new_thing'] = something
    >>> f(params)
    TypeError: function compiled for {treedef}, called with {treedef}.

PiperOrigin-RevId: 461190971
2022-07-15 07:14:28 -07:00
Peter Hawkins
7782feeb6c [JAX] Drop the private _process_pytree method from tree_util.
Removing this tested but otherwise unused method makes it easier to merge https://github.com/tensorflow/tensorflow/pull/56202 which changes the API contract of (undocumented) method .walk().

Technically speaking changing the contract of .walk() breaks backward compatibility, but we've never advertised its existence and as far as I can tell it has no users in the code I have access to.

PiperOrigin-RevId: 455687311
2022-06-17 13:50:50 -07:00
Jake VanderPlas
6efb03cf0d [x64] make tree_util_test compatible with strict dtype promotion 2022-06-14 15:14:44 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
Peter Hawkins
c978df5550 Increase minimum jaxlib version to 0.3.0. 2022-03-04 10:33:03 -05:00