Sergei Lebedev
65d3058944
Migrate a subset of internal modules to use state objects
...
The motivation here is to gradually replace all dynamic lookups on `jax.config`
with statically-typed state objects, which are more type checker/IDE friendly.
PiperOrigin-RevId: 571932143
2023-10-09 07:29:53 -07:00
Matthew Johnson
5715db4832
[run_state] add pjit run_state discharge rule and basic test
2023-10-05 21:14:00 -07:00
Sergei Lebedev
5ab05e42c9
MAINT Clean up leftover Array = Any
aliases in jax/_src/**.py
...
I had to revert to using `Any` for `RaggedAxis.ragged_axes` because pytype
found more latent type errors, which require the understanding of ragedness
and dynamic shapes internals to fix properly.
2023-10-01 12:19:21 +01:00
Jake VanderPlas
4a5bd9e046
Fix typos across the package
2023-09-22 14:54:31 -07:00
jax authors
256612bb80
Merge pull request #17720 from superbobry:tuple-list-comp
...
PiperOrigin-RevId: 567433086
2023-09-21 15:16:12 -07:00
Sergei Lebedev
df7f6a06c0
MAINT Use a generator expression in tuple([... for ... in ...])
...
In a few cases I also replaced tuple([*xs, *ys]) with (*xs, ys), because
tuple literals support unpacking as well.
2023-09-21 22:25:38 +01:00
Sharad Vikram
fdc2f9cab7
[Pallas] Add async_copy_to
and async_remote_copy_to
for doing DMAs.
...
Also add `.at` view syntax for `Ref`s
PiperOrigin-RevId: 565478936
2023-09-14 14:32:08 -07:00
Yash Katariya
c41d271175
Add memories support to remat.
...
This PR adds basic support to remat to allow transferring intermediates (activations) to destination memory in the forward pass. Currently JAX only support host memory kind but the API allows to transfer to other memories too. Remat will automatically load the residuals back to the source memory in the backward pass.
Introduce two singletons called `Recompute`, `Saveable` and a NamedTuple (`Offloadable`) that each policy can return. Currently policies return a bool which if True means saveable else recompute on backward pass. This is a backwards compatible change i.e. policies can still return a bool.
A very basic offloadable policy can look like this:
```
def policy(prim, *avals, **params):
return ad_checkpoint.Offloadable(src='tpu_hbm', dst='unpinned_host')
```
Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 564914301
2023-09-12 20:50:05 -07:00
Peter Hawkins
319ab98980
Apply pyupgrade --py39-plus.
...
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Sharad Vikram
c446b42522
Add discharge rules for scan/while
2023-07-06 22:30:35 +00:00
Peter Hawkins
816ba91263
Use lower-case PEP 585 names for types.
...
Issue https://github.com/google/jax/issues/16537
PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Sharad Vikram
907782d2e5
Deduplicate references closed over across branches of a lax.cond.
...
This fixes a correctness issue that could crop up when doing `run_state(cond)`.
PiperOrigin-RevId: 540795172
2023-06-15 23:58:14 -07:00
Adam Paszke
9c5e3f7ecc
Verify that slices are trivial before discarding them in state primitives
...
At the moment, if `r` is a JAX ref then `r[0:1] = a` works, but it silently ignores the slices
and performs `r[:] = a` instead...
PiperOrigin-RevId: 529385973
2023-05-04 05:59:47 -07:00
Peter Hawkins
c1f65fc8b2
Avoid imports from the public jax.* namespace in more places internally.
...
This change is in preparation for more cycle breaking in the Bazel dependency graph.
PiperOrigin-RevId: 521822756
2023-04-04 11:41:40 -07:00
jax authors
3c1f3abba2
Merge pull request #15149 from sharadmv:runstate
...
PiperOrigin-RevId: 521809360
2023-04-04 10:56:25 -07:00
Sharad Vikram
5101184ad4
Add initial implementation of a run_state
primitive
2023-04-03 21:32:32 -07:00
Peter Hawkins
6cc1bf54a1
Move jax.interpreters.partial_eval to jax._src.interpreters.partial_eval.
...
Also fix up some other internal imports of jax.interpreters.* to use jax._src.interpreters.
PiperOrigin-RevId: 519813664
2023-03-27 13:30:47 -07:00
Peter Hawkins
8fb1fd318d
Replace jax._src.util.prod with math.prod.
...
math.prod() was added in Python 3.8, so we can assume it is always present.
PiperOrigin-RevId: 513011144
2023-02-28 12:41:00 -08:00
Peter Hawkins
f66f6ec98a
[JAX] Move jax._src.lib.xla_bridge to jax._src.xla_bridge.
...
Limit jax._src.lib to shims around jaxlib and nothing else.
The goal of this change is to avoid a dependency cycle between the rest of jax and jax._src.lib in a Bazel build. This allows the types for jax._src.lib to be inferred by pytype in isolation without referring to the rest of JAX.
PiperOrigin-RevId: 512922397
2023-02-28 07:01:57 -08:00
jax authors
8ebfb0be48
Merge pull request #14614 from sharadmv:ref
...
PiperOrigin-RevId: 512315462
2023-02-25 11:12:00 -08:00
Sharad Vikram
4960e656af
Refactor Ref
abstract type to contain other AbstractValue
s
2023-02-23 17:02:40 -08:00
Sharad Vikram
a6c4c87f3e
Add JaxprInputEffect
and refactor StateEffect
s to use it
2023-02-21 16:30:06 -08:00
Sharad Vikram
af2306c0a8
Refactor effects system to use effect types, not objects
2023-02-17 17:40:08 -08:00
Roy Frostig
cb8dcce2fe
migrate more internal dependencies from jax.core
to jax._src.core
...
PiperOrigin-RevId: 509736368
2023-02-14 23:01:11 -08:00
Matthew Johnson
a964dc3b9a
simpler pretty-print for pjit, tweak custom pp rule signature
2023-02-09 12:45:51 -08:00
jax authors
398aaaacc7
Add support for Ellipsis as an index for stateful operations.
...
PiperOrigin-RevId: 497466879
2022-12-23 22:46:50 -08:00
Jake VanderPlas
4a6bbde409
Move jax.linear_util to jax._src.linear_util
2022-12-20 14:49:27 -08:00
Jake VanderPlas
4389216d0c
Remove typing_extensions dependency
2022-12-05 15:42:26 -08:00
Sharad Vikram
e1af93a9ba
Enable state effect in cond_p
(except in grad
and vmap
)
...
PiperOrigin-RevId: 485719926
2022-11-02 16:07:58 -07:00
Jake VanderPlas
97b17af5be
[typing] add type annotations to the first several lax_numpy functions
2022-10-21 11:59:53 -07:00
Jake VanderPlas
7f89fd40a2
Cleanup: remove unused imports in private modules
...
Also improve our flake8 filter rules to avoid ignoring these.
2022-10-20 14:37:21 -07:00
Jake VanderPlas
5d15757741
[typing] annotate jax._src.util.safe_map
2022-10-20 10:15:04 -07:00
Jake VanderPlas
524745f322
TMP: annotate util.safe_zip
2022-10-19 10:29:53 -07:00
Sharad Vikram
bbf69d10cc
Enable partially discharging state effects from jaxprs
2022-10-11 16:52:29 -07:00
Peter Hawkins
ba557d5e1b
Change JAX's copyright attribution from "Google LLC" to "The JAX Authors.".
...
See https://opensource.google/documentation/reference/releasing/contributions#copyright for more details.
PiperOrigin-RevId: 476167538
2022-09-22 12:27:19 -07:00
Sharad Vikram
f26f1e8afc
Add support for closing over Ref
s in nested for loops
2022-09-13 13:32:44 -07:00
Sharad Vikram
6967c7ef51
Add sound loop invariance detection
2022-09-08 10:42:19 -07:00
Sharad Vikram
b6c3b9df19
Split State effect into Read/Write/Accum effects and tie them to Ref avals
2022-09-08 08:04:13 -07:00
Sharad Vikram
b2a5d2c3bb
Add partial_eval_custom rule for for_loop
2022-09-06 11:00:26 -07:00
Matthew Johnson
bbb8048d2e
Add batching rules for state primitives and for_loop
...
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
2022-08-29 11:40:09 -07:00
jax authors
5d9dfbb142
Merge pull request #11900 from sharadmv:for-loop
...
PiperOrigin-RevId: 468608666
2022-08-18 20:14:17 -07:00
Sharad Vikram
49b7729f6b
More tests for transpose
2022-08-18 18:06:21 -07:00
Neil Girdhar
ad38a6bb28
Fix common typo: Tuple[X] -> Tuple[X, ...]
2022-08-16 11:47:22 -04:00
Sharad Vikram
72dbe31172
Initial transpose implementation
...
Co-authored-by: Matthew Johnson <mattjj@google.com>
2022-08-15 10:23:04 -07:00
Sharad Vikram
8b7daa8095
Refactor state out of for_loop
2022-08-01 15:26:55 -07:00