6 Commits

Author SHA1 Message Date
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Sharad Vikram
5d3f48204d Add stateful for loop primitives (#10982)
Adds a `get/swap/addupdate` primitive, along with impl, abstract_eval
and jvp rules.

Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-06-15 15:55:38 -07:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Matthew Johnson
9cd55a2bbd [remove-units] remove units 2022-05-04 10:58:56 -07:00
Matthew Johnson
e7acb82b14 [remove-units] remove units from api_util.py 2022-04-26 12:31:08 -07:00
Peter Hawkins
e9611eb090 Move jax.ad_util to jax._src.ad_util.
Expose ad_util.stop_gradient_p as jax.lax.stop_gradient_p. stop_gradient() is already under the external lax namespace.

PiperOrigin-RevId: 378011152
2021-06-07 14:51:34 -07:00