Jake VanderPlas
108376d792
Remove deprecated function jax.tree_util.tree_multimap
2022-07-26 09:37:27 -07:00
George Necula
66dc95e2de
removes the jax.mask and jax.shapecheck APIs.
...
PiperOrigin-RevId: 463026577
2022-07-25 01:23:38 -07:00
George Necula
07fcf79324
jax.mask and jax.shapecheck are being deprecated
...
Issue: #11557
PiperOrigin-RevId: 462315754
2022-07-21 00:09:31 -07:00
Kuangyuan Chen
c0ec3b33e6
Introduce jax.experimental.clear_backends to delete all JAX runtime backends.
...
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.
PiperOrigin-RevId: 462239974
2022-07-20 15:10:27 -07:00
jax authors
9ee6cacdc8
Merge pull request #11540 from gnecula:ds_check_flag
...
PiperOrigin-RevId: 462061356
2022-07-19 23:07:14 -07:00
Yash Katariya
9f914a93d6
Replace input sharding_specs
with in_shardings
in InputsHandler
...
PiperOrigin-RevId: 461963206
2022-07-19 13:43:28 -07:00
George Necula
2106d65561
[dynamic-shapes] Add check that --jax_dynamic_shapes is set when using abstracted_axes.
...
abstracted_axes has no effect without the --jax_dynamic_shapes. Make this and
explicit error.
2022-07-19 19:48:45 +02:00
Yash Katariya
ea627b807b
Replace out_specs
with out_shardings
and remove out_indices
in ResultsHandler.
...
PiperOrigin-RevId: 461788039
2022-07-18 20:57:02 -07:00
jax authors
ed51c65576
Merge pull request #11405 from mattjj:djax-vmap
...
PiperOrigin-RevId: 459958155
2022-07-09 10:38:39 -07:00
Matthew Johnson
5b82ba787c
[dynamic-shapes] start basic vmap compatibility
2022-07-09 10:03:40 -07:00
Peter Hawkins
0b4b0ba072
Update minimum jaxlib version to 0.3.14.
2022-07-08 00:36:02 +00:00
Matthew Johnson
12a56c3064
[dynamic-shapes] add basic abstracted_axes support to jit(f, ...).lower(...)
2022-07-07 12:48:29 -07:00
Sharad Vikram
6274b9ed39
Enable Python callbacks on TFRT TPU backend
...
PiperOrigin-RevId: 459415455
2022-07-06 20:52:50 -07:00
Jake VanderPlas
cb25a96d43
vmap: better errors for mismatched axis in keyword arguments
2022-06-29 14:31:03 -07:00
Yash Katariya
766c5ba0a2
Check sharding in pmap for jax.Array
.
...
The checks are:
(1) Check if the in_axes given to pmap matches the sharding of Array.
(2) Check if devices in `array.sharding` is equal to the devices provided to pmap
(3) Check if devices for all array inputs are the same.
(4) If devices are not provided to pmap, use the devices on `Array` after checking point (3).
PiperOrigin-RevId: 456567562
2022-06-22 11:37:10 -07:00
Yash Katariya
dce8f64b40
Make device_put_sharded
and device_put_replicated
return Arrays.
...
PiperOrigin-RevId: 456525113
2022-06-22 08:51:29 -07:00
Jake VanderPlas
abcfaec6e3
DOC: clarify variable names
2022-06-21 13:20:53 -07:00
Neil Girdhar
7697616b85
Annotate vmap
2022-06-14 15:41:47 -04:00
Peter Hawkins
78312a7fff
Add an undocumented method on jit() functions to clear the function cache.
2022-06-10 18:36:18 -07:00
Peter Hawkins
b32f83d84d
Clarify that functions passed to jax.jit must be weakly referenceable.
2022-06-10 12:21:23 -07:00
Sharad Vikram
289610eb02
Add a public facing named_scope
function to allow adding to the name stack.
2022-06-08 17:23:57 -07:00
Sharad Vikram
426c7356fb
Guard has_explicit_device
with xla_client version
2022-06-02 12:03:27 -07:00
jax authors
ea54754c49
Merge pull request #9118 from skye:device_context_manager
...
PiperOrigin-RevId: 452570041
2022-06-02 10:33:53 -07:00
jax authors
094e706498
Merge pull request #10823 from LenaMartens:changelist/450914215
...
PiperOrigin-RevId: 451427821
2022-05-27 10:39:24 -07:00
Matthew Johnson
ffa9328a68
Copybara import of the project:
...
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:
[djax] add support for dynamic-shape outputs
PiperOrigin-RevId: 451320477
2022-05-26 23:21:40 -07:00
Matthew Johnson
995220a739
Copybara import of the project:
...
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:
[djax] add support for dynamic-shape outputs
PiperOrigin-RevId: 451268007
2022-05-26 16:26:49 -07:00
Matthew Johnson
9b724647d1
[djax] add support for dynamic-shape outputs
2022-05-26 13:22:06 -07:00
Lena Martens
f8f5a5dca3
Add notes in buffer donation FAQ about key-word args limitation.
2022-05-25 15:33:04 +01:00
Jeppe Klitgaard
838a05329d
feat: validate jit args
2022-05-18 21:54:47 +01:00
Jeppe Klitgaard
17de89b16a
feat: refactor code using pyupgrade
...
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```
a
2022-05-17 22:14:05 +01:00
Sharad Vikram
268b4be21b
Add output token for unordered effects
...
Currently we can't block on *unordered* effectful computations because
there are no runtime tokens for them. This change adds a per-device token
that is returned by effectful computations. This enables us
to block on them if we want. See the design note added in https://github.com/google/jax/pull/10657 .
PiperOrigin-RevId: 449106281
2022-05-16 18:56:33 -07:00
Skye Wanderman-Milne
f26b866e08
Add jax.default_device
context manager
...
This currently only supports setting a specific Device object, not a
platform like "cpu". That should be added in the future.
Bumps the minimum jaxlib version in order to include
https://github.com/tensorflow/tensorflow/pull/53656
2022-05-07 00:31:00 +00:00
Jean-Baptiste Lespiau
5c838d2f6f
Add an option when lowering to not remove unused arguments.
...
This way, code using the output xla executable does not need to also drop the unused arguments, simplifying downstream code.
PiperOrigin-RevId: 446391558
2022-05-04 01:22:14 -07:00
Sharad Vikram
ef982cfa8c
Attach keepalive to executable
2022-05-03 20:08:03 -07:00
Sharad Vikram
8031eee7ee
Add in runtime tokens for effectful jaxprs
2022-05-03 15:55:07 -07:00
Matthew Johnson
11ad045dfd
[remove-units] remove units from partial_eval.py
...
After last week's changes, units are no longer traced or introduced into jaxprs
in any way, so we don't need to use them in partial evaluation.
(Also there are some unrelated removals of dead code in maps.py.)
2022-05-02 13:43:27 -07:00
Matthew Johnson
65bff3c856
[remove-units] avoid unit-generating function in jax.linear_transpose
2022-04-29 16:37:43 -07:00
Matthew Johnson
9fd53bc6f7
[remove-units] prevent ad.py from introducing units
2022-04-26 13:01:01 -07:00
jax authors
5013bd2e3a
Merge pull request #10402 from froystig:aot-jit-avoid-trivial
...
PiperOrigin-RevId: 443533232
2022-04-21 18:13:10 -07:00
Roy Frostig
5c118071cb
always lower/compile computations on the AOT jit path
...
... even trivial ones.
2022-04-21 15:30:36 -07:00
Jake VanderPlas
5782210174
CI: fix flake8 ignore declarations
2022-04-21 13:44:12 -07:00
Peter Hawkins
ad8e6ada4e
[MHLO] Change jax.xla_computation() to use MHLO lowering internally.
...
Change in preparation for removing the non-MHLO lowering path.
PiperOrigin-RevId: 441460875
2022-04-13 06:28:38 -07:00
Matthew Johnson
4354f355a8
prototyping dynamic shapes
...
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Matthew Johnson
902fc0c3d2
Remove invertible_ad since it's not in use.
...
PiperOrigin-RevId: 440890949
2022-04-11 07:56:58 -07:00
Lucas Beyer
f7b749c99c
Explicit doc note about device_put* async
2022-04-04 23:38:51 +02:00
Jake VanderPlas
4949e78859
Re-land changes from https://github.com/google/jax/pull/10069
...
PiperOrigin-RevId: 439381161
2022-04-04 12:18:43 -07:00
Jake VanderPlas
df1ceaeeb1
Deprecate jax.tree_util.tree_multimap
2022-04-01 14:51:54 -07:00
jax authors
1555ba147c
Copybara import of the project:
...
--
de9a948d1ce407056de545b5717c3441298e2f36 by Jake VanderPlas <jakevdp@google.com>:
make device_array.copy() return a device array
PiperOrigin-RevId: 438308145
2022-03-30 08:30:18 -07:00
jax authors
ef2efec649
Merge pull request #10069 from jakevdp:devicearray-copy
...
PiperOrigin-RevId: 438292130
2022-03-30 07:01:19 -07:00
Jake VanderPlas
f4b64f48f4
doc: add examples of using partial with jit
2022-03-29 15:43:58 -07:00