16405 Commits

Author SHA1 Message Date
Yash Katariya
7c4fe2a7cc [sharding_in_types] Allow auto_axes and explicit_axes to take numpy arrays, python scalars.
PiperOrigin-RevId: 729729215
2025-02-21 18:49:02 -08:00
Yash Katariya
34077851d8 Reverts 6e83de5909b5dad6827c1682726c840c9688b32d
PiperOrigin-RevId: 729725907
2025-02-21 18:34:16 -08:00
Yash Katariya
80f18ded23 [sharding_in_types] Make slice and ellipsis work with .at[...].get(out_sharding=P(...))
PiperOrigin-RevId: 729723470
2025-02-21 18:25:11 -08:00
Yash Katariya
629426f89c Allow casting to the same axis type
PiperOrigin-RevId: 729667271
2025-02-21 14:53:28 -08:00
Daniel Suo
6e83de5909 Implement Jax CPU/GPU callbacks with XLA's FFI.
- Change 4 of 4 addressing #3 in https://github.com/jax-ml/jax/issues/25842.

PiperOrigin-RevId: 729641154
2025-02-21 13:34:48 -08:00
Matthias Kramm
b3fcba7c05 roofline: Handle ClosedJaxpr instances.
PiperOrigin-RevId: 729636113
2025-02-21 13:19:31 -08:00
Hyeontaek Lim
96b7dbabdc [JAX] Implement an initial object API for colocated Python
Colocated Python adds `colocated_python_class`. This API wraps a user-defined
class for automatic remoting of object construction/destruction and method calls:

* An object will be initialized on the backend. At least for now,
initialization is deferred until the first method is called; at this point,
colocated Python knows what devices the objects should be accessible and thus
it can construct the object(s).

* When an object method is called, the method call runs as a colocated Python
function call on the backend.

* When the object is destroyed (either by reaching a zero reference count or
through Python GC), destruction also runs as a colocated Python function call
and destroys all objects from the backend.

This change provides an intial API implementation. Main limitations are as
follows:

* The methods of a colocated Python class does not support specialization.
Calling it requires at least one argument.

* Colocated Python objects cannot reference or interact with each other on the
controller or on the colocated Python backend.

These limitations will be lifted as the object API implementation is improved.

PiperOrigin-RevId: 729629265
2025-02-21 12:58:25 -08:00
Kevin Gleason
6c83d43635 Reverts 655267609b5589fda8358ae7aaf2eb832036407a
PiperOrigin-RevId: 729608394
2025-02-21 11:59:44 -08:00
Dan Foreman-Mackey
c4418c1010 Update several remaining lax.linalg primitives to use registration helper functions.
In this change, we update schur, triangular_solve, tridiagonal, and tridiagonal_solve. I batched these ones since they're all pretty straightforward.

PiperOrigin-RevId: 729572705
2025-02-21 10:18:30 -08:00
Daniel Suo
2d1bc5c2a0 Refactor Jax FFI lowering to prepare for implementing CPU/GPU callbacks using XLA's FFI.
- This refactor just moves code around and should have no impact on tests or public-facing APIs.
- `mlir.emit_python_callback` would eventually depend on `ffi.ffi_lowering`, which in turn depends on definitions in `mlir.py`. We break this circular dependency.

PiperOrigin-RevId: 729561359
2025-02-21 09:45:59 -08:00
shuw
bfb9d3ca4b Improve based on comment # 1 2025-02-21 17:32:57 +00:00
Dan Foreman-Mackey
ed10003adc Update lax.linalg.qr primitives to use registration helper functions.
PiperOrigin-RevId: 729551997
2025-02-21 09:15:01 -08:00
Yash Katariya
66037d10e7 Set the mesh of the sharding during broadcast in vmap so that we don't hit an error during canonicalization. This is similar to bcd4048dd5
PiperOrigin-RevId: 729532213
2025-02-21 08:05:42 -08:00
jax authors
237b7941a8 Merge pull request #26649 from jakevdp:wrapped-fun
PiperOrigin-RevId: 729500034
2025-02-21 05:57:53 -08:00
Dan Foreman-Mackey
09325d925f Update internal unop primitive helper to pass kwargs to dtype rule.
To be consistent with other rule registration helpers, `unop_dtype_rule` should pass through its kwargs to the `result_dtype` callable.

