This change adds an experimental API, `jax.experimental.colocated_python`. The ultimate goal of this API is to provide a runtime-agnostic way to wrap Python code that runs close to (or on) accelerator hosts.

Multi-controller JAX can trivially achieve colocated Python code execution today, since each host already runs its own JAX Python process. Single-controller JAX, by contrast, has needed its own solution for distributed Python code execution, which fragments user code across these two runtime architectures. `colocated_python` is an attempt to define a single device model and a portable API so that the user can write code once and run it on both runtime architectures.

This change includes an implementation of the function API portion of `jax.experimental.colocated_python`. A (stateful) object API will be added separately. There will also be a separate change that expresses serialized functions as an IFRT `CustomCallProgram`.

The API is currently at an early stage of development. Please proceed with caution when using it.

PiperOrigin-RevId: 690705899