184 Commits

Author SHA1 Message Date
Jake VanderPlas
108376d792 Remove deprecated function jax.tree_util.tree_multimap 2022-07-26 09:37:27 -07:00
George Necula
66dc95e2de removes the jax.mask and jax.shapecheck APIs.
PiperOrigin-RevId: 463026577
2022-07-25 01:23:38 -07:00
George Necula
07fcf79324 jax.mask and jax.shapecheck are being deprecated
Issue: #11557
PiperOrigin-RevId: 462315754
2022-07-21 00:09:31 -07:00
Kuangyuan Chen
c0ec3b33e6 Introduce jax.experimental.clear_backends to delete all JAX runtime backends.
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.

PiperOrigin-RevId: 462239974
2022-07-20 15:10:27 -07:00
jax authors
9ee6cacdc8 Merge pull request #11540 from gnecula:ds_check_flag
PiperOrigin-RevId: 462061356
2022-07-19 23:07:14 -07:00
Yash Katariya
9f914a93d6 Replace input sharding_specs with in_shardings in InputsHandler
PiperOrigin-RevId: 461963206
2022-07-19 13:43:28 -07:00
George Necula
2106d65561 [dynamic-shapes] Add check that --jax_dynamic_shapes is set when using abstracted_axes.
abstracted_axes has no effect without the --jax_dynamic_shapes. Make this and
explicit error.
2022-07-19 19:48:45 +02:00
Yash Katariya
ea627b807b Replace out_specs with out_shardings and remove out_indices in ResultsHandler.
PiperOrigin-RevId: 461788039
2022-07-18 20:57:02 -07:00
jax authors
ed51c65576 Merge pull request #11405 from mattjj:djax-vmap
PiperOrigin-RevId: 459958155
2022-07-09 10:38:39 -07:00
Matthew Johnson
5b82ba787c [dynamic-shapes] start basic vmap compatibility 2022-07-09 10:03:40 -07:00
Peter Hawkins
0b4b0ba072 Update minimum jaxlib version to 0.3.14. 2022-07-08 00:36:02 +00:00
Matthew Johnson
12a56c3064 [dynamic-shapes] add basic abstracted_axes support to jit(f, ...).lower(...) 2022-07-07 12:48:29 -07:00
Sharad Vikram
6274b9ed39 Enable Python callbacks on TFRT TPU backend
PiperOrigin-RevId: 459415455
2022-07-06 20:52:50 -07:00
Jake VanderPlas
cb25a96d43 vmap: better errors for mismatched axis in keyword arguments 2022-06-29 14:31:03 -07:00
Yash Katariya
766c5ba0a2 Check sharding in pmap for jax.Array.
The checks are:

(1) Check if the in_axes given to pmap matches the sharding of Array.

(2) Check if devices in `array.sharding` is equal to the devices provided to pmap

(3) Check if devices for all array inputs are the same.

(4) If devices are not provided to pmap, use the devices on `Array` after checking point (3).

PiperOrigin-RevId: 456567562
2022-06-22 11:37:10 -07:00
Yash Katariya
dce8f64b40 Make device_put_sharded and device_put_replicated return Arrays.
PiperOrigin-RevId: 456525113
2022-06-22 08:51:29 -07:00
Jake VanderPlas
abcfaec6e3 DOC: clarify variable names 2022-06-21 13:20:53 -07:00
Neil Girdhar
7697616b85 Annotate vmap 2022-06-14 15:41:47 -04:00
Peter Hawkins
78312a7fff Add an undocumented method on jit() functions to clear the function cache. 2022-06-10 18:36:18 -07:00
Peter Hawkins
b32f83d84d Clarify that functions passed to jax.jit must be weakly referenceable. 2022-06-10 12:21:23 -07:00
Sharad Vikram
289610eb02 Add a public facing named_scope function to allow adding to the name stack. 2022-06-08 17:23:57 -07:00
Sharad Vikram
426c7356fb Guard has_explicit_device with xla_client version 2022-06-02 12:03:27 -07:00
jax authors
ea54754c49 Merge pull request #9118 from skye:device_context_manager
PiperOrigin-RevId: 452570041
2022-06-02 10:33:53 -07:00
jax authors
094e706498 Merge pull request #10823 from LenaMartens:changelist/450914215
PiperOrigin-RevId: 451427821
2022-05-27 10:39:24 -07:00
Matthew Johnson
ffa9328a68 Copybara import of the project:
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:

[djax] add support for dynamic-shape outputs

PiperOrigin-RevId: 451320477
2022-05-26 23:21:40 -07:00
Matthew Johnson
995220a739 Copybara import of the project:
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:

[djax] add support for dynamic-shape outputs

