8 Commits

Author SHA1 Message Date
Ivy Zheng
a1dfdc1d61 C++ tree with path API
* Make tree_util.tree_flatten_with_path and tree_map_with_path APIs to be C++-based, to speed up the pytree flattening.

* Moves all the key classes down to C++ level, while keeping the APIs unchanged.
  * Known small caveats: they are no longer Python dataclasses, and pattern matching might make pytype unhappy.

* Registered defaultdict and ordereddict via the keypath API now.

PiperOrigin-RevId: 701613257
2024-11-30 21:26:48 -08:00
Jake VanderPlas
621e39de27 Set __module__ attribute of jax.numpy.linalg APIs 2024-11-20 10:47:23 -08:00
Jake VanderPlas
f652b6ad6a Set __module__ attribute for objects in jax.numpy 2024-11-15 06:03:54 -08:00
Peter Hawkins
7b53c2f39d Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeError class.
Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API.

PiperOrigin-RevId: 679163106
2024-09-26 08:39:30 -07:00
jax authors
a207fe9b77 Export KeyPath and related types to jax.tree_util
These types lie on the APIs in `jax.tree_util`, so it makes sense to export them.

PiperOrigin-RevId: 657987755
2024-07-31 06:41:33 -07:00
Jake VanderPlas
1327143d46 Better documentation for jax.tree_util 2024-05-20 19:56:47 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Jake VanderPlas
1c7f8efce6 Add test framework for module attribute 2023-04-21 13:20:16 -07:00