22142 Commits

Author SHA1 Message Date
Jake VanderPlas
31abce1c80 register several deprecations in jax.numpy
This is in preparation for finalizing these deprecations. They include:
- complex->real casting in jnp.astype
- complex inputs to jnp.clip
- complex inputs to jnp.hypot

PiperOrigin-RevId: 656005670
2024-07-25 10:40:00 -07:00
Paweł Paruzel
ae40c87919 Activate Cholesky Factorization Kernel to XLA's FFI
PiperOrigin-RevId: 655990468
2024-07-25 09:59:28 -07:00
Jake VanderPlas
8c36f90c90 [array API] clean up some superseded definitions 2024-07-25 09:40:02 -07:00
jax authors
9ea79c61f4 Merge pull request #22653 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 655980588
2024-07-25 09:29:07 -07:00
Christos Perivolaropoulos
d54bf529cc [mosaic_gpu] arrive_expect_tx() also accepts index typed values
PiperOrigin-RevId: 655965308
2024-07-25 08:33:55 -07:00
Vladimir Belitskiy
ba50e77407 Increase shard count for //third_party/py/jax/tests:lax_numpy_ufuncs_test_cpu.
PiperOrigin-RevId: 655946922
2024-07-25 07:26:06 -07:00
Sergei Lebedev
99a8b92a7f Fixed Pallas Mosaic GPU tests
* Migrated to the new barrier APIs
* Fixed scratch view casting logic; it previously didn't work for >1 view

PiperOrigin-RevId: 655937541
2024-07-25 06:51:13 -07:00
Adam Paszke
2e6da35e97 [Mosaic GPU] Add support for clusters in the matmul example
With the collective async_copy API, the changes are quite minimal!

PiperOrigin-RevId: 655937185
2024-07-25 06:46:51 -07:00
Adam Paszke
e59303cf3e [Mosaic GPU] Simplify the matmul example
Remove a bunch of WGMMAImpl classes. This is meant to be a simple forkable example,
not a complete kernel.

PiperOrigin-RevId: 655923069
2024-07-25 05:43:57 -07:00
Adam Paszke
be9cc807d9 [Mosaic GPU] Minor cleanups in the matmul example
We were incorrectly reporting the runtime performance.

PiperOrigin-RevId: 655915808
2024-07-25 05:10:41 -07:00
Bart Chrzaszcz
b00f978f70 #sdy Support with_sharding_constraint lowering through Shardy.
PiperOrigin-RevId: 655905063
2024-07-25 04:20:52 -07:00
jax authors
f15f9717c3 [Pallas/TPU] Fix bug with LocalMask grid shrinking
LocalMasks can trigger shrinking of the MaskInfo arrays and of the iteration space.
As a consequence, it is important that in the kernel body we use the `global_kv_index`. This is the kv_index in the "global" space without any shrinking of the iteration space.

PiperOrigin-RevId: 655901432
2024-07-25 04:05:57 -07:00
rajasekharporeddy
4ff69d3853 Fix jnp.triu_indices_from and jnp.tril_indices_from to emulate NumPy's behavior for arrays other than 2-D 2024-07-25 16:26:14 +05:30
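For reference, the NumPy behavior this commit targets can be sketched as follows. This uses plain NumPy rather than JAX, so it shows only the reference behavior and not the patched `jnp` functions; the array shapes are arbitrary choices for illustration.

```python
import numpy as np

# 2-D input: indices of the upper triangle, as usual.
square = np.arange(9).reshape(3, 3)
rows, cols = np.triu_indices_from(square)

# Non-2-D input: NumPy rejects it with a ValueError.
try:
    np.triu_indices_from(np.zeros((2, 3, 3)))
    raised = False
except ValueError:
    raised = True
```

`np.tril_indices_from` applies the same 2-D check.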
jax authors
e14752c0ab Merge pull request #22642 from ROCm:ci_jax_exp
PiperOrigin-RevId: 655894235
2024-07-25 03:36:36 -07:00
jax authors
2dadbd7eb6 Merge pull request #22605 from Cjkkkk:add_sm86_sm89_flash_attention
PiperOrigin-RevId: 655894058
2024-07-25 03:32:45 -07:00
George Necula
e6066e52a2 [pallas] Stop exporting jax.experimental.pallas.pallas
This was giving access to too many internal APIs.

PiperOrigin-RevId: 655887765
2024-07-25 03:03:58 -07:00
George Necula
4063373b22 Reverts 0d058ce86f04a44a51abba1261768fb46edf69d9
PiperOrigin-RevId: 655871052
2024-07-25 01:50:36 -07:00
Matthew Johnson
c8ea86c9c9 remove inlined jax.nn.initializers definitions, resolving TODO of levskaya et al
fixes breakage from cl/655766534 aka https://github.com/google/jax/pull/21069

PiperOrigin-RevId: 655806010
2024-07-24 20:55:36 -07:00
jax authors
76b4c70c23 Merge pull request #22628 from hawkinsp:broadcast2
PiperOrigin-RevId: 655779730
2024-07-24 19:17:25 -07:00
Yash Katariya
51e27923e8 Simplify pjit's batching rule now that xmap is deleted. Also do cleanup around adding manual axes under shard_map
PiperOrigin-RevId: 655776234
2024-07-24 19:02:13 -07:00
jax authors
086b500da6 Merge pull request #21069 from mattjj:remove-named-shapes
PiperOrigin-RevId: 655766534
2024-07-24 18:20:50 -07:00
Matthew Johnson
3f9eb404e4 remove named_shapes (since xmap is now gone) 2024-07-25 00:54:50 +00:00
Christos Perivolaropoulos
80a193d5db [pallas] Use the same primitive run_scoped_p for both mosaic and mosaic_gpu
PiperOrigin-RevId: 655751205
2024-07-24 17:14:30 -07:00
Christos Perivolaropoulos
aa99113173 [MosaicGPU] Expose the BarrierRef type so we can type our variables outside of mosaic.
PiperOrigin-RevId: 655748632
2024-07-24 17:06:44 -07:00
jax authors
b5010a2dac Merge pull request #22634 from jakevdp:array-api-cleanup
PiperOrigin-RevId: 655747905
2024-07-24 17:02:44 -07:00
Tomás Longeri
220ec2aa69 [Mosaic TPU] (8,128),-2 -> (8,128) for non-zero and replicated 2nd minor offset
Also fix bug where relayouts for fully replicated source assumed it was a no-op without checking implicit dims

PiperOrigin-RevId: 655746766
2024-07-24 16:58:35 -07:00
Peter Hawkins
52fa165d75 Simplify promote_shapes.
We can use lax.broadcast_to_rank instead of the considerably more complicated _broadcast_to.

Add a fast path to broadcast_to_rank and broadcast to avoid emitting an equation if the rank is already correct.
2024-07-24 19:42:16 -04:00
jax authors
f1cfd99fe8 Merge pull request #22625 from hawkinsp:broadcast
PiperOrigin-RevId: 655738756
2024-07-24 16:29:13 -07:00
Jake VanderPlas
b8bb869b7c [array API] clean up some unused/unnecessary code 2024-07-24 15:52:18 -07:00
jax authors
422e033d98 Merge pull request #22640 from jakevdp:fix-assert-warns
PiperOrigin-RevId: 655724816
2024-07-24 15:41:45 -07:00
jax authors
02235eb5aa Merge pull request #22629 from hawkinsp:broadcast3
PiperOrigin-RevId: 655722008
2024-07-24 15:33:25 -07:00
Pavel Sountsov
5ba26953be Add canonical arg to Rotation.as_quat() and switch .inv() to use the quaternion conjugate.
This matches scipy behavior as of 1.11.

I also went through the tests and enabled a bunch of disabled tests which appear to pass now(?).

PiperOrigin-RevId: 655719643
2024-07-24 15:29:19 -07:00
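A minimal pure-Python sketch of the two ideas in this commit, assuming SciPy's scalar-last `(x, y, z, w)` quaternion order. The helper names (`quat_multiply`, `quat_conjugate`, `canonicalize`) are hypothetical and are not the functions in `jax.scipy.spatial.transform`; the canonical rule follows SciPy's documented description of `canonical=True`.

```python
import math

def quat_multiply(a, b):
    """Hamilton product of two quaternions in scalar-last (x, y, z, w) order."""
    ax, ay, az, aw = a
    bx, by, bz, bw = b
    return (
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
        aw * bw - ax * bx - ay * by - az * bz,
    )

def quat_conjugate(q):
    """For a unit quaternion, the conjugate is the inverse rotation."""
    x, y, z, w = q
    return (-x, -y, -z, w)

def canonicalize(q):
    """Choose between q and -q so that w is positive; if w is zero,
    so that the first nonzero term of x, y, z is positive."""
    x, y, z, w = q
    for term in (w, x, y, z):
        if term != 0:
            return q if term > 0 else (-x, -y, -z, -w)
    return q

# A 90-degree rotation about the z axis.
half_angle = math.pi / 4
q = (0.0, 0.0, math.sin(half_angle), math.cos(half_angle))

# Composing a rotation with its conjugate gives the identity (0, 0, 0, 1).
identity = quat_multiply(q, quat_conjugate(q))

# q and -q describe the same rotation; both map to one canonical form.
canonical = canonicalize((-q[0], -q[1], -q[2], -q[3]))
```

Using the conjugate avoids a normalization step, which is only valid because rotation quaternions have unit norm.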
Parker Schuh
be6b77cc54 Update shard_map(jit) to properly set manual_axes on in_shardings and out_shardings of the nested jit. This avoids a problem where the jit returns {manual} and then this gets passed to ShardToFull (manual is already considered a full sharding).
PiperOrigin-RevId: 655719254
2024-07-24 15:25:27 -07:00
Ruturaj4
f7039b9142 [ROCM] rocm plugin is no longer experimental! 2024-07-24 17:11:03 -05:00
jax authors
f32b5d9e79 Merge pull request #22616 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 655708718
2024-07-24 14:54:21 -07:00
jax authors
b8aae435ee Merge pull request #22555 from rajasekharporeddy:testbranch2
PiperOrigin-RevId: 655708129
2024-07-24 14:50:30 -07:00
jax authors
770d5169fc Update XLA dependency to use revision 7ec84c52b0.

PiperOrigin-RevId: 655704569
2024-07-24 14:39:18 -07:00
Adam Paszke
4f19af911c [Mosaic GPU] Only split collective TMAs along (multiple) major dimensions
Each TMA only writes to a contiguous subset of SMEM, so skipping a major
dimension while splitting results in incorrect code. To work around the
loss of flexibility, we now allow splitting multiple leading dimensions
to handle larger clusters and tiled references.

PiperOrigin-RevId: 655700486
2024-07-24 14:26:07 -07:00
jax authors
06b199f4ca Merge pull request #22636 from jakevdp:test-typos
PiperOrigin-RevId: 655697660
2024-07-24 14:18:18 -07:00
Jake VanderPlas
0f13646370 Fix incorrect usage of assertWarns in tests 2024-07-24 14:16:51 -07:00
cjkkkk
c8b474e1f2 add sm86/sm89 2024-07-24 21:12:37 +00:00
Adam Paszke
dbe8f56353 [Mosaic GPU] Strengthen cluster-related tests by covering more cluster shapes
In particular test trivial collectives (over singleton cluster axes), collectives
over more than 2 devices and clusters larger than 8 devices. This uncovered a few
more bugs in the implementation.

PiperOrigin-RevId: 655686102
2024-07-24 13:43:52 -07:00
Yash Katariya
b6e86c413a Remove dead code now that xmap is deleted
PiperOrigin-RevId: 655664512
2024-07-24 12:40:20 -07:00
Jake VanderPlas
b0cd3b3ec5 lax_numpy_test: fix some typos 2024-07-24 10:47:27 -07:00
Yash Katariya
0d5dae09ff Delete xmap and the jax.experimental.maps module. It's been 5 months since its deprecation (more than the standard 3 months deprecation period).
PiperOrigin-RevId: 655614395
2024-07-24 10:24:09 -07:00
Jaroslav Sevcik
fa2339c35e Simplify and rename GPU mocking settings 2024-07-24 16:35:49 +00:00
Peter Hawkins
7527101672 Don't broadcast scalar conditions in the jnp.where() implementation.
The underlying lax primitive is perfectly happy to accept scalar conditions with the other arguments being non-scalar.
2024-07-24 12:06:51 -04:00
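The point about the underlying primitive can be illustrated with `jax.lax.select`, which takes a scalar predicate alongside non-scalar branches. This is a small sketch with arbitrary array shapes, not the code changed by the commit.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0).reshape(2, 3)
y = jnp.zeros((2, 3))

# A scalar boolean predicate selects one whole operand;
# it does not need to be broadcast to shape (2, 3) first.
pred = jnp.array(True)
picked = lax.select(pred, x, y)
```

Skipping the explicit broadcast saves one equation in the traced program.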
Vladimir Belitskiy
d9a7cb4490 Skip pallas/gpu_attention_test.py on TPU.
PiperOrigin-RevId: 655575719
2024-07-24 08:24:57 -07:00
Peter Hawkins
34ce9f21db Simplify implementation of _broadcast_to.
_broadcast_to needlessly squeezes away size 1 dimensions before passing its input to broadcast_in_dim. But broadcast_in_dim is perfectly happy to broadcast size 1 dimensions, so we don't need this squeeze.
2024-07-24 10:57:54 -04:00
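A short sketch of the claim that `broadcast_in_dim` handles size 1 dimensions on its own, using the public `jax.lax.broadcast_in_dim` with arbitrary shapes. It does not reproduce the internal `_broadcast_to` helper.

```python
import jax.numpy as jnp
from jax import lax

# The operand keeps its leading size-1 dimension; no squeeze is applied.
x = jnp.arange(4.0).reshape(1, 4)

# Map operand dims (0, 1) onto output dims (0, 1); dim 0 expands from 1 to 3.
out = lax.broadcast_in_dim(x, (3, 4), (0, 1))
```

Each operand dimension must either have size 1 or match the corresponding output dimension.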
Michal Kazmierski
61374c92ad Fix error message in jax.nn.dot_product_attention when the inputs have different dtypes.
PiperOrigin-RevId: 655553414
2024-07-24 07:13:15 -07:00