Colocated Python adds `colocated_python_class`. This API wraps a user-defined
class for automatic remoting of object construction/destruction and method calls:
* An object will be initialized on the backend. At least for now,
initialization is deferred until the first method is called; at this point,
colocated Python knows what devices the objects should be accessible and thus
it can construct the object(s).
* When an object method is called, the method call runs as a colocated Python
function call on the backend.
* When the object is destroyed (either by reaching a zero reference count or
through Python GC), destruction also runs as a colocated Python function call
and destroys all objects from the backend.
This change provides an intial API implementation. Main limitations are as
follows:
* The methods of a colocated Python class does not support specialization.
Calling it requires at least one argument.
* Colocated Python objects cannot reference or interact with each other on the
controller or on the colocated Python backend.
These limitations will be lifted as the object API implementation is improved.
PiperOrigin-RevId: 729629265
- This refactor just moves code around and should have no impact on tests or public-facing APIs.
- `mlir.emit_python_callback` would eventually depend on `ffi.ffi_lowering`, which in turn depends on definitions in `mlir.py`. We break this circular dependency.
PiperOrigin-RevId: 729561359
Also, if all axes of an out_aval are auto, set the corresponding out_sharding to Unspecified during lowering, otherwise things go horribly wrong. This is actually a XLA bug but we can workaround it in JAX for now.
PiperOrigin-RevId: 729307115
When deserializing a colocated Python function or input/output sharding, we
often need to deserialize a device using a device id. This is done by looking
up a CPU device map; this lookup can fail if the device id was referring to a
non-CPU device. Unfortunately, we would see a simple error message like
`KeyError: np.int64(0)` that does not give a context of the problem.
This change adds a slightly more context to the exception so that the error is
more actionable.
PiperOrigin-RevId: 729172296
If a mesh axis is Explicit, we don't canonicalize closed over values yet since that make require shape changes. The workaround is for users to pass those arrays as arguments instead of closing over them in a shard_map.
PiperOrigin-RevId: 728956512
Some caveats of enabling sharding-in-types by default are that we'll see tracing cache misses which will lead to lowering cache miss and compilation cache misses in the **following cases**: (but persistent compilation cache is not affected so we'll see a cache hit there)
1. Call `jitted_f(arr_ns)` with an array on `NamedSharding` and again `jitted_f(arr_ps)` with an array of same shape and dtype but now with `PositionalSharding`
* This leads to a tracing cache miss because on the second call, the aval has no sharding since it's PositionalSharding. This applies to calling with any sharding other than NamedSharding
2. `jitted_f = jit(f, in_shardings=ns)`. Call `jitted_f(sharded_arr)` and then on the second call you pass a numpy array `jitted_f(numpy_arr)`
* This also leads to a cache miss because the avals currently don't look at in_shardings because the semantics of in_shardings is complicated and I don't think we should change the aval based on in_shardings.
**The solution in both cases is make sure to pass the array sharded on the same mesh during both calls to jit.**
PiperOrigin-RevId: 728361493
The atime file is only needed to implement the LRU eviction policy,
which is only needed if a max persistence compilation cache size is
set. Writing this file can cause network filesystem performace and
other issues, so only write it if users are opted-in.
In this PR, only jit and control flows are supported. Support for vmap and multi-device environments will be added in subsequent PRs.
PiperOrigin-RevId: 726920440
We don't support Windows GPU builds right now and skip all the tests,
but at the moment they can't even skip because of the import failure.
PiperOrigin-RevId: 726917651
The corresponding Triton op is restricted to `jnp.stack([x, y], axis=-1)`,
so the lowering only supports that case for now.
See #25321.
PiperOrigin-RevId: 726881284
The PTX guide talks about a few layouts by assigning them different
letters, which do not have an obvious meaning. We redefine the layout
by parameterizing it with a 2D tile size which, as far as I can tell,
is sufficient to represent all layouts we care about.
PiperOrigin-RevId: 726833412
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.
For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.
See https://github.com/jax-ml/jax/issues/26480 for more details.
PiperOrigin-RevId: 726770483
As discovered in https://github.com/jax-ml/jax/issues/26216, for non-standard dtypes, calling `np.array` on a JAX array will unnecessarily cache the constructed `_npy_value` even when a copy isn't required. This change updates the logic to only save the cached value when it is a copy.
This fixes https://github.com/jax-ml/jax/issues/26216 by making the behavior consistent across dtypes, but we probably also want to expose a mechanism for clearing this cached value regardless.
PiperOrigin-RevId: 726522955