33 Commits

Author SHA1 Message Date
Adam Paszke
9c5e3f7ecc Verify that slices are trivial before discarding them in state primitives
At the moment, if `r` is a JAX ref then `r[0:1] = a` works, but it silently ignores the slices
and performs `r[:] = a` instead...

PiperOrigin-RevId: 529385973
2023-05-04 05:59:47 -07:00
Peter Hawkins
c1f65fc8b2 Avoid imports from the public jax.* namespace in more places internally.
This change is in preparation for more cycle breaking in the Bazel dependency graph.

PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
jax authors
3c1f3abba2 Merge pull request #15149 from sharadmv:runstate
PiperOrigin-RevId: 521809360
2023-04-04 10:56:25 -07:00
Sharad Vikram
5101184ad4 Add initial implementation of a run_state primitive 2023-04-03 21:32:32 -07:00
Peter Hawkins
6cc1bf54a1 Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.

PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Peter Hawkins
8fb1fd318d Replace jax._src.util.prod with math.prod.
math.prod() was added in Python 3.8, so we can assume it is always present.

PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Peter Hawkins
f66f6ec98a [JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
Limit jax._src.lib to shims around jaxlib and nothing else.

The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.

PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
jax authors
8ebfb0be48 Merge pull request #14614 from sharadmv:ref
PiperOrigin-RevId: 512315462
2023-02-25 11:12:00 -08:00
Sharad Vikram
4960e656af Refactor Ref abstract type to contain other AbstractValues 2023-02-23 17:02:40 -08:00
Sharad Vikram
a6c4c87f3e Add JaxprInputEffect and refactor StateEffects to use it 2023-02-21 16:30:06 -08:00
Sharad Vikram
af2306c0a8 Refactor effects system to use effect types, not objects 2023-02-17 17:40:08 -08:00
Roy Frostig
cb8dcce2fe migrate more internal dependencies from jax.core to jax._src.core
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Matthew Johnson
a964dc3b9a simpler pretty-print for pjit, tweak custom pp rule signature 2023-02-09 12:45:51 -08:00
jax authors
398aaaacc7 Add support for Ellipsis as an index for stateful operations.
PiperOrigin-RevId: 497466879
2022-12-23 22:46:50 -08:00
Jake VanderPlas
4a6bbde409 Move jax.linear_util to jax._src.linear_util 2022-12-20 14:49:27 -08:00
Jake VanderPlas
4389216d0c Remove typing_extensions dependency 2022-12-05 15:42:26 -08:00
Sharad Vikram
e1af93a9ba Enable state effect in cond_p (except in grad and vmap)
PiperOrigin-RevId: 485719926
2022-11-02 16:07:58 -07:00
Jake VanderPlas
97b17af5be [typing] add type annotations to the first several lax_numpy functions 2022-10-21 11:59:53 -07:00
Jake VanderPlas
7f89fd40a2 Cleanup: remove unused imports in private modules
Also improve our flake8 filter rules to avoid ignoring these.
2022-10-20 14:37:21 -07:00
Jake VanderPlas
5d15757741 [typing] annotate jax._src.util.safe_map 2022-10-20 10:15:04 -07:00
Jake VanderPlas
524745f322 TMP: annotate util.safe_zip 2022-10-19 10:29:53 -07:00
Sharad Vikram
bbf69d10cc Enable partially discharging state effects from jaxprs 2022-10-11 16:52:29 -07:00
Peter Hawkins
ba557d5e1b Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.

PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Sharad Vikram
f26f1e8afc Add support for closing over Refs in nested for loops 2022-09-13 13:32:44 -07:00
Sharad Vikram
6967c7ef51 Add sound loop invariance detection 2022-09-08 10:42:19 -07:00
Sharad Vikram
b6c3b9df19 Split State effect into Read/Write/Accum effects and tie them to Ref avals 2022-09-08 08:04:13 -07:00
Sharad Vikram
b2a5d2c3bb Add partial_eval_custom rule for for_loop 2022-09-06 11:00:26 -07:00
Matthew Johnson
bbb8048d2e Add batching rules for state primitives and for_loop
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-08-29 11:40:09 -07:00
jax authors
5d9dfbb142 Merge pull request #11900 from sharadmv:for-loop
PiperOrigin-RevId: 468608666
2022-08-18 20:14:17 -07:00
Sharad Vikram
49b7729f6b More tests for transpose 2022-08-18 18:06:21 -07:00
Neil Girdhar
ad38a6bb28 Fix common typo: Tuple[X] -> Tuple[X, ...] 2022-08-16 11:47:22 -04:00
Sharad Vikram
72dbe31172 Initial transpose implementation
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-15 10:23:04 -07:00
Sharad Vikram
8b7daa8095 Refactor state out of for_loop 2022-08-01 15:26:55 -07:00