27 Commits

Author SHA1 Message Date
jax authors
d0819ae67b remove unnecessary if statement
PiperOrigin-RevId: 617653292
2024-03-20 16:15:15 -07:00
Jian Li
b6e985ffe7 Add int4 test to ArrayImpl.
PiperOrigin-RevId: 614778550
2024-03-11 13:40:11 -07:00
Daniel Ng
e079bb9938 Add primary_host and replica_id parameters to async_serialize().
PiperOrigin-RevId: 611108421
2024-02-28 08:18:02 -08:00
Jake VanderPlas
e59a0506fe Deprecate jax.tree_map in favor of jax.tree.map 2024-02-22 11:35:39 -08:00
Sergei Lebedev
078bb00fdb Replaced most usages of abc.ABC with util.StrictABC
StrictABC does not allow registering virtual subclasses and can thus avoid
using relatively expensive __instancecheck__/__sublclasscheck__ defined in
abc.ABCMeta.

The only abc.ABC subclass left is jax.Array which *does* use virtual
subclasses for natively-defined array types.
2024-01-29 12:40:43 +00:00
Sergei Lebedev
f936613b06 Upgrade remaining sources to Python 3.9
This PR is a follow up to #18881.

The changes were generated by adding

    from __future__ import annotations

to the files which did not already have them and running

    pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Mark Sandler
569f06cda7 In python 3.11 async.run() always tries to convert repr of the result of a coroutine as integer while fetching sigint handler. This makes the test materialize the whole tensor in memory. This changes the test co-routine to return nothing to avoid triggering this bug.
https://github.com/python/cpython/issues/112559

PiperOrigin-RevId: 586756112
2023-11-30 12:37:12 -08:00
Yash Katariya
d6637da431 Disable test_memory_cosumption test
PiperOrigin-RevId: 586426753
2023-11-29 12:50:46 -08:00
Yash Katariya
0d57330fe0 Add mitigation techniques in the error message when a barrier timeout occurs.
PiperOrigin-RevId: 577214081
2023-10-27 08:57:39 -07:00
Jake VanderPlas
4a5bd9e046 Fix typos across the package 2023-09-22 14:54:31 -07:00
Yash Katariya
6b574708ee Abstract the array_serialization error message to a global variable so that it can be overridden.
PiperOrigin-RevId: 561224461
2023-08-29 21:46:26 -07:00
Peter Hawkins
2c32660a8f Replace references to DeviceArray with Array.
A number of stale references are lurking in our documentation.
2023-08-18 17:46:00 -04:00
Yash Katariya
1bd5fd2a52 Add serialize_with_paths and deserialize_with_paths API to GlobalAsyncCheckpointManager
PiperOrigin-RevId: 555050522
2023-08-08 22:47:27 -07:00
Peter Hawkins
319ab98980 Apply pyupgrade --py39-plus.
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Alexander Kolesnikov
16a3c1018f Improve error message and the relevant test.
PiperOrigin-RevId: 548058071
2023-07-14 01:45:43 -07:00
Parker Schuh
feced360f0 Make the default driver in serialization be a global constant.
PiperOrigin-RevId: 543008650
2023-06-23 18:40:31 -07:00
Peter Hawkins
816ba91263 Use lower-case PEP 585 names for types.
Issue https://github.com/google/jax/issues/16537

PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Peter Hawkins
9f4080ae2b Silence pytype errors under an upcoming pytype change.
PiperOrigin-RevId: 542652577
2023-06-22 13:32:15 -07:00
Yash Katariya
b44f8b4bf0 Fix get_tensorstore_spec for GCS paths if ocdbt is enabled
PiperOrigin-RevId: 538627415
2023-06-07 16:47:35 -07:00
Yash Katariya
9615a31a73 Add concurrent_gb to deserialize
PiperOrigin-RevId: 538537686
2023-06-07 11:03:14 -07:00
Colin Gaffney
88d3d826cc Add assume_metadata option to avoid duplicate read if the ts.Spec has already been fully constructed via ts.open.
PiperOrigin-RevId: 537986108
2023-06-05 14:45:02 -07:00
Yash Katariya
0dffdf4645 Allow GlobalAsyncCheckpointManager to work on a single process without initializing the distributed system.
PiperOrigin-RevId: 537937426
2023-06-05 11:41:54 -07:00
jax authors
c15f30f22e Instrument a new metric to measure the savings of async checkpoint in JAX.
Create the new metric '/jax/checkpoint/write/async/thread_duration_sec' to measure the savings from the async thread creation time.

PiperOrigin-RevId: 529227213
2023-05-03 16:36:54 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Mark Sandler
849e47f79a Makes deserializer put tensors on the device before releasing inflight memory, as well as avoids allocating memory before memory is available.
This makes in-flight memory limiter both reflective of the actual peak usage, as well as reduces peak usage since we no longer try to fully materialize sharded tensors on the host.

PiperOrigin-RevId: 524456216
2023-04-14 21:20:38 -07:00
Colin Gaffney
38f6338299 Switch to zstd for numpy array serialization (jax.Array serialization is handled by JAX library).
PiperOrigin-RevId: 522616067
2023-04-07 09:36:05 -07:00
Yash Katariya
d27a80dbfa Rename gda_serialization to array_serialization but keep gda_serialization around until it is included in a jax release so that OSS projects can be moved to array_serialization
PiperOrigin-RevId: 521055760
2023-03-31 18:07:51 -07:00