Colocated Python adds `colocated_python_class`. This API wraps a user-defined
class for automatic remoting of object construction/destruction and method calls:
* An object will be initialized on the backend. At least for now,
initialization is deferred until the first method is called; at this point,
colocated Python knows what devices the objects should be accessible and thus
it can construct the object(s).
* When an object method is called, the method call runs as a colocated Python
function call on the backend.
* When the object is destroyed (either by reaching a zero reference count or
through Python GC), destruction also runs as a colocated Python function call
and destroys all objects from the backend.
This change provides an intial API implementation. Main limitations are as
follows:
* The methods of a colocated Python class does not support specialization.
Calling it requires at least one argument.
* Colocated Python objects cannot reference or interact with each other on the
controller or on the colocated Python backend.
These limitations will be lifted as the object API implementation is improved.
PiperOrigin-RevId: 729629265
In this change, we update schur, triangular_solve, tridiagonal, and tridiagonal_solve. I batched these ones since they're all pretty straightforward.
PiperOrigin-RevId: 729572705
- This refactor just moves code around and should have no impact on tests or public-facing APIs.
- `mlir.emit_python_callback` would eventually depend on `ffi.ffi_lowering`, which in turn depends on definitions in `mlir.py`. We break this circular dependency.
PiperOrigin-RevId: 729561359
To be consistent with other rule registration helpers, `unop_dtype_rule` should pass through its kwargs to the `result_dtype` callable.
PiperOrigin-RevId: 729483613
As part of my efforts to simplify the primitive implementations in lax.linalg, I've found that all of the primitives share some common logic when it comes to impls, abstract_evals, and batching. This change adds some helper functions and starts the process of abstracting the primitive definitions to simplify and reduce duplication. I will continue with the rest of the primitives in lax.linalg, but I didn't want to overload the first diff.
PiperOrigin-RevId: 729471970
Also, if all axes of an out_aval are auto, set the corresponding out_sharding to Unspecified during lowering, otherwise things go horribly wrong. This is actually a XLA bug but we can workaround it in JAX for now.
PiperOrigin-RevId: 729307115
When deserializing a colocated Python function or input/output sharding, we
often need to deserialize a device using a device id. This is done by looking
up a CPU device map; this lookup can fail if the device id was referring to a
non-CPU device. Unfortunately, we would see a simple error message like
`KeyError: np.int64(0)` that does not give a context of the problem.
This change adds a slightly more context to the exception so that the error is
more actionable.
PiperOrigin-RevId: 729172296
If a mesh axis is Explicit, we don't canonicalize closed over values yet since that make require shape changes. The workaround is for users to pass those arrays as arguments instead of closing over them in a shard_map.
PiperOrigin-RevId: 728956512
Update to use new HLO to StableHLO API. Currently all users of this function have this flag set to true so should be a low impact change.
PiperOrigin-RevId: 728866498