16 Commits

Author SHA1 Message Date
Peter Hawkins
5527966b27 [JAX] Deprecate .to_py() property on arrays. Implement __array__ instead.
.to_py() was something of an accidental export from the JAX array classes. There are other mechanisms to turn a JAX array into a NumPy array, including `np.asarray(x)` and `jax.device_get(x)`. Deprecate this mechanism because it is redundant.

PiperOrigin-RevId: 469984029
2022-08-25 07:28:27 -07:00
Yash Katariya
a6ffa72caa Return Arrays if config.jax_array is enabled always.
PiperOrigin-RevId: 469780136
2022-08-24 11:27:32 -07:00
Yash Katariya
7cdb7e1471 Add checkpointing support for Array similar to GDA.
PiperOrigin-RevId: 469271107
2022-08-22 13:35:18 -07:00
Jake VanderPlas
a10f0377db Avoid top-level aliases of jax.tree_util.* 2022-07-07 11:41:02 -07:00
Yash Katariya
8a23605462 Add a limiter for in-flight bytes. Read a shard from TensorStore if there are enough bytes are available. This only works for deserialization right now.
PiperOrigin-RevId: 458586521
2022-07-01 19:26:59 -07:00
Yash Katariya
0574eb2141 Use JAX's distributed system for fully asynchronous checkpointing.
PiperOrigin-RevId: 449380175
2022-05-17 20:17:57 -07:00
Yash Katariya
de7a872e1b Take temp_checkpoint_dir and final_checkpoint_dir as the arguments to serialize instead of the __init__. THis is because this manager will be defined at the top where the directories may not yet be known.
PiperOrigin-RevId: 446104174
2022-05-02 21:22:43 -07:00
Yash Katariya
b7293d5683 Add fully asynchronous checkpointing. This will allow the training to proceed forward when the checkpoint is being committed.
PiperOrigin-RevId: 446083057
2022-05-02 18:43:54 -07:00
Yash Katariya
cf87e3a4a3 Add a dtypes option to cast host arrays when reloading from TS.
PiperOrigin-RevId: 443804229
2022-04-22 18:00:27 -07:00
Colin Gaffney
41b6e00141 Enable use of GlobalDeviceArray (GDA) in T5X Checkpointer. Add a separate unit test, gda_checkpoints_test, to cover this use case.
GDA is locked behind a `use_gda` bool in Checkpointer. The feature is currently not enabled anywhere.

Our follow-up plan is to add code which would enable GDA use throughout T5X, and to fix any remaining issues with Checkpointer.

PiperOrigin-RevId: 439358913
2022-04-04 10:56:07 -07:00
Yash Katariya
99a103723c Make mesh_axes on GDA strict by only allowing PartitionSpecs to be consistent with pjit.
PiperOrigin-RevId: 432957496
2022-03-07 08:59:23 -08:00
Yash Katariya
3290dd3a4d Make resharding of GDA work if the shape is larger than what it was serialized with.
For example: If you serialize with shape (8, 2) and want to deserialize with global shape (12, 2).

PiperOrigin-RevId: 429680502
2022-02-18 17:28:13 -08:00
Yash Katariya
b7dcc4ce01 Handle serialization of arrays with shape (0,). These arrays are usually empty lists (np.array([]))
PiperOrigin-RevId: 426172532
2022-02-03 10:00:01 -08:00
Yash Katariya
dcca99b052 Remove path from the serde API as tspec encompasses those things.
PiperOrigin-RevId: 425727733
2022-02-01 15:17:01 -08:00
Yash Katariya
47af596ccc Instead of getting users to run a tree_map over gdas, etc and run asyncio.run, absorb those APIs into the gda serde library.
PiperOrigin-RevId: 424923960
2022-01-28 11:59:05 -08:00
Yash Katariya
0bb7d204ab Move serialization/de-serialization of GDA into jax.
PiperOrigin-RevId: 414607092
2021-12-06 20:05:02 -08:00