24 Commits

Author SHA1 Message Date
Yash Katariya
be53ee10b1 Set jax_enable_memories flag to True by default
PiperOrigin-RevId: 660579462
2024-08-07 16:25:25 -07:00
Georg Stefan Schmid
0428871c82 Adapt test case 2024-07-15 09:37:59 +00:00
Georg Stefan Schmid
b8b9d2878c [memories] Transfer to pinned_host fast path in async_serialize 2024-07-15 09:35:43 +00:00
jax authors
d60e2201e7 Roll forward: Improve tensorstore I/O efficiency
Reverts 5462d2e3930c6202ffd66aea37d5876cc5f78dbb

PiperOrigin-RevId: 650332835
2024-07-08 12:12:59 -07:00
Yash Katariya
e1a496d3b6 Add concrete layout API to JAX. The API takes major_to_minor: tuple[int, ...] and tiling: tuple[tuple[int, ...], ...] as the arguments. Allows users to pass layouts to with_sharding_constraint to constrain the layout + sharding.
`sub_byte_element_size_in_bits` is a lowering only thing for now (since we know the dtype of the aval so JAX can add the appropriate value). We can expose it to the user API if required.

memory space is exposed via JAX memories API so it doesn't have to be in the layout API.

Also expose `_xla_layout` as a private API from `PJRTLayout` so that we can access fields to create JAX layouts.

Add construtors to `xla::Layout` so that JAX can create Layouts with minor_to_major and tiling information.

PiperOrigin-RevId: 647487510
2024-06-27 16:47:31 -07:00
jax authors
5462d2e393 Revert: Improve tensorstore I/O efficiency
Reverts 2f749dbe39589fe35d219e0966990e2b70818d92

PiperOrigin-RevId: 642755899
2024-06-12 15:22:05 -07:00
jax authors
2f749dbe39 Improve tensorstore I/O efficiency
Previously, when writing the OCDBT format, the manifest and root B+tree node could be redundantly written multiple times depending on timing.

With this change, the manifest and root B+tree node are always written only once.

Additionally, source data was previously redundantly copied into the TensorStore chunk cache.

PiperOrigin-RevId: 642345928
2024-06-11 12:07:42 -07:00
Jake VanderPlas
a861c55a28 test cleanup: use ExitStack to reduce test boilerplate 2024-06-06 14:18:27 -07:00
Jieying Luo
51e5951858 Use opaque layout PJRT_Layouts_MemoryLayout in PjRtCApiBuffer::layout() to keep all the layout information.
PjRtCApiBuffer::layout() was using PJRT_Buffer_GetMemoryLayout, which will be deprecated. PJRT_Buffer_GetMemoryLayout uses explicit PJRT_Buffer_MemoryLayout which does not contain all the layout information.
PiperOrigin-RevId: 638048293
2024-05-28 15:44:51 -07:00
Yash Katariya
4e7d0f1df2 Fix deserialization with int4 and layout interaction
Fixes: https://github.com/google/jax/issues/21339
PiperOrigin-RevId: 636336957
2024-05-22 16:45:31 -07:00
Daniel Ng
77988ead94 Move dtype settings out of metadata field into the root of Tensorstore spec
Before, dtype used to be in the metadata field of tensorstore spec because of it was the legacy way to config the dtype.  This setting doesn't understand the "str" name, hence, there was special logic to translate bfloat for example.

This CL moves it out of the metadata field and put the dtype directly into the Tensorstore spec to eliminate special dtype translation logic.  This will also add support of other quantized types such as int4.

PiperOrigin-RevId: 629845048
2024-05-01 14:48:55 -07:00
Yash Katariya
5ce7dca969 Add support for loading checkpoints with a given layout to the array serialization library
PiperOrigin-RevId: 624596358
2024-04-13 19:35:50 -07:00
Jake VanderPlas
f090074d86 Avoid 'from jax import config' imports
In some environments this appears to import the config module rather than
the config object.
2024-04-11 13:23:27 -07:00
Jian Li
b6e985ffe7 Add int4 test to ArrayImpl.
PiperOrigin-RevId: 614778550
2024-03-11 13:40:11 -07:00
Daniel Ng
e079bb9938 Add primary_host and replica_id parameters to async_serialize().
PiperOrigin-RevId: 611108421
2024-02-28 08:18:02 -08:00
Mark Sandler
569f06cda7 In python 3.11 async.run() always tries to convert repr of the result of a coroutine as integer while fetching sigint handler. This makes the test materialize the whole tensor in memory. This changes the test co-routine to return nothing to avoid triggering this bug.
https://github.com/python/cpython/issues/112559

PiperOrigin-RevId: 586756112
2023-11-30 12:37:12 -08:00
Yash Katariya
d6637da431 Disable test_memory_cosumption test
PiperOrigin-RevId: 586426753
2023-11-29 12:50:46 -08:00
Yash Katariya
1bd5fd2a52 Add serialize_with_paths and deserialize_with_paths API to GlobalAsyncCheckpointManager
PiperOrigin-RevId: 555050522
2023-08-08 22:47:27 -07:00
Parker Schuh
feced360f0 Make the default driver in serialization be a global constant.
PiperOrigin-RevId: 543008650
2023-06-23 18:40:31 -07:00
Yash Katariya
b44f8b4bf0 Fix get_tensorstore_spec for GCS paths if ocdbt is enabled
PiperOrigin-RevId: 538627415
2023-06-07 16:47:35 -07:00
Yash Katariya
0dffdf4645 Allow GlobalAsyncCheckpointManager to work on a single process without initializing the distributed system.
PiperOrigin-RevId: 537937426
2023-06-05 11:41:54 -07:00
Jake VanderPlas
fbe4f10403 Change to simpler import for jax.config 2023-04-21 11:51:22 -07:00
Mark Sandler
849e47f79a Makes deserializer put tensors on the device before releasing inflight memory, as well as avoids allocating memory before memory is available.
This makes in-flight memory limiter both reflective of the actual peak usage, as well as reduces peak usage since we no longer try to fully materialize sharded tensors on the host.

PiperOrigin-RevId: 524456216
2023-04-14 21:20:38 -07:00
Yash Katariya
d27a80dbfa Rename gda_serialization to array_serialization but keep gda_serialization around until it is included in a jax release so that OSS projects can be moved to array_serialization
PiperOrigin-RevId: 521055760
2023-03-31 18:07:51 -07:00