17 Commits

Author SHA1 Message Date
jax authors
cd7f03f272 Updates the Colocated Python's serialization (and deserialization) implementation to utilize the recently added support for string arrays.
Currently the serialized data and its length are being carried in two separate arrays, a fixed-with bytes array (with a hard-coded max size) and a unit32 array respectively.

PiperOrigin-RevId: 734299259
2025-03-06 14:57:52 -08:00
Peter Hawkins
66293d8897 Remove code present to support jaxlib < 0.5.1.
The new minimum xla_extension_version is 317 and the new mlir_api_version is 58.
2025-02-26 07:40:40 -05:00
Adam Paszke
7fb6788d4f Skip tests using StringDType when NumPy version is below 2.0
PiperOrigin-RevId: 730795666
2025-02-25 02:25:47 -08:00
Hyeontaek Lim
96b7dbabdc [JAX] Implement an initial object API for colocated Python
Colocated Python adds `colocated_python_class`. This API wraps a user-defined
class for automatic remoting of object construction/destruction and method calls:

* An object will be initialized on the backend. At least for now,
initialization is deferred until the first method is called; at this point,
colocated Python knows what devices the objects should be accessible and thus
it can construct the object(s).

* When an object method is called, the method call runs as a colocated Python
function call on the backend.

* When the object is destroyed (either by reaching a zero reference count or
through Python GC), destruction also runs as a colocated Python function call
and destroys all objects from the backend.

This change provides an intial API implementation. Main limitations are as
follows:

* The methods of a colocated Python class does not support specialization.
Calling it requires at least one argument.

* Colocated Python objects cannot reference or interact with each other on the
controller or on the colocated Python backend.

These limitations will be lifted as the object API implementation is improved.

PiperOrigin-RevId: 729629265
2025-02-21 12:58:25 -08:00
Hyeontaek Lim
71f9764edc [JAX] Generate more readable error for failed device deserialization in colocated Python
When deserializing a colocated Python function or input/output sharding, we
often need to deserialize a device using a device id. This is done by looking
up a CPU device map; this lookup can fail if the device id was referring to a
non-CPU device. Unfortunately, we would see a simple error message like
`KeyError: np.int64(0)` that does not give a context of the problem.

This change adds a slightly more context to the exception so that the error is
more actionable.

PiperOrigin-RevId: 729172296
2025-02-20 10:52:21 -08:00
jax authors
9b6b569f3c Adds support for string and binary data processing in Colocated Python.
PiperOrigin-RevId: 727048049
2025-02-14 13:39:20 -08:00
Hyeontaek Lim
f43d2b68d9 [JAX] Add a test verifying the behavior of module-level state accessed by colocated Python
A new test verifies that
* Python module-level variables can be created/set and read from a colocated Python function
* Python module-level variables are not pickled on the controller (JAX) or sent to executors via pickling

An API for defining user-defined state and accessing it from multiple colocated
Python functions (i.e., object support) will be added later. That will be a
recommended way to express user-defined state. The capability of accessing
Python module variables is still crucial because a lot of Python code
(including JAX) requires this behavior to implement caching.

PiperOrigin-RevId: 723595727
2025-02-05 11:49:07 -08:00
Matthew Johnson
1ae02bc069 skip tests with extra requirements 2025-02-05 01:48:28 +00:00
Peter Hawkins
51b9fe3010 [JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU directly.
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.

In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.

PiperOrigin-RevId: 713272197
2025-01-08 06:37:44 -08:00
Hyeontaek Lim
0e0fc0ac03 [JAX] Add a test using inputs with different device orders for a single colocated Python call
PiperOrigin-RevId: 708461989
2024-12-20 16:55:29 -08:00
Peter Hawkins
7776982a8d Bump xla_extension_version after jaxlib release.
The new minimum version is 301.
2024-12-18 08:07:19 -05:00
Peter Hawkins
62e66b684b Don't monkey-patch functions in test_utils to count events for tests.
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.

Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
2024-12-12 09:58:14 -05:00
Hyeontaek Lim
296d1670bf [JAX] Add concurrent execution support in colocated Python
This change makes asynchronous execution run without holding a mutex. This
allows colocated Python executions from multiple Python threads to run
concurrently.

PiperOrigin-RevId: 704340663
2024-12-09 10:43:30 -08:00
Hyeontaek Lim
e20a483bef [JAX] Add end-to-end execution support in colocated Python API
This change adds a capability to run colocated Python function calls through
`PyLoadedExecutable`. This capability is not yet used for McJAX, but is tested
with a prototype of a colocated Python backend. The overall behavior remains
the same for McJAX (running the user code inline when colocated Python is
called); the new logic will be used once we introduce a colocated Python
backend for McJAX.

Key highlights:

* Colocated Python is compiled into `PyLoadedExeutable` and uses the JAX C++
dispatch path.

* `CustomCallProgram` for a colocated Python compilation nows includes
specialization (input/output specs, devices). This information allows a
colocated Python backend to transform input/outputs and validate
PyTree/dtype/shape/sharding.

* `out_specs_fn` now receives `jax.ShapeDTypeStruct`s instead of concrete values.

* Deserialization of devices now prefers the default backend. This improves the
compatibility with an environment using both multi-platform backend as well as
the standard "cpu" backend at the same time.

* Several bugs have been fixed (e.g., correctly using `{}` for kwargs).

PiperOrigin-RevId: 703172997
2024-12-05 10:52:40 -08:00
Hyeontaek Lim
bbaec6ea59 [JAX] Add Python binding for building a colocated Python program
This change adds a Python binding that makes `ifrt::CustomCallProgram` for a
colocated Python program. This Python binding will be used internally in the
colocated Python API implementation. The API does not yet compile the program
into an executable, which will be added separately.

PiperOrigin-RevId: 700443656
2024-11-26 13:31:15 -08:00
Vadym Matsishevskyi
e127053304 Fix ColocatedPythonTest. The test has been failing only on pytest nighlies because there are more GPU devices than CPU devices available, but the tests was making assumption that number of cpu devices is always bigger.
PiperOrigin-RevId: 692707314
2024-11-03 08:41:16 -08:00
Hyeontaek Lim
77797f434d [JAX] Add the function API of jax.experimental.colocated_python
This change adds an experimental API `jax.experimental.colocated_python`. The
ultimate goal of this API is to provide a runtime-agnostic way to wrap a Python
code that runs close to (or on) accelerator hosts. Multi-controller JAX can
trivially achieve this colocated Python code execution today, while
single-controller JAX needed its own solution for distributed Python code
execution, which creates fragmentation of the user code for these two runtime
architectures. `colocated_python` is an attempt to define a single device model
and portable API to allow the user to write a single code once that can run on
both runtime architectures.

This change includes an implementation of the function API portion of
`jax.experimental.colocated_python`. A (stateful) object API will be added
separately. Also there will be a separate change that expresses serialized
functions as an IFRT `CustomCallProgram`.

It is currently in an early development stage. Please proceed with a caution
when using the API.

PiperOrigin-RevId: 690705899
2024-10-28 12:18:48 -07:00