70 Commits

Author SHA1 Message Date
Peter Hawkins
66293d8897 Remove code present to support jaxlib < 0.5.1.
The new minimum xla_extension_version is 317 and the new mlir_api_version is 58.
2025-02-26 07:40:40 -05:00
jax authors
d3850e7fdd Support optimization_level and memory_fitting_level XLA compilation options.
PiperOrigin-RevId: 727070422
2025-02-14 14:46:11 -08:00
Olli Lupton
1bba1ea2e2 Add JAX_COMPILATION_CACHE_EXPECT_PGLE option
This allows using external profiling tools, such as Nsight Systems,
with the automatic PGLE workflow supported by JAX with a simple two-step
workflow:

export JAX_COMPILATION_CACHE_DIR=...
JAX_ENABLE_PGLE=yes python model.py
JAX_COMPILATION_CACHE_EXPECT_PGLE=yes nsys profile python model.py
2025-02-06 08:19:45 +00:00
Peter Hawkins
efab6945ca Remove code that supported jaxlib < 0.5.
The new xla_extension_version is 303 and the new mlir_api_version is 57.
2025-01-17 14:22:27 -05:00
jax authors
4023810565 [AutoPGLE] FIx PGLE kokoro test failures.
PiperOrigin-RevId: 712930537
2025-01-07 08:59:59 -08:00
jax authors
b6aead6f3a [AutoPGLE] Explicitly disable command buffers when profiler is used.
PiperOrigin-RevId: 709475833
2024-12-24 21:31:05 -08:00
jax authors
a123d4e39e Remove autotune sharing.
xla_gpu_shard_autotuning can be used now instead and it is enabled by default.

PiperOrigin-RevId: 705792463
2024-12-13 01:22:27 -08:00
jax authors
e6dfe8f380 [AutoPGLE] Share FDO profile even when compilation cache disabled.
PiperOrigin-RevId: 704757991
2024-12-10 10:23:42 -08:00
jax authors
8813973d96 [AutoPGLE] Cleanup compiler code.
PiperOrigin-RevId: 704741308
2024-12-10 09:37:35 -08:00
Peter Hawkins
79318a08cf Remove dead code after minimum jaxlib version bump to v0.4.36.
New minimum xla_extension_version is 299, and the new mlir_api_version is 57.

PiperOrigin-RevId: 704280856
2024-12-09 07:35:05 -08:00
Jed Borovik
c65ce4b093
Merge branch 'main' into add-optimization-effort-flags 2024-11-27 14:08:10 -05:00
Jed Borovik
83b54d97e7 Add version check for effort flags 2024-11-27 13:54:33 -05:00
labs-code-app[bot]
762301fc5d Add exec_time_optimization_effort and memory_fitting_effort flags.
These flags control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. They can be set via the command line, e.g. . Valid values are between -1.0 and 1.0, default is 0.0.
2024-11-26 13:57:47 +00:00
jax authors
231967fdb5 [AutoPGLE] Explicitly ignore host callback pointers
Before this change users had to specify remove_custom_partitioning_ptr_from_cache_key config flag when using AutoPGLE.

PiperOrigin-RevId: 700289965
2024-11-26 04:06:15 -08:00
Trevor Morris
a79d307ac7 When caching is enabled, also enable XLA caching features as well
Add unit test

Fix typechecker

Set caching mode depending on process id
2024-11-13 10:30:04 -08:00
Matthew Johnson
0f3ba4250d support exec_time_optimization_effort and memory_fitting_effort xla compilation
options

PiperOrigin-RevId: 692322944
2024-11-01 16:25:50 -07:00
jax authors
96d5542aae Support single-process AutoPGLE usage.
PiperOrigin-RevId: 686819261
2024-10-17 01:43:58 -07:00
Dan Foreman-Mackey
19313f4c0f Fix lint at HEAD. 2024-10-02 16:18:03 -04:00
Keshav
caf57495cf use bool_state instead of bool_flag 2024-09-18 14:53:45 -07:00
Keshav
efe69200e5 add config flag to enable/disable remat HLO pass 2024-09-18 14:07:47 -07:00
Peter Hawkins
940860625e Remove code that existed to support jaxlib < 0.4.32.
New minimum versions:
* jaxlib 0.4.32
* xla_extension_version 283
* mlir_api_version 57

PiperOrigin-RevId: 675291231
2024-09-16 14:30:00 -07:00
Justin Fu
aa66fb37c3 [Pallas][XLA:Mosaic] Add python stack traces to Mosaic errors that occur in Pallas.
PiperOrigin-RevId: 662232859
2024-08-12 14:42:48 -07:00
Bixia Zheng
c81f5cd2fc [xla] Replace debug option xla_use_shardy with execution option
use_shardy_partitioner.

Replace the use of xla_use_shardy with use_shardy_partitioner and remove
xla_use_shardy.

