24 Commits

Author SHA1 Message Date
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Matthew Johnson
b7426b5ef9 rolling forward deletion of custom_jvp_call_jaxpr_p yet again...
PiperOrigin-RevId: 468541924
2022-08-18 14:02:40 -07:00
jax authors
03e2ca0ee7 roll-forward deletion of custom_jvp_call_jaxpr_p
PiperOrigin-RevId: 468522879
2022-08-18 12:39:21 -07:00
Matthew Johnson
3a20de1575 roll-forward deletion of custom_jvp_call_jaxpr_p
PiperOrigin-RevId: 468499658
2022-08-18 11:01:10 -07:00
jax authors
fe665b3a64 Copybara import of the project:
--
887b7ce2cb3d6d8aedac5cc273e137f1c876e3c7 by Matthew Johnson <mattjj@google.com>:

remove custom_jvp_call_jaxpr_p and its rules

They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).

This change languished until we could land #11830 / #11950 and friends. But now
we can!

PiperOrigin-RevId: 468373797
2022-08-17 22:40:58 -07:00
Matthew Johnson
887b7ce2cb remove custom_jvp_call_jaxpr_p and its rules
They were superfluous! Instead use the "new" mechanism for converting from
jaxpr params to bind params (in #9136).

This change languished until we could land #11830 / #11950 and friends. But now
we can!
2022-08-17 21:12:27 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Peter Hawkins
8b2123968a Switch internal users of jax.util.partial to use functools.partial. 2021-09-13 21:09:58 -04:00
Peter Hawkins
e869e5e0f8 Move contents of jax.api_util to jax._src.api_util and add a forwarding shim.
One of many changes to codify the set of exported symbols in the jax.* namespace.

PiperOrigin-RevId: 395484706
2021-09-08 09:00:56 -07:00
Peter Hawkins
a2a60844cb Use MainTrace payload mechanism in experimental callback tracer.
A change that adds jit() decorators on a number of standard library functions was triggering incorrect cache hits for these tests. This is because the payload fields of the MainTrace were not being included in __hash__() and __eq__().
2021-08-05 17:41:07 -04:00
Sharad Vikram
252bd6c0c8 Add support for custom derivatives in jax.experimental.callback 2021-03-22 12:57:33 -07:00
Sharad Vikram
ddaef193fe Add scan and while rule for jax.experimental.callback transformation 2021-03-09 19:46:16 -08:00
Peter Hawkins
3ac809ede3 [JAX] Move jax.util to jax._src_util.
PiperOrigin-RevId: 351234602
2021-01-11 14:21:07 -08:00
Roy Frostig
fa1d7ab5fa move wraps from api_util to util to avoid cyclic dependencies 2020-10-26 12:31:19 -07:00
Sharad Vikram
cae5334f1d fix lint error 2020-10-16 15:04:08 -07:00
Sharad Vikram
e8901d51af Add default implementation of process_custom_jvp_call and
process_custom_vjp_call to `jax.experimental.callback`
2020-10-16 14:53:23 -07:00
George Necula
634c6259df
More renaming of master to main in JAX internals (#4179) 2020-08-30 12:38:14 +03:00
Matthew Johnson
6b6789a53b
applied simple find+sed for 'master' -> 'main' (#4174)
* applied simple find+sed for 'master' -> 'main'

* Rename master->main in JAX API and internals (#4178)

* Started with #4174 
* Renamed Trace.master to Trace.main
* Renamed core.new_master and core.new_base_master

Co-authored-by: George Necula <gcnecula@gmail.com>
2020-08-30 11:16:51 +03:00
Matthew Johnson
4236eb2b59
omnistaging, under a flag and disabled by default (#3370)
This change, when enabled, stages out all primitive calls in the dynamic
scope of a jitted, pmapped, or control flow function, rather than only
staging out based on data dependence. One improvement is that jitted
functions can consume less memory, by avoiding instantiating large
constants at trace time, and cause less memory fragmentation as well. It
also simplifies several internals.

See https://github.com/google/jax/pull/3370 fo more information.
2020-07-30 12:59:36 -07:00
Jake Vanderplas
fb1717233a
Cleanup: deflake jax.experimental and jax.ops (#3329) 2020-06-05 19:00:04 -07:00
Peter Hawkins
b1bc841ae5
Replace np -> jnp, onp -> np in more places. (#2973)
* Replace np -> jnp, onp -> np in more places.

Context: #2370

* Fix typo in random_test.py
2020-05-05 16:40:41 -04:00
Peter Hawkins
1298e9e8c4
Fix some test failures. (#2713) 2020-04-14 18:23:19 -04:00
Sam Schoenholz
f5908cb5ee Fixed copyright, added type annotations, disabled mypy on one line. 2020-04-09 15:20:41 -07:00
Sam Schoenholz
25d797b1d9 Added callback_transform, examples, and tests. 2020-04-09 13:27:16 -07:00