21 Commits

Author SHA1 Message Date
Peter Hawkins
7782feeb6c [JAX] Drop the private _process_pytree method from tree_util.
Removing this tested but otherwise unused method makes it easier to merge https://github.com/tensorflow/tensorflow/pull/56202 which changes the API contract of (undocumented) method .walk().

Technically speaking changing the contract of .walk() breaks backward compatibility, but we've never advertised its existence and as far as I can tell it has no users in the code I have access to.

PiperOrigin-RevId: 455687311
2022-06-17 13:50:50 -07:00
Peter Hawkins
a30cfed875 Improve documentation for jax.tree_util.tree_map.
Add some examples.
2022-06-10 12:38:57 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
jax authors
375777f43c Merge pull request #9569 from GJBoth:tree_flatten_docs
PiperOrigin-RevId: 441878577
2022-04-14 16:09:47 -07:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
Gert-Jan Both
add6c8254e Clarify tree_flatten docstring. 2022-02-21 07:50:49 +00:00
Matthew Johnson
d57990ecf9 improve pjit in/out_axis_resources pytree errors
This is an application of the utilities in #9372.
2022-02-08 16:23:15 -08:00
Matthew Johnson
e186aa3f1e add and test pytree utils for better errors 2022-02-03 17:04:38 -08:00
Peter Hawkins
042c9bd7a5 Ensure that tree_util.Partial's .func attribute is stable.
Fixes #9429.
2022-02-03 10:44:13 -05:00
Matthew Johnson
d9dcd1394a djax: let make_jaxpr build dyn shape jaxprs 2022-02-01 00:10:21 -08:00
Tom Hennigan
2f62574e8e Add is_leaf to tree_{leaves,structure}.
PiperOrigin-RevId: 417783880
2021-12-22 02:56:56 -08:00
Peter Hawkins
4e21922055 Use imports relative to the jax package consistently, rather than .-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
2021-11-24 07:48:29 -08:00
Neil Girdhar
832cf214e3 Fix jacfwd and jacrev for heterogeneous pytrees
Changed the behavior of `jacfwd`, `jacrev`, and `grad` when the input
pytree elements have heterogeneous dtypes, e.g., real and complex
elements:

* Changed the dtypes of the pytree elements of the Jacobian produced by
  jacfwd to be those of the input tangent basis.

* Changed the dtypes of the pytree elements of the Jacobian produced by
  jacrev to be those of the output tangent basis.

* Changed the dtypes of the pytree elements of the primals and tangents
  produced by jacfwd and jacrev to be the same as the corresponding
  elements in the input.

Changed the behavior of the flags to `jacfwd` and `jacrev`:

* Changed the allow_int flag to only allows integer and Boolean dtypes.
  Previously, this flag allowed all other types.
2021-10-12 19:41:47 -04:00
Peter Hawkins
29447ed261 Fixes for Python 3.10.
With these changes, the JAX test suite passes on Python 3.10.
2021-10-05 15:25:28 -04:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
Jake VanderPlas
00f36173bd Specify weak_type in DeviceArray repr 2021-08-23 13:19:33 -07:00
Sergei Lebedev
af41a959d3 Most of JAX now uses concrete types for things defined in jaxlib.xla_client
Note that a few call sites in the diff got a ``# type: ignore``, because
the latest jaxlib does not have up-to-date signatures for the correpsonding
callables.
2021-08-16 20:33:36 +01:00
Roy Frostig
b8f9dd6269 unify tree_map and tree_multimap 2021-04-28 19:59:31 -07:00
Peter Hawkins
6ee6c59235 Move jax.tree_util implementation to jax._src.tree_util.
NFC intended.

PiperOrigin-RevId: 364857920
2021-03-24 12:00:38 -07:00