Adam Paszke
9c5e3f7ecc
Verify that slices are trivial before discarding them in state primitives
...
At the moment, if `r` is a JAX ref then `r[0:1] = a` works, but it silently ignores the slices
and performs `r[:] = a` instead...
PiperOrigin-RevId: 529385973
2023-05-04 05:59:47 -07:00
Peter Hawkins
c1f65fc8b2
Avoid imports from the public jax.* namespace in more places internally.
...
This change is in preparation for more cycle breaking in the Bazel dependency graph.
PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
jax authors
3c1f3abba2
Merge pull request #15149 from sharadmv:runstate
...
PiperOrigin-RevId: 521809360
2023-04-04 10:56:25 -07:00
Sharad Vikram
5101184ad4
Add initial implementation of a run_state
primitive
2023-04-03 21:32:32 -07:00
Peter Hawkins
6cc1bf54a1
Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
...
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.
PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Peter Hawkins
8fb1fd318d
Replace jax._src.util.prod with math.prod.
...
math.prod() was added in Python 3.8, so we can assume it is always present.
PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Peter Hawkins
f66f6ec98a
[JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
...
Limit jax._src.lib to shims around jaxlib and nothing else.
The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.
PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
jax authors
8ebfb0be48
Merge pull request #14614 from sharadmv:ref
...
PiperOrigin-RevId: 512315462
2023-02-25 11:12:00 -08:00
Sharad Vikram
4960e656af
Refactor Ref
abstract type to contain other AbstractValue
s
2023-02-23 17:02:40 -08:00
Sharad Vikram
a6c4c87f3e
Add JaxprInputEffect
and refactor StateEffect
s to use it
2023-02-21 16:30:06 -08:00
Sharad Vikram
af2306c0a8
Refactor effects system to use effect types, not objects
2023-02-17 17:40:08 -08:00
Roy Frostig
cb8dcce2fe
migrate more internal dependencies from jax.core
to jax._src.core
...
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Matthew Johnson
a964dc3b9a
simpler pretty-print for pjit, tweak custom pp rule signature
2023-02-09 12:45:51 -08:00
jax authors
398aaaacc7
Add support for Ellipsis as an index for stateful operations.
...
PiperOrigin-RevId: 497466879
2022-12-23 22:46:50 -08:00
Jake VanderPlas
4a6bbde409
Move jax.linear_util to jax._src.linear_util
2022-12-20 14:49:27 -08:00
Jake VanderPlas
4389216d0c
Remove typing_extensions dependency
2022-12-05 15:42:26 -08:00
Sharad Vikram
e1af93a9ba
Enable state effect in cond_p
(except in grad
and vmap
)
...
PiperOrigin-RevId: 485719926
2022-11-02 16:07:58 -07:00
Jake VanderPlas
97b17af5be
[typing] add type annotations to the first several lax_numpy functions
2022-10-21 11:59:53 -07:00
Jake VanderPlas
7f89fd40a2
Cleanup: remove unused imports in private modules
...
Also improve our flake8 filter rules to avoid ignoring these.
2022-10-20 14:37:21 -07:00
Jake VanderPlas
5d15757741
[typing] annotate jax._src.util.safe_map
2022-10-20 10:15:04 -07:00
Jake VanderPlas
524745f322
TMP: annotate util.safe_zip
2022-10-19 10:29:53 -07:00
Sharad Vikram
bbf69d10cc
Enable partially discharging state effects from jaxprs
2022-10-11 16:52:29 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Sharad Vikram
f26f1e8afc
Add support for closing over Ref
s in nested for loops
2022-09-13 13:32:44 -07:00
Sharad Vikram
6967c7ef51
Add sound loop invariance detection
2022-09-08 10:42:19 -07:00
Sharad Vikram
b6c3b9df19
Split State effect into Read/Write/Accum effects and tie them to Ref avals
2022-09-08 08:04:13 -07:00
Sharad Vikram
b2a5d2c3bb
Add partial_eval_custom rule for for_loop
2022-09-06 11:00:26 -07:00
Matthew Johnson
bbb8048d2e
Add batching rules for state primitives and for_loop
...
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-08-29 11:40:09 -07:00
jax authors
5d9dfbb142
Merge pull request #11900 from sharadmv:for-loop
...
PiperOrigin-RevId: 468608666
2022-08-18 20:14:17 -07:00
Sharad Vikram
49b7729f6b
More tests for transpose
2022-08-18 18:06:21 -07:00
Neil Girdhar
ad38a6bb28
Fix common typo: Tuple[X] -> Tuple[X, ...]
2022-08-16 11:47:22 -04:00
Sharad Vikram
72dbe31172
Initial transpose implementation
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-15 10:23:04 -07:00
Sharad Vikram
8b7daa8095
Refactor state out of for_loop
2022-08-01 15:26:55 -07:00