Sharad Vikram
39ec5dacb4
[Pallas TPU] Add matrix multiplication tutorial
2024-07-16 18:12:19 -07:00
Sharad Vikram
7069a5a2e1
[Pallas/Mosaic] Refactor how we lower Pallas to a custom call
...
This avoids a round-trip through lower_fun which creates unnecessary HLO ops and name stack entries (which can add noise to trace viewers).
PiperOrigin-RevId: 653018101
2024-07-16 16:21:00 -07:00
jax authors
778477b62b
Update XLA dependency to use revision
...
d8e2b28463
.
PiperOrigin-RevId: 652932751
2024-07-16 12:11:43 -07:00
jax authors
4907c38742
Merge pull request #22386 from rdyro:rdyro/explain_persistent_compilation
...
PiperOrigin-RevId: 652928829
2024-07-16 12:00:37 -07:00
jax authors
84827bf247
Merge pull request #22418 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 652921968
2024-07-16 11:40:06 -07:00
jax authors
5ddec63a47
Merge pull request #22441 from gnecula:test_clean_hypothesis
...
PiperOrigin-RevId: 652919414
2024-07-16 11:32:46 -07:00
Chris Jones
15c542228a
[jax] Make the SupportsDType
protocol runtime checkable.
...
This allows `DTypeLike` to be used as a type annotation for type-checked functions without triggering a warning.
PiperOrigin-RevId: 652905699
2024-07-16 10:57:15 -07:00
jax authors
661ecdd83e
Merge pull request #22475 from dfm:fix-lint
...
PiperOrigin-RevId: 652865171
2024-07-16 09:01:31 -07:00
Dan Foreman-Mackey
556cc23fa5
Fix lint at head.
...
It looks like https://github.com/google/jax/pull/22330 introduced some
mypy lint. This PR fixes it.
2024-07-16 10:53:49 -04:00
Kaixi Hou
09531d2ff8
PR #22330 : [NVIDIA] Remove logic of combining bias and mask
...
Imported from GitHub PR https://github.com/google/jax/pull/22330
The cudnn API has already supported the combination of bias and mask from [this PR](https://github.com/google/jax/pull/22078 ). We are removing the logic from the public sdpa API and pass the mask directly.
cc. @Cjkkkk
Copybara import of the project:
--
0f75f58a9d81c0ae0a83701a71998c940318732a by kaixih <kaixih@nvidia.com>:
Remove logic of combining bias and mask
Merging this change closes #22330
COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/22330 from kaixih:remove_combine_bias_mask 0f75f58a9d81c0ae0a83701a71998c940318732a
PiperOrigin-RevId: 652830016
2024-07-16 07:19:01 -07:00
Sergei Lebedev
7a62b8dd18
Re-enabled PallasCallPrintTest on Cloud TPUs
...
PiperOrigin-RevId: 652823653
2024-07-16 06:55:20 -07:00
rajasekharporeddy
a00c94c5a2
Improved docs for jnp.median and nanmedian
2024-07-16 17:27:46 +05:30
George Necula
e7be205a39
[jax2tf] Fix jax2tf tolerance for Cholesky, needed for newer TPUs
...
PiperOrigin-RevId: 652751771
2024-07-16 02:13:11 -07:00
George Necula
d34a6e9ce2
[jax2tf] Deprecate jax2tf with native_serialization=False or enable_xla=False.
...
Also disable many of the non-native-serialization jax2tf tests.
In particular, I am disabling the thousands of primitives tests in
graph serialization mode.
I kept jax2tf_test running in both native and graph serialization mode.
PiperOrigin-RevId: 652749891
2024-07-16 02:05:43 -07:00
Christos Perivolaropoulos
28ffa25496
[MosaicGPU] Move parity computations to a separate function to allow the user to use wait_parity without duplicate code.
...
PiperOrigin-RevId: 652665738
2024-07-15 19:08:10 -07:00
Justin Fu
0690988626
[Pallas] Add limited boolean memref support for scalars.
...
PiperOrigin-RevId: 652653003
2024-07-15 17:59:05 -07:00
rdyro
c6d6207170
Unifying persistent cache messages
...
and moving them to WARNING logging when explain_cache_misses is true.
2024-07-16 00:47:53 +00:00
jax authors
21e9dabdad
Merge pull request #22465 from ROCm:ci_multiprocess_gpu_test
...
PiperOrigin-RevId: 652602616
2024-07-15 14:38:58 -07:00
Ruturaj4
fb5c516405
[ROCM] test_computation_follows_data fix for rocm and cuda
2024-07-15 16:28:22 -05:00
jax authors
3e09040cf7
Merge pull request #22463 from hawkinsp:tpuci2
...
PiperOrigin-RevId: 652597728
2024-07-15 14:25:15 -07:00
jax authors
fc3654792b
Merge pull request #22450 from kaixih:update_sdpa_doc
...
PiperOrigin-RevId: 652597691
2024-07-15 14:21:27 -07:00
Yash Katariya
bb7a6995f9
Remove the spmd_mode
check. It's disabled in OSS since a long time.
...
PiperOrigin-RevId: 652591122
2024-07-15 13:58:23 -07:00
Peter Hawkins
f488c4cc31
Disable some tests that fail on Cloud TPU.
2024-07-15 16:00:58 -04:00
Peter Hawkins
2f45cd725a
Bump some SVD test tolerances.
...
These just barely fail on recent TPUs.
PiperOrigin-RevId: 652571985
2024-07-15 12:54:00 -07:00
jax authors
5d3392927a
Update XLA dependency to use revision
...
83e38528e0
.
PiperOrigin-RevId: 652563753
2024-07-15 12:27:15 -07:00
Kevin Gleason
5e897c61f5
Integrate StableHLO at openxla/stablehlo@8817ff1d
...
PiperOrigin-RevId: 652528759
2024-07-15 10:38:09 -07:00
kaixih
0d387e0839
Update jax doc sdpa
2024-07-15 17:30:54 +00:00
jax authors
26ec43f9e5
Merge pull request #22445 from dfm:numpy-nightly-unique
...
PiperOrigin-RevId: 652520243
2024-07-15 10:16:03 -07:00
jax authors
cab1f85f09
Merge pull request #22448 from hawkinsp:tpuci
...
PiperOrigin-RevId: 652520107
2024-07-15 10:12:39 -07:00
jax authors
7255ab146b
Merge pull request #22440 from gnecula:pallas_test_clean
...
PiperOrigin-RevId: 652513116
2024-07-15 09:55:38 -07:00
jax authors
2b29a94255
Merge pull request #22375 from jakevdp:mypy-docs
...
PiperOrigin-RevId: 652511749
2024-07-15 09:52:07 -07:00
jax authors
5216719996
Merge pull request #22405 from gnecula:poly_pad
...
PiperOrigin-RevId: 652511693
2024-07-15 09:48:21 -07:00
Yash Katariya
b6264e99b1
Skip layout tests because they require xla_extension_version >= 274
...
PiperOrigin-RevId: 652510288
2024-07-15 09:43:03 -07:00
Peter Hawkins
3019467992
Update references to jaxlib_nightly_releases (which is a legacy index and no longer updated) to jax_nightly_releases.
...
Fixes CI jobs using old jaxlib nightly wheels.
2024-07-15 12:42:50 -04:00
Peter Hawkins
a1f69713f5
Disable Pallas vmap test that is very slow under tsan.
...
PiperOrigin-RevId: 652505878
2024-07-15 09:28:35 -07:00
jax authors
9c72e67711
Merge pull request #22169 from gspschmid:gschmid/async_serialize-overlap-shard-copies
...
PiperOrigin-RevId: 652502425
2024-07-15 09:15:21 -07:00
jax authors
86f4bb4346
Added more Mosaic bug reproducers.
...
PiperOrigin-RevId: 652498944
2024-07-15 09:02:48 -07:00
Dan Foreman-Mackey
7857bd3319
Fix compatibility of jnp.unique with numpy nightly
...
In https://github.com/numpy/numpy/pull/26914 , the behavior of the
`return_inverse` argument to `np.unique` was partially reverted to the
pre-v2.0 behavior. The PR brings JAX's implementation compatible with
the `numpy>2.0.0` behavior.
2024-07-15 11:41:00 -04:00
jax authors
2e7e700090
Merge pull request #22402 from superbobry:maint
...
PiperOrigin-RevId: 652487969
2024-07-15 08:20:56 -07:00
George Necula
d3454f374e
Add some hypothesis testing utilities and developer documentation.
...
Add a helper function for setting up hypothesis testing,
with support for selecting an interactive hypothesis profile
that speeds up interactive development.
2024-07-15 17:05:32 +02:00
George Necula
791743c296
[pallas] Some minor test housekeeping.
...
Add missing interpreter tests in tpu_pallas_test.
2024-07-15 17:05:15 +02:00
George Necula
7817b6785b
[shape_poly] Expand the support for shape polymorphism for jnp.pad
...
Handle several new padding modes: wrap, reflect, symmetric, linear_ramp, maximum.
Not all situations are handled; try to give a clear error for the unsupported
cases.
While implementing this, I needed to add shape polymorphism support
also for jnp.linspace.
And I discovered a bug in the implementation of `divmod(0, b)`.
2024-07-15 17:04:54 +02:00
jax authors
b966c74fae
Merge pull request #22443 from gnecula:docs_checkify
...
PiperOrigin-RevId: 652482390
2024-07-15 08:01:31 -07:00
George Necula
be8e83adc1
[docs] Fix docs building error
...
The checkify APIs were mentioned in the jax.experimental.rst and also
in jax.experimental.checkify.rst.
2024-07-15 15:42:33 +01:00
Georg Stefan Schmid
0428871c82
Adapt test case
2024-07-15 09:37:59 +00:00
Georg Stefan Schmid
cdd17cabba
Also overlap staged transfers (ba to hostseline case when jax memories or pinned_host is unavailable)
2024-07-15 09:37:53 +00:00
Georg Stefan Schmid
6b0136b829
async_serialize: Overlap shard copies
2024-07-15 09:37:47 +00:00
Georg Stefan Schmid
b8b9d2878c
[memories] Transfer to pinned_host fast path in async_serialize
2024-07-15 09:35:43 +00:00
jax authors
a8b425cac5
Update XLA dependency to use revision
...
c50e9f58f0
.
PiperOrigin-RevId: 652100151
2024-07-13 13:14:57 -07:00
jax authors
764ec92118
Add support for elementwise op canonicalization in fp32 for older hardware.
...
PiperOrigin-RevId: 651959463
2024-07-12 19:58:55 -07:00