9761 Commits

Author SHA1 Message Date
Hyeontaek Lim
96b7dbabdc [JAX] Implement an initial object API for colocated Python
Colocated Python adds `colocated_python_class`. This API wraps a user-defined
class for automatic remoting of object construction/destruction and method calls:

* An object is initialized on the backend. At least for now, initialization
is deferred until the first method call; at that point, colocated Python knows
on which devices the object should be accessible and can therefore construct
the object(s).

* When an object method is called, the method call runs as a colocated Python
function call on the backend.

* When the object is destroyed (either by reaching a zero reference count or
through Python GC), destruction also runs as a colocated Python function call
and destroys all objects from the backend.

This change provides an initial API implementation. The main limitations are as
follows:

* The methods of a colocated Python class do not support specialization.
Calling them requires at least one argument.

* Colocated Python objects cannot reference or interact with each other on the
controller or on the colocated Python backend.

These limitations will be lifted as the object API implementation is improved.
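
The deferred-construction behavior described above can be illustrated with a plain-Python toy sketch. This is NOT the actual colocated Python API; `DeferredProxy` and `Counter` are made-up names illustrating the pattern of deferring object construction until the first method call (the point where, in colocated Python, the target devices are finally known):

```python
class DeferredProxy:
    """Toy wrapper: constructs the real object on first method access."""

    def __init__(self, cls, *args, **kwargs):
        self._cls, self._args, self._kwargs = cls, args, kwargs
        self._obj = None  # backend object not yet constructed

    def __getattr__(self, name):
        # First attribute access triggers construction; subsequent calls
        # reuse the already-constructed object.
        if self._obj is None:
            self._obj = self._cls(*self._args, **self._kwargs)
        return getattr(self._obj, name)


class Counter:
    def __init__(self, start):
        self.value = start

    def add(self, n):
        self.value += n
        return self.value


proxy = DeferredProxy(Counter, 10)
# Counter(10) is constructed here, at the first method call, then add runs.
print(proxy.add(5))
```

In the real API, construction, method calls, and destruction additionally run as colocated Python function calls on the backend rather than locally.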

PiperOrigin-RevId: 729629265
2025-02-21 12:58:25 -08:00
Daniel Suo
2d1bc5c2a0 Refactor Jax FFI lowering to prepare for implementing CPU/GPU callbacks using XLA's FFI.
- This refactor just moves code around and should have no impact on tests or public-facing APIs.
- `mlir.emit_python_callback` would eventually depend on `ffi.ffi_lowering`, which in turn depends on definitions in `mlir.py`. We break this circular dependency.

PiperOrigin-RevId: 729561359
2025-02-21 09:45:59 -08:00
shuw
bfb9d3ca4b Improve based on comment # 1 2025-02-21 17:32:57 +00:00
Peter Hawkins
673a02d614 Don't set PYTHONWARNINGS=error for tests that use TensorFlow.
Protobuf, which is used by TF, sometimes emits a warning under Python 3.12.

PiperOrigin-RevId: 729554740
2025-02-21 09:22:59 -08:00
Daniel Suo
87a7158f43 Disable tests/debug_info_test.py:test_vjp_of_jit test. Currently failing Python 3.13 GitHub CI tests.
PiperOrigin-RevId: 729544807
2025-02-21 08:50:47 -08:00
Yash Katariya
66037d10e7 Set the mesh of the sharding during broadcast in vmap so that we don't hit an error during canonicalization. This is similar to bcd4048dd5
PiperOrigin-RevId: 729532213
2025-02-21 08:05:42 -08:00
George Necula
edf401d775 [better_errors] Fix a debug_info test, and expand the docstring for the helper function 2025-02-21 13:03:18 +02:00
Yash Katariya
bcd4048dd5 Set the mesh of tangent.aval when creating zeros_like_aval, because when you close over an array that is unused, we error out during canonicalization
PiperOrigin-RevId: 729340808
2025-02-20 19:32:34 -08:00
Yash Katariya
250e2ee7da Use the mesh of out_aval when converting GSPMDSharding to NamedSharding. This makes sure that the axis types of the corresponding output are correct.
Also, if all axes of an out_aval are auto, set the corresponding out_sharding to Unspecified during lowering; otherwise things go horribly wrong. This is actually an XLA bug, but we can work around it in JAX for now.

PiperOrigin-RevId: 729307115
2025-02-20 17:13:24 -08:00
Sergei Lebedev
7438976e79 [pallas:mosaic_gpu] Added support for binary/comparison ops with WG semantics
PiperOrigin-RevId: 729266484
2025-02-20 15:06:27 -08:00
Hyeontaek Lim
71f9764edc [JAX] Generate more readable error for failed device deserialization in colocated Python
When deserializing a colocated Python function or input/output sharding, we
often need to deserialize a device from a device id. This is done by looking
up a CPU device map; the lookup can fail if the device id refers to a
non-CPU device. Unfortunately, we would then see a bare error message like
`KeyError: np.int64(0)` that gives no context for the problem.

This change adds slightly more context to the exception so that the error is
more actionable.

PiperOrigin-RevId: 729172296
2025-02-20 10:52:21 -08:00
Yash Katariya
262aab74f0 Canonicalize closed-over values if **at least** 1 mesh axis is Manual and **all other mesh axes** are Manual or Auto. This makes canonicalization work properly with shmap partial-auto.
If a mesh axis is Explicit, we don't canonicalize closed-over values yet, since that may require shape changes. The workaround is for users to pass those arrays as arguments instead of closing over them in a shard_map.

PiperOrigin-RevId: 728956512
2025-02-19 22:18:56 -08:00
Yash Katariya
b6b319cd06 If cur_mesh is empty and the AxisTypes of the Mesh passed to shmap are Explicit, then treat the axes mentioned in auto as explicit too. In other words, "auto" really means "don't convert to manual", i.e., leave the listed mesh axes as they are, whether of explicit or auto type.
PiperOrigin-RevId: 728942780
2025-02-19 21:25:53 -08:00
Yash Katariya
8305803b76 [sharding_in_types] Initial support for partial-auto/explicit shard_map + sharding-in-types. If an axis in shmap(..., auto=...) is an explicit axis in the outer mesh context, then that axis is treated as Explicit instead of Auto.
PiperOrigin-RevId: 728920514
2025-02-19 20:04:54 -08:00
cjkkkk
3a80080392 fix unit tests to not use fmha rewriter 2025-02-20 00:41:04 +00:00
Parker Schuh
b7c66bd22e Only add new manual axes to residuals when adding axes with partial_auto.
PiperOrigin-RevId: 728839349
2025-02-19 15:27:32 -08:00
Jacob Burnim
ac74857d27 [Pallas] Support dynamic grids in the new TPU interpret mode
PiperOrigin-RevId: 728786896
2025-02-19 13:09:23 -08:00
Matthias Kramm
7eee2de703 roofline: Support computing flops for binary ops.
PiperOrigin-RevId: 728708058
2025-02-19 09:45:24 -08:00
Yash Katariya
66d04f85e6 Error out when going from Manual -> Auto/Explicit AxisTypes in the auto_axes and explicit_axes APIs, which do mesh_cast implicitly.
Also, improve the error raised by canonicalize_sharding to include the API name and the current source location.

PiperOrigin-RevId: 728701237
2025-02-19 09:21:53 -08:00
Yash Katariya
b35083331c Expose get_ty aka get_aval from jax namespace
PiperOrigin-RevId: 728490205
2025-02-18 21:22:19 -08:00
jax authors
09491e2bef Merge pull request #26172 from ZacCranko:is-distributed-init
PiperOrigin-RevId: 728445236
2025-02-18 18:43:48 -08:00
Zac Cranko
5db78e7ae0 add distributed.is_initialized 2025-02-18 16:47:19 -08:00
Yash Katariya
1079dc4477 Let users pass in pspecs to with_sharding_constraint when use_mesh is set. This is in line with other APIs that accept pspecs, like einsum, reshape, etc.
PiperOrigin-RevId: 728392216
2025-02-18 15:47:03 -08:00
Yash Katariya
8bcbf585df Make device_put resharding on single device array input work under use_mesh. Fixes https://github.com/jax-ml/jax/issues/26552
PiperOrigin-RevId: 728382461
2025-02-18 15:22:39 -08:00
Yash Katariya
00d8297071 [sharding_in_types] Set the sharding_in_types config to True. This is a purely internal change and shouldn't affect any public APIs.
Some caveats of enabling sharding-in-types by default: we'll see tracing cache misses, which lead to lowering and compilation cache misses, in the **following cases** (the persistent compilation cache is not affected, so we'll still see a cache hit there):

1. Call `jitted_f(arr_ns)` with an array on a `NamedSharding`, and then `jitted_f(arr_ps)` with an array of the same shape and dtype but with a `PositionalSharding`.
    * This leads to a tracing cache miss because on the second call the aval has no sharding, since it's a PositionalSharding. The same applies to calling with any sharding other than NamedSharding.

2. `jitted_f = jit(f, in_shardings=ns)`. Call `jitted_f(sharded_arr)`, and then on the second call pass a numpy array: `jitted_f(numpy_arr)`.
   * This also leads to a cache miss because the avals currently don't look at in_shardings; the semantics of in_shardings are complicated, and I don't think we should change the aval based on in_shardings.

**The solution in both cases is to make sure to pass arrays sharded on the same mesh during both calls to jit.**
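
The recommended fix can be sketched as follows (a minimal illustration assuming a CPU backend with at least one device; `f` and the array values are made up for the example):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


@jax.jit
def f(x):
    return x * 2


# A one-device mesh; in practice the mesh would span many devices.
mesh = Mesh(np.array(jax.devices()[:1]), ("i",))
s = NamedSharding(mesh, P())

# Placing inputs with the same mesh/sharding on every call keeps the
# tracing cache warm; mixing in differently-sharded or numpy inputs can
# retrace, as described above.
a = jax.device_put(jnp.arange(4.0), s)
b = jax.device_put(jnp.ones(4), s)
print(f(a))  # traced on this call
print(f(b))  # cache hit: same aval (shape, dtype, sharding)
```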

PiperOrigin-RevId: 728361493
2025-02-18 14:35:14 -08:00
Jevin Jiang
bb68124c33 [Mosaic TPU] Support mask concat
PiperOrigin-RevId: 728349788
2025-02-18 14:03:46 -08:00
Yash Katariya
1dc58b79bf Error unconditionally for jit, pjit and with_sharding_constraint if use_mesh and with mesh are used together.
PiperOrigin-RevId: 728310200
2025-02-18 12:16:25 -08:00
Sergei Lebedev
d4559ba404 [pallas] Skip OpsTest.test_concat_constant if TPU is not available
PiperOrigin-RevId: 728269127
2025-02-18 10:38:35 -08:00
Peter Hawkins
b3ed528f7d Fix test failure in PGLE tests.
We weren't completely resetting the compilation cache.
2025-02-18 09:44:16 +00:00
Dimitar (Mitko) Asenov
52f8fbeee0 [Mosaic GPU] Implement lowerings for Tile and Transpose transforms from the MLIR dialect.
PiperOrigin-RevId: 727762334
2025-02-17 01:29:47 -08:00
jax authors
eaceac3bf9 [Pallas] Reductions with replicated axes.
PiperOrigin-RevId: 727292293
2025-02-15 07:41:16 -08:00
Ayaka
b6361b3e76 Minor format cleanup
Remove 2 redundant whitespaces mentioned in https://github.com/jax-ml/jax/pull/25056#pullrequestreview-2615387492.

PiperOrigin-RevId: 727264168
2025-02-15 04:56:27 -08:00
Marcello Maggioni
9a8c9a56cf [JAX] Allow pallas to accept scalar shape semaphores.
PiperOrigin-RevId: 727198066
2025-02-14 23:20:47 -08:00
Matthew Johnson
3681960427 [shard_map] fix debug_print with partial auto shmap
Co-authored-by: Parker Schuh <parkers@google.com>
2025-02-15 00:23:59 +00:00
jax authors
d3850e7fdd Support optimization_level and memory_fitting_level XLA compilation options.
PiperOrigin-RevId: 727070422
2025-02-14 14:46:11 -08:00
jax authors
531d80dc72 Merge pull request #26529 from skye:atime
PiperOrigin-RevId: 727061966
2025-02-14 14:19:22 -08:00
jax authors
9b6b569f3c Adds support for string and binary data processing in Colocated Python.
PiperOrigin-RevId: 727048049
2025-02-14 13:39:20 -08:00
Skye Wanderman-Milne
d5d43fc46e Don't write the atime file if JAX_COMPILATION_CACHE_MAX_SIZE == -1
The atime file is only needed to implement the LRU eviction policy,
which is only needed if a maximum persistent compilation cache size is
set. Writing this file can cause network filesystem performance and
other issues, so only write it if users have opted in.
2025-02-14 12:01:55 -08:00
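
The access-time-based LRU policy the atime file supports can be sketched in plain Python (a toy model, not JAX's actual cache implementation; `evict_lru` and the size cap are illustrative):

```python
import os
import tempfile
import time


def evict_lru(cache_dir, max_bytes):
    """Delete least-recently-accessed entries until the directory fits.

    With max_bytes < 0 ("no limit"), nothing is tracked or evicted,
    mirroring the "don't write the atime file" fast path.
    """
    if max_bytes < 0:
        return []  # unlimited cache: no eviction, no atime bookkeeping
    entries = [
        (os.path.getatime(p), os.path.getsize(p), p)
        for p in (os.path.join(cache_dir, f) for f in os.listdir(cache_dir))
    ]
    total = sum(size for _, size, _ in entries)
    evicted = []
    for _, size, path in sorted(entries):  # oldest access time first
        if total <= max_bytes:
            break
        os.remove(path)
        total -= size
        evicted.append(os.path.basename(path))
    return evicted


# Demo: three 10-byte entries, capped at 25 bytes; the entry with the
# oldest access time is evicted first.
d = tempfile.mkdtemp()
for name in ("a", "b", "c"):
    with open(os.path.join(d, name), "wb") as fh:
        fh.write(b"x" * 10)
    time.sleep(0.01)
print(evict_lru(d, 25))
```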
Ayaka
6addf02add Add JAX error checking support
In this PR, only jit and control flows are supported. Support for vmap and multi-device environments will be added in subsequent PRs.

PiperOrigin-RevId: 726920440
2025-02-14 07:28:21 -08:00
Adam Paszke
b287c3924a Ignore ImportError for Triton on Windows
We don't support Windows GPU builds right now and skip all the tests,
but at the moment they can't even skip because of the import failure.

PiperOrigin-RevId: 726917651
2025-02-14 07:17:49 -08:00
jax authors
12d533f635 Merge pull request #26522 from andportnoy:aportnoy/mosaic-gpu-test-sm90a
PiperOrigin-RevId: 726899717
2025-02-14 06:13:53 -08:00
Adam Paszke
5ab8c5a8a4 Make sure that tests don't change the state of the compilation cache
If it was initialized before the test, it should stay initialized after,
and vice versa.

PiperOrigin-RevId: 726899671
2025-02-14 06:12:02 -08:00
Christos Perivolaropoulos
49ad24152c [pallas:mgpu] Change the FA3 kernel because lax.div no longer accepts mixed types.
PiperOrigin-RevId: 726883573
2025-02-14 05:10:49 -08:00
Sergei Lebedev
3162cc4d0d [pallas:triton] Added basic support for lax.concatenate
The corresponding Triton op is restricted to `jnp.stack([x, y], axis=-1)`,
so the lowering only supports that case for now.

See #25321.

PiperOrigin-RevId: 726881284
2025-02-14 05:02:53 -08:00
Adam Paszke
4a8023fe1e [Mosaic GPU] Define TMEMLayout without referring to the PTX guide
The PTX guide talks about a few layouts by assigning them different
letters, which do not have an obvious meaning. We redefine the layout
by parameterizing it with a 2D tile size which, as far as I can tell,
is sufficient to represent all layouts we care about.

PiperOrigin-RevId: 726833412
2025-02-14 02:06:17 -08:00
George Necula
a0812cd57e [better_errors] Make it explicit that debug_info is not None.
Now all internal uses of lu.wrap_init and core.Jaxpr are with actual
debug info. This enables us to clean up the type declarations and
to remove the checks whether debug_info is present.

For usage outside of the JAX internals, we change
`jax.extend.linear_util.wrap_init` to be usable without debug_info,
for temporary backwards compatibility. We emit a deprecation
warning and fill-in some fake debugging info.

See https://github.com/jax-ml/jax/issues/26480 for more details.

PiperOrigin-RevId: 726770483
2025-02-13 22:07:04 -08:00
jax authors
f0cd1686ec Merge pull request #26509 from andportnoy:aportnoy/pallas-mosaic-gpu-test-sm90a
PiperOrigin-RevId: 726624339
2025-02-13 13:52:31 -08:00
Andrey Portnoy
ae9389dc0f [Mosaic GPU] Factor out Mosaic GPU dialect arch-specific tests 2025-02-13 15:04:34 -05:00
Dan Foreman-Mackey
ea4e324fe4 Fix some busted batching rules in lax.linalg.
PiperOrigin-RevId: 726543703
2025-02-13 10:28:39 -08:00
Dan Foreman-Mackey
7f999298ac Only cache jax.Array._npy_value when a copy is required.
As discovered in https://github.com/jax-ml/jax/issues/26216, for non-standard dtypes, calling `np.array` on a JAX array will unnecessarily cache the constructed `_npy_value` even when a copy isn't required. This change updates the logic to only save the cached value when it is a copy.

This fixes https://github.com/jax-ml/jax/issues/26216 by making the behavior consistent across dtypes, but we probably also want to expose a mechanism for clearing this cached value regardless.

PiperOrigin-RevId: 726522955
2025-02-13 09:36:55 -08:00