Yash Katariya
be53ee10b1
Set jax_enable_memories
flag to True
by default
...
PiperOrigin-RevId: 660579462
2024-08-07 16:25:25 -07:00
jax authors
7efca0490f
Merge pull request #22920 from jakevdp:fix-lint
...
PiperOrigin-RevId: 660570457
2024-08-07 16:01:09 -07:00
jax authors
a57d6591ee
Update XLA dependency to use revision
...
3bf7e1ae48
.
PiperOrigin-RevId: 660570144
2024-08-07 15:57:41 -07:00
Jake VanderPlas
53af0d4d90
CI: fix mypy errors
2024-08-07 15:15:45 -07:00
jax authors
de02988e94
Merge pull request #22909 from ROCm:ci_fix_solver_paths
...
PiperOrigin-RevId: 660515208
2024-08-07 13:26:17 -07:00
jax authors
cce725059a
Merge pull request #22830 from kaixih:support_vmap
...
PiperOrigin-RevId: 660509938
2024-08-07 13:12:59 -07:00
jax authors
d3b6066f91
Merge pull request #22820 from Rifur13:mha-faster
...
PiperOrigin-RevId: 660461104
2024-08-07 11:11:15 -07:00
jax authors
32131d0288
Merge pull request #22897 from jakevdp:bool-indexing
...
PiperOrigin-RevId: 660444193
2024-08-07 10:30:41 -07:00
Sergei Lebedev
6fc57c0eb6
Rolling forward #22836
...
This version, proposed by @dfm, does not have a custom JVP for the whole
logsumexp and instead fixes #22398 directly.
Reverts e416c6675acfd82866a6e83e8c221640c4d02f29
PiperOrigin-RevId: 660438802
2024-08-07 10:17:55 -07:00
jax authors
893ae6eb80
Merge pull request #22869 from dfm:custom-batching-polish
...
PiperOrigin-RevId: 660421503
2024-08-07 09:40:46 -07:00
jax authors
5cb9510f60
Merge pull request #22908 from gnecula:pallas_warn
...
PiperOrigin-RevId: 660421476
2024-08-07 09:37:15 -07:00
jax authors
930c8ca791
Merge pull request #22914 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 660421322
2024-08-07 09:33:43 -07:00
Ruturaj4
a2d79936df
[ROCM] Fix BUILD.bazel library source paths
2024-08-07 09:18:20 -05:00
Sergei Lebedev
3a1567f57a
Do not run nn_test under asan -- it times out
...
PiperOrigin-RevId: 660377176
2024-08-07 07:14:27 -07:00
rajasekharporeddy
3095c570b8
Better docs for jnp.fft.rfft2 and jnp.fft.irfft2
2024-08-07 17:59:53 +05:30
George Necula
3e5e947542
Move some backwards compatibility tests from jax_triton to jax/pallas.
...
While doing this I moved `matmul.py` to `jax/experimental/pallas/ops/tpu`
PiperOrigin-RevId: 660341331
2024-08-07 05:00:29 -07:00
Sergei Lebedev
28ca734d9b
Added another boxDim check to mosaic_gpu_init_tma_desc
...
PiperOrigin-RevId: 660314586
2024-08-07 03:16:54 -07:00
George Necula
64eb8e9639
[pallas] Add a warning message about experimental and incomplete status
2024-08-07 08:38:56 +03:00
Sharad Vikram
803453ed74
[Pallas TPU] Close over consts in while_loop lowering to avoid passing refs in/out of loop
...
PiperOrigin-RevId: 660238073
2024-08-06 22:33:15 -07:00
Yash Katariya
dd958adc39
Add mesh_shape
to the lowering context. This is to allow custom partitioning to not depend on the mesh context manager to return NamedShardings even if the arguments have NamedShardings on them.
...
Since `shardy`, sharding in types work, world 2 dagger is going in a direction of making Mesh and PartitionSpec a first class sharding type, let's pull the trigger right now to start fixing these bad user interactions.
Some things that will break due to this change: Before passing NamedSharding and an equivalent PositionalSharding to the same jitted function one after another would lead to a lowering cache hit. But now we will cache miss. In other words: `f(ns); f(ps) # cache hit before`
In followup CLs, we will make the tracing cache aware of the mesh shape too to fix some other issues related to tracing and lowering cache misses
PiperOrigin-RevId: 660177423
2024-08-06 18:35:44 -07:00
Yue Sheng
7f44edc01e
Change log level of clearing JAX backend caches from info to debug.
...
PiperOrigin-RevId: 660141868
2024-08-06 16:27:56 -07:00
jax authors
798297af98
Update XLA dependency to use revision
...
08b8d938eb
.
PiperOrigin-RevId: 660133285
2024-08-06 16:05:14 -07:00
jax authors
53ab5eb24f
Merge pull request #22900 from jakevdp:dep-bfloat16
...
PiperOrigin-RevId: 660102762
2024-08-06 14:42:43 -07:00
jax authors
9074e8544f
Add test for zero-sized host memory parameter
...
PiperOrigin-RevId: 660097039
2024-08-06 14:31:41 -07:00
jax authors
aec6efb44b
Merge pull request #22649 from ROCm:ci_jax_export_harness
...
PiperOrigin-RevId: 660096296
2024-08-06 14:27:13 -07:00
jax authors
cc9665749f
Merge pull request #22901 from ROCm:ci_test_harness_vmap
...
PiperOrigin-RevId: 660089572
2024-08-06 14:04:57 -07:00
Jieying Luo
abe7982d65
Remove enable_gpu and xla_python_enable_gpu from jax .bazelrc.
...
The plugin is released and the flag is no longer needed.
Also set default value of enable_gpu to False. enable_gpu will be removed in the next change.
PiperOrigin-RevId: 660059432
2024-08-06 12:39:45 -07:00
Kanglan Tang
ae541203bc
Skip flaky test_weight_offload_with_dp_on_output test on GPU backend.
...
PiperOrigin-RevId: 660057950
2024-08-06 12:35:53 -07:00
Ruturaj4
707cdd4706
[ROCM] Fix hipsolverSsyevd tests due to align with the rocm behavior.
2024-08-06 14:10:09 -05:00
Jake VanderPlas
a009e1cf50
deprecate jax.lib.xla_client.bfloat16
2024-08-06 11:22:27 -07:00
jax authors
f67f73c352
Merge pull request #22834 from rajasekharporeddy:testbranch1
...
PiperOrigin-RevId: 660014909
2024-08-06 10:46:56 -07:00
jax authors
799de71c41
Merge pull request #22896 from jakevdp:pin-sphinx
...
PiperOrigin-RevId: 660012176
2024-08-06 10:39:40 -07:00
Jake VanderPlas
b45f0fe50f
Support empty boolean indexing
2024-08-06 09:56:03 -07:00
Jake VanderPlas
4f8c5a335d
CI: pin sphinx to avoid build errors on 8.0
2024-08-06 09:16:41 -07:00
Ruturaj4
35c70fd3ec
[ROCM] Fix export harness tests
2024-08-06 10:12:31 -05:00
jax authors
8b9ceb598b
Handle bool comparisons.
...
PiperOrigin-RevId: 659919931
2024-08-06 05:37:35 -07:00
Adam Paszke
209f6cd6c1
[Mosaic GPU] Profiler improvements
...
1. Each process now corresponds to an SM, showing how many blocks
are executing concurrently.
2. The timeline now accounts for the start offset of each block,
instead of aligning them together. This makes a lot more sense in
the SM view.
3. We now use inline PTX to emit profiler events. This sometimes slightly
pessimizes code generation, but allows us to predicate out write on
all threads other than the leader of each warpgroup, improving the
trace quality.
4. We make sure each trace is monotonic. I can't explain why but the clocks
can behave very weirdly, potentially due to rescheduling on the SASS level.
We now fix up all backward movements and emit a warning if big shifts have
been detected.
PiperOrigin-RevId: 659911268
2024-08-06 05:02:59 -07:00
Dan Foreman-Mackey
23da11b609
Re-land FFI port of GPU LU decomposition after fixing XLA FFI memory leak.
...
PiperOrigin-RevId: 659867028
2024-08-06 02:13:21 -07:00
Yue Sheng
f255fb700a
Async dispatch expensive computations on the JAX CPU backend. By setting jax.config.update('jax_cpu_enable_async_dispatch', False)
, one could opt out of the change and recover the old behavior.
...
PiperOrigin-RevId: 659741822
2024-08-05 17:48:17 -07:00
jax authors
0ab4d68511
Merge pull request #22885 from jakevdp:dep-xla
...
PiperOrigin-RevId: 659724044
2024-08-05 16:43:46 -07:00
Jake VanderPlas
06f29bbb97
Deprecate jax.lib.xla_client._xla
...
This is an alias for jax.lib.xla_extension. Why the deprecation warning
for this when #22844 removed other APIs without any warning? This one
is relatively commonly used (I found a few dozen downstream references)
so I feld that a deprecation warning might be helpful.
2024-08-05 16:19:59 -07:00
Yash Katariya
489fbc0ed5
Add a test for streaming in closed over constants from host to device
...
PiperOrigin-RevId: 659711557
2024-08-05 16:00:45 -07:00
jax authors
a497fbc558
Update XLA dependency to use revision
...
b33429c33d
.
PiperOrigin-RevId: 659695240
2024-08-05 15:08:18 -07:00
Jake VanderPlas
3d857b02ac
export jax.lib.xla_extension.HloModule
...
Followup to #22844 , because the symbol is used downstream.
PiperOrigin-RevId: 659678623
2024-08-05 14:16:40 -07:00
Sergei Lebedev
e416c6675a
Reverts 0f103d33849ca017e6a199d0f79fa0d83b373995
...
PiperOrigin-RevId: 659670593
2024-08-05 13:52:04 -07:00
jax authors
c2c04e054e
Merge pull request #22608 from kaixih:fix_cuda_version_check
...
PiperOrigin-RevId: 659664879
2024-08-05 13:34:43 -07:00
kaixih
09b88430e9
Fix CUDA version checks
2024-08-05 20:09:17 +00:00
jax authors
0f103d3384
Merge pull request #22836 from superbobry:maint-2
...
PiperOrigin-RevId: 659644462
2024-08-05 12:30:51 -07:00
jax authors
af1a69edfd
Merge pull request #22870 from google:dependabot/github_actions/actions/upload-artifact-4.3.5
...
PiperOrigin-RevId: 659604731
2024-08-05 10:47:06 -07:00
dependabot[bot]
6d7cf3fcfd
Bump actions/upload-artifact from 4.3.3 to 4.3.5
...
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact ) from 4.3.3 to 4.3.5.
- [Release notes](https://github.com/actions/upload-artifact/releases )
- [Commits](65462800fd...89ef406dd8
)
---
updated-dependencies:
- dependency-name: actions/upload-artifact
dependency-type: direct:production
update-type: version-update:semver-patch
...
Signed-off-by: dependabot[bot] <support@github.com>
2024-08-05 17:09:11 +00:00