Also add a config option to switch to the new checkpoint implementation
globally (default False for now), as the first step in replacing and then
deleting old remat.
In cases like unit tests, users may want to clean up all the backends along with the resources used in the end of the test, and reinitialize them in the next test.
PiperOrigin-RevId: 462239974
The checks are:
(1) Check if the in_axes given to pmap matches the sharding of Array.
(2) Check if devices in `array.sharding` is equal to the devices provided to pmap
(3) Check if devices for all array inputs are the same.
(4) If devices are not provided to pmap, use the devices on `Array` after checking point (3).
PiperOrigin-RevId: 456567562
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:
[djax] add support for dynamic-shape outputs
PiperOrigin-RevId: 451320477
--
9b724647d169a73ffae08610741676cb9b182d26 by Matthew Johnson <mattjj@google.com>:
[djax] add support for dynamic-shape outputs
PiperOrigin-RevId: 451268007
Currently we can't block on *unordered* effectful computations because
there are no runtime tokens for them. This change adds a per-device token
that is returned by effectful computations. This enables us
to block on them if we want. See the design note added in https://github.com/google/jax/pull/10657.
PiperOrigin-RevId: 449106281
This currently only supports setting a specific Device object, not a
platform like "cpu". That should be added in the future.
Bumps the minimum jaxlib version in order to include
https://github.com/tensorflow/tensorflow/pull/53656
This way, code using the output xla executable does not need to also drop the unused arguments, simplifying downstream code.
PiperOrigin-RevId: 446391558
After last week's changes, units are no longer traced or introduced into jaxprs
in any way, so we don't need to use them in partial evaluation.
(Also there are some unrelated removals of dead code in maps.py.)
--
de9a948d1ce407056de545b5717c3441298e2f36 by Jake VanderPlas <jakevdp@google.com>:
make device_array.copy() return a device array
PiperOrigin-RevId: 438308145