Bart Chrzaszcz
801fe87da6
Do not allow None
axis names in meshes.
...
PiperOrigin-RevId: 686557025
2024-10-16 10:32:25 -07:00
Phuong Nguyen
82113cd047
rm CmdBuffer traits
...
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
2024-10-16 10:27:09 -07:00
Phuong Nguyen
f3775aa233
added cudaGraph traits + use register_ffi_target()
...
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
2024-10-16 10:01:20 -07:00
Sergei Lebedev
bb271aaff8
[pallas:mosaic_gpu] Added FragmentedArray.to_layout
...
PiperOrigin-RevId: 686524192
2024-10-16 08:53:02 -07:00
Mantas Pajarskas
1222b4a571
[Pallas TPU] Add a better error message for rank 1 block mappings check.
...
Currently, the error message refers to "last two dimensions" which is confusing for a rank-1 case; furthermore, the error does not match the check in the code.
PiperOrigin-RevId: 686520781
2024-10-16 08:41:22 -07:00
Sergei Lebedev
4c0d82824f
[pallas:mosaic_gpu] Added a few more operations necessary to port Flash Attention
...
PiperOrigin-RevId: 686451398
2024-10-16 04:05:36 -07:00
jax authors
56eea2b5bb
Merge pull request #24312 from jakevdp:gather-doc
...
PiperOrigin-RevId: 686372450
2024-10-15 22:41:39 -07:00
Yash Katariya
66c6292e6a
Make committed a public property of jax.Array.
...
Why?
Because users need to know if an array is committed or not since JAX raises errors based on committedness of a jax.Array. JAX also makes decisions about dispatching based on committedness of a jax.Array.
But the placement of such arrays on devices is an internal implementation detail.
PiperOrigin-RevId: 686329828
2024-10-15 19:46:10 -07:00
Jake VanderPlas
284ca8bc01
Improve docs for lax.gather & lax.scatter
2024-10-15 16:42:44 -07:00
selamw1
24b6f50938
tile_docstring_added
2024-10-15 15:34:01 -07:00
jax authors
ad99ab1ce3
Merge pull request #24315 from jakevdp:setitem-error
...
PiperOrigin-RevId: 686260067
2024-10-15 15:31:06 -07:00
Andrey Portnoy
0da52cd139
[Mosaic GPU] Skip OSS Flash Attention test unless running on sm90
2024-10-15 18:28:02 -04:00
jax authors
30ba7f37e0
Reverts 0788a0a589318b98ae2cb2169af109622ba26cf0
...
PiperOrigin-RevId: 686252325
2024-10-15 15:05:36 -07:00
jax authors
f0459907dc
Merge pull request #24316 from ROCm:ci_build_fix
...
PiperOrigin-RevId: 686236716
2024-10-15 14:18:30 -07:00
jax authors
5e03a573bc
Use Iota order for certain v5e with 8 devices.
...
PiperOrigin-RevId: 686227482
2024-10-15 13:53:01 -07:00
jax authors
1df58b1854
Update XLA dependency to use revision
...
9136df1d97
.
PiperOrigin-RevId: 686213543
2024-10-15 13:13:28 -07:00
Ayaka
5ac2076fb7
[Pallas TPU] Fix boolean comparison
...
Fixes https://github.com/jax-ml/jax/issues/24030
Also added tests to cover all scalar comparison cases.
PiperOrigin-RevId: 686197357
2024-10-15 12:24:58 -07:00
Jevin Jiang
3a7d9137a4
[Pallas TPU] Support ref reshape.
...
Jaxpr example:
```
{ lambda ; a:MemRef<None>{int32[32,256]} b:MemRef<None>{int32[8,128]}. let
c:i32[8,128] <- a[:16,:][bitcast(int16[32,256])][reshape(int16[2,16,256])][bitcast(float16[2,16,256])][1:,:,:][reshape(float16[16,256])][:,:128][bitcast(int32[8,128])][:,:]
b[:,:] <- c
in () }
```
Tested:
- DMA with reshaped ref
- Load from reshaped ref
- Store to reshaped ref
- Multiple transforms
- Interpret Mode for ref transforms (updated discharge rules)
PiperOrigin-RevId: 686186426
2024-10-15 11:52:15 -07:00
Ruturaj4
a2824862f5
[ROCm] build script fix
2024-10-15 13:43:08 -05:00
Jake VanderPlas
63f4299e0e
Improve array setitem error
2024-10-15 11:39:31 -07:00
Jevin Jiang
a47b755619
[Mosaic TPU] Support native int4 @ int4
...
PiperOrigin-RevId: 686179715
2024-10-15 11:35:23 -07:00
jax authors
94152e8dfc
Merge pull request #24305 from jakevdp:signbit-doc
...
PiperOrigin-RevId: 686174793
2024-10-15 11:21:07 -07:00
jax authors
87d8f3817b
Merge pull request #24294 from jakevdp:invert-doc
...
PiperOrigin-RevId: 686159879
2024-10-15 10:44:46 -07:00
Praveen Batra
3a3190fbce
Fix typo in Pallas TPU matmul doc. I think the logical layout of the input array is non-transposed, rather than transposed?
...
PiperOrigin-RevId: 686151692
2024-10-15 10:23:39 -07:00
jax authors
e461c0496f
Merge pull request #23684 from simonster:sjk/fix-prefix-error
...
PiperOrigin-RevId: 686133952
2024-10-15 09:32:30 -07:00
Vladimir Belitskiy
2f2fd8a334
Skip some Shardy-enabled tests if XLA < 292.
...
PiperOrigin-RevId: 686133374
2024-10-15 09:30:41 -07:00
Jake VanderPlas
dd4a0408a4
Improve docs for jnp.invert and related functions
2024-10-15 08:57:19 -07:00
jax authors
2c2c1eebc7
Merge pull request #24251 from dfm:dot-algorithm-jax2tf
...
PiperOrigin-RevId: 686116542
2024-10-15 08:35:38 -07:00
Sharad Vikram
cd78c653e7
[Pallas] Use core_map instead of shard_map for Shmallas
...
- core_map is like a shard_map but it takes in no inputs and outputs
- we can use it in Pallas to generalize mapping a function over the cores of a chip (e.g. TensorCores in a TPU or SMs in a GPU)
- we specify how the function will be mapped over the device with a `mesh` object. This is also a convenient mechanism for picking the backend for pallas to target
PiperOrigin-RevId: 686036101
2024-10-15 03:26:58 -07:00
Parker Schuh
b0768906db
Add back missing manual axis to fallback.
...
PiperOrigin-RevId: 685951218
2024-10-14 21:50:17 -07:00
Yash Katariya
2f6cb89ac0
Add a private property to NamedSharding called _logical_device_ids
which allows you to pass a custom tile_assignment_devices()
equivalent.
...
This is because for Shardy, GSPMDSharding doesn't work, so `device_put` on a mesh with different device order needs `NamedSharding` support. Bonus is that the logic is now simplified wrt the previous version in `_different_device_order_reshard`.
This will also allow us to remove OpSharding usage in other projects which require such kind of permutation capabilities.
PiperOrigin-RevId: 685925636
2024-10-14 20:08:54 -07:00
Jake VanderPlas
a096986844
Better docs for jnp.signbit
2024-10-14 19:51:42 -07:00
Tzu-Wei Sung
0788a0a589
Internal change.
...
PiperOrigin-RevId: 685905230
2024-10-14 18:43:36 -07:00
jax authors
829e315963
Merge pull request #24299 from jakevdp:apply-doc
...
PiperOrigin-RevId: 685888268
2024-10-14 17:27:05 -07:00
jax authors
dfee4d1549
Merge pull request #24258 from dfm:remove-jaxlib-version-checks
...
PiperOrigin-RevId: 685884478
2024-10-14 17:12:03 -07:00
jax authors
2a828d5d6b
Merge pull request #23467 from abhinavgoel95:patch-4
...
PiperOrigin-RevId: 685872800
2024-10-14 16:29:24 -07:00
Tongfei Guo
d621737f13
[XLA:Collective] Expose a factory for constructing HLOSharding with explicit device ordering.
...
PiperOrigin-RevId: 685858699
2024-10-14 15:41:23 -07:00
Ayaka
bfc3d3cd18
[Pallas TPU] Add lowerings for scalar sin
, cos
, tan
and tanh
...
This PR is similar to https://github.com/jax-ml/jax/pull/24238
PiperOrigin-RevId: 685842905
2024-10-14 14:49:11 -07:00
Jake VanderPlas
0c307fe706
Better docs for jnp.apply_along_axis & apply_over_axes
2024-10-14 14:48:33 -07:00
jax authors
1f0b5728a4
Add a memory saving index rewrite step to vmap with ragged inputs over pallas_call.
...
The approach here is to add a new notion to jax, for ragged_prop. Ragged prop is useful for computing the dynamism/raggedness of an output, given a set of inputs. In the limit, if we decide that this is a useful property to have in jax as a first class citizen, we could fold the raggedness into the type system. At the moment, however, it is just a small set of rules implemented per op.
PiperOrigin-RevId: 685827096
2024-10-14 14:01:42 -07:00
jax authors
fff3b8747f
Update XLA dependency to use revision
...
867b9f6e80
.
PiperOrigin-RevId: 685797421
2024-10-14 12:30:37 -07:00
jax authors
90cd8a79dc
Merge pull request #24290 from jax-ml:dependabot/github_actions/actions/upload-artifact-4.4.3
...
PiperOrigin-RevId: 685764189
2024-10-14 11:00:05 -07:00
jax authors
13bc497836
Merge pull request #24289 from jax-ml:dependabot/github_actions/actions/cache-4.1.1
...
PiperOrigin-RevId: 685763882
2024-10-14 10:58:22 -07:00
dependabot[bot]
93adc0e931
Bump actions/upload-artifact from 4.4.1 to 4.4.3
...
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact ) from 4.4.1 to 4.4.3.
- [Release notes](https://github.com/actions/upload-artifact/releases )
- [Commits](604373da63...b4b15b8c7c
)
---
updated-dependencies:
- dependency-name: actions/upload-artifact
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
2024-10-14 17:42:00 +00:00
dependabot[bot]
0fdd653509
Bump actions/cache from 4.1.0 to 4.1.1
...
Bumps [actions/cache](https://github.com/actions/cache ) from 4.1.0 to 4.1.1.
- [Release notes](https://github.com/actions/cache/releases )
- [Changelog](https://github.com/actions/cache/blob/main/RELEASES.md )
- [Commits](2cdf405574...3624ceb22c
)
---
updated-dependencies:
- dependency-name: actions/cache
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
2024-10-14 17:41:55 +00:00
Yash Katariya
824ccd7183
[Shardy] Inline meshes when using shardy and get rid of global meshes from the MLIR body.
...
Also do a couple of cleanups.
PiperOrigin-RevId: 685746298
2024-10-14 10:08:04 -07:00
Bart Chrzaszcz
75e22f2ccd
#sdy Run inlined mesh lifter pass at the end of JAX lowering.
...
PiperOrigin-RevId: 685728692
2024-10-14 09:13:12 -07:00
jax authors
d15d70d67f
Merge pull request #24271 from jakevdp:hist-doc
...
PiperOrigin-RevId: 685717635
2024-10-14 08:35:30 -07:00
jax authors
1de9f25c2d
Merge pull request #24264 from ROCm:ci_apt_update
...
PiperOrigin-RevId: 685717538
2024-10-14 08:35:12 -07:00
jax authors
57ef7a4a59
Merge pull request #24274 from ROCm:ci_linalg_fix
...
PiperOrigin-RevId: 685717437
2024-10-14 08:33:33 -07:00