8106 Commits

Author SHA1 Message Date
Mathew Odden
6b35155294
Fix invalid lowerings for ROCm in Pallas (#223)
popcount and clz were effectively broken on ROCm,
since math_dialect had incorrect lowerings.

Use the device intrinsics for these functions, as
well as for exp and absf, which fixes some accuracy issues in
the pallas tests.

Docs for OCML/OCKL

- https://github.com/ROCm/llvm-project/blob/amd-staging/amd/device-libs/doc/OCML.md
- https://github.com/ROCm/llvm-project/blob/amd-staging/amd/device-libs/doc/OCKL.md
2025-02-14 11:27:52 -06:00
jax authors
a527aba646 Reverts f1b894d14a28ac22a037fb79177b991275c75a18
PiperOrigin-RevId: 716653711
2025-01-17 07:00:31 -08:00
Yash Katariya
ce85b89884 [sharding_in_types] Error out for reshape for splits like this: (4, 6, 8) -> (4, 4, 2, 6)
PiperOrigin-RevId: 716653203
2025-01-17 06:58:29 -08:00
Sergei Lebedev
d34c40f6b6 [mosaic_gpu] Added a serialization pass
The pass adds versioning to the Mosaic GPU IR in the lowered custom calls
and can apply forward/backward migration rules. Currently, no rules are
necessary since we are at version 1.

PiperOrigin-RevId: 716596848
2025-01-17 03:12:51 -08:00
Yash Katariya
af667199db [sharding_in_types] Rename .at[...].get(out_spec) to .at[...].get(out_sharding).
PiperOrigin-RevId: 716466870
2025-01-16 18:56:52 -08:00
Yash Katariya
97cd748376 Rename out_type -> out_sharding parameter on einsum
PiperOrigin-RevId: 716454800
2025-01-16 18:16:52 -08:00
Yash Katariya
49224d6cdb Replace Auto/User/Collective AxisTypes names with Hidden/Visible/Collective.
Replace `with set_mesh(mesh):` with `with use_mesh(mesh):` context manager

Also expose `AxisTypes` and `use_mesh` into public API via `jax.sharding.AxisTypes` and `jax.sharding.use_mesh`.

PiperOrigin-RevId: 716446406
2025-01-16 17:55:54 -08:00
Parker Schuh
f2f552c108 Allow resharding between tokens on a single device
and multiple devices.

Whenever this happens we can essentially introduce an effects barrier
instead of doing the normal device -> host -> device transfer.

Fixes https://github.com/jax-ml/jax/issues/25671.

PiperOrigin-RevId: 716309978
2025-01-16 11:24:22 -08:00
Yash Katariya
b23c42372b [sharding_in_types] If an indexing operation hits into gather_p, error out saying to use .at[...].get(out_spec=...) instead.
This will basically drop the gather operation into full auto mode and add a sharding constraint on the output given by the user via `out_spec`.

Co-authored-by: Matthew Johnson <mattjj@google.com>
PiperOrigin-RevId: 716295953
2025-01-16 10:51:15 -08:00
Yash Katariya
0df4475aeb Make result_handler of _DeferredShardArg a method instead of a property. Also play some code golf.
PiperOrigin-RevId: 716273533
2025-01-16 09:53:48 -08:00
Yash Katariya
c6b5ac5c7b [sharding_in_types] Expand reshape's sharding rule to add support for the following cases:
* Split on 1 dimension only and the splitting dimension should be unsharded.

  `operand.shape = (4@x, 6@y, 8), new_shape = (4@x, 6@y, 2, 2, 2)`

* Merging into 1 dimension only and all the merging dimensions should be unsharded.

  `operand.shape = (4@y, 2, 3, 8), new_shape = (4@y, 6, 8)`

* Split into singleton dimensions i.e. adding extra dims of size 1

  `operand.shape = (4@x, 6@y, 8@z), new_shape = (1, 4@x, 1, 6@y, 1, 8@z, 1)`

* Merge singleton dimensions i.e. removing extra dims of size 1

  `operand.shape = (1, 4@x, 6, 1, 8, 1), new_shape = (1, 4@x, 6, 8)`

* Identity reshape

  `operand.shape = (4@(x,y), 6), new_shape = (4@(x,y), 6)`

These cases are unambiguous to handle. In all other cases, we error out and ask the user to provide the out_sharding.

PiperOrigin-RevId: 716216240
2025-01-16 06:47:26 -08:00
Sharad Vikram
0ac63157f5 [Pallas TPU] Add helpers file with copy_ref function
PiperOrigin-RevId: 716030813
2025-01-15 18:34:58 -08:00
Zachary Garrett
f7d097f7cc Make utils for reporting function name work with functools.partial by using the inner .func attribute if the object doesn't have a __name__ attribute. functools.partial objects do not have __name__ attributes by default.
PiperOrigin-RevId: 715881812
2025-01-15 11:40:59 -08:00
jax authors
ca012d7ad6 Merge pull request #25864 from jax-ml:yet-more-linearization-fixes
PiperOrigin-RevId: 715840148
2025-01-15 10:00:31 -08:00
Zac Mustin
2d72e8de84 Jax: Stop returning a list of cost-analyses.
As it stands, there is only ever one element in this list (see b/384741132) and only the 0th element is ever used so we can simplify.

This is a potentially breaking change for external users, but (as stated in the [documentation](https://jax.readthedocs.io/en/latest/aot.html#debug-information-and-analyses-when-available)) no guarantees are made on this type, which is intended for debugging purposes and not intended to be a reliable public API.

PiperOrigin-RevId: 715837855
2025-01-15 09:53:59 -08:00
jax authors
70c1ee5d9c Merge pull request #25876 from gnecula:debug_info_3
PiperOrigin-RevId: 715831527
2025-01-15 09:35:03 -08:00
jax authors
2e5e4799fd Merge pull request #25880 from jakevdp:fix-gather
PiperOrigin-RevId: 715804120
2025-01-15 08:10:44 -08:00
Dougal
9fe553ca49 More linearization fixes 2025-01-15 10:27:21 -05:00
Sergei Lebedev
afcb21ddf1 [pallas:mosaic_gpu] Fixed a crash in MLIR Python bindings
The error message produced by MLIR is not really clear, but AFAICT the crash
was caused by the "temporary module" hack we use in the lax.cond lowering
rule.

PiperOrigin-RevId: 715785632
2025-01-15 07:09:43 -08:00
George Necula
f9dfe7f646 [better_errors] More cleanup 2025-01-15 10:22:29 +00:00
jax authors
c4406d2759 [pallas] Fix bad rebase, deleted lowering for a print
PiperOrigin-RevId: 715694818
2025-01-15 01:18:30 -08:00
jax authors
c18492be65 [pallas][mosaic kernel export] Add initial support for exporting a dynamic shapes (placeholder bound) kernel out of mosaic, via pallas as both MLIR and jaxpr.
PiperOrigin-RevId: 715629439
2025-01-14 20:34:11 -08:00
Jevin Jiang
6851700ed4 [Mosaic TPU] Append dump id to timestamp to make dump list ordered
PiperOrigin-RevId: 715488504
2025-01-14 12:44:10 -08:00
Jake VanderPlas
54fbf0b3f2 Indexing: avoid dynamic_slice when mode='clip'
This causes issues in the backward pass, where effectively mode='promise_in_bounds'
2025-01-14 11:20:50 -08:00
George Necula
f1b894d14a Reverts 391bad8ff59c07c8fad7b8ce05cd0e29dee4cf1a
PiperOrigin-RevId: 715435319
2025-01-14 10:31:59 -08:00
Justin Fu
b6acb9cb7a Fix remat bug on primitives with multiple outputs.
Addresses https://github.com/jax-ml/jax/issues/25841

PiperOrigin-RevId: 715434084
2025-01-14 10:26:58 -08:00
Yash Katariya
b7e06f1937 Remove dead codepaths now that MemorySpaceDescription works in OSS
PiperOrigin-RevId: 715410774
2025-01-14 09:22:26 -08:00
jax authors
ee724565bf Merge pull request #25827 from gnecula:debug_info_2
PiperOrigin-RevId: 715407809
2025-01-14 09:12:37 -08:00
Yash Katariya
c72ed260fe [sharding_in_types] Handle ShapeDtypeStruct inputs with sharding_in_types by registering the sharding on the aval properly created by SDS in it's pytype_aval_mapping.
Also If we are running under full auto mode, don't error out if primitives don't have a sharding rule registered.

PiperOrigin-RevId: 715383866
2025-01-14 08:03:50 -08:00
Dougal
7d11d12bcd Mention expected tangent aval in error message, see #25517. 2025-01-14 08:51:12 -05:00
George Necula
b30df36d7d [better_errors] Add debug_info to DynamicJaxprTrace and JaxprStackFrame
This is part of a sequence of changes to ensure that the debugging information
is propagated properly.

Additional cleanup:
* Rename `result_paths` to `result_paths_thunk` in `TracingDebugInfo` to clarify the
  difference from the similar field in `JaxprDebugInfo`
* Added more type declarations
2025-01-14 13:49:18 +00:00
Bart Chrzaszcz
74e912c3c0 #sdy dynamically choose which custom_partitioning API to use based on the current
value of the `use_shardy_partitioner` feature flag.

Before the way the API works depends on the value of the flag when the partitioning is defined. But we should allow this to be dynamically swapped in and out when the function is actually called. This change allows for that.

PiperOrigin-RevId: 715293018
2025-01-14 02:11:55 -08:00
jax authors
4f2f5fa53a Merge pull request #25798 from gnecula:fix_fori_error
PiperOrigin-RevId: 715258789
2025-01-14 00:01:30 -08:00
Roy Frostig
a60ead6fd1 enable partitionable threefry by default
PiperOrigin-RevId: 715242560
2025-01-13 22:46:24 -08:00
Ayaka
9ba1fd2801 [Pallas TPU] Add vector support to pl.debug_print
PiperOrigin-RevId: 715085454
2025-01-13 13:22:21 -08:00
Justin Fu
f69592ae78 [Mosaic GPU] Fix layout API bugs.
PiperOrigin-RevId: 715077057
2025-01-13 12:59:30 -08:00
jax authors
dabe27bc1b Merge pull request #25833 from jax-ml:more-linearize-fixes
PiperOrigin-RevId: 715073762
2025-01-13 12:51:05 -08:00
Dougal
96769f96c2 Even more linearize fixes 2025-01-13 15:30:49 -05:00
Adam Paszke
391bad8ff5 [Mosaic TPU] Add support for arith.fptosi with non-32bit source and target types
This effectively moves some of the Pallas logic to the layer below.

PiperOrigin-RevId: 714965374
2025-01-13 07:49:13 -08:00
George Necula
3faff78ca8 [better_errors] Ensure that tracer errors in for_loop points to use code
Fixes: 23637
2025-01-13 15:33:30 +00:00
Yash Katariya
6b253b2f75 [shardy] Fix cases in shardy where you have a nullary function with partially specified out_shardings (i.e. some out_sharding's are None and others are NamedShardings).
In this case, the returned out_shardings should all be NamedSharding (because of NamedSharding's presence in some out_sharding's).

PiperOrigin-RevId: 714681941
2025-01-12 08:09:17 -08:00
Yash Katariya
aeac6b0383 Fix pmap with sharded typed prng key
PiperOrigin-RevId: 714293671
2025-01-10 18:20:09 -08:00
Yash Katariya
a817f532b4 [sharding_in_types] Introduce auto_mode, user_mode, auto_mode_ctx and user_mode_ctx as **private** APIs to make writing auto/user sharding in types code way easier and noise-free.
These can be made public in the future under different names.

PiperOrigin-RevId: 714169304
2025-01-10 14:14:25 -08:00
Justin Fu
8e86bede9f [Mosaic GPU] Allow multiple gmem indexers on copies.
This is implemented by merging multiple indexers into one.

PiperOrigin-RevId: 714150733
2025-01-10 13:12:50 -08:00
Justin Fu
73b64b8e56 [Mosaic GPU] Enable loop carries in the pipeline emitter.
PiperOrigin-RevId: 714141077
2025-01-10 12:40:42 -08:00
jax authors
a16fbffc13 [Mosaic][TPU] Add a compatibility mode to Mosaic's canonicalization pass, skipping over elementwise and matmul op insertions and/or type compat casts.
PiperOrigin-RevId: 714132282
2025-01-10 12:12:54 -08:00
Jake VanderPlas
1ee015674f [internal] add deprecation test utilities 2025-01-10 11:54:09 -08:00
jax authors
5d0ee43222 Merge pull request #25741 from jakevdp:solve-dep
PiperOrigin-RevId: 714124347
2025-01-10 11:46:33 -08:00
jax authors
aed79707e2 Merge pull request #25791 from mattjj:logsumexp-where-grad-nan
PiperOrigin-RevId: 714118085
2025-01-10 11:27:35 -08:00
Jake VanderPlas
051abafd6d jnp.linalg.solve: finalize deprecation of batched 1D solves 2025-01-10 10:42:32 -08:00