PiperOrigin-RevId: 729483613
2025-02-21 04:52:51 -08:00
Dan Foreman-Mackey
126909b62a Update lax.linalg.lu primitive to use registration helper functions.
PiperOrigin-RevId: 729483456
2025-02-21 04:50:46 -08:00
Dan Foreman-Mackey
a981e1c4b9 Start adding primitive registration helper functions to lax.linalg.
As part of my efforts to simplify the primitive implementations in lax.linalg, I've found that all of the primitives share some common logic when it comes to impls, abstract_evals, and batching. This change adds some helper functions and starts the process of abstracting the primitive definitions to simplify and reduce duplication. I will continue with the rest of the primitives in lax.linalg, but I didn't want to overload the first diff.

PiperOrigin-RevId: 729471970
2025-02-21 04:05:34 -08:00
jax authors
655267609b Reverts 9ad9c3d3c2b6b5c3ea736af9b4d8c595537de93c
PiperOrigin-RevId: 729462063
2025-02-21 03:30:21 -08:00
Yash Katariya
bcd4048dd5 Set the mesh of tangent.aval when we are creating zeros_like_aval because when you close over an array which is unused, we error out during canonicalization
PiperOrigin-RevId: 729340808
2025-02-20 19:32:34 -08:00
Yash Katariya
250e2ee7da Use the mesh of out_aval when converting GSPMDSharding to NamedSharding. This makes sure that the axis types of the corresponding output is correct.
Also, if all axes of an out_aval are auto, set the corresponding out_sharding to Unspecified during lowering, otherwise things go horribly wrong. This is actually a XLA bug but we can workaround it in JAX for now.

PiperOrigin-RevId: 729307115
2025-02-20 17:13:24 -08:00
Jake VanderPlas
fe00aa0f65 Internal: improved type annotations for lu.WrappedFun 2025-02-20 15:12:39 -08:00
Sergei Lebedev
7438976e79 [pallas:mosaic_gpu] Added support for binary/comparison ops with WG semantics
PiperOrigin-RevId: 729266484
2025-02-20 15:06:27 -08:00
Robert David
08de0128b6 Fix head comment: was referring to nonexistent parameters.
PiperOrigin-RevId: 729231457
2025-02-20 13:29:40 -08:00
Hyeontaek Lim
71f9764edc [JAX] Generate more readable error for failed device deserialization in colocated Python
When deserializing a colocated Python function or input/output sharding, we
often need to deserialize a device using a device id. This is done by looking
up a CPU device map; this lookup can fail if the device id was referring to a
non-CPU device. Unfortunately, we would see a simple error message like
`KeyError: np.int64(0)` that does not give a context of the problem.

This change adds a slightly more context to the exception so that the error is
more actionable.

PiperOrigin-RevId: 729172296
2025-02-20 10:52:21 -08:00
jax authors
b7968474c2 [Pallas][Mosaic] Support float8_e4m3b11fnuz
PiperOrigin-RevId: 729169181
2025-02-20 10:44:33 -08:00
Michael Whittaker
ddcb7deeaf Added jax.experimental.multihost_utils.live_devices API.
This API is intended to enable fault tolerant multi-controller JAX programs.

PiperOrigin-RevId: 729153679
2025-02-20 10:04:17 -08:00
Yash Katariya
262aab74f0 canonicalize closed over values if **atleast** 1 mesh axis is Manual and **all other mesh axes** are Manual or Auto. This would make the canonicalization work properly with shmap partial-auto.
If a mesh axis is Explicit, we don't canonicalize closed over values yet since that make require shape changes. The workaround is for users to pass those arrays as arguments instead of closing over them in a shard_map.

PiperOrigin-RevId: 728956512
2025-02-19 22:18:56 -08:00
Yash Katariya
b6b319cd06 If cur_mesh is empty and AxisTypes of Mesh passed to shmap are Explicit, then treat the axes mentioned in auto as explicit too. In other words, "auto" really means "don't convert to manual", ie leave the listed mesh axes as they are, whether explicit or auto type
PiperOrigin-RevId: 728942780
2025-02-19 21:25:53 -08:00
Yash Katariya
8305803b76 [sharding_in_types] Initial support for partial-auto/explicit shard_map + sharding-in-types. If the axes in shmap(..., auto=...) is an explicit axes in the outer mesh context, then that axis is treated as Explicit instead of Auto.
PiperOrigin-RevId: 728920514
2025-02-19 20:04:54 -08:00
jax authors
cb0d326e16 Merge pull request #26591 from jakevdp:lax-docs
PiperOrigin-RevId: 728908919
2025-02-19 19:22:48 -08:00
Kevin Gleason
9ad9c3d3c2 [StableHLO] Only emit StableHLO from xla_computation_to_mlir_module
Update to use new HLO to StableHLO API. Currently all users of this function have this flag set to true so should be a low impact change.

