23755 Commits

Author SHA1 Message Date
Bart Chrzaszcz
801fe87da6 Do not allow None axis names in meshes.
PiperOrigin-RevId: 686557025
2024-10-16 10:32:25 -07:00
Phuong Nguyen
82113cd047 rm CmdBuffer traits
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
2024-10-16 10:27:09 -07:00
Phuong Nguyen
f3775aa233 added cudaGraph traits + use register_ffi_target()
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
2024-10-16 10:01:20 -07:00
Sergei Lebedev
bb271aaff8 [pallas:mosaic_gpu] Added FragmentedArray.to_layout
PiperOrigin-RevId: 686524192
2024-10-16 08:53:02 -07:00
Mantas Pajarskas
1222b4a571 [Pallas TPU] Add a better error message for rank 1 block mappings check.
Currently, the error message refers to "last two dimensions" which is confusing for a rank-1 case; furthermore, the error does not match the check in the code.

PiperOrigin-RevId: 686520781
2024-10-16 08:41:22 -07:00
Sergei Lebedev
4c0d82824f [pallas:mosaic_gpu] Added a few more operations necessary to port Flash Attention
PiperOrigin-RevId: 686451398
2024-10-16 04:05:36 -07:00
jax authors
56eea2b5bb Merge pull request #24312 from jakevdp:gather-doc
PiperOrigin-RevId: 686372450
2024-10-15 22:41:39 -07:00
Yash Katariya
66c6292e6a Make committed a public property of jax.Array.
Why?

Because users need to know if an array is committed or not since JAX raises errors based on committedness of a jax.Array. JAX also makes decisions about dispatching based on committedness of a jax.Array.
But the placement of such arrays on devices is an internal implementation detail.

PiperOrigin-RevId: 686329828
2024-10-15 19:46:10 -07:00
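The note in commit 66c6292e6a that JAX raises errors and makes dispatch decisions based on committedness can be pictured with a toy model. This sketch does not call JAX: `ToyArray`, `pick_device`, and the device labels are hypothetical names invented for illustration. It assumes the documented behaviour that committed operands pin a computation to their device and that operands committed to different devices are an error.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyArray:
    """Hypothetical stand-in for an array: a device label plus a committed flag."""
    device: str
    committed: bool


def pick_device(operands, default_device="dev0"):
    """Pick where a toy computation runs, or raise if committed operands disagree."""
    committed_devices = {a.device for a in operands if a.committed}
    if len(committed_devices) > 1:
        # Two operands are pinned to different devices: there is no valid choice.
        raise ValueError(
            f"operands committed to different devices: {sorted(committed_devices)}"
        )
    if committed_devices:
        # Exactly one committed device: uncommitted operands follow it.
        return next(iter(committed_devices))
    # Nothing is committed, so the default device is used.
    return default_device
```

In this model the caller can only predict the outcome by inspecting the `committed` flag, which is the motivation the commit message gives for exposing it publicly.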
Jake VanderPlas
284ca8bc01 Improve docs for lax.gather & lax.scatter
2024-10-15 16:42:44 -07:00
selamw1
24b6f50938 tile_docstring_added
2024-10-15 15:34:01 -07:00
jax authors
ad99ab1ce3 Merge pull request #24315 from jakevdp:setitem-error
PiperOrigin-RevId: 686260067
2024-10-15 15:31:06 -07:00
Andrey Portnoy
0da52cd139 [Mosaic GPU] Skip OSS Flash Attention test unless running on sm90
2024-10-15 18:28:02 -04:00
jax authors
30ba7f37e0 Reverts 0788a0a589318b98ae2cb2169af109622ba26cf0
PiperOrigin-RevId: 686252325
2024-10-15 15:05:36 -07:00
jax authors
f0459907dc Merge pull request #24316 from ROCm:ci_build_fix
PiperOrigin-RevId: 686236716
2024-10-15 14:18:30 -07:00
jax authors
5e03a573bc Use Iota order for certain v5e with 8 devices.
PiperOrigin-RevId: 686227482
2024-10-15 13:53:01 -07:00
jax authors
1df58b1854 Update XLA dependency to use revision 9136df1d97.

PiperOrigin-RevId: 686213543
2024-10-15 13:13:28 -07:00
Ayaka
5ac2076fb7 [Pallas TPU] Fix boolean comparison
Fixes https://github.com/jax-ml/jax/issues/24030

Also added tests to cover all scalar comparison cases.

PiperOrigin-RevId: 686197357
2024-10-15 12:24:58 -07:00
Jevin Jiang
3a7d9137a4 [Pallas TPU] Support ref reshape.
Jaxpr example:
```
{ lambda ; a:MemRef<None>{int32[32,256]} b:MemRef<None>{int32[8,128]}. let
    c:i32[8,128] <- a[:16,:][bitcast(int16[32,256])][reshape(int16[2,16,256])][bitcast(float16[2,16,256])][1:,:,:][reshape(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:]
    b[:,:] <- c
  in () }
```

Tested:

- DMA with reshaped ref
- Load from reshaped ref
- Store to reshaped ref
- Multiple transforms
- Interpret Mode for ref transforms (updated discharge rules)

PiperOrigin-RevId: 686186426
2024-10-15 11:52:15 -07:00
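The transform chain in the jaxpr example of commit 3a7d9137a4 can be checked by hand with some shape bookkeeping. The sketch below is plain Python, not Pallas: the helper names (`bitcast`, `reshape`, `slice_`, `trace_example`) are hypothetical and track only dtype and shape. It assumes a bitcast must preserve the total bit count and a reshape must preserve the element count. It says nothing about how elements are actually laid out in TPU memory.

```python
from math import prod

# Bit widths of the dtypes that appear in the jaxpr example.
BITS = {"int32": 32, "int16": 16, "float16": 16}


def bitcast(dtype, shape, new_dtype, new_shape):
    """Change dtype (and shape) while keeping the total number of bits."""
    if prod(shape) * BITS[dtype] != prod(new_shape) * BITS[new_dtype]:
        raise ValueError(f"bitcast changes total bits: {dtype}{shape} -> {new_dtype}{new_shape}")
    return new_dtype, tuple(new_shape)


def reshape(dtype, shape, new_shape):
    """Change shape while keeping the number of elements."""
    if prod(shape) != prod(new_shape):
        raise ValueError(f"reshape changes element count: {shape} -> {new_shape}")
    return dtype, tuple(new_shape)


def slice_(dtype, shape, slices):
    """Apply one Python slice per dimension and return the resulting shape."""
    new_shape = tuple(len(range(*s.indices(n))) for s, n in zip(slices, shape))
    return dtype, new_shape


def trace_example():
    """Replay the transforms applied to `a` in the jaxpr, step by step."""
    full = slice(None)
    ref = ("int32", (32, 256))                       # a: int32[32,256]
    ref = slice_(*ref, (slice(None, 16), full))      # a[:16,:]
    ref = bitcast(*ref, "int16", (32, 256))          # bitcast(int16[32,256])
    ref = reshape(*ref, (2, 16, 256))                # reshape(int16[2,16,256])
    ref = bitcast(*ref, "float16", (2, 16, 256))     # bitcast(float16[2,16,256])
    ref = slice_(*ref, (slice(1, None), full, full)) # [1:,:,:]
    ref = reshape(*ref, (16, 256))                   # reshape(float16[16,256])
    ref = slice_(*ref, (full, slice(None, 128)))     # [:,:128]
    ref = bitcast(*ref, "int32", (8, 128))           # bitcast(int32[8,128])
    ref = slice_(*ref, (full, full))                 # [:,:]
    return ref
```

Calling `trace_example()` ends at `("int32", (8, 128))`, which agrees with the type `c:i32[8,128]` in the jaxpr.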
Ruturaj4
a2824862f5 [ROCm] build script fix
2024-10-15 13:43:08 -05:00
Jake VanderPlas
63f4299e0e Improve array setitem error
2024-10-15 11:39:31 -07:00
Jevin Jiang
a47b755619 [Mosaic TPU] Support native int4 @ int4
PiperOrigin-RevId: 686179715
2024-10-15 11:35:23 -07:00
jax authors
94152e8dfc Merge pull request #24305 from jakevdp:signbit-doc
PiperOrigin-RevId: 686174793
2024-10-15 11:21:07 -07:00
jax authors
87d8f3817b Merge pull request #24294 from jakevdp:invert-doc
PiperOrigin-RevId: 686159879
2024-10-15 10:44:46 -07:00
Praveen Batra
3a3190fbce Fix typo in Pallas TPU matmul doc. I think the logical layout of the input array is non-transposed, rather than transposed?
PiperOrigin-RevId: 686151692
2024-10-15 10:23:39 -07:00
jax authors
e461c0496f Merge pull request #23684 from simonster:sjk/fix-prefix-error
PiperOrigin-RevId: 686133952
2024-10-15 09:32:30 -07:00
Vladimir Belitskiy
2f2fd8a334 Skip some Shardy-enabled tests if XLA < 292.
PiperOrigin-RevId: 686133374
2024-10-15 09:30:41 -07:00
Jake VanderPlas
dd4a0408a4 Improve docs for jnp.invert and related functions
2024-10-15 08:57:19 -07:00
jax authors
2c2c1eebc7 Merge pull request #24251 from dfm:dot-algorithm-jax2tf
PiperOrigin-RevId: 686116542
2024-10-15 08:35:38 -07:00
Sharad Vikram
cd78c653e7 [Pallas] Use core_map instead of shard_map for Shmallas
- core_map is like a shard_map but it takes in no inputs and outputs
- we can use it in Pallas to generalize mapping a function over the cores of a chip (e.g. TensorCores in a TPU or SMs in a GPU)
- we specify how the function will be mapped over the device with a `mesh` object. This is also a convenient mechanism for picking the backend for pallas to target

PiperOrigin-RevId: 686036101
2024-10-15 03:26:58 -07:00
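The description of core_map in commit cd78c653e7 (a mapped function that takes no inputs and returns no outputs) can be pictured with a toy model. This is not the Pallas API: `toy_core_map`, `core_index`, and `double_in_place` are hypothetical names, the "cores" run sequentially in one Python process, and the idea that the body does its work through closed-over mutable state is an assumption made for illustration.

```python
_current_core = None  # set by toy_core_map while a body is running


def core_index():
    """Return the index of the toy core currently executing the body."""
    return _current_core


def toy_core_map(num_cores):
    """Hypothetical decorator: run a zero-argument body once per core, right away."""
    def decorator(body):
        global _current_core
        for i in range(num_cores):
            _current_core = i
            body()  # no inputs, and any return value is ignored
        _current_core = None
        return None
    return decorator


def double_in_place(buffer, num_cores=2):
    """Split `buffer` into equal chunks and let each toy core double its chunk."""
    chunk = len(buffer) // num_cores

    @toy_core_map(num_cores)
    def _():
        start = core_index() * chunk
        for i in range(start, start + chunk):
            buffer[i] *= 2  # the only visible effect is this mutation

    return buffer
```

Because the body neither receives nor returns values, everything it does is visible only through what it closes over. Here that is the list `buffer`.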
Parker Schuh
b0768906db Add back missing manual axis to fallback.
PiperOrigin-RevId: 685951218
2024-10-14 21:50:17 -07:00
Yash Katariya
2f6cb89ac0 Add a private property to NamedSharding called _logical_device_ids which allows you to pass a custom tile_assignment_devices() equivalent.
This is because for Shardy, GSPMDSharding doesn't work, so `device_put` on a mesh with different device order needs `NamedSharding` support. Bonus is that the logic is now simplified wrt the previous version in `_different_device_order_reshard`.

This will also allow us to remove OpSharding usage in other projects which require such kind of permutation capabilities.

PiperOrigin-RevId: 685925636
2024-10-14 20:08:54 -07:00
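The "permutation capabilities" mentioned in commit 2f6cb89ac0 amount to describing one device order in terms of another. The sketch below is plain Python and does not touch `NamedSharding` or its private `_logical_device_ids` property. The function name and the device labels are hypothetical, and the format JAX uses internally is not assumed to match this output.

```python
def device_order_permutation(reference_order, new_order):
    """For each device in `new_order`, return its position in `reference_order`."""
    position = {dev: i for i, dev in enumerate(reference_order)}
    if sorted(position) != sorted(new_order):
        # The two orders must contain exactly the same devices.
        raise ValueError("both orders must list the same set of devices")
    return [position[dev] for dev in new_order]
```

For example, `device_order_permutation(["d0", "d1", "d2", "d3"], ["d0", "d2", "d1", "d3"])` gives `[0, 2, 1, 3]`, a list of indices that records how the second mesh reorders the first.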
Jake VanderPlas
a096986844 Better docs for jnp.signbit
2024-10-14 19:51:42 -07:00
Tzu-Wei Sung
0788a0a589 Internal change.
PiperOrigin-RevId: 685905230
2024-10-14 18:43:36 -07:00
jax authors
829e315963 Merge pull request #24299 from jakevdp:apply-doc
PiperOrigin-RevId: 685888268
2024-10-14 17:27:05 -07:00
jax authors
dfee4d1549 Merge pull request #24258 from dfm:remove-jaxlib-version-checks
PiperOrigin-RevId: 685884478
2024-10-14 17:12:03 -07:00
jax authors
2a828d5d6b Merge pull request #23467 from abhinavgoel95:patch-4
PiperOrigin-RevId: 685872800
2024-10-14 16:29:24 -07:00
Tongfei Guo
d621737f13 [XLA:Collective] Expose a factory for constructing HLOSharding with explicit device ordering.
PiperOrigin-RevId: 685858699
2024-10-14 15:41:23 -07:00
Ayaka
bfc3d3cd18 [Pallas TPU] Add lowerings for scalar sin, cos, tan and tanh
This PR is similar to https://github.com/jax-ml/jax/pull/24238

PiperOrigin-RevId: 685842905
2024-10-14 14:49:11 -07:00
Jake VanderPlas
0c307fe706 Better docs for jnp.apply_along_axis & apply_over_axes
2024-10-14 14:48:33 -07:00
jax authors
1f0b5728a4 Add a memory saving index rewrite step to vmap with ragged inputs over pallas_call.
The approach here is to add a new notion to jax, for ragged_prop. Ragged prop is useful for computing the dynamism/raggedness of an output, given a set of inputs. In the limit, if we decide that this is a useful property to have in jax as a first class citizen, we could fold the raggedness into the type system. At the moment, however, it is just a small set of rules implemented per op.

PiperOrigin-RevId: 685827096
2024-10-14 14:01:42 -07:00
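Commit 1f0b5728a4 describes ragged prop as a small set of per-op rules that compute the raggedness of an output from its inputs. A rule table of that kind might look roughly like the following. This is a toy in plain Python, not the JAX implementation: the registry, the op names, and the encoding of raggedness as one boolean per dimension are all assumptions made for illustration.

```python
RAGGED_RULES = {}  # op name -> rule mapping input raggedness to output raggedness


def register_rule(op_name):
    """Register a raggedness rule for one (hypothetical) op name."""
    def wrap(rule):
        RAGGED_RULES[op_name] = rule
        return rule
    return wrap


@register_rule("add")
def _elementwise_rule(in_ragged):
    # An output dimension is ragged if any same-shaped input is ragged there.
    return tuple(any(flags) for flags in zip(*in_ragged))


@register_rule("sum_last_axis")
def _reduce_last_axis_rule(in_ragged):
    # Reducing over the last axis drops that axis, and its flag with it.
    (flags,) = in_ragged
    return flags[:-1]


def ragged_prop(op_name, in_ragged):
    """Look up the rule for `op_name` and apply it to the input raggedness."""
    if op_name not in RAGGED_RULES:
        raise NotImplementedError(f"no raggedness rule for op {op_name!r}")
    return RAGGED_RULES[op_name](in_ragged)
```

Keeping the rules in a table beside the ops, rather than in the types themselves, mirrors the commit's remark that raggedness is not yet part of the type system.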
jax authors
fff3b8747f Update XLA dependency to use revision 867b9f6e80.

PiperOrigin-RevId: 685797421
2024-10-14 12:30:37 -07:00
jax authors
90cd8a79dc Merge pull request #24290 from jax-ml:dependabot/github_actions/actions/upload-artifact-4.4.3
PiperOrigin-RevId: 685764189
2024-10-14 11:00:05 -07:00
jax authors
13bc497836 Merge pull request #24289 from jax-ml:dependabot/github_actions/actions/cache-4.1.1
PiperOrigin-RevId: 685763882
2024-10-14 10:58:22 -07:00
dependabot[bot]
93adc0e931 Bump actions/upload-artifact from 4.4.1 to 4.4.3
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4.4.1 to 4.4.3.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](604373da63...b4b15b8c7c)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-10-14 17:42:00 +00:00
dependabot[bot]
0fdd653509 Bump actions/cache from 4.1.0 to 4.1.1
Bumps [actions/cache](https://github.com/actions/cache) from 4.1.0 to 4.1.1.
- [Release notes](https://github.com/actions/cache/releases)
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md)
- [Commits](2cdf405574...3624ceb22c)

---
updated-dependencies:
- dependency-name: actions/cache
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-10-14 17:41:55 +00:00
Yash Katariya
824ccd7183 [Shardy] Inline meshes when using shardy and get rid of global meshes from the MLIR body.
Also do a couple of cleanups.

PiperOrigin-RevId: 685746298
2024-10-14 10:08:04 -07:00
Bart Chrzaszcz
75e22f2ccd #sdy Run inlined mesh lifter pass at the end of JAX lowering.
PiperOrigin-RevId: 685728692
2024-10-14 09:13:12 -07:00
jax authors
d15d70d67f Merge pull request #24271 from jakevdp:hist-doc
PiperOrigin-RevId: 685717635
2024-10-14 08:35:30 -07:00
jax authors
1de9f25c2d Merge pull request #24264 from ROCm:ci_apt_update
PiperOrigin-RevId: 685717538
2024-10-14 08:35:12 -07:00
jax authors
57ef7a4a59 Merge pull request #24274 from ROCm:ci_linalg_fix
PiperOrigin-RevId: 685717437
2024-10-14 08:33:33 -07:00