jax authors
d0819ae67b
remove unnecessary if statement
...
PiperOrigin-RevId: 617653292
2024-03-20 16:15:15 -07:00
Jian Li
b6e985ffe7
Add int4 test to ArrayImpl.
...
PiperOrigin-RevId: 614778550
2024-03-11 13:40:11 -07:00
Daniel Ng
e079bb9938
Add primary_host and replica_id parameters to async_serialize().
...
PiperOrigin-RevId: 611108421
2024-02-28 08:18:02 -08:00
Jake VanderPlas
e59a0506fe
Deprecate jax.tree_map in favor of jax.tree.map
2024-02-22 11:35:39 -08:00
Sergei Lebedev
078bb00fdb
Replaced most usages of abc.ABC with util.StrictABC
...
StrictABC does not allow registering virtual subclasses and can thus avoid
using relatively expensive __instancecheck__/__sublclasscheck__ defined in
abc.ABCMeta.
The only abc.ABC subclass left is jax.Array which *does* use virtual
subclasses for natively-defined array types.
2024-01-29 12:40:43 +00:00
Sergei Lebedev
f936613b06
Upgrade remaining sources to Python 3.9
...
This PR is a follow up to #18881 .
The changes were generated by adding
from __future__ import annotations
to the files which did not already have them and running
pyupgrade --py39-plus --keep-percent-format {jax,tests,jaxlib,examples,benchmarks}/**/*.py
2023-12-13 10:29:45 +00:00
Mark Sandler
569f06cda7
In python 3.11 async.run() always tries to convert repr of the result of a coroutine as integer while fetching sigint handler. This makes the test materialize the whole tensor in memory. This changes the test co-routine to return nothing to avoid triggering this bug.
...
https://github.com/python/cpython/issues/112559
PiperOrigin-RevId: 586756112
2023-11-30 12:37:12 -08:00
Yash Katariya
d6637da431
Disable test_memory_cosumption test
...
PiperOrigin-RevId: 586426753
2023-11-29 12:50:46 -08:00
Yash Katariya
0d57330fe0
Add mitigation techniques in the error message when a barrier timeout occurs.
...
PiperOrigin-RevId: 577214081
2023-10-27 08:57:39 -07:00
Jake VanderPlas
4a5bd9e046
Fix typos across the package
2023-09-22 14:54:31 -07:00
Yash Katariya
6b574708ee
Abstract the array_serialization error message to a global variable so that it can be overridden.
...
PiperOrigin-RevId: 561224461
2023-08-29 21:46:26 -07:00
Peter Hawkins
2c32660a8f
Replace references to DeviceArray with Array.
...
A number of stale references are lurking in our documentation.
2023-08-18 17:46:00 -04:00
Yash Katariya
1bd5fd2a52
Add serialize_with_paths
and deserialize_with_paths
API to GlobalAsyncCheckpointManager
...
PiperOrigin-RevId: 555050522
2023-08-08 22:47:27 -07:00
Peter Hawkins
319ab98980
Apply pyupgrade --py39-plus.
...
Notable changes:
* use PEP 585 type names
* use PEP 604 type union syntax where `from __future__ import annotations` is present.
* use f-strings in more places.
* remove redundant arguments to open().
2023-07-21 14:49:44 -04:00
Alexander Kolesnikov
16a3c1018f
Improve error message and the relevant test.
...
PiperOrigin-RevId: 548058071
2023-07-14 01:45:43 -07:00
Parker Schuh
feced360f0
Make the default driver in serialization be a global constant.
...
PiperOrigin-RevId: 543008650
2023-06-23 18:40:31 -07:00
Peter Hawkins
816ba91263
Use lower-case PEP 585 names for types.
...
Issue https://github.com/google/jax/issues/16537
PiperOrigin-RevId: 542969282
2023-06-23 15:12:14 -07:00
Peter Hawkins
9f4080ae2b
Silence pytype errors under an upcoming pytype change.
...
PiperOrigin-RevId: 542652577
2023-06-22 13:32:15 -07:00
Yash Katariya
b44f8b4bf0
Fix get_tensorstore_spec
for GCS paths if ocdbt is enabled
...
PiperOrigin-RevId: 538627415
2023-06-07 16:47:35 -07:00
Yash Katariya
9615a31a73
Add concurrent_gb to deserialize
...
PiperOrigin-RevId: 538537686
2023-06-07 11:03:14 -07:00
Colin Gaffney
88d3d826cc
Add assume_metadata
option to avoid duplicate read if the ts.Spec
has already been fully constructed via ts.open
.
...
PiperOrigin-RevId: 537986108
2023-06-05 14:45:02 -07:00
Yash Katariya
0dffdf4645
Allow GlobalAsyncCheckpointManager to work on a single process without initializing the distributed system.
...
PiperOrigin-RevId: 537937426
2023-06-05 11:41:54 -07:00
jax authors
c15f30f22e
Instrument a new metric to measure the savings of async checkpoint in JAX.
...
Create the new metric '/jax/checkpoint/write/async/thread_duration_sec' to measure the savings from the async thread creation time.
PiperOrigin-RevId: 529227213
2023-05-03 16:36:54 -07:00
Jake VanderPlas
fbe4f10403
Change to simpler import for jax.config
2023-04-21 11:51:22 -07:00
Mark Sandler
849e47f79a
Makes deserializer put tensors on the device before releasing inflight memory, as well as avoids allocating memory before memory is available.
...
This makes in-flight memory limiter both reflective of the actual peak usage, as well as reduces peak usage since we no longer try to fully materialize sharded tensors on the host.
PiperOrigin-RevId: 524456216
2023-04-14 21:20:38 -07:00
Colin Gaffney
38f6338299
Switch to zstd for numpy array serialization (jax.Array serialization is handled by JAX library).
...
PiperOrigin-RevId: 522616067
2023-04-07 09:36:05 -07:00
Yash Katariya
d27a80dbfa
Rename gda_serialization to array_serialization but keep gda_serialization around until it is included in a jax release so that OSS projects can be moved to array_serialization
...
PiperOrigin-RevId: 521055760
2023-03-31 18:07:51 -07:00