PiperOrigin-RevId: 657359119
2024-07-29 16:11:36 -07:00
jax authors
694c14bbe6 Merge pull request #22556 from cool-RR:log-cache-key
PiperOrigin-RevId: 656364840
2024-07-26 05:32:11 -07:00
Bart Chrzaszcz
864178d3a3 #sdy Initial set of changes to allow for lowering to the Shardy dialect.
The OpenXLA project is working on an open source, MLIR, named-axis based propagation (and in the future SP<D partitioning) system that will be dialect agnostic (would work for any dialect - MHLO, StableHLO, YourDialect). We plan on having frontends like JAX and PyTorch target this when using XLA and wanting SPMD propagation/partitioning. See www.github.com/openxla/shardy for more info.

Currently Shardy is implemented inside the XLA compiler, requiring us to round-trip between StableHLO and HLO with `mhlo.sharding`s. But we will eventually make Shardy the first pass in the XLA pipeline while it's still working on StableHLO. Partitioning (the system that adds the collectives like all-gathers/all-reduces) will still be the GSPMD Partitioner, but next year the Shardy partitioner will be developed, allowing for propagation and partitioning to be completely in MLIR and the first pass in the pipeline. So then we'd have:
1. Traced jaxpr
2. Jaxpr -> StableHLO
3. StableHLO with Shardy propagation
4. StableHLO with Shardy partitioning
5. StableHLO -> HLO
6. XLA optimizations

The following test:

```py
def test_sdy_lowering(self):
  mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  np_inp = np.arange(16).reshape(8, 2)
  s = jax.sharding.NamedSharding(mesh, P('x', 'y'))
  arr = jax.device_put(np_inp, s)

  @partial(jax.jit, out_shardings=s)
  def f(x):
    return x * 2

  print(f.lower(arr).as_text())
```

outputs:

```
module @jit_f attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} {
  sdy.mesh @mesh = <"x"=4, "y"=2>
  func.func public @main(%arg0: tensor<8x2xi64> {mhlo.layout_mode = "{1,0}", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) -> (tensor<8x2xi64> {jax.result_info = "", mhlo.layout_mode = "default", sdy.sharding = #sdy.sharding<@mesh, [{"x"}, {"y"}]>}) {
    %c = stablehlo.constant dense<2> : tensor<i64>
    %0 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i64>) -> tensor<8x2xi64>
    %1 = stablehlo.multiply %arg0, %0 : tensor<8x2xi64>
    return %1 : tensor<8x2xi64>
  }
}
```

Shardy will be hidden behind the `jax_use_shardy_partitioner` flag initially before becoming enabled by default in the future.

PiperOrigin-RevId: 655127611
2024-07-23 05:32:06 -07:00
jax authors
44241eeab1 Add a beacon metric to report tasks disabled Jax compilation cache.
- Create metric '/jax/compilation_cache/task_disabled_cache' as a beacon metric to monitor tasks which have disabled compilation cache.
- Modified existing logic for reporting the '/jax/compilation_cache/tasks_using_cache' metric and make it easier to find the two adoption related metrics in the code.

PiperOrigin-RevId: 654970654
2024-07-22 18:44:18 -07:00
Ram Rachum
00bd6ddf95 Show cache_key when logging compilation cache hits/misses 2024-07-22 11:56:02 +03:00
rdyro
c6d6207170 Unifying persistent cache messages
and moving them to WARNING logging when explain_cache_misses is true.
2024-07-16 00:47:53 +00:00
Ayaka
6c05aa2f32 Clean up 2024-07-04 17:16:32 +04:00
Peter Hawkins
7f4ef63cd8 Run pyupgrade --py310-plus.
Also apply manual fixes to import sorting and unused imports.
2024-06-26 16:10:18 -04:00
Sergei Lebedev
ce0d9e9b9f Changed the naming of internal config APIs
The new naming highlights that we have two kinds of configuration options:
flags, set at most once, and states, which can be changed locally per thread
via a context manager.

The renames are

* FlagHolder -> Flag
* DEFINE_<type> -> <type>_flag
* _StateContextManager -> State
* define_<type>_state -> <type>_state
2024-06-18 11:48:57 +01:00
Parker Schuh
eba0564b70 Allow disabling compilation cache for particular runtime_types.
PiperOrigin-RevId: 640264856
2024-06-04 13:30:42 -07:00
jax authors
26f9820417 [JAX] Automatically share PGO data for GPU latency-hiding scheduler.
Overall the idea is to collect profile data for each module given amount of times (which can be configured) then recompile the module with the aggregated profile data.

1. We need to track how many times each module were profiled and collect profiling results. For this i added a ProfileSessionRunner class at profile.py. The class can track how many times an instance of it was called to profile a session and also can aggregate profile results.

2. We need associate profiling session to the module at the interpreter. To do this i added a dictionary to pjit.py which associates Jaxpr with profile session runner.

3. The profile session runner should be passed to pxla.py and then called.