PiperOrigin-RevId: 728866498
2025-02-19 16:45:56 -08:00
cjkkkk
3a80080392 fix unit tests to not use fmha rewriter 2025-02-20 00:41:04 +00:00
Parker Schuh
b7c66bd22e Only add new manual axes to residuals when adding axes with partial_auto.
PiperOrigin-RevId: 728839349
2025-02-19 15:27:32 -08:00
Shu Wang
ae111f7c97
Rename custom-call name. 2025-02-19 16:46:44 -06:00
Yash Katariya
dbb46e9214 Relax one more check in partial_eval_jaxpr_nounits
PiperOrigin-RevId: 728788472
2025-02-19 13:14:08 -08:00
Jacob Burnim
ac74857d27 [Pallas] Support dynamic grids in the new TPU interpret mode
PiperOrigin-RevId: 728786896
2025-02-19 13:09:23 -08:00
jax authors
eef829cbcb Merge pull request #26615 from froystig:scan-docfix
PiperOrigin-RevId: 728753989
2025-02-19 11:41:35 -08:00
Yash Katariya
1081c1f11a Relax the check in _mapped_axis_spec to allow () and None to be treated the same
PiperOrigin-RevId: 728746291
2025-02-19 11:23:17 -08:00
Roy Frostig
ae10f2da13 fix scan doc on the unroll argument.
Looks like a typo worth fixing.
2025-02-19 11:01:44 -08:00
Yash Katariya
401fa9019c Mark in_shardings and out_shardings as Any for typing reasons since they can take pytrees. Fixes https://github.com/jax-ml/jax/issues/26609
PiperOrigin-RevId: 728730349
2025-02-19 10:46:09 -08:00
Matthias Kramm
7eee2de703 roofline: Support computing flops for binary ops.
PiperOrigin-RevId: 728708058
2025-02-19 09:45:24 -08:00
Yash Katariya
66d04f85e6 Error out if going from Manual -> Auto/Explicit AxisTypes in the auto_axes and explicit_axes API that do mesh_cast implicitly.
Also, improve the error raised by canonicalize_sharding to include the api name and current source location.

PiperOrigin-RevId: 728701237
2025-02-19 09:21:53 -08:00
Yash Katariya
a3edfb43ef Now that sharding_in_types config flag is True, remove the config and all the conditionals
PiperOrigin-RevId: 728653433
2025-02-19 06:53:35 -08:00
Sebastian Bodenstein
d5e5b42de8 Use consistent dtype for forward and backwards in jax.nn.dot_product_attention.
Fixes https://github.com/jax-ml/jax/issues/24047

PiperOrigin-RevId: 728613700
2025-02-19 04:30:23 -08:00
Yash Katariya
b35083331c Expose get_ty aka get_aval from jax namespace
PiperOrigin-RevId: 728490205
2025-02-18 21:22:19 -08:00
Parker Schuh
c825241ccc Exclude auto axes whenever extending the axis_env via core.extend_axis_env_nd.
(Just pure refactoring, doesn't fix any bugs quite yet).

PiperOrigin-RevId: 728461916
2025-02-18 19:40:40 -08:00
jax authors
09491e2bef Merge pull request #26172 from ZacCranko:is-distributed-init
PiperOrigin-RevId: 728445236
2025-02-18 18:43:48 -08:00
Zac Cranko
5db78e7ae0 add distributed.is_initialized 2025-02-18 16:47:19 -08:00
Yash Katariya
1079dc4477 Let users pass in pspecs to with_sharding_constraint when use_mesh is set. This is in-line with other APIs which allow pspecs like einsum, reshape, etc
PiperOrigin-RevId: 728392216
2025-02-18 15:47:03 -08:00
Yash Katariya
8bcbf585df Make device_put resharding on single device array input work under use_mesh. Fixes https://github.com/jax-ml/jax/issues/26552
PiperOrigin-RevId: 728382461
2025-02-18 15:22:39 -08:00