8695 Commits

jax authors
e461c0496f Merge pull request #23684 from simonster:sjk/fix-prefix-error
PiperOrigin-RevId: 686133952
2024-10-15 09:32:30 -07:00
Vladimir Belitskiy
2f2fd8a334 Skip some Shardy-enabled tests if XLA < 292.
PiperOrigin-RevId: 686133374
2024-10-15 09:30:41 -07:00
Jake VanderPlas
dd4a0408a4 Improve docs for jnp.invert and related functions 2024-10-15 08:57:19 -07:00
Sharad Vikram
cd78c653e7 [Pallas] Use core_map instead of shard_map for Shmallas
- core_map is like shard_map, but it takes no inputs and produces no outputs
- we can use it in Pallas to generalize mapping a function over the cores of a chip (e.g. TensorCores on a TPU or SMs on a GPU)
- we specify how the function is mapped over the device with a `mesh` object, which is also a convenient mechanism for picking the backend for Pallas to target; see the sketch below
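
A hedged sketch of the intended usage; `core_map` and the mesh constructor spelling here are assumptions based on this description, not a confirmed interface:

```
import jax.experimental.pallas as pl
from jax.experimental.pallas import tpu as pltpu

# Assumed API: a mesh over the TensorCores of a chip, which also picks
# the Pallas TPU backend.
mesh = pltpu.create_tensorcore_mesh('core')

@pl.core_map(mesh)
def per_core():
  # Runs once per core; unlike shard_map, there are no positional
  # inputs or outputs -- state is communicated through refs.
  ...
```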

PiperOrigin-RevId: 686036101
2024-10-15 03:26:58 -07:00
Yash Katariya
2f6cb89ac0 Add a private property to NamedSharding called `_logical_device_ids`, which allows you to pass a custom tile_assignment_devices() equivalent.
This is because GSPMDSharding doesn't work with Shardy, so `device_put` on a mesh with a different device order needs `NamedSharding` support. A bonus is that the logic is now simpler than the previous version in `_different_device_order_reshard`.

This will also allow us to remove OpSharding usage in other projects that require this kind of permutation capability.
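
A hedged sketch of what this enables; the keyword spelling is an assumption taken from the message above, and it assumes at least 4 devices:

```
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = Mesh(devices, ('x', 'y'))

# Permuted logical device order, playing the role of a custom
# tile_assignment_devices() (private API; subject to change).
s = NamedSharding(mesh, P('x', 'y'), _logical_device_ids=(3, 2, 1, 0))
```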

PiperOrigin-RevId: 685925636
2024-10-14 20:08:54 -07:00
jax authors
dfee4d1549 Merge pull request #24258 from dfm:remove-jaxlib-version-checks
PiperOrigin-RevId: 685884478
2024-10-14 17:12:03 -07:00
Tongfei Guo
d621737f13 [XLA:Collective] Expose a factory for constructing HLOSharding with explicit device ordering.
PiperOrigin-RevId: 685858699
2024-10-14 15:41:23 -07:00
Ayaka
bfc3d3cd18 [Pallas TPU] Add lowerings for scalar sin, cos, tan and tanh
This PR is similar to https://github.com/jax-ml/jax/pull/24238

PiperOrigin-RevId: 685842905
2024-10-14 14:49:11 -07:00
jax authors
1f0b5728a4 Add a memory-saving index rewrite step to vmap with ragged inputs over pallas_call.
The approach here is to add a new notion to JAX: ragged_prop. Ragged prop is a set of rules for computing the dynamism/raggedness of an output given a set of inputs. In the limit, if we decide that this is a useful property to have in JAX as a first-class citizen, we could fold raggedness into the type system. At the moment, however, it is just a small set of rules implemented per op.
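
As a hedged illustration of "a small set of rules implemented per op" (all names here are hypothetical, not JAX internals):

```
# Hypothetical shape of a per-op ragged-prop rule: given which inputs
# are ragged, decide whether the output is ragged.
def elementwise_ragged_prop(*inputs_ragged: bool) -> bool:
  # For elementwise ops, the output is ragged iff any input is.
  return any(inputs_ragged)

assert elementwise_ragged_prop(True, False)
```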

PiperOrigin-RevId: 685827096
2024-10-14 14:01:42 -07:00
Yash Katariya
824ccd7183 [Shardy] Inline meshes when using shardy and get rid of global meshes from the MLIR body.
Also do a couple of cleanups.

PiperOrigin-RevId: 685746298
2024-10-14 10:08:04 -07:00
Yash Katariya
4be1e332f7 [sharding_in_types] Add constraints during lowering for dot_general and reduce_sum so that we can enforce the sharding we choose during tracing
PiperOrigin-RevId: 685216047
2024-10-12 09:58:53 -07:00
Yash Katariya
8139c531a3 Fix repr of sharding in aval when a dimension is sharded on multiple mesh axes
PiperOrigin-RevId: 685215764
2024-10-12 09:56:02 -07:00
Yash Katariya
5b8775dc2f [sharding_in_types] Add a sharding rule for reduce_sum, which simply drops the specs for the axes we are reducing over
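
A minimal sketch of the rule in terms of `PartitionSpec`s (the helper name is hypothetical):

```
from jax.sharding import PartitionSpec as P

def reduce_spec(spec, axes):
  # Drop the spec entry for each reduced axis; keep the others.
  return P(*(s for i, s in enumerate(spec) if i not in axes))

assert reduce_spec(P('x', 'y'), axes={1}) == P('x')
```
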
PiperOrigin-RevId: 685069065
2024-10-11 21:31:25 -07:00
Yash Katariya
89fcd9f1f1 Better repr of aval when shardings are present
Example (for an array of shape (8, 2) with dtype float32):

```
P('x', 'y') -- float32[8@x,2@y]

P('x', None) -- float32[8@x,2]

P(('x', 'y'), None) -- float32[8@xy,2]

P(None, None) -- float32[8,2]
```

PiperOrigin-RevId: 684996577
2024-10-11 16:48:13 -07:00
Yash Katariya
18bc354305 [sharding_in_types] Add dot_general sharding rule. We only handle the simple cases and rely on XLA to insert the collectives.
Cases where we error:

* batch dimensions not having consistent sharding (ignoring None)
* contracting dimensions not having consistent sharding (ignoring None)
* lhs.mesh != rhs.mesh
* a batch dimension and a tensor dimension having matching sharding
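
For example, a contraction that satisfies these checks (a sketch assuming at least 4 devices):

```
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = Mesh(devices, ('x', 'y'))

# The contracting dimensions (axis 1 of lhs, axis 0 of rhs) are both
# sharded on 'y', so the rule accepts them and relies on XLA to insert
# the reduction collective.
lhs = jax.device_put(jnp.ones((8, 4)), NamedSharding(mesh, P('x', 'y')))
rhs = jax.device_put(jnp.ones((4, 2)), NamedSharding(mesh, P('y', None)))
out = lhs @ rhs
```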

PiperOrigin-RevId: 684983567
2024-10-11 16:05:13 -07:00
Yash Katariya
a2973be051 Don't add mhlo.layout_mode = "default", since that is already the default in PJRT; this helps reduce cruft in the IR
PiperOrigin-RevId: 684963359
2024-10-11 14:54:32 -07:00
Justin Fu
cff9e93824 [Pallas] Add a runtime assert via checkify.check. This check will halt the TPU if triggered, meaning that the program must be restarted to recover.
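
For reference, this is the `checkify.check` API being lowered; outside Pallas it is functionalized rather than halting the device:

```
import jax.numpy as jnp
from jax.experimental import checkify

def f(x):
  checkify.check(jnp.all(x >= 0), "negative input!")
  return jnp.sqrt(x)

err, out = checkify.checkify(f)(jnp.array([1.0, 4.0]))
err.throw()  # raises only if the check failed
```
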
PiperOrigin-RevId: 684940271
2024-10-11 13:34:04 -07:00
Dan Foreman-Mackey
5ed2f4ef1c Remove checks for jaxlib v0.4.33 in tests 2024-10-11 15:39:24 -04:00
Bart Chrzaszcz
fb32841b1b #sdy add JAX Shardy support for memories.
PiperOrigin-RevId: 684867097
2024-10-11 09:44:24 -07:00
Sergei Lebedev
59ae2af699 [pallas:mosaic_gpu] Added a test doing manual in-kernel pipelining
I think we have most of the primitives necessary, so the next step is to sketch
`emit_pipeline`.

PiperOrigin-RevId: 684840800
2024-10-11 08:08:04 -07:00
Sergei Lebedev
acd0e497af [pallas:mosaic_gpu] GPUBlockSpec no longer accepts swizzle
It was previously possible to pass `swizzle` both directly and via `transforms`.
This change eliminates the ambiguity at a slight cost to ergonomics.
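
A hedged sketch of the surviving spelling (the transform class names are assumptions about the mosaic_gpu API of this era):

```
import jax.experimental.pallas.mosaic_gpu as plgpu

# Swizzling is now expressed only via `transforms`:
spec = plgpu.GPUBlockSpec(
    (64, 64),
    lambda i: (i, 0),
    transforms=(
        plgpu.TilingTransform((64, 32)),
        plgpu.SwizzleTransform(128),
    ),
)
```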

PiperOrigin-RevId: 684797980
2024-10-11 05:11:26 -07:00
Ayaka
633cb31577 [Pallas TPU] Add lowering for scalar jnp.log1p
Fixes https://github.com/jax-ml/jax/issues/24239

PiperOrigin-RevId: 684650608
2024-10-10 18:44:41 -07:00
Ayaka
3bd8ca480a [Pallas] Add tests for scalar elementwise operations
On TPU, instructions differentiate between vectors and scalars, and the corresponding lowering paths are different. Existing Pallas tests only exercise the vector versions of operations, not the scalar versions. This PR adds tests for scalar elementwise operations.

The structure of these tests mirrors that of the vector tests above.

PiperOrigin-RevId: 684569107
2024-10-10 14:00:21 -07:00
Yash Katariya
8ef41a6e14 [sharding_in_types] Normalize partition specs when creating avals so that P(None, None) and P() are treated as replicated and equivalent. Shardings on avals are always normalized.
PiperOrigin-RevId: 684465123
2024-10-10 09:07:44 -07:00
Peter Hawkins
66f526894f Reenable some test cases that were disabled due to bugs that now seem fixed.
PiperOrigin-RevId: 684464642
2024-10-10 09:06:06 -07:00
Peter Hawkins
19dbff5326 Move additional CI enabled/disabled configurations into jax BUILD files.
PiperOrigin-RevId: 684457403
2024-10-10 08:41:45 -07:00
Sergei Lebedev
70ee8e1161 [pallas:mosaic_gpu] pl.run_scoped now supports scoped barriers
PiperOrigin-RevId: 684449776
2024-10-10 08:16:13 -07:00
Peter Hawkins
94abaf430e Add lax.FftType.
We had never provided a public name for the enum of FFT types; instead it was only known by a semi-private name (jax.lib.xla_client.FftType). Add a public name (jax.lax.FftType) and deprecate the private one.

We define a new FftType IntEnum rather than trying to expose the one in xla_client. The xla_client definition was useful when building classic HLO, but we no longer do that, so there's no reason to couple our type to XLA's type.
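
Usage after this change, roughly:

```
import jax
import jax.numpy as jnp

x = jnp.ones((8,), dtype=jnp.complex64)
y = jax.lax.fft(x, fft_type=jax.lax.FftType.FFT, fft_lengths=(8,))
```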

PiperOrigin-RevId: 684447186
2024-10-10 08:07:35 -07:00
Yash Katariya
351187d9da [sharding_in_types] Add support for nary ops to propagate sharding when one input is sharded and all others are replicated.
PiperOrigin-RevId: 684289345
2024-10-09 21:24:37 -07:00
Justin Fu
73418427a8 [Pallas] Add lowering for threefry PRNG.
PiperOrigin-RevId: 684179182
2024-10-09 14:48:26 -07:00
Ayaka
3fc4ba29ea [Pallas TPU] Add lowerings for lax.population_count_p and lax.clz_p
PiperOrigin-RevId: 684158096
2024-10-09 13:46:29 -07:00
Andrey Portnoy
2c731320af Use py_deps("absl/testing") instead of //third_party/py/absl/testing 2024-10-09 15:24:40 -04:00
jax authors
bcb0f6466a Merge pull request #24172 from shuhand0:dev/shuhan/adopt_jaxlib0.4.34
PiperOrigin-RevId: 684104487
2024-10-09 11:15:59 -07:00
jax authors
db71965c56 Merge pull request #24167 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 684085868
2024-10-09 10:23:44 -07:00
Ayaka
77613f21aa [Pallas TPU] Fix comparison lowering for unsigned integers
Fixes https://github.com/jax-ml/jax/issues/23972.

In Pallas, we use `i32` for both `jnp.int32` and `jnp.uint32`, but we need to choose the correct operation (e.g. `arith.extUI` vs `arith.extSI`) or the correct attribute (e.g. `sle` vs `ule` for `arith::CmpIOp`).

In this particular issue, we need to use attributes like `ule` for `jnp.uint32`, but it's currently lowered to attributes for `jnp.int32` such as `sle`.

This PR fixes this issue by distinguishing the attributes to use for signed and unsigned types.
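
A quick illustration of why the distinction matters:

```
import jax.numpy as jnp

a = jnp.uint32(1)
b = jnp.uint32(0xFFFFFFFF)
# Unsigned `ule`: 1 <= 4294967295 is True. Reinterpreted as signed i32,
# b is -1, so a signed `sle` would incorrectly return False.
print(a <= b)  # True
```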

PiperOrigin-RevId: 684065893
2024-10-09 09:28:53 -07:00
rajasekharporeddy
ed028be7fb Better docs for jnp.left_shift 2024-10-09 12:09:33 +05:30
Justin Fu
9cf952a535 [Pallas] Add support for runtime checking of grid bounds using checkify.
PiperOrigin-RevId: 683791662
2024-10-08 15:48:16 -07:00
Yash Katariya
e5fa9656b2 Error out in donation if:
1) input layout is AUTO and output layout is not AUTO (i.e. default or concrete)

2) input layout is not AUTO (i.e. default or concrete) and output layout is AUTO

This is because there is a conflict in such cases, and it almost always leads to the wrong layout being chosen by the compiler. For example, consider (1): since the input layout is AUTO, the output layout is default, and the two are aliased, XLA will end up choosing the default layout for the input, which is not what you want in the majority of cases.
Erroring is best in such cases; the user can mark the input layout as default if that is what they want.

The correct choice is to make both of them AUTO, since you want the compiler to choose the best possible layout rather than propagating the input or output layout when the other one is AUTO.
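
A hedged sketch of the now-erroring pattern; the layout API spelling here is an assumption about `jax.experimental.layout` at the time:

```
import functools
import jax
from jax.experimental.layout import Layout, DeviceLocalLayout as DLL

# Input layout AUTO, output layout default, with donation: this now
# errors instead of silently pinning the input to the default layout.
@functools.partial(jax.jit, in_shardings=Layout(DLL.AUTO), donate_argnums=0)
def f(x):
  return x * 2
```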

PiperOrigin-RevId: 683688470
2024-10-08 11:02:17 -07:00
Christos Perivolaropoulos
28b0934272 [mosaic_gpu] Memref reshape by means of folding/unfolding
A reshape function that folds/unfolds by touching a minimal number of dimensions, to potentially circumvent issues with strided memrefs.

PiperOrigin-RevId: 683663541
2024-10-08 09:59:54 -07:00
Adam Paszke
25c1519a84 [Pallas/MGPU] Allow delaying the release of pipelined buffers
This is useful so that we don't have to block on the WGMMA immediately after it runs.
`delay_release=n` means that the input/output buffers will not be mutated by the system for at least `n` sequential steps following the one in which they were kernel arguments.
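
A hedged sketch of how this is spelled; the `GPUCompilerParams` field is an assumption based on this description:

```
import jax.experimental.pallas.mosaic_gpu as plgpu

# Hypothetical: keep pipelined buffers unmutated for one extra step so
# a WGMMA reading them need not be awaited immediately.
params = plgpu.GPUCompilerParams(
    max_concurrent_steps=4,
    delay_release=1,  # buffers untouched for >= 1 step after last use
)
# ...then passed as pl.pallas_call(..., compiler_params=params).
```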

PiperOrigin-RevId: 683629935
2024-10-08 08:17:58 -07:00
George Necula
023f2a78be Remove remaining implementations of jax.experimental.host_callback.call.
The host_callback module has been deprecated since March 2024, and we are now removing the implementation. We keep the functions so that we can give a nicer error message than AttributeError, and because removing them now breaks internal pytype checking. We will remove them in the near future.

See https://github.com/google/jax/issues/20385.

PiperOrigin-RevId: 683564340
2024-10-08 04:22:20 -07:00
Adam Paszke
7102c7adbf Bump the shard_count of FFT tests to avoid timeouts
PiperOrigin-RevId: 683537643
2024-10-08 02:44:41 -07:00
Ayaka
6a958b90b3 [Pallas] Simplify OpsTest by skipping 64-bit tests on 32-bit environments
This PR is similar to https://github.com/jax-ml/jax/pull/23814.

Background: We run tests in both 32-bit and 64-bit environments. Currently, when a test encounters 64-bit dtypes in a 32-bit environment, it enters a local 64-bit environment using `stack.enter_context(config.enable_x64(True))`. This is not necessary, since we also run the same tests in 64-bit environments. This PR instead skips those tests in 32-bit environments.
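
The skip pattern, roughly (a sketch; the helper and its placement are assumptions):

```
import numpy as np
import jax

def skip_if_64bit_unsupported(test, dtype):
  # 64-bit dtypes are exercised by the dedicated x64 CI job instead.
  if np.dtype(dtype).itemsize == 8 and not jax.config.jax_enable_x64:
    test.skipTest("64-bit dtypes require x64 mode")
```
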
PiperOrigin-RevId: 683405197
2024-10-07 18:41:14 -07:00
Yash Katariya
ce2b49787f Skip test_ragged_copy_on_host if xla_extension_version < 290
PiperOrigin-RevId: 683326972
2024-10-07 14:30:34 -07:00
Dan Foreman-Mackey
28bbbf894f Simplify and consolidate dot algorithm control in lax.
In https://github.com/jax-ml/jax/pull/23574, we added a new `algorithm` parameter to `lax.dot_general` with the goal of giving users explicit control over the specific algorithm used for dot product accumulation. When using this feature in real use cases, we have found that the API is both too conservative (it required the user to pass the appropriate input types) and too restrictive for common use cases. In this change, I simplify the API to bring it more in line with user expectations, and generalize it to support a broader range of use cases.

The core change is to update the dot_general lowering rule to add explicit type casts to the inputs, making sure that they always have the appropriate storage types going into the `DotGeneral` StableHLO op. Before this change, some backends would implicitly cast for some algorithms (e.g. f32 -> bf16), but error for others. It seems more user friendly to include automatic casts in all cases where a specific algorithm is requested.

Another change in behavior is to (if needed) cast the result of the `DotGeneral` op (which is defined by the algorithm's `accumulation_type`) to match the input types. This means that, regardless of the algorithm choice, the output type will match what a user would expect from past use of `lax.dot_general`. The `preferred_element_type` parameter can now be used to control the output type, even when an algorithm is selected.

To summarize, the updated version of `dot_general` accepts _any_ input dtypes, and the output will always match the inputs (under the existing promotion rules if the LHS and RHS don't match) unless `preferred_element_type` is used to select a specific output type. The specified "algorithm" is now more of an implementation detail, rather than the defining feature of the API, and JAX will do whatever it can to satisfy the user's request. (If an algorithm is not supported on the current device, we will still get a compile time error.)

With the above changes in mind, it's no longer really necessary to have a `transpose_algorithm` parameter, because we can now use the same algorithm for the backwards pass. For users who need to customize the algorithm on the backwards pass, that is still possible using `custom_vjp`.

Given the above changes, @sbodenstein made the excellent point that we don't really need the `algorithm` parameter anymore: just accept `DotAlgorithm` inputs to `precision`. I think this is a really nice suggestion, so I have updated the interface to implement this.

One minor negative of this approach is that `preferred_element_type` isn't a great name for what that parameter does when it is used in conjunction with an algorithm. In the long run, I'd like to rename this parameter, but keeping it as is for now seems like the best short term approach.
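
The resulting interface, roughly:

```
import jax
import jax.numpy as jnp

a = jnp.ones((128, 128), dtype=jnp.float32)
b = jnp.ones((128, 128), dtype=jnp.float32)

# Request a specific accumulation algorithm via `precision`; inputs are
# cast as needed, and the output dtype still matches the inputs unless
# `preferred_element_type` says otherwise.
out = jax.lax.dot(a, b, precision=jax.lax.DotAlgorithmPreset.BF16_BF16_F32)
```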

PiperOrigin-RevId: 683302687
2024-10-07 13:21:34 -07:00
Shuhan Ding
1aa32f51ee adopt jax.extend.backend 2024-10-07 12:27:35 -07:00
Christos Perivolaropoulos
9ac6723561 [pallas:mosaic_gpu] Dereferencing the accumulator now supports slicing
PiperOrigin-RevId: 683235013
2024-10-07 10:33:08 -07:00
Peter Hawkins
a9926f0f01 Remove classic HLO lowering rule support from JAX.
(JAX now always uses StableHLO, with the exception of one use case in jax2tf.)

PiperOrigin-RevId: 683205145
2024-10-07 09:06:20 -07:00
George Necula
5fabd34e7e [jax2tf] Remove non-native serialization test from jax_to_ir_test
PiperOrigin-RevId: 683124315
2024-10-07 04:21:38 -07:00
Sergei Lebedev
95631a7d92 Added jax.experimental.pallas.mosaic_gpu
I also deprecated `jax.experimental.pallas.gpu` in favor of
`jax.experimental.pallas.triton` to avoid confusion with the Mosaic GPU
backend.
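
The new import surface:

```
# Backend-explicit namespaces:
from jax.experimental.pallas import mosaic_gpu as plgpu  # Mosaic GPU backend
from jax.experimental.pallas import triton as plt        # replaces pallas.gpu
```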

PiperOrigin-RevId: 683119193
2024-10-07 04:05:08 -07:00