9 Commits

Author SHA1 Message Date
Peter Hawkins
51b9fe3010 [JAX] Add a new jax_num_cpu_devices flag that allows the user to specify the number of CPU directly.
This subsumes (and ultimately will deprecate) overriding the number of CPU devices via XLA_FLAGS.

In addition, replace the test utility jtu.set_host_platform_device_count with jtu.request_cpu_devices(...), which sets or increases the flag's value. This both removes the need for an overly complicated context stack, and prepares for removing remaining uses of setUpModule as part of work parallelizing the test suite with threads.

PiperOrigin-RevId: 713272197
2025-01-08 06:37:44 -08:00
Hyeontaek Lim
0e0fc0ac03 [JAX] Add a test using inputs with different device orders for a single colocated Python call
PiperOrigin-RevId: 708461989
2024-12-20 16:55:29 -08:00
Peter Hawkins
7776982a8d Bump xla_extension_version after jaxlib release.
The new minimum version is 301.
2024-12-18 08:07:19 -05:00
Peter Hawkins
62e66b684b Don't monkey-patch functions in test_utils to count events for tests.
This has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread-parallelism.
* it's not very maintainable.

Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need to make thread-unsafe global monkey patches.
2024-12-12 09:58:14 -05:00
Hyeontaek Lim
296d1670bf [JAX] Add concurrent execution support in colocated Python
This change makes asynchronous execution run without holding a mutex. This
allows colocated Python executions from multiple Python threads to run
concurrently.

PiperOrigin-RevId: 704340663
2024-12-09 10:43:30 -08:00
Hyeontaek Lim
e20a483bef [JAX] Add end-to-end execution support in colocated Python API
This change adds a capability to run colocated Python function calls through
`PyLoadedExecutable`. This capability is not yet used for McJAX, but is tested
with a prototype of a colocated Python backend. The overall behavior remains
the same for McJAX (running the user code inline when colocated Python is
called); the new logic will be used once we introduce a colocated Python
backend for McJAX.

Key highlights:

* Colocated Python is compiled into `PyLoadedExeutable` and uses the JAX C++
dispatch path.

* `CustomCallProgram` for a colocated Python compilation nows includes
specialization (input/output specs, devices). This information allows a
colocated Python backend to transform input/outputs and validate
PyTree/dtype/shape/sharding.

* `out_specs_fn` now receives `jax.ShapeDTypeStruct`s instead of concrete values.

* Deserialization of devices now prefers the default backend. This improves the
compatibility with an environment using both multi-platform backend as well as
the standard "cpu" backend at the same time.

* Several bugs have been fixed (e.g., correctly using `{}` for kwargs).

PiperOrigin-RevId: 703172997
2024-12-05 10:52:40 -08:00
Hyeontaek Lim
bbaec6ea59 [JAX] Add Python binding for building a colocated Python program
This change adds a Python binding that makes `ifrt::CustomCallProgram` for a
colocated Python program. This Python binding will be used internally in the
colocated Python API implementation. The API does not yet compile the program
into an executable, which will be added separately.

PiperOrigin-RevId: 700443656
2024-11-26 13:31:15 -08:00
Vadym Matsishevskyi
e127053304 Fix ColocatedPythonTest. The test has been failing only on pytest nighlies because there are more GPU devices than CPU devices available, but the tests was making assumption that number of cpu devices is always bigger.
PiperOrigin-RevId: 692707314
2024-11-03 08:41:16 -08:00
Hyeontaek Lim
77797f434d [JAX] Add the function API of jax.experimental.colocated_python
This change adds an experimental API `jax.experimental.colocated_python`. The
ultimate goal of this API is to provide a runtime-agnostic way to wrap a Python
code that runs close to (or on) accelerator hosts. Multi-controller JAX can
trivially achieve this colocated Python code execution today, while
single-controller JAX needed its own solution for distributed Python code
execution, which creates fragmentation of the user code for these two runtime
architectures. `colocated_python` is an attempt to define a single device model
and portable API to allow the user to write a single code once that can run on
both runtime architectures.

This change includes an implementation of the function API portion of
`jax.experimental.colocated_python`. A (stateful) object API will be added
separately. Also there will be a separate change that expresses serialized
functions as an IFRT `CustomCallProgram`.

It is currently in an early development stage. Please proceed with a caution
when using the API.

PiperOrigin-RevId: 690705899
2024-10-28 12:18:48 -07:00