22479 Commits

Author SHA1 Message Date
jax authors
87a5591db5 Merge pull request #23040 from jakevdp:isin-method
PiperOrigin-RevId: 662957551
2024-08-14 09:57:49 -07:00
Dan Foreman-Mackey
ad1bd38790 Move logic about when to dispatch to batched LU decomposition algorithm on GPU into the kernel.
This simplifies the lowering logic, and means that we don't get hit with a performance penalty when exporting with shape polymorphism.

PiperOrigin-RevId: 662945116
2024-08-14 09:20:40 -07:00
jax authors
bab70dda97 Reverts 734ebd570891ceaf8c7104e12256a1edfe942b14
PiperOrigin-RevId: 662942100
2024-08-14 09:12:03 -07:00
Yash Katariya
229cbae5ea Add num_devices to Sharding interface so that it works with NamedSharding containing AbstractMesh too.
PiperOrigin-RevId: 662938823
2024-08-14 09:03:17 -07:00
Benjamin Chetioui
df2e9c3836 [Mosaic] Fix lowering for _dot_general_lowering_rule to match the new vector.MultiDimReductionOp signature.
PiperOrigin-RevId: 662933072
2024-08-14 08:42:43 -07:00
Dan Foreman-Mackey
b0a144ae4b Don't export ir_attribute from interpreters.mlir.
PiperOrigin-RevId: 662918256
2024-08-14 07:53:18 -07:00
jax authors
807dcb5a06 Integrate LLVM at llvm/llvm-project@c8b5d30f70
Updates LLVM usage to match
[c8b5d30f7077](https://github.com/llvm/llvm-project/commit/c8b5d30f7077)

PiperOrigin-RevId: 662906261
2024-08-14 07:09:53 -07:00
Sergei Lebedev
6290cd77fc Added pl.program_id and pl.num_programs to Mosaic GPU lowering
PiperOrigin-RevId: 662836490
2024-08-14 02:23:38 -07:00
Adam Paszke
2ab7558425 [Mosaic GPU] Add support for grid tiling to improve L2 cache utilization
While CUDA technically does not guarantee anything about the order in
which blocks will be executed, in practice they are generally scheduled
in column-major order within the grid. We can use this property to launch
the blocks in a tiled way, which can lead to an improved rate of L2 hits
and a significant performance boost.

PiperOrigin-RevId: 662834982
2024-08-14 02:17:55 -07:00
Adam Paszke
f384497f68 [Mosaic GPU] Add support for cluster collective loads and barriers over multiple dimensions
This will be useful for an upcoming change to the matmul kernel that splits the N blocks
over two cluster dimensions.

PiperOrigin-RevId: 662825455
2024-08-14 01:47:12 -07:00
jax authors
4c4660a117 Merge pull request #23047 from froystig:docs
PiperOrigin-RevId: 662779345
2024-08-13 22:25:54 -07:00
George Necula
dbd6aeebb7 Disable some asan tests, times out
PiperOrigin-RevId: 662774152
2024-08-13 22:03:29 -07:00
Jake VanderPlas
25da7add37 Add method argument for jnp.isin 2024-08-13 19:04:14 -07:00
Roy Frostig
d17edb4c4d docs: fix shard_map guide headings
These were off by one level, causing section titles to be listed in
the guide index.
2024-08-13 18:29:42 -07:00
Yash Katariya
9a8f0a67f5 Add a devices property to AbstractMesh but raise an error in it. This is to make pytype happy
PiperOrigin-RevId: 662712450
2024-08-13 17:37:58 -07:00
Peter Hawkins
323e257f67 Fix test failures.
PiperOrigin-RevId: 662703221
2024-08-13 17:02:14 -07:00
jax authors
2b3ccce793 Merge pull request #23042 from jakevdp:dataclass-doc
PiperOrigin-RevId: 662701533
2024-08-13 16:55:51 -07:00
Jake VanderPlas
5903c772f4 doc: clarify data_fields & meta_fields in register_dataclass 2024-08-13 16:06:47 -07:00
jax authors
4e580d167e Update XLA dependency to use revision
55476059f6.

PiperOrigin-RevId: 662672660
2024-08-13 15:26:15 -07:00
Jevin Jiang
8f23392a8c [Mosaic:TPU] Refactor relayout helper functions to take ctx instead of only target shape.
PiperOrigin-RevId: 662672417
2024-08-13 15:22:46 -07:00
Yash Katariya
daa69da321 Introduce jax.sharding.AbstractMesh(shape_tuple: tuple[tuple[str, int], ...]) and allow with_sharding_constraint and shard_map to accept an abstract mesh as input (with_sharding_constraint is via NamedSharding(abstract_mesh, pspec)).
**Semantics**

Inside jit, we don't need to talk about concrete devices ever so the semantics stay the same as today i.e. we can lower a NamedSharding with abstract mesh with only mesh axis names and sizes and PartitionSpec. The only restriction is that the number of devices need to be consistent throughout the program when we are tracing.
During compilation, the order of devices throughout the program needs to be consistent (same as before this change).

Outside jit i.e. eager mode, if a `shard_map` or `with_sharding_constraint` contains AbstractMesh, then the input to those primitives should contain a concrete Mesh with the same shape and names as the abstract mesh.

**Why do this?**

There are cases, where you want the change the devices in the mesh but keep the mesh shape the same (axis names and axis sizes). But this leads to a device mismatch error if you have `with_sharding_constraint` or `shard_map` in your computation because they embed concrete devices in their signature.

So to fix the error, you need to change the mesh in `wsc` and `shmap` which will lead to a tracing cache miss (because function id is now different) and consequently a lowering to stableHLO cache miss. Explaining via an example:

```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

@jax.jit
def f(x):
  y = with_sharding_constraint(x, NamedSharding(mesh1, P('x')))
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # DEVICE MISMATCH ERROR!
```

The same problem exists for `shard_map` since it takes a mesh with concrete devices in it's signature.

**Okay, so how do you fix this?**

As mentioned above, we need the above program to work and get tracing and lowering cache hits (**cache hits is the most important** part here)

The approach in this change, allows `with_sharding_constraint` to accept a `NamedSharding(abstract_mesh, pspec)` as input. This leads to no errors downstream and we get tracing and lowering cache hits since we don't encode the concrete devices anymore. Just the axis_names and axis_size of the mesh.

**The important part is that the concrete device information should only come from the arguments. Inside `jax.jit`, you should never reference concrete devices ever.**

```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(arr_mesh1.shape_tuple)

@jax.jit
def f(x):
  y = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x')))
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # tracing and lowering cache hit
```

**One caveat is that this only works with `jax.NamedSharding` but that's fine because `NamedSharding` is the most used `Sharding` in JAX.**

**What about `shard_map`?**

shard_map's signature will be: `shmap(f, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs)`.

```
mesh1 = Mesh(jax.devices()[:2], 'x')
mesh2 = Mesh(jax.devices()[2:4], 'x')

arr_mesh1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
arr_mesh2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

# Creating abstract mesh with mesh1 but since both meshes have the same shape (names
# and axis size), it should be ok.
abstract_mesh = jax.sharding.AbstractMesh(arr_mesh1.shape_tuple)

@jax.jit
def f(x):
  y = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), out_specs=P('x'))
  return y * 2

f(arr_mesh1)
f(arr_mesh2)  # tracing and lowering cache hit
```

This is a fully backwards change. So your current code will continue to work as is but you can opt-into this new behavior and get all the benefits!

PiperOrigin-RevId: 662670932
2024-08-13 15:18:08 -07:00
jax authors
98521ad35d Add todo for slow codegen in Pallas pipeline
PiperOrigin-RevId: 662661951
2024-08-13 14:53:23 -07:00
Sergei Lebedev
28dfe0d280 Import etils.epath lazily
This shaves off an extra 0.1-0.2s from JAX import times internally.

PiperOrigin-RevId: 662660356
2024-08-13 14:48:38 -07:00
Jevin Jiang
2dea3d6a0c [Mosaic:TPU] Add shuffled load and store.
we also emulate shuffled store using (store + shuffled load + store) for previous generations.

PiperOrigin-RevId: 662657663
2024-08-13 14:41:16 -07:00
jax authors
d2b85a48af Merge pull request #23036 from froystig:docs2
PiperOrigin-RevId: 662635310
2024-08-13 13:40:19 -07:00
jax authors
9f6857620b Disable TensorRT in TF, XLA and JAX.
This is needed for hermetic CUDA integration in Google ML projects since tensorRT is not distributed in the same free way as other CUDA/CUDNN distributives.

PiperOrigin-RevId: 662601190
2024-08-13 11:58:31 -07:00
Roy Frostig
3c223cd253 docs: tidy up titles and headings
This shortens some titles and makes them more consistent. It also
removes "JAX" from several titles ("in JAX", "for JAX", "JAX's",
etc.). Since these are JAX docs, that ought to be clear from context.
2024-08-13 11:53:57 -07:00
Sergei Lebedev
a755f1db83 Import from `mlir.dialects` lazily
These imports jointly account for ~0.3s of import time internally.

PiperOrigin-RevId: 662588167
2024-08-13 11:22:41 -07:00
John QiangZhang
1bba83894a Add logging the jax2tf mlir_module_serialized module size.
PiperOrigin-RevId: 662574156
2024-08-13 10:47:07 -07:00
jax authors
955699cc65 Merge pull request #23000 from dfm:dce-bug
PiperOrigin-RevId: 662565548
2024-08-13 10:29:30 -07:00
jax authors
3849e0e7d0 Merge pull request #23020 from jakevdp:setxor1d-size
PiperOrigin-RevId: 662565510
2024-08-13 10:25:46 -07:00
Loren Maggiore
7d2fbd5418 [pallas] enable lowering on an AbstractMesh.
PiperOrigin-RevId: 662533955
2024-08-13 08:52:13 -07:00
Adam Paszke
bab096e563 [Mosaic GPU] Add an autotuning harness to the matmul example
PiperOrigin-RevId: 662521895
2024-08-13 08:11:02 -07:00
Adam Paszke
f4c0b1feb0 [Mosaic GPU] Add control over the output format in the matmul example
PiperOrigin-RevId: 662478648
2024-08-13 05:33:12 -07:00
Jake VanderPlas
52c269c8cd jnp.setxor1d: add support for static size argument 2024-08-13 05:24:59 -07:00
Adam Paszke
5cf89b3f61 [Mosaic GPU] Add support for various swizzles in the matmul example
PiperOrigin-RevId: 662459766
2024-08-13 04:12:43 -07:00
Adam Paszke
ca6be2573b [Mosaic GPU] Move matmul tests to Hypothesis
We've been generating thousands of test cases and that's just not
scalable. Hypothesis should let us efficiently explore a large
number of configurations.

PiperOrigin-RevId: 662447113
2024-08-13 03:21:51 -07:00
Paweł Paruzel
354293da48 Activate Singular Value Decomposition to XLA's FFI
PiperOrigin-RevId: 662436635
2024-08-13 02:41:57 -07:00
George Necula
1a7c6aa186 [pallas] Fix test timeouts
PiperOrigin-RevId: 662420238
2024-08-13 01:42:41 -07:00
Paweł Paruzel
5fc992e5e1 Determine LAPACK workspaces during SVD kernel runtime
The SVD kernel implementation used to require workspace shapes to be determined prior to the custom call on the JAX's side. The new FFI kernels need not demand these shapes to be specified anymore. They are evaluated during kernel runtime.

PiperOrigin-RevId: 662413273
2024-08-13 01:17:44 -07:00
Dan Foreman-Mackey
850edee36e Fix bug in custom_vjp with optimize_remat and custom_vmap.
When used with a `custom_vmap` that introduces a new const the previous
implementation of `optimize_remat` would error in its DCE rule because
of unexpected consts when closing the fwd jaxpr. This shouldn't have
ever been hit, but there was a bug in the batching rule for
`remat_opt_p` where we weren't properly converting constvars to invars.
This fixes this bug and should unbreak internal users.
2024-08-13 09:06:57 +01:00
Dan Foreman-Mackey
69fc8bb419 Consolidate handling of input argument resolution in custom_* APIs.
This is a partial re-land of https://github.com/google/jax/pull/22869 with some updates to ensure that it doesn't break existing uses of `custom_vmap`.

Previously, using a `custom_jvp` or `custom_vjp` with a primal function that has keyword-only arguments would result in a type error, even if these arguments weren't passed by the caller. I believe that this check is actually slightly stricter than it needed to be, as discovered when adding a similar check to `custom_vmap`. Instead, I think that it is sufficient to check that the caller hasn't _passed_ any keyword-only arguments.

The previous behavior in `custom_vmap` was even harsher: it would error if any keyword arguments were passed.

In this change, I have moved `resolve_kwargs` into `api_utils` so that the same function can be used in both `custom_derivatives` and `custom_batching`. I've also updated the logic to only throw a `TypeError` if the caller passes a keyword only argument when calling a `custom_*`-decorated function. This changes the behavior of `custom_jvp` and `custom_vjp`, although users shouldn't see that effect, since previously having kwargs would have errored.

PiperOrigin-RevId: 662402158
2024-08-13 00:30:23 -07:00
jax authors
23effba503 Merge pull request #23027 from froystig:docs
PiperOrigin-RevId: 662392585
2024-08-12 23:59:39 -07:00
jax authors
ff15835797 Merge pull request #23017 from jakevdp:set-op-tests
PiperOrigin-RevId: 662392445
2024-08-12 23:59:21 -07:00
jax authors
69ba5f62c1 Merge pull request #22976 from jakevdp:extra-params-doc
PiperOrigin-RevId: 662392239
2024-08-12 23:54:49 -07:00
Roy Frostig
09e73118bf docs: more sentence case 2024-08-12 20:07:49 -07:00
Yash Katariya
4533aeaf26 Remove jax_enable_memories conditionals from JAX and remove it from tests too.
PiperOrigin-RevId: 662322241
2024-08-12 19:15:43 -07:00
jax authors
833560deb1 Merge pull request #23023 from froystig:docs3
PiperOrigin-RevId: 662318149
2024-08-12 19:03:12 -07:00
Roy Frostig
b8f8b7b07f docs: sentence case page titles, section headings, some content 2024-08-12 18:12:17 -07:00
Parker Schuh
734ebd5708 Support donating arrays with non-default layouts by setting up XLA donation
directly instead of defining aliasing for arrays with potentially incompatible
layouts.

PiperOrigin-RevId: 662258042
2024-08-12 15:58:52 -07:00