8681 Commits

Author SHA1 Message Date
jax authors
4944dcb977 Merge pull request #26897 from jakevdp:cond-doc
PiperOrigin-RevId: 733077065
2025-03-03 15:13:23 -08:00
jax authors
07d1cd0290 Merge pull request #26876 from carlosgmartin:fix_matrix_norm_empty_matrix
PiperOrigin-RevId: 733077011
2025-03-03 15:11:31 -08:00
Jake VanderPlas
84ca80d215 doc: in lax.cond, note that both branches will be traced 2025-03-03 13:05:24 -08:00
Peter Hawkins
7f05b74bca Fix wrong results in multidimensional pad.
When there are multiple dimensions, NumPy's semantics are as if the padding is applied to each dimension in order.

We lacked test coverage for this case because constant values ((0, 2),) and (0, 2) were handled by different code paths.

Fixes https://github.com/jax-ml/jax/issues/26888
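
A small NumPy sketch of the semantics described above, not the JAX fix itself. The array and the per-axis constants are made up for illustration. Padding both axes in one call should match padding axis 0 first and axis 1 second.

```python
import numpy as np

a = np.zeros((1, 1))
# Hypothetical per-axis constants: (before, after) for axis 0, then axis 1.
values = ((1, 2), (3, 4))

# One call that pads both axes at once.
one_shot = np.pad(a, ((1, 1), (1, 1)), constant_values=values)

# The same padding applied one axis at a time, in axis order.
step = np.pad(a, ((1, 1), (0, 0)), constant_values=(values[0], (0, 0)))
sequential = np.pad(step, ((0, 0), (1, 1)), constant_values=((0, 0), values[1]))

# The corners are written by the last axis, so they carry the axis-1 constants.
corners = (one_shot[0, 0], one_shot[0, -1], one_shot[-1, 0], one_shot[-1, -1])
print(one_shot)
```

The corners are where a per-dimension ordering becomes visible, which is why a test with only one padded dimension would not catch a mismatch.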
2025-03-03 15:25:08 -05:00
carlosgmartin
897e1a1310 Fix linalg.norm to return zero for proper norms of empty matrices. 2025-03-03 15:02:34 -05:00
Bart Chrzaszcz
ed4a7bbab1 #sdy Add JAX backwards compatibility test.
This tests saving a module with one set of axis names, but loading it with another set of axis names.

This also tests the custom calls:

- `@Sharding`
- `@xla.sdy.GlobalToLocalShape`
- `@xla.sdy.LocalToGlobalShape`

But note that there are a bunch of other custom calls that will be tested in the Shardy and XLA codebases. The way the testing utils are tested here doesn't allow me to set `out_shardings`, for example. So JAX can rely on the existence of those tests as stability guarantees, just like for StableHLO.

PiperOrigin-RevId: 732893432
2025-03-03 06:01:34 -08:00
Bart Chrzaszcz
ac493655bf #sdy support JAX export tests when Shardy is enabled.
This CL only supports lowering a module with the exact same mesh, and loading it with either the exact same mesh or different meshes.

Note that we will be introducing some restrictions under Shardy for JAX export:

- You can only lower/save the module with meshes all of the same shape, but different axis names (this PR is right now only allowing the same axis names, but this will be relaxed in a follow-up)
- When loading the module, just like with GSPMD, you can use a different mesh with a different mesh shape and axis names. However, like with the restriction in the previous point, all shardings must use the same axis shapes, but can use different axis names (again this will be relaxed in a follow-up)

We may remove the restriction of having to use the exact same mesh shapes at export saving time and the exact same mesh shapes at export loading time in the future. But for now we will keep this restriction while no one is using Shardy with JAX export.

PiperOrigin-RevId: 732878916
2025-03-03 04:57:06 -08:00
George Necula
a6c47d6f36 Use the same name for aliased Vars when pretty-printing Jaxprs.
Add a mechanism for using the same Var names for Vars that
are aliased. In this PR, we use this for `pjit`, such that the
following `print(jax.make_jaxpr(lambda a: jax.jit(lambda a: a + 1)(a))(0.))`
prints:

```
{ lambda ; a:f32[]. let
    b:f32[] = pjit[
          name=<lambda>
          jaxpr={ lambda ; a:f32[]. let b:f32[] = add a 1.0 in (b,) }
          ] a
    in (b,) }
```

instead of the previous:

```
{ lambda ; a:f32[]. let
    b:f32[] = pjit[
          name=<lambda>
          jaxpr={ lambda ; c:f32[]. let d:f32[] = add c 1.0 in (d,) }
          ] a
    in (b,) }
```

The same mechanism could be used for other higher-order primitives,
e.g., cond, and others.

Also add some typing declarations and rename APIs to use "shared jaxpr"
in lieu of "top-level jaxpr" for those Jaxprs that are used multiple
times and are printed first. I presume that the term "top-level jaxpr"
was picked because these are printed first at top-level. But this is
confusing, because they are really subjaxprs. In fact, there was already
a function `core.pp_toplevel_jaxpr` for printing the top-level Jaxpr,
and there was also `core.pp_top_level_jaxpr` (which now is named
`core.pp_shared_jaxpr`).
2025-03-03 11:38:51 +01:00
Parker Schuh
b8b690e594 Add use_high_dynamic_range_gumbel flag which allows sampling gumbel such
that it more closely matches the CDF for low-probability events (less than
2**-nmant).

Because -log(-log(x)) is more sensitive close to 1 than 0, we must use
-log(-log1p(-x)) instead to make better use of the extra range around 0.
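
A rough NumPy illustration of the numerical point, assuming float32 (23 mantissa bits). The function names are hypothetical and this is not the JAX implementation. If `u` is uniform on (0, 1) then so is `x = 1 - u`, so both transforms produce Gumbel samples, but the second one takes its large values from `x` near 0.

```python
import numpy as np

def gumbel_naive(u):
    # Textbook transform; loses resolution when u is close to 1.
    with np.errstate(divide="ignore"):
        return -np.log(-np.log(u))

def gumbel_high_dynamic_range(x):
    # Same distribution, but driven by x = 1 - u, so the large samples
    # come from x near 0, where floating point is densest.
    return -np.log(-np.log1p(-x))

x = np.float32(1e-10)        # a rare event, far below 2**-23
u = np.float32(1.0) - x      # rounds to exactly 1.0 in float32

naive = gumbel_naive(u)                  # the rare event is lost
hdr = gumbel_high_dynamic_range(x)       # stays finite, close to -log(x)
print(naive, hdr)
```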

PiperOrigin-RevId: 732757388
2025-03-02 19:42:40 -08:00
Yash Katariya
53494ade2d PRNGKeyArray.aval should have the correct logical sharding. This required refactoring code so that we don't hit recursion errors.
PiperOrigin-RevId: 732536521
2025-03-01 18:18:19 -08:00
jax authors
2a1eeb0ce8 Changes for kernel export
PiperOrigin-RevId: 732383028
2025-03-01 00:32:39 -08:00
Anton Osokin
1f3176636d Reverts 10f6edeb496a2eec2a09c2c5cecbe4f8f02452ab
PiperOrigin-RevId: 732315349
2025-02-28 18:04:27 -08:00
Jake VanderPlas
b2c45b8eb9 Improved errors when indexing with floats 2025-02-28 15:04:07 -08:00
Jake VanderPlas
c56e794a66 doc: fix description of logsumexp axis 2025-02-28 12:53:33 -08:00
Yash Katariya
da1cc0a50e [sharding_in_types] out_sharding argument on einsum should only apply to the last einsum and not intermediate einsums.
For example: Consider this einsum: `jnp.einsum('bthD, bthi, bthj->ijD', dy, i, j, out_sharding=P('data', None, None))`

This will decompose into 2 einsums where the intermediate einsum output will be of rank `5`:
  * `'bthj,bthD->bthjD'`
  * `'bthjD,bthi->ijD'`

The out_sharding specified (`P('data', None, None)`) is not compatible with the intermediate einsum: `'bthj,bthD->bthjD'` since the `length of spec (3) != out_aval.ndim (5)`.

This change makes it so that out_sharding is only applied to the contraction that leads to the final output. **If there are conflicts in intermediate einsums, then the user has to reshard the input or split into multiple einsums (and maybe provide out_sharding) so that conflicts don't exist.**

Note: We won't drop into auto mode for intermediate einsums. The user will have to split the einsum if any conflict is detected.
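
The decomposition can be checked with plain NumPy. This sketch uses made-up dimension sizes and no shardings; it only shows that the two pairwise contractions reproduce the three-operand einsum and that the intermediate has rank 5.

```python
import numpy as np

rng = np.random.default_rng(0)
b, t, h, D, i_dim, j_dim = 2, 3, 2, 4, 3, 5   # hypothetical sizes
dy = rng.normal(size=(b, t, h, D))
i = rng.normal(size=(b, t, h, i_dim))
j = rng.normal(size=(b, t, h, j_dim))

# The three-operand einsum from the example above.
direct = np.einsum('bthD,bthi,bthj->ijD', dy, i, j)

# The same computation as two pairwise einsums.
intermediate = np.einsum('bthj,bthD->bthjD', j, dy)    # rank 5
final = np.einsum('bthjD,bthi->ijD', intermediate, i)  # rank 3

print(intermediate.ndim, final.shape)
```

A rank-3 spec such as `P('data', None, None)` can only describe `final`, which is why it cannot be applied to `intermediate`.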
PiperOrigin-RevId: 732205849
2025-02-28 11:39:14 -08:00
Dan Foreman-Mackey
bb9aed5eec Reimplement custom_vjp.optimize_remat using custom_dce. 2025-02-28 10:00:28 -05:00
Benjamin Chetioui
a9ab614123 [Pallas/Mosaic GPU] Add an abstraction to obtain a slice of dynamic shared memory when using warpgroup semantics.
Explicitly make the assumption that `runtime_smem` starts at `0` in the Pallas
module context---which should be enforced by Mosaic GPU.

This is in preparation for changes implementing transform inference.

PiperOrigin-RevId: 732091266
2025-02-28 04:38:25 -08:00
Yash Katariya
dda62f576f Make sure default layout is None for input and output layout in all codepaths
PiperOrigin-RevId: 731865511
2025-02-27 14:26:25 -08:00
jax authors
c7ca35fe32 Merge pull request #26345 from wenscarl:scaled_matmul
PiperOrigin-RevId: 731865430
2025-02-27 14:24:48 -08:00
jax authors
6a7736754f Reverts 0f0d5e90ef1c3d60f35020141710ea350d17816b
PiperOrigin-RevId: 731844119
2025-02-27 13:27:32 -08:00
Sharad Vikram
6f57410e12 [Pallas TPU] Use grid_env for pipeline body so we can query num_programs/program_id inside the block spec
PiperOrigin-RevId: 731831543
2025-02-27 12:53:02 -08:00
Yash Katariya
07f192cd48 Merge _check_mesh_resource_axis and _check_axis_type_consistency into 1 function.
PiperOrigin-RevId: 731830347
2025-02-27 12:51:25 -08:00
Yash Katariya
c265568530 Remove parsed_pspec from NamedSharding constructor
PiperOrigin-RevId: 731820173
2025-02-27 12:24:17 -08:00
Peter Hawkins
1e5d9a9158 Add an allow_negative_indices option to lax.dynamic_slice and lax.dynamic_update_slice.
The goal of this change is to avoid generating code to wrap negative indices back into range in cases where we know it doesn't matter. Change scan to pass allow_negative_indices=False to avoid emitting index wrapping code for each scan argument.
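
A pure-Python model of what "wrapping" costs, for illustration only. The helper name is hypothetical, and how a negative index is treated once wrapping is disabled is an assumption of this sketch (it is simply clamped).

```python
def start_index(index, dim_size, slice_size, allow_negative_indices=True):
    """Hypothetical model of how a dynamic-slice start index is resolved."""
    if allow_negative_indices and index < 0:
        # The wrapping step: count from the end of the dimension.
        index += dim_size
    # Clamp so that the slice stays in bounds.
    return max(0, min(index, dim_size - slice_size))

print(start_index(-3, 10, 2))                                 # wrapped
print(start_index(4, 10, 2, allow_negative_indices=False))    # unchanged
```

When the caller knows an index is never negative, as with a scan's loop counter, both settings give the same answer and the comparison-and-add can be dropped.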

PiperOrigin-RevId: 731812827
2025-02-27 12:04:28 -08:00
Yash Katariya
c94ec0eb0d Use batched_device_put for token shard_arg handler
PiperOrigin-RevId: 731800613
2025-02-27 11:30:22 -08:00
jax authors
da39b6f3d4 Comment change
PiperOrigin-RevId: 731792151
2025-02-27 11:07:59 -08:00
Yash Katariya
d69da3b012 More cleanups around ParsedPartitionSpec. In a follow up CL, I can remove it from NamedSharding constructor. Deleting ParsedPartitionSpec is remaining but that's after 0.5.2 release.
PiperOrigin-RevId: 731785005
2025-02-27 10:51:04 -08:00
Yash Katariya
034a827a4d Remove _parsed_pspec from everywhere in JAX except for NamedSharding constructor. I'll do that in the next CL since that has a dependency on C++ so needs guards.
PiperOrigin-RevId: 731772222
2025-02-27 10:17:06 -08:00
Yash Katariya
177e1f6ed9 Canonicalize PartitionSpec so that we can delete ParsedPartitionSpec. We need to do this after sharding-in-types to speed up NamedSharding construction and remove a lot of tech debt and unnecessary complexity.
* `_partitions` is now canonicalized and only contains `tuples`, `singular strings`, `None` or `UNCONSTRAINED`. No more empty tuples (`P((), 'x')`) and singleton tuples.

* Cache the creation of the sharding on ShapedArray, since creating it many times is expensive.

* Change the `__hash__` and `__eq__` of `NamedSharding` to depend on `self.spec` instead of `self._parsed_pspec`.
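
A toy sketch of the canonical form described in the first bullet. This is not JAX's code: the helper names and the sentinel are stand-ins, and mapping an empty tuple to `None` is an assumption of the sketch.

```python
UNCONSTRAINED = object()  # stand-in sentinel, not JAX's actual object

def canonicalize_entry(entry):
    if entry is None or entry is UNCONSTRAINED or isinstance(entry, str):
        return entry
    entry = tuple(entry)
    if len(entry) == 0:
        return None        # assumption: an empty tuple means "not sharded"
    if len(entry) == 1:
        return entry[0]    # a singleton tuple collapses to a single string
    return entry

def canonicalize(partitions):
    return tuple(canonicalize_entry(e) for e in partitions)

print(canonicalize(((), 'x')))
print(canonicalize((('x',), ('y', 'z'))))
```

With one spelling per spec, equality and hashing can work directly on the stored partitions without a separate parsed representation.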

PiperOrigin-RevId: 731745062
2025-02-27 08:59:25 -08:00
Dan Foreman-Mackey
f93c2a1aa5 Add and test support for partitioning of batch dimensions in lax.linalg.
On CPU and GPU, almost all of the primitives in lax.linalg are backed by custom calls that support simple semantics when batch dimensions are sharded. Before this change, all linalg operations on CPU and GPU would insert an `all-gather` before being executed when called on sharded inputs, even when that shouldn't be necessary. This change adds support for this type of partitioning, to cover a wide range of use cases.

There are a few remaining GPU ops that don't support partitioning either because they are backed by HLO ops that don't partition properly (Cholesky factorization and triangular solves), or because they're still using descriptors with problem dimensions in kernel. I'm going to fix these in follow up changes.

PiperOrigin-RevId: 731732301
2025-02-27 08:16:16 -08:00
Adrian Kuegel
de4d047852 Change int4 packing from big-endian to little-endian
LLVM uses little-endian format for int4 packing. To avoid converting between
these formats, we should also use little-endian in XLA.
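
For readers unfamiliar with the two layouts, a pure-Python sketch follows. It assumes "little-endian" means the first of each pair of 4-bit values lands in the low nibble of the byte; the helper is hypothetical and is not XLA code.

```python
def pack_int4(values, little_endian=True):
    """Pack pairs of 4-bit values (0..15) into bytes."""
    if len(values) % 2:
        raise ValueError("expected an even number of values")
    out = bytearray()
    for first, second in zip(values[0::2], values[1::2]):
        if little_endian:
            # First value in the low nibble, second in the high nibble.
            out.append((first & 0xF) | ((second & 0xF) << 4))
        else:
            # First value in the high nibble, second in the low nibble.
            out.append(((first & 0xF) << 4) | (second & 0xF))
    return bytes(out)

print(pack_int4([1, 2]).hex(), pack_int4([1, 2], little_endian=False).hex())
```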

PiperOrigin-RevId: 731731530
2025-02-27 08:13:43 -08:00
jax authors
07f5d7a475 Reverts f3fade3b70443b6cf87f01f360e6a1cb85d4b1fb
PiperOrigin-RevId: 731658204
2025-02-27 03:26:37 -08:00
Chris Jones
d6752e9267 [pallas:triton] Generate more efficient code for loading contiguous slices of int4 values.
The existing `int4` loading code is very generic. When reading contiguous data, it will read with offsets like `0, 0, 1, 1, ...`. Triton doesn't consider these to be contiguous in memory and emits much less efficient code than when reading contiguous blocks.

PiperOrigin-RevId: 731635736
2025-02-27 01:57:47 -08:00
Tom Hennigan
1becb57ac9 Add jax.copy_to_host_async(tree).
A relatively common pattern I've observed is the following:

```python
_, metrics = some_jax_function()

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

We are missing an opportunity here to begin the D2H copy of
the metrics more eagerly (e.g. to overlap it with closing the "compute_metrics"
context manager). The intention of `jax.copy_to_host_async(x)` is to make it
simple to begin D2H transfers as early as possible. Adapting the above code:

```python
_, metrics = some_jax_function()

# Begin D2H copies as early as we can.
jax.copy_to_host_async(metrics)

with profiler.Trace('compute_metrics'):
  jax.block_until_ready(metrics)

with profiler.Trace('copy_to_host'):
  metrics = jax.device_get(metrics)
```

PiperOrigin-RevId: 731626446
2025-02-27 01:22:15 -08:00
Sharad Vikram
2646b8d4ad [Pallas TPU] Add support for GridDimensionSemantics to pallas_call
PiperOrigin-RevId: 731543938
2025-02-26 19:34:36 -08:00
Sharad Vikram
b5fcffadd4 Add swap as method to TransformedRef
PiperOrigin-RevId: 731541165
2025-02-26 19:19:10 -08:00
Sharad Vikram
1ecbac9702 [Pallas] Add name parameter to core_map
PiperOrigin-RevId: 731536152
2025-02-26 18:59:01 -08:00
Sharad Vikram
0f0d5e90ef Add support for TPU v5 2x2 tray configuration
PiperOrigin-RevId: 731529917
2025-02-26 18:33:49 -08:00
Emily Fertig
82124da5cd Redefine is_fully_addressable in shardings to support zero local devices for McJAX.
PiperOrigin-RevId: 731526750
2025-02-26 18:17:35 -08:00
Emily Fertig
7f9e7473cf Rolling back a commit that caused a 50-90% performance regression in most MaxText workloads.
Reverts 9d421c9149a1db006444adeea87464bd6b8c0743

PiperOrigin-RevId: 731506280
2025-02-26 16:57:18 -08:00
carlosgmartin
ba428d8cda Extend random.orthogonal to semi-orthogonal matrices. Simplify initializers.orthogonal by using it. 2025-02-26 16:39:45 -05:00
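
For context, a semi-orthogonal matrix is a non-square matrix whose columns (or rows) are orthonormal. The NumPy sketch below builds one from the QR factorization of a Gaussian matrix. It only illustrates the property and makes no claim about how `random.orthogonal` is implemented.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 5, 3                       # hypothetical tall shape, n > m
gaussian = rng.normal(size=(n, m))
q, _ = np.linalg.qr(gaussian)     # reduced QR: q has shape (n, m)

gram = q.T @ q                    # identity: the columns are orthonormal
projector = q @ q.T               # a rank-m projector, not the identity
print(q.shape)
```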
Shu Wang
7f0a5bc83e Add apache header. 2025-02-26 15:26:56 -06:00
Jake VanderPlas
7be7c48985 Implement jnp.ndarray.__contains__
Currently this falls back to a linear scan via __iter__, which is slow
and raises unclear error messages in unsupported cases.
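
The fallback mentioned here is standard Python behavior: without `__contains__`, the `in` operator iterates the object until it finds a match. A minimal pure-Python sketch with hypothetical classes (not JAX arrays):

```python
class IterOnly:
    def __init__(self, items):
        self.items = list(items)
        self.visited = 0   # counts elements produced by __iter__

    def __iter__(self):
        for item in self.items:
            self.visited += 1
            yield item

class WithContains(IterOnly):
    def __contains__(self, value):
        # A dedicated membership test; `in` no longer goes through __iter__.
        return any(value == item for item in self.items)

slow = IterOnly(range(5))
fast = WithContains(range(5))
found_slow = 3 in slow   # walks the elements one by one via __iter__
found_fast = 3 in fast   # calls __contains__ directly
print(slow.visited, fast.visited)
```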
2025-02-26 11:13:45 -08:00
Klaus Greff
5acfc88a00 fix Initializer protocol 2025-02-26 14:25:15 +01:00
Peter Hawkins
66293d8897 Remove code present to support jaxlib < 0.5.1.
The new minimum xla_extension_version is 317 and the new mlir_api_version is 58.
2025-02-26 07:40:40 -05:00
Adam Paszke
3251b55ef2 [Pallas:MGPU] Don't recreate single_thread_predicate at every rule
While the predicate helps us avoid branching, it can be created once per
block. Its creation uses `*.sync` instructions, which are not DCEd by
LLVM and end up polluting the final code.

PiperOrigin-RevId: 731253109
2025-02-26 04:02:21 -08:00
Benjamin Chetioui
7a34f1cedc [Pallas/Mosaic GPU][NFC] Move thread_semantics to ModuleContext.
This simplifies the propagation of the argument, and is the proper place to
put it.

PiperOrigin-RevId: 731239831
2025-02-26 03:08:42 -08:00
shuw
17088e9025 Improve after review # 2 2025-02-26 04:48:25 +00:00
Jacob Burnim
4c7140fa03 [Pallas] Add option for async DMAs in the new TPU interpret mode
When dma_execution_mode='on_wait', we wait to execute DMAs until we are interpreting a `dma_wait` instruction.  In particular, while a device is waiting on a DMA semaphore, we will (partially) execute DMAs that signal that semaphore until the wait operation can succeed.

PiperOrigin-RevId: 731103569
2025-02-25 18:19:20 -08:00
jax authors
7c26ab53f6 Use jax.Array as type annotation for pallas random keys
jax_prng.PRNGKeyArray is not exposed to the public jax API, resulting in type check errors when sampling outside of tests.

PiperOrigin-RevId: 731008883
2025-02-25 13:30:58 -08:00