8695 Commits

Author SHA1 Message Date
Ruturaj Vaidya
dceb5310fe
[ROCm] Implement RNN support (#217) 2025-02-07 11:08:12 -06:00
jax authors
9a2dd19a92 Merge pull request #21524 from andportnoy:aportnoy/unknown-platform-lowering-warning
PiperOrigin-RevId: 688630259
2024-10-22 11:40:39 -07:00
Ayaka
c60bafcc33 [Pallas TPU] Fix lowering for jnp.remainder
Fixes https://github.com/jax-ml/jax/issues/24027

PiperOrigin-RevId: 688614799
2024-10-22 11:01:58 -07:00
Andrey Portnoy
2aaa108f06 Raise an error when registering a lowering for an unknown platform 2024-10-22 13:29:48 -04:00
Ayaka
2b7b0742a4 [Pallas TPU] Add lowerings for bf16 jnp.ceil and jnp.floor in TPU v6+
This PR is similar to https://github.com/jax-ml/jax/pull/24284

Note that `np.testing.assert_allclose()` is changed to `self.assertAllClose()` because the latter is a wrapper with bfloat16 support.

PiperOrigin-RevId: 688581914
2024-10-22 09:33:53 -07:00
jax authors
a2e4aff897 Merge pull request #24425 from dfm:rename-vmap-methods
PiperOrigin-RevId: 688547393
2024-10-22 07:51:29 -07:00
Adam Paszke
2db03ba54b [Pallas:MGPU] Add support for grid dims in GPUMesh
Of course no communication can happen across grid dimensions (unlike over the WG dim),
but we need to be able to launch multiple blocks somehow.

PiperOrigin-RevId: 688488660
2024-10-22 04:10:46 -07:00
jax authors
0b3f0e11fb Reverts ebb75db8a523150c48376d15391f84380a2bb110
PiperOrigin-RevId: 688477769
2024-10-22 03:29:32 -07:00
Adam Paszke
84a303f32f [Pallas:MGPU] Allow allocating transformed refs in run_scoped
PiperOrigin-RevId: 688448592
2024-10-22 01:38:46 -07:00
Yash Katariya
ebb75db8a5 [sharding_in_types] Add out_type argument to einsum and dot_general to allow specifying the output type. Right now, it only accepts a NamedSharding, but in the future we can allow a polymorphic type: jax.ShapeDtypeStruct | Sharding | Layout.
PiperOrigin-RevId: 688399552
2024-10-21 22:23:53 -07:00
Praveen Narayanan
ad1aff098d Respect dot algorithm spec on TPU backends.
PiperOrigin-RevId: 688274131
2024-10-21 14:30:48 -07:00
Jake VanderPlas
66971a2869 Fix jnp.diff for boolean inputs 2024-10-21 13:35:13 -07:00
Dan Foreman-Mackey
61701af4a2 Rename vmap methods for callbacks. 2024-10-21 15:03:04 -04:00
jax authors
4a5ca2fd00 Merge pull request #24400 from jakevdp:subtract-ufunc
PiperOrigin-RevId: 688190106
2024-10-21 10:38:52 -07:00
Jake VanderPlas
6467d03925 Make jnp.subtract a ufunc 2024-10-21 10:11:51 -07:00
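The commit above makes `jnp.subtract` a `jnp.ufunc`. A minimal sketch of what this enables, assuming it picks up the standard ufunc methods such as `.outer` and `.at` (the method set and the `inplace=False` requirement are assumptions, not spelled out in the commit message):

```
import jax.numpy as jnp

# Pairwise differences via the ufunc .outer method.
diffs = jnp.subtract.outer(jnp.arange(3), jnp.arange(3))

# Functional scatter-subtract via .at; JAX ufuncs return a new array rather
# than updating in place, hence inplace=False.
x = jnp.ones(4)
y = jnp.subtract.at(x, jnp.array([0, 2]), 1.0, inplace=False)
```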
Dan Foreman-Mackey
0b651f0f45 Make ffi_call return a callable 2024-10-21 12:16:57 -04:00
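The commit above changes `ffi_call` so that it returns a callable instead of taking the operands directly. A sketch of the new calling convention, assuming an FFI target named "my_custom_op" has already been registered (the target name is hypothetical, and the import path reflects `jax.extend.ffi` as of this change):

```
import jax
import jax.numpy as jnp
from jax.extend import ffi

x = jnp.arange(4, dtype=jnp.float32)

# ffi_call now returns a callable bound to the target and output spec...
call = ffi.ffi_call("my_custom_op", jax.ShapeDtypeStruct(x.shape, x.dtype))
# ...which is then invoked with the operands.
out = call(x)
```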
Adam Paszke
f833891c87 [Pallas:MGPU] Add support for passing in WGMMA lhs from registers
PiperOrigin-RevId: 688117316
2024-10-21 06:42:18 -07:00
Adam Paszke
f08801b8d6 [Pallas:MGPU] Allow indexing to appear anywhere in the list of transforms
We only need to exchange the transforms preceding the indexer, while
the rest can remain unmodified.

PiperOrigin-RevId: 688112088
2024-10-21 06:22:16 -07:00
jax authors
f4b84e1c97 Merge pull request #24342 from gnecula:export_custom_types
PiperOrigin-RevId: 688093192
2024-10-21 05:08:04 -07:00
George Necula
2feea414ac [export] Add support for serialization for some custom PyTree nodes
See the added documentation for `jax._src.export.register_pytree_node_serialization`
and `jax._src.export.register_namedtuple_serialization`.

Serialization of PyTree nodes is needed to serialize the `in_tree` and
`out_tree` fields of `Exported` functions (not to serialize actual instances
of the custom types).

When writing this I looked at how TensorFlow handles namedtuples. It does
so transparently, without requiring the user to register a serialization
handler for the namedtuple type. But this has the disadvantage that on
deserialization a fresh, distinct namedtuple type is created for
each input and output type of the serialized function. This means that
calling the deserialized function will return outputs of different types
than the function that was serialized. This can be confusing.

The Python pickle module does a bit better: it attempts to look up the
namedtuple type as a module attribute in the deserializing code,
automatically importing the module whose name was saved during serialization.
This is too much magic for my taste, as it can result in strange import errors.

Hence I added an explicit step for the user to say how they want
the namedtuple to be serialized and deserialized.

Since I wanted to also add support for `collections.OrderedDict`, which
users are asking for, I added more general support for PyTree custom nodes.
Note that this registration mechanism works in conjunction with the
PyTree custom node registration mechanism. The burden is on the
user to decide how to serialize and deserialize the custom auxdata that
the PyTree custom registration mechanism uses. Not all custom types
will be serializable, but many commonly used ones, e.g., dataclasses,
can now be inputs and outputs of the serialized functions.
2024-10-21 11:38:13 +02:00
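A sketch of how the registration step described above might look for a namedtuple; the public import path and the `serialized_name` keyword are assumptions here (the commit documents the functions under `jax._src.export`):

```
import collections
from jax import export

# Hypothetical namedtuple used as an input/output of an exported function.
Point = collections.namedtuple("Point", ["x", "y"])

# Explicitly opt the type into Exported serialization, giving it a stable name
# so deserialization can map back to this exact type.
export.register_namedtuple_serialization(Point, serialized_name="my_lib.Point")
```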
Adam Paszke
bbcc3eef3c [Pallas:MGPU] Fix the implementation of WGMMA with transposed RHS
It's not enough to have the physical transpose between the order
of tiled dimensions; we also need the user to explicitly transpose the
logical dimensions. This fixes a shape error that was previously hidden
because the RHS was square.

PiperOrigin-RevId: 687350270
2024-10-18 10:31:42 -07:00
Yash Katariya
2153de4ce0 [sharding_in_types] If out_aval.sharding is not None and the user-specified out_sharding is None, concretize it with the available device assignment and add it to the final out_shardings used for lowering and compilation.
This will allow us to return the exact sharding spec that sharding propagation rules figured out.

PiperOrigin-RevId: 687349015
2024-10-18 10:27:58 -07:00
Christos Perivolaropoulos
f8a3c0366b [pallas] run_scoped now supports partial discharge.
PiperOrigin-RevId: 687347284
2024-10-18 10:22:31 -07:00
jax authors
eba5748094 Disable breaking test-case
PiperOrigin-RevId: 687320199
2024-10-18 08:54:36 -07:00
Adam Paszke
0ee9531ef2 [Pallas:MGPU] Add support for indexed refs to WGMMA
PiperOrigin-RevId: 687258992
2024-10-18 04:55:34 -07:00
Adam Paszke
f2edc83af3 [Pallas:MGPU] Properly commute indexing with other transforms
Doing so requires us to modify the other transforms when we attempt to
move indexing before them.

PiperOrigin-RevId: 687240515
2024-10-18 03:39:51 -07:00
Yash Katariya
4db212d2c6 Add _sharding argument to broadcasted_iota as a private parameter which only works under sharding_in_types mode.
This is required because `jax.nn.one_hot` calls into `broadcasted_iota`.

PiperOrigin-RevId: 687152343
2024-10-17 21:16:51 -07:00
jax authors
dd5426301a Allow a simple host call that uses a host tensor as parameter/result in
linear layout. This CL only handles very simple host call patterns.
A more thorough implementation of propagation of T(1)S(5) will be done
later.

This CL doesn't handle host calls that pass/return tensors that
live on device with a linear layout either; that will also be implemented
separately.

PiperOrigin-RevId: 687113203
2024-10-17 18:22:46 -07:00
Dan Foreman-Mackey
8361eb58e1 Activate the FFI implementation of SVD on GPU.
Alongside activating this new implementation, this change adds a new `algorithm` parameter to `jax.lax.svd`. Previously the choice of algorithm was made based on heuristics in the lowering rule, but it probably also makes sense to expose an option for users to specify the algorithm explicitly because our heuristics are not very carefully optimized.

This change updates the implementation of SVD in `lax` to use the FFI version which was added to jaxlib in https://github.com/jax-ml/jax/pull/23794. This comes with a few benefits:

1. When running on a CUDA platform, the 64-bit API will be used for the algorithm based on QR decomposition. (Note that it looks like the 64-bit API isn't available on ROCm.) This addresses part of the feature request in https://github.com/jax-ml/jax/issues/23413, although there's still work to do to port the rest of the GPU calls to the 64-bit API.

2. This implementation supports shape polymorphism in all dimensions with some caveats. By default, we use heuristics based on the matrix sizes to select the algorithm, and the three different algorithms (QR, Jacobi, and batched Jacobi) have sufficiently different behavior (QR returns V^H, whereas Jacobi returns V; batched Jacobi doesn't support `full_matrices=False`) that I couldn't work out a simple way to push this logic into the kernel. If the symbolic constraints are not sufficient to concretely determine the heuristics, we always use the QR algorithm. But I've also exposed the algorithm selection in the user API, so it's possible to bypass the heuristics and get consistent behavior alongside shape polymorphism if needed.

Besides these core changes, I removed the forward compatibility checks from the CPU lowering, since we're well outside of the forward compatibility window now.

PiperOrigin-RevId: 687106965
2024-10-17 17:57:06 -07:00
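A sketch of the explicit algorithm selection mentioned above, assuming the new parameter is the `algorithm` keyword on the lax-level SVD with a QR/Jacobi enum (names follow `jax.lax.linalg`; details of the exact API added here may differ):

```
import jax.numpy as jnp
from jax.lax import linalg as lax_linalg

a = jnp.ones((16, 8), dtype=jnp.float32)

# Bypass the size-based heuristics and force the QR-based algorithm,
# which is also the fallback under shape polymorphism.
u, s, vh = lax_linalg.svd(a, full_matrices=False,
                          algorithm=lax_linalg.SvdAlgorithm.QR)
```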
Yash Katariya
3e634d9530 [sharding_in_types] Add lax.transpose sharding propagation rule
PiperOrigin-RevId: 687094297
2024-10-17 17:08:04 -07:00
Yash Katariya
57a95a77ff [sharding_in_types] Support jnp.array with sharding_in_types. When the input array has a sharding, propagate it through without dropping the sharding.
PiperOrigin-RevId: 687089357
2024-10-17 16:51:41 -07:00
Yash Katariya
5df4878ad0 [sharding_in_types] Add reduce max, integer_pow and standard_unop sharding rules
PiperOrigin-RevId: 687073144
2024-10-17 15:55:29 -07:00
Yash Katariya
e92e1191b3 [sharding_in_types] Add broadcast_in_dim rule.
PiperOrigin-RevId: 687054181
2024-10-17 14:55:10 -07:00
Adam Paszke
2d78b17226 [Pallas:MGPU] Add support for transforms in user-specified async copies
PiperOrigin-RevId: 687019020
2024-10-17 13:10:45 -07:00
jax authors
6c2649fdf2 Rewrite mosaic concat to support operand shapes that do not align with native shapes. Expand tests to cover multi-operand and batch-dim concat, etc.
PiperOrigin-RevId: 687003778
2024-10-17 12:24:51 -07:00
Ionel Gog
ec279f9c54 Add config option to log or fatal when jax.Arrays are GCed.
Introduces `jax.config.array_garbage_collection_guard`, which is a tristate config for setting up a `jax.Array` garbage collection guard. The possible configs are:
* allow: `jax.Array`s are allowed to be garbage collected. This is the default value.
* log: whenever a `jax.Array` is GCed a log entry is generated with the array's traceback.
* fatal: fatal crash when a `jax.Array` is GCed. This is meant to be used for mature code bases that do tight memory management, and are reference cycle free.

PiperOrigin-RevId: 687003464
2024-10-17 12:23:16 -07:00
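A sketch of enabling the guard described above, assuming the option is set through the usual config mechanism under a flag name derived from the attribute (the exact flag string is an assumption):

```
import jax

# Log a traceback whenever a jax.Array is garbage collected rather than
# freed deterministically; use "fatal" to crash instead.
jax.config.update("jax_array_garbage_collection_guard", "log")
```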
jax authors
1b5cf5a494 Fix breaking test-case
PiperOrigin-RevId: 686932281
2024-10-17 08:57:15 -07:00
Sergei Lebedev
de7beb91a7 [pallas:mosaic_gpu] Added layout_cast
PiperOrigin-RevId: 686917796
2024-10-17 08:08:05 -07:00
Adam Paszke
0519db15ab [Pallas:MGPU] Add lowerings for more ops
PiperOrigin-RevId: 686910947
2024-10-17 07:42:56 -07:00
Adam Paszke
f72376ae0a [Pallas:MGPU] Add support for debug_print of arrays that use the WGMMA layout
PiperOrigin-RevId: 686885229
2024-10-17 06:06:16 -07:00
Adam Paszke
ef361f05a4 [Mosaic GPU] Add support for launching multiple warpgroups using core_map
PiperOrigin-RevId: 686876014
2024-10-17 05:30:48 -07:00
jax authors
96d5542aae Support single-process AutoPGLE usage.
PiperOrigin-RevId: 686819261
2024-10-17 01:43:58 -07:00
Bart Chrzaszcz
801fe87da6 Do not allow None axis names in meshes.
PiperOrigin-RevId: 686557025
2024-10-16 10:32:25 -07:00
Sergei Lebedev
bb271aaff8 [pallas:mosaic_gpu] Added FragmentedArray.to_layout
PiperOrigin-RevId: 686524192
2024-10-16 08:53:02 -07:00
Mantas Pajarskas
1222b4a571 [Pallas TPU] Add a better error message for rank 1 block mappings check.
Currently, the error message refers to "last two dimensions", which is confusing for the rank-1 case; furthermore, the error does not match the check in the code.

PiperOrigin-RevId: 686520781
2024-10-16 08:41:22 -07:00
Sergei Lebedev
4c0d82824f [pallas:mosaic_gpu] Added a few more operations necessary to port Flash Attention
PiperOrigin-RevId: 686451398
2024-10-16 04:05:36 -07:00
jax authors
5e03a573bc Use Iota order for certain v5e with 8 devices.
PiperOrigin-RevId: 686227482
2024-10-15 13:53:01 -07:00
Ayaka
5ac2076fb7 [Pallas TPU] Fix boolean comparison
Fixes https://github.com/jax-ml/jax/issues/24030

Also added tests to cover all scalar comparison cases.

PiperOrigin-RevId: 686197357
2024-10-15 12:24:58 -07:00
Jevin Jiang
3a7d9137a4 [Pallas TPU] Support ref reshape.
Jaxpr example:
```
{ lambda ; a:MemRef<None>{int32[32,256]} b:MemRef<None>{int32[8,128]}. let
    c:i32[8,128] <- a[:16,:][bitcast(int16[32,256])][reshape(int16[2,16,256])][bitcast(float16[2,16,256])][1:,:,:][reshape(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:]
    b[:,:] <- c
  in () }
```

Tested:

- DMA with reshaped ref
- Load from reshaped ref
- Store to reshaped ref
- Multiple transforms
- Interpret Mode for ref transforms (updated discharge rules)

PiperOrigin-RevId: 686186426
2024-10-15 11:52:15 -07:00
jax authors
87d8f3817b Merge pull request #24294 from jakevdp:invert-doc
PiperOrigin-RevId: 686159879
2024-10-15 10:44:46 -07:00