18 Commits

Author SHA1 Message Date
Yash Katariya
0d57330fe0 Add mitigation techniques in the error message when a barrier timeout occurs.
PiperOrigin-RevId: 577214081
2023-10-27 08:57:39 -07:00
Jake VanderPlas
4a5bd9e046 Fix typos across the package 2023-09-22 14:54:31 -07:00
Yash Katariya
6b574708ee Abstract the array_serialization error message to a global variable so that it can be overridden.
PiperOrigin-RevId: 561224461
2023-08-29 21:46:26 -07:00
Peter Hawkins
2c32660a8f Replace references to DeviceArray with Array.
A number of stale references are lurking in our documentation.
2023-08-18 17:46:00 -04:00
Yash Katariya
1bd5fd2a52 Add serialize_with_paths and deserialize_with_paths API to GlobalAsyncCheckpointManager
PiperOrigin-RevId: 555050522
2023-08-08 22:47:27 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Alexander Kolesnikov
16a3c1018f Improve error message and the relevant test.
PiperOrigin-RevId: 548058071
2023-07-14 01:45:43 -07:00
Parker Schuh
feced360f0 Make the default driver in serialization be a global constant.
PiperOrigin-RevId: 543008650
2023-06-23 18:40:31 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Peter Hawkins
9f4080ae2b Silence pytype errors under an upcoming pytype change.
PiperOrigin-RevId: 542652577
2023-06-22 13:32:15 -07:00
Yash Katariya
b44f8b4bf0 Fix get_tensorstore_spec for GCS paths if ocdbt is enabled
PiperOrigin-RevId: 538627415
2023-06-07 16:47:35 -07:00
Yash Katariya
9615a31a73 Add concurrent_gb to deserialize
PiperOrigin-RevId: 538537686
2023-06-07 11:03:14 -07:00
Colin Gaffney
88d3d826cc Add assume_metadata option to avoid duplicate read if the ts.Spec has already been fully constructed via ts.open.
PiperOrigin-RevId: 537986108
2023-06-05 14:45:02 -07:00
Yash Katariya
0dffdf4645 Allow GlobalAsyncCheckpointManager to work on a single process without initializing the distributed system.
PiperOrigin-RevId: 537937426
2023-06-05 11:41:54 -07:00
jax authors
c15f30f22e Instrument a new metric to measure the savings of async checkpoint in JAX.
Create the new metric '/jax/checkpoint/write/async/thread_duration_sec' to measure the savings from the async thread creation time.

PiperOrigin-RevId: 529227213
2023-05-03 16:36:54 -07:00
Mark Sandler
849e47f79a Makes deserializer put tensors on the device before releasing inflight memory, as well as avoids allocating memory before memory is available.
This makes in-flight memory limiter both reflective of the actual peak usage, as well as reduces peak usage since we no longer try to fully materialize sharded tensors on the host.

PiperOrigin-RevId: 524456216
2023-04-14 21:20:38 -07:00
Colin Gaffney
38f6338299 Switch to zstd for numpy array serialization (jax.Array serialization is handled by JAX library).
PiperOrigin-RevId: 522616067
2023-04-07 09:36:05 -07:00
Yash Katariya
d27a80dbfa Rename gda_serialization to array_serialization but keep gda_serialization around until it is included in a jax release so that OSS projects can be moved to array_serialization
PiperOrigin-RevId: 521055760
2023-03-31 18:07:51 -07:00