PiperOrigin-RevId: 451268007
2022-05-26 16:26:49 -07:00
Matthew Johnson
9b724647d1 [djax] add support for dynamic-shape outputs 2022-05-26 13:22:06 -07:00
Lena Martens
f8f5a5dca3 Add notes in buffer donation FAQ about key-word args limitation. 2022-05-25 15:33:04 +01:00
Jeppe Klitgaard
838a05329d feat: validate jit args 2022-05-18 21:54:47 +01:00
Jeppe Klitgaard
17de89b16a feat: refactor code using pyupgrade
This PR upgrades legacy Python code to 3.7+ code using pyupgrade:
```sh
pyupgrade --py37-plus --keep-runtime-typing **.py
```

a
2022-05-17 22:14:05 +01:00
Sharad Vikram
268b4be21b Add output token for unordered effects
Currently we can't block on *unordered* effectful computations because
there are no runtime tokens for them. This change adds a per-device token
that is returned by effectful computations. This enables us
to block on them if we want. See the design note added in https://github.com/google/jax/pull/10657.

PiperOrigin-RevId: 449106281
2022-05-16 18:56:33 -07:00
Skye Wanderman-Milne
f26b866e08 Add jax.default_device context manager
This currently only supports setting a specific Device object, not a
platform like "cpu". That should be added in the future.

Bumps the minimum jaxlib version in order to include
https://github.com/tensorflow/tensorflow/pull/53656
2022-05-07 00:31:00 +00:00
Jean-Baptiste Lespiau
5c838d2f6f Add an option when lowering to not remove unused arguments.
This way, code using the output xla executable does not need to also drop the unused arguments, simplifying downstream code.

PiperOrigin-RevId: 446391558
2022-05-04 01:22:14 -07:00
Sharad Vikram
ef982cfa8c Attach keepalive to executable 2022-05-03 20:08:03 -07:00
Sharad Vikram
8031eee7ee Add in runtime tokens for effectful jaxprs 2022-05-03 15:55:07 -07:00
Matthew Johnson
11ad045dfd [remove-units] remove units from partial_eval.py
After last week's changes, units are no longer traced or introduced into jaxprs
in any way, so we don't need to use them in partial evaluation.

(Also there are some unrelated removals of dead code in maps.py.)
2022-05-02 13:43:27 -07:00
Matthew Johnson
65bff3c856 [remove-units] avoid unit-generating function in jax.linear_transpose 2022-04-29 16:37:43 -07:00
Matthew Johnson
9fd53bc6f7 [remove-units] prevent ad.py from introducing units 2022-04-26 13:01:01 -07:00
jax authors
5013bd2e3a Merge pull request #10402 from froystig:aot-jit-avoid-trivial
PiperOrigin-RevId: 443533232
2022-04-21 18:13:10 -07:00
Roy Frostig
5c118071cb always lower/compile computations on the AOT jit path
... even trivial ones.
2022-04-21 15:30:36 -07:00
Jake VanderPlas
5782210174 CI: fix flake8 ignore declarations 2022-04-21 13:44:12 -07:00
Peter Hawkins
ad8e6ada4e [MHLO] Change jax.xla_computation() to use MHLO lowering internally.
Change in preparation for removing the non-MHLO lowering path.

PiperOrigin-RevId: 441460875
2022-04-13 06:28:38 -07:00
Matthew Johnson
4354f355a8 prototyping dynamic shapes
Co-authored-by: Dougal Maclaurin <dougalm@google.com>
2022-04-11 22:10:47 -07:00
Matthew Johnson
902fc0c3d2 Remove invertible_ad since it's not in use.
PiperOrigin-RevId: 440890949
2022-04-11 07:56:58 -07:00
Lucas Beyer
f7b749c99c
Explicit doc note about device_put* async 2022-04-04 23:38:51 +02:00
Jake VanderPlas
4949e78859 Re-land changes from https://github.com/google/jax/pull/10069
PiperOrigin-RevId: 439381161
2022-04-04 12:18:43 -07:00
Jake VanderPlas
df1ceaeeb1 Deprecate jax.tree_util.tree_multimap 2022-04-01 14:51:54 -07:00
jax authors
1555ba147c Copybara import of the project:
--
de9a948d1ce407056de545b5717c3441298e2f36 by Jake VanderPlas <jakevdp@google.com>:

make device_array.copy() return a device array

PiperOrigin-RevId: 438308145
2022-03-30 08:30:18 -07:00
jax authors
ef2efec649 Merge pull request #10069 from jakevdp:devicearray-copy
PiperOrigin-RevId: 438292130
2022-03-30 07:01:19 -07:00
Jake VanderPlas
f4b64f48f4 doc: add examples of using partial with jit 2022-03-29 15:43:58 -07:00