4. We need to correctly deal with fast path at the interpreter level, so JAX won't use HLO directly if PGLE need to be collected, but also JAX will not recompiled the module only for PGLE. See changes in pjit.py and in lru_cache.h

5. Once FDO is collected we need to share it between hosts to keep deterministic compilation.

PiperOrigin-RevId: 638197166
2024-05-29 01:50:03 -07:00
jax authors
e0bf783c3e [JAX] Fix multi-process data sharing comments.
PiperOrigin-RevId: 636817251
2024-05-24 00:22:28 -07:00
Sergei Lebedev
c3bc88d5e4 Bumped mypy to 1.10.0 and ruff to 0.4.4 2024-05-16 23:16:32 +01:00
jax authors
95c4ba961c [JAX] Use first process id instead of process 0 to share multi-host data.
PiperOrigin-RevId: 633526220
2024-05-14 03:45:41 -07:00
Yash Katariya
395d3cb79e Bump minimum jaxlib version to 0.4.27
xla_extension_version is 261 and mlir_api_version is 56

PiperOrigin-RevId: 631579739
2024-05-07 16:07:59 -07:00
jax authors
7681493760 Don't create temp directory when module is getting imported.
PiperOrigin-RevId: 630958402
2024-05-06 00:58:45 -07:00
jax authors
88dd29a0b5 Re-enable persistent cache on cpu.
CPU cache key now includes machine attributes, so there should no longer
be a problem with incompatible CPUs accessing the same cache entry.

PiperOrigin-RevId: 621341638
2024-04-02 17:30:52 -07:00
jax authors
98ad6ef057 Add debug logging for autotune profile sharing.
PiperOrigin-RevId: 614915339
2024-03-11 22:46:56 -07:00
jax authors
3708336f8f To avoid the inconsistency between process_index and process_id, replace backend.process_index with distributed.global_state.process_id in Jax compilation _cache_write function.
Testing: new unit test.
PiperOrigin-RevId: 607385112
2024-02-15 10:48:05 -08:00
Junwhan Ahn
b9c7296754 Add an interface for IFRT backends to specify executable serializability and use this information to automatically disable JAX compilation cache for backends that do not implement executable serialization
This CL adds `supports_executable_serialization` as an IFRT client attribute indicating whether the executables produced by a given IFRT implementation are serializable or not. This is based on a principle where `xla::ifrt::Client::attributes()` returns a set of attributes representing the "capabilities" of an IFRT implementation, so that the users of IFRT can act based on such capabilities without having to know the exact backend that they are using.

This change is backward compatible as IFRT backends that do not implement `supports_executable_serialization` are assumed to implement executable serialization.

PiperOrigin-RevId: 606799188
2024-02-13 17:27:58 -08:00
jax authors
9a098e922a Share autotune config between hosts.
PiperOrigin-RevId: 604569298
2024-02-06 01:28:18 -08:00
Eugene Zhulenev
28ef77dfb0 Disable JAX compilation cache for XLA:CPU
PiperOrigin-RevId: 603551262
2024-02-01 19:20:39 -08:00
jax authors
34d22fc498 Exclude modules with host callbacks from inter-host sharing.
PiperOrigin-RevId: 600697796
2024-01-23 00:38:13 -08:00
jax authors
b8b119d9b9 Cleanup deprecated compilation cache APIs.
Since the compilation cache is now initialized lazily,
existing APIs initialize_cache() and is_initialized()
are confusing. Deprecate these APIs.

Introduce a new API set_cache_dir() to explicitly set the
cache directory path in code.

Testing: revised unit tests, test workload.
PiperOrigin-RevId: 598073423
2024-01-12 22:44:48 -08:00
jax authors
adbbe69cc2 Add option to share compiled module between hosts.
PiperOrigin-RevId: 597754861
2024-01-11 23:38:02 -08:00
Eugene Zhulenev
ba4c2b1c7d [pjrt:cpu] Add CpuTopology to TfrtCpuClient and enable persistent compilation cache for cpu backend
PiperOrigin-RevId: 597327136
2024-01-10 12:40:57 -08:00
jax authors
da96633f11 Correct the cache miss metric instrumentation due to the new min cache entry size flag
Since introduction of the min cache entry size check for compilation cache, the cache miss metric overcounts the skipped caches whose sizes are smaller than the min cache entry size. After moving the metric instrumentation to compilation_cache.put_executable_and_time, the cache miss metric will be incremented if both compile time and cache entry size are greater than the minimum thresholds.

PiperOrigin-RevId: 596696013
2024-01-08 14:03:33 -08:00
jax authors
c28aa2cecc Remove CPU support from compilation cache.
CPU support was originally added to the compilation cache
in anticipation of the availability of CPU acceleration
compilation. Since this is not available and the
--xla_cpu_use_xla_runtime flag has been deprecated,
cleanup the code and test.

Testing: test workload, revised unit test.
PiperOrigin-RevId: 592962316
2023-12-21 15:25:56 -08:00