22340 Commits

Author SHA1 Message Date
Yash Katariya
be53ee10b1 Set jax_enable_memories flag to True by default
PiperOrigin-RevId: 660579462
2024-08-07 16:25:25 -07:00
jax authors
7efca0490f Merge pull request #22920 from jakevdp:fix-lint
PiperOrigin-RevId: 660570457
2024-08-07 16:01:09 -07:00
jax authors
a57d6591ee Update XLA dependency to use revision
3bf7e1ae48.

PiperOrigin-RevId: 660570144
2024-08-07 15:57:41 -07:00
Jake VanderPlas
53af0d4d90 CI: fix mypy errors 2024-08-07 15:15:45 -07:00
jax authors
de02988e94 Merge pull request #22909 from ROCm:ci_fix_solver_paths
PiperOrigin-RevId: 660515208
2024-08-07 13:26:17 -07:00
jax authors
cce725059a Merge pull request #22830 from kaixih:support_vmap
PiperOrigin-RevId: 660509938
2024-08-07 13:12:59 -07:00
jax authors
d3b6066f91 Merge pull request #22820 from Rifur13:mha-faster
PiperOrigin-RevId: 660461104
2024-08-07 11:11:15 -07:00
jax authors
32131d0288 Merge pull request #22897 from jakevdp:bool-indexing
PiperOrigin-RevId: 660444193
2024-08-07 10:30:41 -07:00
Sergei Lebedev
6fc57c0eb6 Rolling forward #22836
This version, proposed by @dfm, does not have a custom JVP for the whole
logsumexp and instead fixes #22398 directly.

Reverts e416c6675acfd82866a6e83e8c221640c4d02f29

PiperOrigin-RevId: 660438802
2024-08-07 10:17:55 -07:00
jax authors
893ae6eb80 Merge pull request #22869 from dfm:custom-batching-polish
PiperOrigin-RevId: 660421503
2024-08-07 09:40:46 -07:00
jax authors
5cb9510f60 Merge pull request #22908 from gnecula:pallas_warn
PiperOrigin-RevId: 660421476
2024-08-07 09:37:15 -07:00
jax authors
930c8ca791 Merge pull request #22914 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 660421322
2024-08-07 09:33:43 -07:00
Ruturaj4
a2d79936df [ROCM] Fix BUILD.bazel library source paths 2024-08-07 09:18:20 -05:00
Sergei Lebedev
3a1567f57a Do not run nn_test under asan -- it times out
PiperOrigin-RevId: 660377176
2024-08-07 07:14:27 -07:00
rajasekharporeddy
3095c570b8 Better docs for jnp.fft.rfft2 and jnp.fft.irfft2 2024-08-07 17:59:53 +05:30
George Necula
3e5e947542 Move some backwards compatibility tests from jax_triton to jax/pallas.
While doing this I moved `matmul.py` to `jax/experimental/pallas/ops/tpu`

PiperOrigin-RevId: 660341331
2024-08-07 05:00:29 -07:00
Sergei Lebedev
28ca734d9b Added another boxDim check to mosaic_gpu_init_tma_desc
PiperOrigin-RevId: 660314586
2024-08-07 03:16:54 -07:00
George Necula
64eb8e9639 [pallas] Add a warning message about experimental and incomplete status 2024-08-07 08:38:56 +03:00
Sharad Vikram
803453ed74 [Pallas TPU] Close over consts in while_loop lowering to avoid passing refs in/out of loop
PiperOrigin-RevId: 660238073
2024-08-06 22:33:15 -07:00
Yash Katariya
dd958adc39 Add mesh_shape to the lowering context. This is to allow custom partitioning to not depend on the mesh context manager to return NamedShardings even if the arguments have NamedShardings on them.
Since `shardy`, sharding in types work, world 2 dagger is going in a direction of making Mesh and PartitionSpec a first class sharding type, let's pull the trigger right now to start fixing these bad user interactions.

Some things that will break due to this change: Before passing NamedSharding and an equivalent PositionalSharding to the same jitted function one after another would lead to a lowering cache hit. But now we will cache miss. In other words: `f(ns); f(ps) # cache hit before`

In followup CLs, we will make the tracing cache aware of the mesh shape too to fix some other issues related to tracing and lowering cache misses

PiperOrigin-RevId: 660177423
2024-08-06 18:35:44 -07:00
Yue Sheng
7f44edc01e Change log level of clearing JAX backend caches from info to debug.
PiperOrigin-RevId: 660141868
2024-08-06 16:27:56 -07:00
jax authors
798297af98 Update XLA dependency to use revision
08b8d938eb.

PiperOrigin-RevId: 660133285
2024-08-06 16:05:14 -07:00
jax authors
53ab5eb24f Merge pull request #22900 from jakevdp:dep-bfloat16
PiperOrigin-RevId: 660102762
2024-08-06 14:42:43 -07:00
jax authors
9074e8544f Add test for zero-sized host memory parameter
PiperOrigin-RevId: 660097039
2024-08-06 14:31:41 -07:00
jax authors
aec6efb44b Merge pull request #22649 from ROCm:ci_jax_export_harness
PiperOrigin-RevId: 660096296
2024-08-06 14:27:13 -07:00
jax authors
cc9665749f Merge pull request #22901 from ROCm:ci_test_harness_vmap
PiperOrigin-RevId: 660089572
2024-08-06 14:04:57 -07:00
Jieying Luo
abe7982d65 Remove enable_gpu and xla_python_enable_gpu from jax .bazelrc.
The plugin is released and the flag is no longer needed.

