Jake VanderPlas
31abce1c80
register several deprecations in jax.numpy
...
This is in preparation for finalizing these deprecations. They include:
- complex->real casting in jnp.astype
- complex inputs to jnp.clip
- complex inputs to jnp.hypot
PiperOrigin-RevId: 656005670
2024-07-25 10:40:00 -07:00
Paweł Paruzel
ae40c87919
Activate Cholesky Factorization Kernel to XLA's FFI
...
PiperOrigin-RevId: 655990468
2024-07-25 09:59:28 -07:00
Jake VanderPlas
8c36f90c90
[array API] clean up some superseded definitions
2024-07-25 09:40:02 -07:00
jax authors
9ea79c61f4
Merge pull request #22653 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 655980588
2024-07-25 09:29:07 -07:00
Christos Perivolaropoulos
d54bf529cc
[mosaic_gpu] arrive_expect_tx() also accepts index typed values
...
PiperOrigin-RevId: 655965308
2024-07-25 08:33:55 -07:00
Vladimir Belitskiy
ba50e77407
Increase shard count for //third_party/py/jax/tests:lax_numpy_ufuncs_test_cpu.
...
PiperOrigin-RevId: 655946922
2024-07-25 07:26:06 -07:00
Sergei Lebedev
99a8b92a7f
Fixed Pallas Mosaic GPU tests
...
* Migrated to the new barrier APIs
* Fixed scratch view casting logic, it previously didn't work for >1 view
PiperOrigin-RevId: 655937541
2024-07-25 06:51:13 -07:00
Adam Paszke
2e6da35e97
[Mosaic GPU] Add support for clusters in the matmul example
...
With the collective async_copy API, the changes are quite minimal!
PiperOrigin-RevId: 655937185
2024-07-25 06:46:51 -07:00
Adam Paszke
e59303cf3e
[Mosaic GPU] Simplify the matmul example
...
Remove a bunch of WGMMAImpl classes. This is meant to be a simple forkable example,
not a complete kernel.
PiperOrigin-RevId: 655923069
2024-07-25 05:43:57 -07:00
Adam Paszke
be9cc807d9
[Mosaic GPU] Minor cleanups in the matmul example
...
We were incorrectly reporting the runtime performance.
PiperOrigin-RevId: 655915808
2024-07-25 05:10:41 -07:00
Bart Chrzaszcz
b00f978f70
#sdy Support with_sharding_constraint lowering through Shardy.
...
PiperOrigin-RevId: 655905063
2024-07-25 04:20:52 -07:00
jax authors
f15f9717c3
[Pallas/TPU] Fix bug with LocalMask grid shrinking
...
LocalMasks can trigger shrinking of the MaskInfo arrays and of the iteration space.
As a consequence, it is important that in the kernel body we use the `global_kv_index`. This is the kv_index in the "global" space without any shrinking of the iteration space.
PiperOrigin-RevId: 655901432
2024-07-25 04:05:57 -07:00
rajasekharporeddy
4ff69d3853
Fix jnp.triu_indices_from and jnp.tril_indices_from to emulate NumPy's behavior for arrays other than 2-D
2024-07-25 16:26:14 +05:30
jax authors
e14752c0ab
Merge pull request #22642 from ROCm:ci_jax_exp
...
PiperOrigin-RevId: 655894235
2024-07-25 03:36:36 -07:00
jax authors
2dadbd7eb6
Merge pull request #22605 from Cjkkkk:add_sm86_sm89_flash_attention
...
PiperOrigin-RevId: 655894058
2024-07-25 03:32:45 -07:00
George Necula
e6066e52a2
[pallas] Stop exporting jax.experimental.pallas.pallas
...
This was giving access to too many internal APIs.
PiperOrigin-RevId: 655887765
2024-07-25 03:03:58 -07:00
George Necula
4063373b22
Reverts 0d058ce86f04a44a51abba1261768fb46edf69d9
...
PiperOrigin-RevId: 655871052
2024-07-25 01:50:36 -07:00
Matthew Johnson
c8ea86c9c9
remove inlined jax.nn.initializers definitions, resolving TODO of levskaya et al
...
fixes breakage from cl/655766534 aka https://github.com/google/jax/pull/21069
PiperOrigin-RevId: 655806010
2024-07-24 20:55:36 -07:00
jax authors
76b4c70c23
Merge pull request #22628 from hawkinsp:broadcast2
...
PiperOrigin-RevId: 655779730
2024-07-24 19:17:25 -07:00
Yash Katariya
51e27923e8
Simplify pjit's batching rule now that xmap is deleted. Also do cleanup around adding manual axes under shard_map
...
PiperOrigin-RevId: 655776234
2024-07-24 19:02:13 -07:00
jax authors
086b500da6
Merge pull request #21069 from mattjj:remove-named-shapes
...
PiperOrigin-RevId: 655766534
2024-07-24 18:20:50 -07:00
Matthew Johnson
3f9eb404e4
remove named_shapes (since xmap is now gone)
2024-07-25 00:54:50 +00:00
Christos Perivolaropoulos
80a193d5db
[pallas] Use the same primitive run_scoped_p
for moth mosaic and mosaic_gpu
...
PiperOrigin-RevId: 655751205
2024-07-24 17:14:30 -07:00
Christos Perivolaropoulos
aa99113173
[MosaicGPU] Expose the BarrierRef type so we can type our variables outside of mosaic.
...
PiperOrigin-RevId: 655748632
2024-07-24 17:06:44 -07:00
jax authors
b5010a2dac
Merge pull request #22634 from jakevdp:array-api-cleanup
...
PiperOrigin-RevId: 655747905
2024-07-24 17:02:44 -07:00
Tomás Longeri
220ec2aa69
[Mosaic TPU] (8,128),-2 -> (8,128) for non-zero and replicated 2nd minor offset
...
Also fix bug where relayouts for fully replicated source assumed it was a no-op without checking implicit dims
PiperOrigin-RevId: 655746766
2024-07-24 16:58:35 -07:00
Peter Hawkins
52fa165d75
Simplify promote_shapes.
...
We can use lax.broadcast_to_rank instead of the considerably more complicated _broadcast_to.
Add a fast path to broadcast_to_rank and broadcast to avoid emitting an equation if the rank is already correct.
2024-07-24 19:42:16 -04:00
jax authors
f1cfd99fe8
Merge pull request #22625 from hawkinsp:broadcast
...
PiperOrigin-RevId: 655738756
2024-07-24 16:29:13 -07:00
Jake VanderPlas
b8bb869b7c
[array API] clean up some unused/unnecessary code
2024-07-24 15:52:18 -07:00
jax authors
422e033d98
Merge pull request #22640 from jakevdp:fix-assert-warns
...
PiperOrigin-RevId: 655724816
2024-07-24 15:41:45 -07:00
jax authors
02235eb5aa
Merge pull request #22629 from hawkinsp:broadcast3
...
PiperOrigin-RevId: 655722008
2024-07-24 15:33:25 -07:00
Pavel Sountsov
5ba26953be
Add canonical
arg to Rotation.as_quat() and switch .inv() to use the quaternion conjugate.
...
This matches scipy behavior as of 1.11.
I also went through the tests and enabled a bunch of disabled tests which appear to pass now(?).
PiperOrigin-RevId: 655719643
2024-07-24 15:29:19 -07:00
Parker Schuh
be6b77cc54
Update shard_map(jit) to properly set manual_axes on in_shardings and out_shardings of the nested jit. This avoids a problem where the jit returns {manaual} and then this gets passed to ShardToFull (manual is already considered a full sharding).
...
PiperOrigin-RevId: 655719254
2024-07-24 15:25:27 -07:00
Ruturaj4
f7039b9142
[ROCM] rocm plugin is no longer experiemental!
2024-07-24 17:11:03 -05:00
jax authors
f32b5d9e79
Merge pull request #22616 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 655708718
2024-07-24 14:54:21 -07:00
jax authors
b8aae435ee
Merge pull request #22555 from rajasekharporeddy:testbranch2
...
PiperOrigin-RevId: 655708129
2024-07-24 14:50:30 -07:00
jax authors
770d5169fc
Update XLA dependency to use revision
...
7ec84c52b0
.
PiperOrigin-RevId: 655704569
2024-07-24 14:39:18 -07:00
Adam Paszke
4f19af911c
[Mosaic GPU] Only split collective TMAs only (multiple) major dimensions
...
Each TMA only writes to a contiguous subset of SMEM, so skipping a major
dimension while splitting results in incorrect code. To work around the
loss of flexibility, we now allow splitting multiple leading dimensions
to handle larger clusters and tiled references.
PiperOrigin-RevId: 655700486
2024-07-24 14:26:07 -07:00
jax authors
06b199f4ca
Merge pull request #22636 from jakevdp:test-typos
...
PiperOrigin-RevId: 655697660
2024-07-24 14:18:18 -07:00
Jake VanderPlas
0f13646370
Fix incorrect usage of assertWarns in tests
2024-07-24 14:16:51 -07:00
cjkkkk
c8b474e1f2
add sm86/sm89
2024-07-24 21:12:37 +00:00
Adam Paszke
dbe8f56353
[Mosaic GPU] Strengthen cluster-related tests by covering more cluster shapes
...
In particular test trivial collectives (over singleton cluster axes), collectives
over more than 2 devices and clusters larger than 8 devices. This uncovered a few
more bugs in the implementation.
PiperOrigin-RevId: 655686102
2024-07-24 13:43:52 -07:00
Yash Katariya
b6e86c413a
Remove dead code now that xmap is deleted
...
PiperOrigin-RevId: 655664512
2024-07-24 12:40:20 -07:00
Jake VanderPlas
b0cd3b3ec5
lax_numpy_test: fix some typos
2024-07-24 10:47:27 -07:00
Yash Katariya
0d5dae09ff
Delete xmap
and the jax.experimental.maps
module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
...
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
Jaroslav Sevcik
fa2339c35e
Simplify and rename GPU mocking settings
2024-07-24 16:35:49 +00:00
Peter Hawkins
7527101672
Don't broadcast scalar conditions in the jnp.where implementation().
...
The underlying lax primitive is perfectly happy to accept scalar conditions with the other arguments being non-scalar.
2024-07-24 12:06:51 -04:00
Vladimir Belitskiy
d9a7cb4490
Skip pallas/gpu_attention_test.py on TPU.
...
PiperOrigin-RevId: 655575719
2024-07-24 08:24:57 -07:00
Peter Hawkins
34ce9f21db
Simplify implementation of _broadcast_to.
...
_broadcast_to needlessly squeezes away size 1 dimensions before passing its input to broadcast_in_dim. But broadcast_in_dim is perfectly happy to broadcast size 1 dimensions, so we don't need this squeeze.
2024-07-24 10:57:54 -04:00
Michal Kazmierski
61374c92ad
Fix error message in jax.nn.dot_product_attention when the inputs have different dtypes.
...
PiperOrigin-RevId: 655553414
2024-07-24 07:13:15 -07:00