22318 Commits

Author SHA1 Message Date
Ruturaj4
644ac10c92 [ROCm] improve gpu script rocm-jax-stable-2024_08_07 2024-08-08 09:04:02 -05:00
Ruturaj4
a2d79936df [ROCM] Fix BUILD.bazel library source paths 2024-08-07 09:18:20 -05:00
Yash Katariya
dd958adc39 Add mesh_shape to the lowering context. This is to allow custom partitioning to not depend on the mesh context manager to return NamedShardings even if the arguments have NamedShardings on them.
Since `shardy`, the sharding-in-types work, and world 2 dagger are all moving toward making Mesh and PartitionSpec a first-class sharding type, let's pull the trigger now and start fixing these bad user interactions.

Some things that will break due to this change: Before, passing a NamedSharding and an equivalent PositionalSharding to the same jitted function one after another would lead to a lowering cache hit. Now it will be a cache miss. In other words: `f(ns); f(ps) # cache hit before`

In followup CLs, we will make the tracing cache aware of the mesh shape too to fix some other issues related to tracing and lowering cache misses

PiperOrigin-RevId: 660177423
2024-08-06 18:35:44 -07:00
Yue Sheng
7f44edc01e Change log level of clearing JAX backend caches from info to debug.
PiperOrigin-RevId: 660141868
2024-08-06 16:27:56 -07:00
jax authors
798297af98 Update XLA dependency to use revision
08b8d938eb.

PiperOrigin-RevId: 660133285
2024-08-06 16:05:14 -07:00
jax authors
53ab5eb24f Merge pull request #22900 from jakevdp:dep-bfloat16
PiperOrigin-RevId: 660102762
2024-08-06 14:42:43 -07:00
jax authors
9074e8544f Add test for zero-sized host memory parameter
PiperOrigin-RevId: 660097039
2024-08-06 14:31:41 -07:00
jax authors
aec6efb44b Merge pull request #22649 from ROCm:ci_jax_export_harness
PiperOrigin-RevId: 660096296
2024-08-06 14:27:13 -07:00
jax authors
cc9665749f Merge pull request #22901 from ROCm:ci_test_harness_vmap
PiperOrigin-RevId: 660089572
2024-08-06 14:04:57 -07:00
Jieying Luo
abe7982d65 Remove enable_gpu and xla_python_enable_gpu from jax .bazelrc.
The plugin is released and the flag is no longer needed.

Also set default value of enable_gpu to False. enable_gpu will be removed in the next change.

PiperOrigin-RevId: 660059432
2024-08-06 12:39:45 -07:00
Kanglan Tang
ae541203bc Skip flaky test_weight_offload_with_dp_on_output test on GPU backend.
PiperOrigin-RevId: 660057950
2024-08-06 12:35:53 -07:00
Ruturaj4
707cdd4706 [ROCM] Fix hipsolverSsyevd tests to align with the ROCm behavior. 2024-08-06 14:10:09 -05:00
Jake VanderPlas
a009e1cf50 deprecate jax.lib.xla_client.bfloat16 2024-08-06 11:22:27 -07:00
jax authors
f67f73c352 Merge pull request #22834 from rajasekharporeddy:testbranch1
PiperOrigin-RevId: 660014909
2024-08-06 10:46:56 -07:00
jax authors
799de71c41 Merge pull request #22896 from jakevdp:pin-sphinx
PiperOrigin-RevId: 660012176
2024-08-06 10:39:40 -07:00
Jake VanderPlas
4f8c5a335d CI: pin sphinx to avoid build errors on 8.0 2024-08-06 09:16:41 -07:00
Ruturaj4
35c70fd3ec [ROCM] Fix export harness tests 2024-08-06 10:12:31 -05:00
jax authors
8b9ceb598b Handle bool comparisons.
PiperOrigin-RevId: 659919931
2024-08-06 05:37:35 -07:00
Adam Paszke
209f6cd6c1 [Mosaic GPU] Profiler improvements
1. Each process now corresponds to an SM, showing how many blocks
   are executing concurrently.
2. The timeline now accounts for the start offset of each block,
   instead of aligning them together. This makes a lot more sense in
   the SM view.
3. We now use inline PTX to emit profiler events. This sometimes slightly
   pessimizes code generation, but allows us to predicate out writes on
   all threads other than the leader of each warpgroup, improving the
   trace quality.
4. We make sure each trace is monotonic. I can't explain why but the clocks
   can behave very weirdly, potentially due to rescheduling on the SASS level.
   We now fix up all backward movements and emit a warning if big shifts have
   been detected.

PiperOrigin-RevId: 659911268
2024-08-06 05:02:59 -07:00
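
Item 4 of the profiler change above (209f6cd6c1) describes fixing up backward clock movements and warning on large shifts. The commit message does not say how the fix-up is done, so the following is only a minimal sketch of one possible approach: clamp each timestamp to the running maximum and warn if any correction exceeds a threshold. The function name `fix_monotonic` and the `warn_threshold` parameter are hypothetical and do not come from the JAX source.

```python
import warnings


def fix_monotonic(times, warn_threshold=100):
    """Return a copy of `times` in which no value is smaller than its predecessor.

    Hypothetical helper: a backward movement is replaced by the previous
    (already fixed) timestamp. A warning is emitted if the largest
    correction exceeds `warn_threshold` clock ticks.
    """
    fixed = []
    largest_shift = 0
    for t in times:
        if fixed and t < fixed[-1]:
            # The clock went backwards: record how far and clamp.
            largest_shift = max(largest_shift, fixed[-1] - t)
            t = fixed[-1]
        fixed.append(t)
    if largest_shift > warn_threshold:
        warnings.warn(
            f"Clock moved backwards by up to {largest_shift} ticks; "
            "the trace was adjusted to be monotonic."
        )
    return fixed


print(fix_monotonic([0, 5, 3, 8, 7, 9]))
```

Clamping keeps the number of events unchanged and never reorders them, which is the property a timeline viewer needs. Other fix-ups (for example, shifting all later events forward) would also satisfy the description in the commit message.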
Dan Foreman-Mackey
23da11b609 Re-land FFI port of GPU LU decomposition after fixing XLA FFI memory leak.
PiperOrigin-RevId: 659867028
2024-08-06 02:13:21 -07:00
Yue Sheng
f255fb700a Async dispatch expensive computations on the JAX CPU backend. By setting jax.config.update('jax_cpu_enable_async_dispatch', False), one could opt out of the change and recover the old behavior.
PiperOrigin-RevId: 659741822
2024-08-05 17:48:17 -07:00
jax authors
0ab4d68511 Merge pull request #22885 from jakevdp:dep-xla
PiperOrigin-RevId: 659724044
2024-08-05 16:43:46 -07:00
Jake VanderPlas
06f29bbb97 Deprecate jax.lib.xla_client._xla
This is an alias for jax.lib.xla_extension. Why the deprecation warning
for this when #22844 removed other APIs without any warning? This one
is relatively commonly used (I found a few dozen downstream references)
so I felt that a deprecation warning might be helpful.
2024-08-05 16:19:59 -07:00
Yash Katariya
489fbc0ed5 Add a test for streaming in closed over constants from host to device
PiperOrigin-RevId: 659711557
2024-08-05 16:00:45 -07:00
jax authors
a497fbc558 Update XLA dependency to use revision
b33429c33d.

PiperOrigin-RevId: 659695240
2024-08-05 15:08:18 -07:00
Jake VanderPlas
3d857b02ac export jax.lib.xla_extension.HloModule
Followup to #22844, because the symbol is used downstream.

PiperOrigin-RevId: 659678623
2024-08-05 14:16:40 -07:00
Sergei Lebedev
e416c6675a Reverts 0f103d33849ca017e6a199d0f79fa0d83b373995
PiperOrigin-RevId: 659670593
2024-08-05 13:52:04 -07:00
jax authors
c2c04e054e Merge pull request #22608 from kaixih:fix_cuda_version_check
PiperOrigin-RevId: 659664879
2024-08-05 13:34:43 -07:00
kaixih
09b88430e9 Fix CUDA version checks 2024-08-05 20:09:17 +00:00
jax authors
0f103d3384 Merge pull request #22836 from superbobry:maint-2
PiperOrigin-RevId: 659644462
2024-08-05 12:30:51 -07:00
jax authors
af1a69edfd Merge pull request #22870 from google:dependabot/github_actions/actions/upload-artifact-4.3.5
PiperOrigin-RevId: 659604731
2024-08-05 10:47:06 -07:00
dependabot[bot]
6d7cf3fcfd Bump actions/upload-artifact from 4.3.3 to 4.3.5
Bumps [actions/upload-artifact](https://github.com/actions/upload-artifact) from 4.3.3 to 4.3.5.
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](65462800fd...89ef406dd8)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-08-05 17:09:11 +00:00
jax authors
5f99c557db Merge pull request #22844 from jakevdp:xla-extension
PiperOrigin-RevId: 659586676
2024-08-05 09:55:24 -07:00
rajasekharporeddy
1acff9c739 Better docs for jnp.fft.hfft and jnp.fft.ihfft 2024-08-05 21:53:29 +05:30
Sergei Lebedev
0a48aca965 Added a custom JVP rule for jax.nn.logsumexp
Fixes #22398 where the Jacobian of jax.nn.logsumexp was wrong if b= contained
exact zeros.
2024-08-05 17:05:03 +01:00
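
As background for the logsumexp fix above (0a48aca965): for positive weights `b`, the weighted log-sum-exp is log(sum_i b_i * exp(a_i)), and its derivative with respect to a_i is b_i * exp(a_i) divided by that sum, so an entry with b_i exactly zero should contribute a zero derivative. The sketch below shows only this mathematics in plain Python. It is not the JAX implementation or its JVP rule, the function names are hypothetical, and it assumes non-negative weights with at least one non-zero entry.

```python
import math


def weighted_terms(a, b):
    """Return (shift, terms) with terms[i] = b[i] * exp(a[i] - shift).

    Entries with a zero weight are skipped entirely, so a very large a[i]
    paired with b[i] == 0 cannot overflow or influence the shift.
    """
    shift = max(ai for ai, bi in zip(a, b) if bi != 0)
    terms = [bi * math.exp(ai - shift) if bi != 0 else 0.0
             for ai, bi in zip(a, b)]
    return shift, terms


def logsumexp(a, b):
    """log(sum_i b[i] * exp(a[i])), computed with a max shift for stability."""
    shift, terms = weighted_terms(a, b)
    return shift + math.log(sum(terms))


def logsumexp_grad(a, b):
    """Derivative of logsumexp(a, b) with respect to each a[i]."""
    _, terms = weighted_terms(a, b)
    total = sum(terms)
    return [t / total for t in terms]


print(logsumexp([0.0, 1000.0], [1.0, 0.0]))
print(logsumexp_grad([0.0, 1000.0], [1.0, 0.0]))
```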
jax authors
9762ac53c8 Move CostEstimate from pltu to pl
* Move CostEstimate from TPU-specific `compiler_params` to a platform-independent argument of `pallas_call`.
Passing a CostEstimate in `compiler_params` is now deprecated and will be removed in 3 months' time.
* Update the CostEstimate when batching a kernel by scaling it by the size of the batch axis.

PiperOrigin-RevId: 659560330
2024-08-05 08:18:01 -07:00
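
The second bullet of the CostEstimate change above (9762ac53c8) says the estimate is scaled by the size of the batch axis when a kernel is batched. A minimal sketch of that arithmetic follows. The `CostEstimate` class here is a hypothetical stand-in written for illustration, not the Pallas class, and the field names (`flops`, `transcendentals`, `bytes_accessed`) and the helper `scale_for_batch` are assumptions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CostEstimate:
    """Hypothetical stand-in for a per-kernel cost estimate."""
    flops: int
    transcendentals: int
    bytes_accessed: int


def scale_for_batch(cost: CostEstimate, batch_size: int) -> CostEstimate:
    """Scale every cost field by the size of the batch axis.

    Running the same kernel once per batch element multiplies each
    counted quantity by the number of batch elements.
    """
    return CostEstimate(
        flops=cost.flops * batch_size,
        transcendentals=cost.transcendentals * batch_size,
        bytes_accessed=cost.bytes_accessed * batch_size,
    )


single = CostEstimate(flops=100, transcendentals=10, bytes_accessed=4096)
print(scale_for_batch(single, 8))
```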
Philippe Hamel
ecf9f64240 Nit. Fix missing backticks documentation for jnp.where.
PiperOrigin-RevId: 659549362
2024-08-05 07:37:41 -07:00
jax authors
44a8c98912 Merge pull request #22141 from dfm:update-cuda-call-example-to-ffi-call
PiperOrigin-RevId: 659542133
2024-08-05 07:09:12 -07:00
jax authors
58927470dd Merge pull request #22849 from jakevdp:dlpack-doc
PiperOrigin-RevId: 659541072
2024-08-05 07:04:42 -07:00
George Necula
252032a368 [pallas] Improve error and debugging messages with source locations
Document the `name` argument to `pallas_call` and supplement it with source location information for the kernel function.
Pass all this as the `name_and_src_info` parameter to the `pallas_call_p` primitive.

Added some more information to the `if debug` prints.

Set the MLIR module names so that the debug dumps are named properly.

I changed `import pallas.core as pl_core` to `... as pallas_core` for consistency, in a couple of modules.

PiperOrigin-RevId: 659506675
2024-08-05 04:23:55 -07:00
Paweł Paruzel
b2a469b361 Port Eigenvalue Decompositions to XLA's FFI
This CL only contains the C++ changes. Python lowering code will be added after the forward compatibility window of 3 weeks.

PiperOrigin-RevId: 659492696
2024-08-05 03:18:13 -07:00
George Necula
9b35b760ce [pallas] Enable check for GPU lowering that tensor sizes are power of 2
Triton has a restriction that all operations have arguments and results
that are tensors whose size is a power of 2. Added a lowering check
for this. Without this, when we violate the condition we get an
unfriendly crash.

PiperOrigin-RevId: 659483450
2024-08-05 02:34:21 -07:00
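
The power-of-2 restriction described in the commit above (9b35b760ce) can be illustrated with a small check. This is a sketch, not the Pallas lowering code: it takes "size" to mean the total number of elements, and the names `is_power_of_2` and `check_tensor_size` are hypothetical.

```python
import math


def is_power_of_2(n: int) -> bool:
    """True if n is a positive integer with exactly one bit set."""
    return n > 0 and (n & (n - 1)) == 0


def check_tensor_size(shape):
    """Raise a readable error if the element count is not a power of 2.

    Returns the element count when the check passes.
    """
    size = math.prod(shape)
    if not is_power_of_2(size):
        raise ValueError(
            f"Tensor of shape {tuple(shape)} has {size} elements, "
            "which is not a power of 2."
        )
    return size


print(check_tensor_size((8, 16)))
```

Raising an error with the offending shape at lowering time is what turns the "unfriendly crash" mentioned in the commit message into something a user can act on.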
jax authors
0b87bf48f9 Update XLA dependency to use revision
b94c84fa54.

PiperOrigin-RevId: 659354942
2024-08-04 14:53:18 -07:00
John Ryan
56ff247c2e Reverts 80560663d3fab4c0c3f87d7c8e52fb9931526dbb
PiperOrigin-RevId: 659334027
2024-08-04 12:11:30 -07:00
jax authors
83b5c7a0dd Merge pull request #22857 from mattjj:improve-while-loop-error
PiperOrigin-RevId: 659160966
2024-08-03 15:32:31 -07:00
jax authors
06c6a73236 Update XLA dependency to use revision
83ef35ce9e.

PiperOrigin-RevId: 659159129
2024-08-03 15:20:24 -07:00
Matthew Johnson
bdcd358b65 improve while_loop carry pytree/type mismatch errors
Now we call into the same error utility as we use in scan.
2024-08-03 21:57:29 +00:00
Jake VanderPlas
4d637c8feb Improve documentation for jnp.from_dlpack 2024-08-03 05:48:48 -07:00
Jake VanderPlas
521c94c6c6 Tighten the public API for jax.lib.xla_client & xla_extension 2024-08-03 05:26:22 -07:00
Yue Sheng
09beb33226 Don't call api.clean_up when there is no default backend.
PiperOrigin-RevId: 658936536
2024-08-02 16:14:29 -07:00