Also set default value of enable_gpu to False. enable_gpu will be removed in the next change.

PiperOrigin-RevId: 660059432
2024-08-06 12:39:45 -07:00
Kanglan Tang
ae541203bc Skip flaky test_weight_offload_with_dp_on_output test on GPU backend.
PiperOrigin-RevId: 660057950
2024-08-06 12:35:53 -07:00
Ruturaj4
707cdd4706 [ROCM] Fix hipsolverSsyevd tests due to align with the rocm behavior. 2024-08-06 14:10:09 -05:00
Jake VanderPlas
a009e1cf50 deprecate jax.lib.xla_client.bfloat16 2024-08-06 11:22:27 -07:00
jax authors
f67f73c352 Merge pull request #22834 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 660014909
2024-08-06 10:46:56 -07:00
jax authors
799de71c41 Merge pull request #22896 from jakevdp:pin-sphinx
PiperOrigin-RevId: 660012176
2024-08-06 10:39:40 -07:00
Jake VanderPlas
b45f0fe50f Support empty boolean indexing 2024-08-06 09:56:03 -07:00
Jake VanderPlas
4f8c5a335d CI: pin sphinx to avoid build errors on 8.0 2024-08-06 09:16:41 -07:00
Ruturaj4
35c70fd3ec [ROCM] Fix export harness tests 2024-08-06 10:12:31 -05:00
jax authors
8b9ceb598b Handle bool comparisons.
PiperOrigin-RevId: 659919931
2024-08-06 05:37:35 -07:00
Adam Paszke
209f6cd6c1 [Mosaic GPU] Profiler improvements
1. Each process now corresponds to an SM, showing how many blocks
   are executing concurrently.
2. The timeline now accounts for the start offset of each block,
   instead of aligning them together. This makes a lot more sense in
   the SM view.
3. We now use inline PTX to emit profiler events. This sometimes slightly
   pessimizes code generation, but allows us to predicate out write on
   all threads other than the leader of each warpgroup, improving the
   trace quality.
4. We make sure each trace is monotonic. I can't explain why but the clocks
   can behave very weirdly, potentially due to rescheduling on the SASS level.
   We now fix up all backward movements and emit a warning if big shifts have
   been detected.

PiperOrigin-RevId: 659911268
2024-08-06 05:02:59 -07:00
Dan Foreman-Mackey
23da11b609 Re-land FFI port of GPU LU decomposition after fixing XLA FFI memory leak.
PiperOrigin-RevId: 659867028
2024-08-06 02:13:21 -07:00
Yue Sheng
f255fb700a Async dispatch expensive computations on the JAX CPU backend. By setting jax.config.update('jax_cpu_enable_async_dispatch', False), one could opt out of the change and recover the old behavior.
PiperOrigin-RevId: 659741822
2024-08-05 17:48:17 -07:00
jax authors
0ab4d68511 Merge pull request #22885 from jakevdp:dep-xla
PiperOrigin-RevId: 659724044
2024-08-05 16:43:46 -07:00
Jake VanderPlas
06f29bbb97 Deprecate jax.lib.xla_client._xla
This is an alias for jax.lib.xla_extension. Why the deprecation warning
for this when #22844 removed other APIs without any warning? This one
is relatively commonly used (I found a few dozen downstream references)
so I feld that a deprecation warning might be helpful.
2024-08-05 16:19:59 -07:00
Yash Katariya
489fbc0ed5 Add a test for streaming in closed over constants from host to device
PiperOrigin-RevId: 659711557
2024-08-05 16:00:45 -07:00
jax authors
a497fbc558 Update XLA dependency to use revision
b33429c33d.

PiperOrigin-RevId: 659695240
2024-08-05 15:08:18 -07:00
Jake VanderPlas
3d857b02ac export jax.lib.xla_extension.HloModule
Followup to #22844, because the symbol is used downstream.

PiperOrigin-RevId: 659678623
2024-08-05 14:16:40 -07:00
Sergei Lebedev
e416c6675a Reverts 0f103d33849ca017e6a199d0f79fa0d83b373995
PiperOrigin-RevId: 659670593
2024-08-05 13:52:04 -07:00
jax authors
c2c04e054e Merge pull request #22608 from kaixih:fix_cuda_version_check
PiperOrigin-RevId: 659664879
2024-08-05 13:34:43 -07:00
kaixih
09b88430e9 Fix CUDA version checks 2024-08-05 20:09:17 +00:00
jax authors
0f103d3384 Merge pull request #22836 from superbobry:maint-2
PiperOrigin-RevId: 659644462
2024-08-05 12:30:51 -07:00
jax authors
af1a69edfd Merge pull request #22870 from google:dependabot/github_actions/actions/upload-artifact-4.3.5
PiperOrigin-RevId: 659604731
2024-08-05 10:47:06 -07:00
dependabot[bot]
6d7cf3fcfd
Bump actions/upload-artifact from 4.3.3 to 4.3.5
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4.3.3 to 4.3.5.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](65462800fd...89ef406dd8)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-08-05 17:09:11 +00:00