73 Commits

Author SHA1 Message Date
Matthew Johnson
d57990ecf9 improve pjit in/out_axis_resources pytree errors
This is an application of the utilities in #9372.
2022-02-08 16:23:15 -08:00
Matthew Johnson
e186aa3f1e add and test pytree utils for better errors 2022-02-03 17:04:38 -08:00
Peter Hawkins
042c9bd7a5 Ensure that tree_util.Partial's .func attribute is stable.
Fixes #9429.
2022-02-03 10:44:13 -05:00
Peter Hawkins
9bc6d1103e [JAX] Fix spurious inequality for two apparently equal PyTreeDefs.
When constructed via one path we were filling in the .custom field of nodes that weren't custom types.

Fixes https://github.com/google/jax/issues/9066

PiperOrigin-RevId: 420858917
2022-01-10 14:35:56 -08:00
Tom Hennigan
2f62574e8e Add is_leaf to tree_{leaves,structure}.
PiperOrigin-RevId: 417783880
2021-12-22 02:56:56 -08:00
Matthew Johnson
984e8d79f0 make ravel_pytree unraveler dtype-polymorphic
fixes #7809
2021-12-14 14:47:35 -08:00
Peter Hawkins
3fd3c46f20 Increase minimum jaxlib version to 0.1.74. 2021-11-18 15:06:58 -05:00
Peter Hawkins
29447ed261 Fixes for Python 3.10.
With these changes, the JAX test suite passes on Python 3.10.
2021-10-05 15:25:28 -04:00
Peter Hawkins
256e7220ff [JAX] Fix pylint errors.
* trailing-whitespace
* dangerous-default-value. None of these appear to be bugs in practice, but the potential for accidentally mutating the default value is there, and the cost of avoiding the problem is small.
* invalid-envvar-default. Pass strings as getenv() defaults.
* unnecessary-semicolon. Use tuples instead for this one-liner.
* invalid-hash-returned. Raise an exception rather than asserting false.
* pointless-string-statement. Use comments instead.
* unreachable. Use @unittest.skip() decorator rather than raising as first line in test.
* logging-not-lazy. Make the logging lazy.
* bad-format-string-type. Use f-string instead.
* subprocess-run-check. Pass check=...

PiperOrigin-RevId: 400858477
2021-10-04 17:54:46 -07:00
Peter Hawkins
db2e91eba2 Move jax.test_util to jax._src.test_util.
Add forwarding shims for names used by external clients of JAX in practice.

PiperOrigin-RevId: 398721725
2021-09-24 07:02:49 -07:00
Peter Hawkins
2c2f4033cc Move contents of jax.lib to jax._src.lib.
Add shim libraries for functions exported from jax.lib that other code seems to use in practice.

PiperOrigin-RevId: 398471863
2021-09-23 06:33:55 -07:00
Markus Kunesch
f0219d6a26 xla: fix the string representation of empty dict in PyTreeDef.
This commit fixes a bug in the string representation of empty dictionaries in a
PyTreeDef (the opening brace was missing).

PiperOrigin-RevId: 391083297
2021-08-16 10:43:36 -07:00
Roy Frostig
86c48ccb7c [jax] set leaf and node counts when creating a tuple pytree definition
PiperOrigin-RevId: 388479354
2021-08-03 09:54:28 -07:00
Peter Hawkins
3ddcec27f2 Update minimum jaxlib version to 0.1.69. 2021-07-15 17:00:13 -04:00
Peter Hawkins
75c9bf01f3 Fix most test failures under NumPy 1.21. 2021-06-22 16:31:44 -04:00
Qiao Zhang
99c142a0f3 Use assertRegex in TreeTest since pytest expects a different module prefix. 2021-04-20 14:10:30 -07:00
Markus Kunesch
f030e70e82 xla: improvement to string representation of PyTreeDef
The string representation of PyTreeDef was different to how the underlying
containers are represented in python. This sometimes made it harder to read
error messages. This commit modifies the representation of tuples, lists,
dicts, and None so that it matches the pythonic representation.

The representation of custom nodes and NamedTuples is left unchanged since
their structure is not easily accessible in C++. However, to avoid confusion
they are now labelled "CustomNode" instead of "PyTreeDef". The latter is now
only used to wrap the whole representation. See below for examples.

Tests that relied on a specific string representation of PyTreeDef in error
messages are modified to be agnostic to the representation. Instead, this
commit adds a separate test of the string representation in tree_util_test.

Examples:

```
OLD: PyTreeDef(dict[['a', 'b']], [*,*])
NEW: PyTreeDef({'a': *, 'b': *})

OLD: PyTreeDef(tuple, [PyTreeDef(tuple, [*,*]),PyTreeDef(list, [*,PyTreeDef(tuple, [*,PyTreeDef(None, []),*])])])
NEW: PyTreeDef(((*, *), [*, (*, None, *)]))

OLD: PyTreeDef(list, [PyTreeDef(<class '__main__.AnObject'>[[4, 'foo']], [*,PyTreeDef(None, [])])])
NEW: PyTreeDef([CustomNode(<class '__main__.AnObject'>[[4, 'foo']], [*, None])])
```
PiperOrigin-RevId: 369298267
2021-04-19 14:06:36 -07:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Peter Hawkins
6ee6c59235 Move jax.tree_util implementation to jax._src.tree_util.
NFC intended.

PiperOrigin-RevId: 364857920
2021-03-24 12:00:38 -07:00
Matthew Johnson
5c6ff67e4e generalize ravel_pytree to handle int types, add tests 2021-03-19 10:50:02 -07:00
Jake VanderPlas
5e7be4a61f Cleanup: remove obsolete jaxlib version checks 2021-02-04 15:13:39 -08:00
jax authors
10cff5f2bf Merge pull request #5621 from NathanHowell:enable-tree-util-tests
PiperOrigin-RevId: 355640225
2021-02-04 09:11:54 -08:00
Nathan Howell
36cbf0302f Rename {tree_util_tests.py->tree_util_test.py} 2021-02-03 16:27:24 -08:00