15 Commits

Author SHA1 Message Date
Jieying Luo
3f1900e2e3 [PJRT C API] Add a util method to get the PJRT C API version of the backend.
Disable some memories tests that are not supported on plugins older than 0.32.
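A minimal sketch of how such version gating could look in a test, assuming the backend version is available as a `(major, minor)` tuple. The names `MIN_VERSION_FOR_MEMORIES`, `plugin_supports_memories`, and `skip_if_plugin_too_old` are hypothetical and are not taken from this commit:

```python
import unittest

# Assumption: memories tests need PJRT C API version 0.32 or newer,
# expressed here as a (major, minor) tuple.
MIN_VERSION_FOR_MEMORIES = (0, 32)


def plugin_supports_memories(version):
    """Return True if a (major, minor) version is at least 0.32."""
    # Tuple comparison is lexicographic: major first, then minor.
    return tuple(version) >= MIN_VERSION_FOR_MEMORIES


def skip_if_plugin_too_old(version):
    """Raise unittest.SkipTest when the plugin predates the minimum version."""
    if not plugin_supports_memories(version):
        raise unittest.SkipTest(
            f"memories tests need PJRT C API >= 0.32, got {version}")
```

Comparing tuples rather than strings avoids ordering mistakes such as `"0.4" > "0.32"`.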

PiperOrigin-RevId: 581008059
2023-11-09 13:30:19 -08:00
Yash Katariya
cf3c041366 Disable jax memories flag.
PiperOrigin-RevId: 580961421
2023-11-09 10:54:02 -08:00
Yunlong Liu
b99958db37 Places the remat decorator on top of the body function.
PiperOrigin-RevId: 577320028
2023-10-27 15:27:19 -07:00
jax authors
74983770cb Add GetOpSharding to XLA/PjRt utils.
PiperOrigin-RevId: 574287268
2023-10-17 15:46:52 -07:00
Roy Frostig
3247db774e add tests for host offloading (plus operations) under a custom VJP
Co-authored-by: Yash Katariya <yashkatariya@google.com>
PiperOrigin-RevId: 569333314
2023-09-28 17:19:21 -07:00
Peter Hawkins
1885c4933c Add a new internal test utility test_device_matches() and use it instead of equality tests on device_under_test().
This change prepares for allowing more flexible tag matching. For example, we may want to write "gpu" in a test and have it match both "cuda" and "rocm" devices, which we cannot do under the current API but can easily do under this design.

Replace uses of device_under_test() in a context that performs an equality test with a call to test_device_matches().
Replace uses of if_device_under_test() with test_device_matches() and delete if_device_under_test().
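
The flexible tag matching described above could be sketched roughly as follows. This is an illustration only, not the implementation in this commit: `device_matches` and the `_TAG_ALIASES` table are hypothetical names, and the device under test is passed in as a plain string.

```python
# Assumption: a tag such as "gpu" expands to a set of concrete device types.
_TAG_ALIASES = {
    "gpu": {"gpu", "cuda", "rocm"},
}


def device_matches(device_under_test, device_types):
    """Return True if the device under test matches any of the given tags."""
    for tag in device_types:
        # Tags without an alias entry match only themselves.
        if device_under_test in _TAG_ALIASES.get(tag, {tag}):
            return True
    return False
```

With a plain equality test, a check written against `"gpu"` would fail on a `"cuda"` device; routing every check through one matcher keeps the alias logic in a single place.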

PiperOrigin-RevId: 568923117
2023-09-27 12:10:43 -07:00
jax authors
d120dd682a Remove PjRt C API check in PjRtArray::Reshard since PjRtCApiBuffer::CopyToMemorySpace has been supported.
PiperOrigin-RevId: 568754787
2023-09-26 23:26:26 -07:00
Hyeontaek Lim
51589bbe70 Clarify that PjRtClient and PjRtDevice memory_spaces are not in any particular order
PiperOrigin-RevId: 567630629
2023-09-22 08:37:07 -07:00
jax authors
68eddd16f3 Update the logic of PjRtArray::Reshard after PjRtBuffer::CopyToMemorySpace was introduced. Users should use PjRtBuffer::CopyToMemorySpace instead of PjRtBuffer::CopyToDevice when memories are supported, since the latter always copies to the default memory space of the device.
PiperOrigin-RevId: 567154400
2023-09-20 19:39:01 -07:00
jax authors
7bc01d9472 Add memory kind check in PjRtArray::Create.
PiperOrigin-RevId: 566851924
2023-09-19 22:58:37 -07:00
Yash Katariya
05729513fb Delete TransferPjRtBufferBetweenMemories and replace it with CopyToMemorySpace, which is more robust, is fully async, and transfers between any memory spaces.
PiperOrigin-RevId: 566420233
2023-09-18 14:48:55 -07:00
Yash Katariya
ebc24c737b Pass sharded inputs to remat offloading tests. When we execute, these inputs will be useful for validating the correctness of the compiler passes.
PiperOrigin-RevId: 565180089
2023-09-13 15:43:40 -07:00
Yash Katariya
8340149336 Check if the input which is donated is actually deleted along with the AOT check.
PiperOrigin-RevId: 565098239
2023-09-13 10:50:16 -07:00
Yash Katariya
c41d271175 Add memories support to remat.
This PR adds basic support to remat for transferring intermediates (activations) to a destination memory in the forward pass. Currently JAX only supports the host memory kind, but the API allows transferring to other memories too. Remat automatically loads the residuals back to the source memory in the backward pass.

Introduce two singletons, `Recompute` and `Saveable`, and a NamedTuple, `Offloadable`, that each policy can return. Currently policies return a bool: True means the value is saveable, and False means it is recomputed on the backward pass. This change is backwards compatible, i.e. policies can still return a bool.

A very basic offloadable policy can look like this:

```
def policy(prim, *avals, **params):
  return ad_checkpoint.Offloadable(src='tpu_hbm', dst='unpinned_host')
```
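
To illustrate how the three return values and the legacy bool could coexist, here is a self-contained sketch. It uses stand-in definitions rather than the real `ad_checkpoint` objects, treats `prim` as a plain string for simplicity, and the helper names `offload_dots_policy` and `normalize` are hypothetical:

```python
from typing import NamedTuple


class _Sentinel:
    """Stand-in for a singleton policy result (an assumption for this sketch)."""

    def __init__(self, name):
        self._name = name

    def __repr__(self):
        return self._name


Recompute = _Sentinel("Recompute")
Saveable = _Sentinel("Saveable")


class Offloadable(NamedTuple):
    src: str  # memory kind the residual lives in during the forward pass
    dst: str  # memory kind the residual is offloaded to


def offload_dots_policy(prim, *avals, **params):
    """Offload results of 'dot_general'; recompute everything else."""
    if prim == "dot_general":
        return Offloadable(src="tpu_hbm", dst="unpinned_host")
    return Recompute


def normalize(decision):
    """Map a legacy bool onto the singletons; pass other results through."""
    if decision is True:
        return Saveable
    if decision is False:
        return Recompute
    return decision
```

A normalization step of this kind is one way to keep bool-returning policies working unchanged while newer policies return the richer values.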

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 564914301
2023-09-12 20:50:05 -07:00
Yash Katariya
76a5dc3cac Move memories_test.py to JAX
PiperOrigin-RevId: 564551723
2023-09-11 17:41:55 -07:00