12 Commits

Author SHA1 Message Date
Peter Hawkins
e7e5140dc9 Move implementation of jax.flatten_util to jax._src.flatten_util. Add a jax.flatten_util shim.
Change as part of cleaning up the jax.* namespace.

PiperOrigin-RevId: 395551093
2021-09-08 13:54:25 -07:00
Peter Hawkins
6a6f13e1b0 [JAX] Move contents of jax/dtypes.py to jax/_src/dtypes.py.
PiperOrigin-RevId: 367345623
2021-04-07 19:35:51 -07:00
Matthew Johnson
5c6ff67e4e generalize ravel_pytree to handle int types, add tests 2021-03-19 10:50:02 -07:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Matthew Johnson
159a61b2f7 deflake 2020-06-12 15:41:49 -07:00
Matthew Johnson
ae9df752de add docstring to ravel_pytree 2020-06-12 15:41:07 -07:00
Peter Hawkins
b1bc841ae5
Replace np -> jnp, onp -> np in more places. (#2973)
* Replace np -> jnp, onp -> np in more places.

Context: #2370

* Fix typo in random_test.py
2020-05-05 16:40:41 -04:00
Peter Hawkins
e60d5dd54c
Remove "from __future__" uses from JAX. (#2117)
The future (Python 3) has arrived; no need to request it explicitly.
2020-01-29 12:29:03 -05:00
George Necula
528a69f32e Added some more documentation to the linear_util module
Also cleaned up the inconsistent way of importing the module.
Prefer importing with qualified name 'lu.transformation' rather
than just 'transformation'.
2020-01-05 16:40:26 +01:00
Matthew Johnson
9c2e1c35b1 prevent jit from treating keyword args as static
fixes #523
2019-04-10 22:09:14 -07:00
Matthew Johnson
0f7c7c4eab generalize jacfwd and jacrev to handle pytrees 2019-01-06 12:49:41 -08:00
Matthew Johnson
ad4322c5da playing around with flattening functions 2019-01-06 12:49:35 -08:00