# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Interface to the compiler

from __future__ import annotations
from collections.abc import Sequence
import copy
from functools import partial
import logging
import time
from typing import Any, Callable
import warnings

from jax._src import cache_key as cache_key_type
from jax._src import compilation_cache
from jax._src import config as config
from jax._src import distributed
from jax._src import lib
from jax._src import monitoring
from jax._src import path as pathlib
from jax._src import profiler
from jax._src import traceback_util
from jax._src.interpreters import mlir
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
import numpy as np

_DISABLE_MOST_OPTIMIZATIONS = config.bool_flag(
    'jax_disable_most_optimizations',
    config.bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
    'Try not to do much optimization work. This can be useful if the cost of '
    'optimization is greater than that of running a less-optimized program.')

_COMPILER_DETAILED_LOGGING_MIN_OPS = config.int_flag(
    "jax_compiler_detailed_logging_min_ops",
    config.int_env("JAX_COMPILER_DETAILED_LOGGING_MIN_OPS", 10),
    help=(
        'How big should a module be in MLIR operations before JAX enables '
        'detailed compiler logging? The intent of this flag is to suppress '
        'detailed logging for small/uninteresting computations.'
    ),
)

# The special XLA-AutoFDO profile version that indicates that a profile is not
# available and retrieval should not be attempted.
_NO_PROFILE_DONT_RETRIEVE = -1

traceback_util.register_exclusion(__file__)

CompileOptions = xc.CompileOptions

logger = logging.getLogger(__name__)


# Will be monkeypatched with the function that gets the XLA-AutoFDO profile
# version. The default (-1) takes care of errors.
# TODO(b/289098047): consider refactoring this interface.
def get_latest_profile_version(backend: xc.Client) -> int:
  del backend
  return -1


def _walk_operations(op, k):
  # Walk op's subtree, decrementing k once per operation visited; return a
  # negative value as soon as more than k operations have been seen, so large
  # modules can be detected without a full traversal.
  k -= 1
  if k < 0:
    return k
  for region in op.regions:
    for block in region:
      for child_op in block:
        k = _walk_operations(child_op, k)
        if k < 0:
          return k
  return k


def use_detailed_logging(module: ir.Module) -> bool:
  """Returns 'true' if detailed logging should be enabled for 'module'."""
  bound = _COMPILER_DETAILED_LOGGING_MIN_OPS.value
  return _walk_operations(module.operation, bound) < 0
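
# Illustrative sketch (not part of the library): with the default threshold of
# 10 MLIR operations, a trivial module should not trigger detailed logging.
# Using `mlir.make_ir_context()` to parse the example module is an assumption
# made for this sketch.
#
#   with mlir.make_ir_context():
#     m = ir.Module.parse("module { func.func @main() { return } }")
#     assert not use_detailed_logging(m)  # 3 ops, well under the threshold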


def log_persistent_cache_hit(module_name: str, cache_key: str) -> None:
  hit_log_priority = (logging.WARNING if config.log_compiles.value
                      else logging.DEBUG)
  logger.log(hit_log_priority, "Persistent compilation cache hit for '%s' with key %r",
             module_name, cache_key)


def log_persistent_cache_miss(module_name: str, cache_key: str) -> None:
  miss_log_priority = (logging.WARNING
                       if config.explain_cache_misses.value
                       and compilation_cache.is_persistent_cache_enabled()
                       else logging.DEBUG)
  # all caps to match the tracing cache "TRACING CACHE MISS"
  logger.log(miss_log_priority, "PERSISTENT COMPILATION CACHE MISS for '%s' with key %r",
             module_name, cache_key)
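
# Illustrative sketch (not part of the library): the log levels above are
# driven by user-facing configuration, which is typically set like so:
#
#   import jax
#   jax.config.update("jax_log_compiles", True)          # log cache hits
#   jax.config.update("jax_explain_cache_misses", True)  # log cache misses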


def get_compile_options(
    num_replicas: int,
    num_partitions: int,
    device_assignment=None,
    use_spmd_partitioning: bool = True,
    use_shardy_partitioner: bool = False,
    use_auto_spmd_partitioning: bool = False,
    auto_spmd_partitioning_mesh_shape: list[int] | None = None,
    auto_spmd_partitioning_mesh_ids: list[int] | None = None,
    env_options_overrides: dict[str, str] | None = None,
    fdo_profile: bytes | None = None,
    detailed_logging: bool = True,
    backend: xc.Client | None = None,
) -> xc.CompileOptions:
  """Returns the compile options to use, as derived from flag values.

  Args:
    num_replicas: Number of replicas for which to compile.
    num_partitions: Number of partitions for which to compile.
    device_assignment: Optional ndarray of jax devices indicating the assignment
      of logical replicas to physical devices (default inherited from
      xla_client.CompileOptions). Must be consistent with `num_replicas` and
      `num_partitions`.
    use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
      partitioning in XLA.
    use_shardy_partitioner: boolean indicating whether to use the Shardy
      partitioner in XLA. Shardy is a new open-source propagation framework for
      MLIR. Currently Shardy is experimental in JAX. See
      www.github.com/openxla/shardy.
    use_auto_spmd_partitioning: boolean indicating whether to automatically
      generate XLA shardings for the SPMD partitioner.
    auto_spmd_partitioning_mesh_shape: device mesh shape used to create the
      auto_spmd_partitioning search space.
    auto_spmd_partitioning_mesh_ids: device ids used to create the
      auto_spmd_partitioning search space.
    env_options_overrides: dict of additional options parsed by the compiler.
    fdo_profile: Optional profile for feedback-directed optimization passed to
      XLA.
    detailed_logging: Is this an "interesting" computation about which XLA would
      be wise to log compilation information?
    backend: the client, if available.
  """
  compile_options = xc.CompileOptions()
  compile_options.num_replicas = num_replicas
  compile_options.num_partitions = num_partitions
  build_options = compile_options.executable_build_options
  build_options.use_spmd_partitioning = use_spmd_partitioning
  build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
  build_options.use_shardy_partitioner = use_shardy_partitioner
  if fdo_profile is not None:
    build_options.fdo_profile = fdo_profile
  if use_auto_spmd_partitioning:
    build_options.auto_spmd_partitioning_mesh_shape = auto_spmd_partitioning_mesh_shape or []
    build_options.auto_spmd_partitioning_mesh_ids = auto_spmd_partitioning_mesh_ids or []
  if device_assignment is not None:
    logger.debug(
        'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
        num_replicas, num_partitions, device_assignment)
    device_assignment = np.array(device_assignment)

    # Allow 1D device assignment if num_partitions is 1.
    if (device_assignment.ndim == 1) and (num_partitions == 1):
      device_assignment = device_assignment[:, None]

    if num_replicas != device_assignment.shape[0]:
      msg = 'device_assignment does not match num_replicas: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_replicas))

    if num_partitions != device_assignment.shape[1]:
      msg = 'device_assignment does not match num_partitions: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_partitions))

    if device_assignment.dtype == object:
      device_assignment = np.vectorize(lambda d: d.id, otypes=[int])(
          device_assignment)
    device_assignment = xc.DeviceAssignment.create(device_assignment)
    assert device_assignment.replica_count() == num_replicas
    assert device_assignment.computation_count() == num_partitions
    compile_options.device_assignment = device_assignment

  build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
  build_options.memory_fitting_effort = config.memory_fitting_effort.value
  build_options.optimization_level = config.EffortLevel(
      config.optimization_level.value
  ).value
  build_options.memory_fitting_level = config.EffortLevel(
      config.memory_fitting_level.value
  ).value

  if env_options_overrides is not None:
    # Some overrides are passed directly on build_options.
    overrides_on_build_options = [
        "exec_time_optimization_effort", "memory_fitting_effort"]
    overrides_on_build_options.extend(
        ["optimization_level", "memory_fitting_level"]
    )
    env_options_overrides = dict(env_options_overrides)
    for name in overrides_on_build_options:
      if name in env_options_overrides:
        setattr(build_options, name, env_options_overrides.pop(name))
    compile_options.env_option_overrides = list(env_options_overrides.items())

  debug_options = compile_options.executable_build_options.debug_options
  if lib.cuda_path is not None:
    debug_options.xla_gpu_cuda_data_dir = lib.cuda_path

  if _DISABLE_MOST_OPTIMIZATIONS.value:
    debug_options.xla_backend_optimization_level = 0
    debug_options.xla_llvm_disable_expensive_passes = True
    debug_options.xla_test_all_input_layouts = False

  if not config.enable_remat_opt_pass.value:
    debug_options.xla_disable_hlo_passes = "rematerialization"

  # XLA-AutoFDO profile version: precedence order is:
  # 1. Whatever --jax_xla_profile_version is set to.
  # 2. If --jax_xla_profile_version is not set (i.e., 0), call the function
  #    set in get_latest_profile_version and use the return value if non-zero.
  #    If the function returns 0, set -1; this is an error.
  # -1 indicates that no attempt should be made to retrieve the latest profile
  # later on.
  jax_xla_profile_version = config.jax_xla_profile_version.value
  if jax_xla_profile_version > 0:
    compile_options.profile_version = jax_xla_profile_version
    logger.debug("get_compile_options XLA-AutoFDO profile: " +
                 "using JAX XLA profile version %d from flag",
                 jax_xla_profile_version)
  else:
    compile_options.profile_version = _NO_PROFILE_DONT_RETRIEVE
    if backend is None:
      logger.info("get_compile_options: no backend supplied; "
                  "disabling XLA-AutoFDO profile")
    else:
      fdo_profile_version = get_latest_profile_version(backend)
      if fdo_profile_version != 0:
        compile_options.profile_version = fdo_profile_version
        logger.debug("get_compile_options XLA-AutoFDO profile: " +
                     "using XLA-AutoFDO profile version %d",
                     fdo_profile_version)
      else:
        logger.error("get_compile_options XLA-AutoFDO profile: " +
                     "XLA-AutoFDO profile version is 0; this should not happen")

  debug_options.xla_detailed_logging = detailed_logging

  # If persistent cache is enabled, also enable additional XLA caching features.
  if compilation_cache.is_persistent_cache_enabled():
    # compilation_cache_dir can't be None here, but the type checker is a bit
    # strict.
    path = pathlib.Path(config.compilation_cache_dir.value or "")
    enabled_flags = config.persistent_cache_enable_xla_caches.value or ""

    if enabled_flags == "all" or "xla_gpu_kernel_cache_file" in enabled_flags:
      kernel_cache_path = path / "xla_gpu_kernel_cache_file"
      debug_options.xla_gpu_kernel_cache_file = str(kernel_cache_path)
      # This option is required to use the kernel cache.
      debug_options.xla_gpu_enable_llvm_module_compilation_parallelism = True
      logger.debug("Enabling XLA kernel cache at '%s'", kernel_cache_path)

    if enabled_flags == "all" or "xla_gpu_per_fusion_autotune_cache_dir" in enabled_flags:
      autotune_cache_path = path / "xla_gpu_per_fusion_autotune_cache_dir"
      debug_options.xla_gpu_per_fusion_autotune_cache_dir = str(autotune_cache_path)
      logger.debug("Enabling XLA autotuning cache at '%s'", autotune_cache_path)

      # Set caching mode so that only process 0 can write to the cache.
      if distributed.global_state.process_id == 0:
        debug_options.xla_gpu_experimental_autotune_cache_mode = xc.AutotuneCacheMode.UPDATE
      else:
        debug_options.xla_gpu_experimental_autotune_cache_mode = xc.AutotuneCacheMode.READ

  return compile_options
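
# Illustrative usage sketch (not part of the library): a single-replica,
# single-partition compile, overriding one XLA option. The option name shown
# is only an example.
#
#   opts = get_compile_options(
#       num_replicas=1, num_partitions=1,
#       env_options_overrides={
#           "xla_gpu_enable_latency_hiding_scheduler": "true"})
#   assert opts.num_replicas == 1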


@profiler.annotate_function
def backend_compile(
    backend: xc.Client,
    module: ir.Module,
    options: xc.CompileOptions,
    host_callbacks: Sequence[Any],
) -> xc.LoadedExecutable:
  sym_name = module.operation.attributes['sym_name']
  module_name = ir.StringAttr(sym_name).value
  # Convert ir.Module to a string representation, unless the backend
  # explicitly flags the ability to handle a module directly (avoiding the
  # overhead of back and forth conversions).
  # TODO(slebedev): Change the backend.compile() to accept ir.Module.
  built_c: Any
  if getattr(backend, "needs_str_ir", True):
    built_c = mlir.module_to_bytecode(module)
  else:
    built_c = module

  if (options.executable_build_options.fdo_profile is not None
      and len(options.executable_build_options.fdo_profile)):
    logger.debug(
        "Compiling module %s with FDO profile of length %d",
        module_name,
        len(options.executable_build_options.fdo_profile),
    )

  try:
    # We use a separate function call to ensure that XLA compilation appears
    # separately in Python profiling results.
    if host_callbacks:
      return backend.compile(
          built_c, compile_options=options, host_callbacks=host_callbacks
      )
    # Some backends don't have the `host_callbacks` option yet.
    # TODO(sharadmv): remove this fallback when all backends allow `compile`
    # to take in `host_callbacks`.
    return backend.compile(built_c, compile_options=options)
  except xc.XlaRuntimeError as e:
    for error_handler in _XLA_RUNTIME_ERROR_HANDLERS:
      handler_result = error_handler(e)
      if handler_result is not None:
        raise handler_result from e
    raise e


_XLA_RUNTIME_ERROR_HANDLERS = []


def register_xla_runtime_error_handler(
    handler_fn: Callable[[xc.XlaRuntimeError], Exception | None],
):
  """Registers a custom exception handler for XLA runtime errors.

  Registering a custom handler allows re-raising a more informative exception
  after encountering an XlaRuntimeError.

  Args:
    handler_fn: A function which returns a new exception to replace the
      original XLA runtime error, or None if the original error should be
      propagated.
  """
  _XLA_RUNTIME_ERROR_HANDLERS.append(handler_fn)
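
# Illustrative sketch (not part of the library): re-raising a friendlier error
# for out-of-memory failures. The string-matching logic is an assumption made
# for this example; real handlers can inspect the error however they like.
#
#   def _oom_handler(err: xc.XlaRuntimeError) -> Exception | None:
#     if "RESOURCE_EXHAUSTED" in str(err):
#       return RuntimeError("XLA ran out of device memory: " + str(err))
#     return None  # propagate the original error
#
#   register_xla_runtime_error_handler(_oom_handler)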

def compile_or_get_cached(
    backend: xc.Client,
    computation: ir.Module,
    devices: np.ndarray,
    compile_options: xc.CompileOptions,
    host_callbacks: Sequence[Any],
    pgle_profiler: profiler.PGLEProfiler | None = None,
) -> xc.LoadedExecutable:
  sym_name = computation.operation.attributes['sym_name']
  module_name = ir.StringAttr(sym_name).value

  if dumped_to := mlir.dump_module_to_file(computation, "compile"):
    logger.info("Dumped the module to %s.", dumped_to)

  is_multi_process = (
      len({device.process_index for device in devices.flatten()}) > 1
  )
  min_device_process_id = min(
      devices.flatten(), key=lambda device: device.id
  ).process_index

  # cache_key: may be None if compilation caching is disabled
  cache_key, compile_options = _resolve_compilation_strategy(
      computation,
      devices,
      compile_options,
      backend,
      pgle_profiler,
      is_multi_process,
      module_name,
      min_device_process_id,
  )

  if cache_key is None:
    return backend_compile(backend, computation, compile_options,
                           host_callbacks)

  monitoring.record_event('/jax/compilation_cache/compile_requests_use_cache')

  cache_retrieval_start = time.monotonic()
  retrieved_executable, retrieved_compile_time = _cache_read(
      module_name, cache_key, compile_options, backend)
  cache_retrieval_time = time.monotonic() - cache_retrieval_start

  if retrieved_executable is not None:
    assert retrieved_compile_time is not None
    log_persistent_cache_hit(module_name, cache_key)

    monitoring.record_event('/jax/compilation_cache/cache_hits')
    monitoring.record_event_duration_secs(
        '/jax/compilation_cache/compile_time_saved_sec',
        retrieved_compile_time - cache_retrieval_time)

    monitoring.record_event_duration_secs(
        "/jax/compilation_cache/cache_retrieval_time_sec", cache_retrieval_time)

    return retrieved_executable
  elif (
      config.share_binary_between_hosts.value
      and is_multi_process
      and distributed.global_state.client is not None
      # Host callbacks are currently baked into the HLO module so we can't
      # share them.
      and len(host_callbacks) == 0
  ):
    log_persistent_cache_miss(module_name, cache_key)
    return _compile_and_share_module(
        backend,
        computation,
        compile_options,
        host_callbacks,
        distributed.global_state.client,
        module_name,
        cache_key,
        min_device_process_id
    )
  else:
    log_persistent_cache_miss(module_name, cache_key)
    return _compile_and_write_cache(
        backend,
        computation,
        compile_options,
        host_callbacks,
        module_name,
        cache_key,
    )
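
# Illustrative sketch (not part of the library): the persistent cache that
# compile_or_get_cached consults is typically enabled from user code, e.g.
#
#   import jax
#   jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
#   jax.config.update("jax_persistent_cache_min_compile_time_secs", 1.0)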


# When PGLE is enabled there might be 3 types of situations:
# 1. The PGLE-optimized module (the one which was recompiled with an FDO
# profile) is in the persistent cache. In this case the module should be
# returned from the cache and PGLE should be disabled for this module. The
# module is stored in the persistent cache under the
# "pgle_optimized_cache_key", which is calculated by replacing the FDO profile
# with a sentinel value that identifies that the module was optimized with
# PGLE.
# 2. The PGLE-profiled module is not in the persistent cache and the module is
# getting built with an FDO profile. In this case we need to share the FDO
# profile with any other processes and store the result under the
# "pgle_optimized_cache_key" so later in case 1 we will be able to find the
# module.
# 3. The PGLE-profiled module is not in the persistent cache and the module is
# getting compiled to be PGLEd (FDO profile is empty). In this case we need to
# simply return the non-PGLE-profiled module from the persistent cache if it
# exists, and otherwise compile it.
#
# If the compilation_cache_expect_pgle option is set then in case 1 the
# PGLE-optimized module will be loaded even if PGLE is not enabled in the
# current process. This is useful if we want to combine the use of PGLE with
# other profiling tools (e.g. Nsight Systems) that cannot co-exist with PGLE
# due to contention for CUPTI resources.
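
# Illustrative sketch (not part of the library): AutoPGLE is typically enabled
# from user code before running the workload, e.g.
#
#   import jax
#   jax.config.update("jax_enable_pgle", True)
#   jax.config.update("jax_pgle_profiling_runs", 3)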
def _resolve_compilation_strategy(
    computation: ir.Module,
    devices: np.ndarray,
    compile_options: xc.CompileOptions,
    backend: xc.Client,
    pgle_profiler: profiler.PGLEProfiler | None,
    is_multi_process: bool,
    module_name: str,
    min_device_process_id: int,
) -> tuple[str | None, xc.CompileOptions]:
  is_auto_pgle_used = (
      config.enable_pgle.value and config.pgle_profiling_runs.value > 0
  )

  get_cache_key = partial(_get_cache_key, backend=backend,
                          computation=computation, devices=devices)

  if is_auto_pgle_used or config.compilation_cache_expect_pgle.value:
    # This can be None if cache key generation fails.
    pgle_optimized_cache_key = get_cache_key(compile_options,
                                             override_fdo_profile=b"pgle profiled")
    # TODO(b/376647494): remove the workaround when the bug is fixed; the JAX
    # profiler cannot collect sufficiently detailed profile data for PGLE if
    # command buffers / CUDA graphs are enabled. Therefore disable command
    # buffers when compiling for PGLE data collection, but not if AutoPGLE is
    # not enabled, and not when re-compiling using PGLE data. This condition
    # includes `compilation_cache_expect_pgle` so that slow-to-compile modules
    # that are not executed often enough to trigger re-compilation will still
    # be cached between an "enable_pgle" run and an "expect_pgle" run.
    first_pass_compile_options = copy.deepcopy(compile_options)
    first_pass_compile_options.env_option_overrides += [
        ("xla_gpu_enable_command_buffer", ""),
    ]
  else:
    pgle_optimized_cache_key = None
    first_pass_compile_options = compile_options

  # This can be None if cache key generation fails or caching is disabled
  cache_key = get_cache_key(first_pass_compile_options)

  if cache_key is not None and pgle_optimized_cache_key is not None:
    # The compilation cache is enabled and AutoPGLE is enabled/expected
    if _is_executable_in_cache(backend, pgle_optimized_cache_key):
      if config.compilation_cache_expect_pgle.value:
        logger.info("PGLE-optimized %s loaded from compilation cache",
                    module_name)
      # No need to record N profiles in this case
      if pgle_profiler is not None:
        pgle_profiler.disable()
      return pgle_optimized_cache_key, compile_options
    elif (config.compilation_cache_expect_pgle.value
          and _is_executable_in_cache(backend, cache_key)):
      # No PGLE-optimized module found in the persistent cache, and the user
      # asserted (expect_pgle) that this miss was unexpected
      warnings.warn(f"PERSISTENT CACHE MISS for PGLE-optimized {module_name} "
                    "despite non-PGLE hit; it may not have been executed "
                    "enough times when the cache was populated")

  if (is_auto_pgle_used
      and compile_options.executable_build_options.fdo_profile is not None
      and len(compile_options.executable_build_options.fdo_profile)):
    # Profile data are available to trigger a PGLE-optimized recompilation;
    # store under `pgle_optimized_cache_key` if the cache is enabled
    if is_multi_process and distributed.global_state.client is not None:
      compile_options.executable_build_options.fdo_profile = (
          _share_fdo_profiles(
              computation,
              devices,
              compile_options,
              backend,
              distributed.global_state.client,
              min_device_process_id,
          )
      )
    return pgle_optimized_cache_key, compile_options
  else:
    # Compile for PGLE collection, store under `cache_key` if the cache is
    # enabled. This is also the AutoPGLE-disabled path.
    return cache_key, first_pass_compile_options


def _get_cache_key(
    options: xc.CompileOptions,
    backend: xc.Client,
    computation: ir.Module,
    devices: np.ndarray,
    override_fdo_profile: bytes | None = None) -> str | None:
  if not compilation_cache.is_cache_used(backend):
    return None
  if config.remove_custom_partitioning_ptr_from_cache_key.value:
    ignore_callbacks = cache_key_type.IgnoreCallbacks.CUSTOM_PARTITIONING
  else:
    ignore_callbacks = cache_key_type.IgnoreCallbacks.NO
  if override_fdo_profile is not None:
    options = copy.deepcopy(options)
    options.executable_build_options.fdo_profile = override_fdo_profile
  try:
    return compilation_cache.get_cache_key(
        computation,
        devices,
        options,
        backend,
        ignore_callbacks,
    )
  except xc._xla.XlaRuntimeError as ex:
    logger.error("compile_or_get_cached: unable to generate cache key, "
                 "skipping the cache: %s", ex)
    return None


# The process that has the lowest device ID should share the FDO profile with
# the other processes before compilation.
def _share_fdo_profiles(
    computation: ir.Module,
    devices: np.ndarray,
    compile_options: xc.CompileOptions,
    backend: xc.Client,
    global_client: lib.xla_extension.DistributedRuntimeClient,
    min_process_id
) -> bytes | None:
  sym_name = computation.operation.attributes['sym_name']
  module_name = ir.StringAttr(sym_name).value
  fdo_profile = compile_options.executable_build_options.fdo_profile
  if fdo_profile is None or len(fdo_profile) == 0:
    return fdo_profile

  compile_options.executable_build_options.fdo_profile = b""
  try:
    profile_key = (
        compilation_cache.get_cache_key(
            computation,
            devices,
            compile_options,
            backend,
            cache_key_type.IgnoreCallbacks.ALL,
        )
        + "_fdo_sync"
    )
  except xc._xla.XlaRuntimeError as ex:
    logger.error(
        "compile_or_get_cached: unable to generate cache key, "
        "skipping the fdo profile sharing: %s",
        ex,
    )
    return fdo_profile

  if profile_key in _share_fdo_profiles.modules_profiles:
    return _share_fdo_profiles.modules_profiles[profile_key]

  share_timeout = config.share_binary_between_hosts_timeout_ms.value
  if distributed.global_state.process_id == min_process_id:
    logger.debug(
        "Module %s. Sharing FDO profile. Process %d.",
        module_name,
        min_process_id,
    )
    global_client.key_value_set_bytes(profile_key, fdo_profile)
  else:
    logger.debug(
        "Module %s. Waiting for FDO profile which should be set by process %d.",
        module_name,
        min_process_id,
    )
    fdo_profile = global_client.blocking_key_value_get_bytes(
        profile_key, share_timeout
    )

  _share_fdo_profiles.modules_profiles[profile_key] = fdo_profile
  return fdo_profile


_share_fdo_profiles.modules_profiles = {}


# The process with the first_process_id should compile the module and write it
# to the K-V storage.
def _compile_and_share_module(
    backend: xc.Client,
    computation: ir.Module,
    compile_options: xc.CompileOptions,
    host_callbacks: Sequence[Any],
    global_client: lib.xla_extension.DistributedRuntimeClient,
    module_name: str,
    cache_key: str,
    first_process_id: int
) -> xc.LoadedExecutable:
  share_timeout = config.share_binary_between_hosts_timeout_ms.value

  if cache_key in _compile_and_share_module.modules_cache:
    return _compile_and_share_module.modules_cache[cache_key]

  if distributed.global_state.process_id == first_process_id:
    logger.debug("Process %d compiling and sharing module: %s",
                 first_process_id, module_name)
    executable = _compile_and_write_cache(
        backend,
        computation,
        compile_options,
        host_callbacks,
        module_name,
        cache_key,
    )
    serialized_executable = backend.serialize_executable(executable)
    serialized_executable = compilation_cache.compress_executable(
        serialized_executable
    )
    global_client.key_value_set_bytes(cache_key, serialized_executable)
  else:
    logger.debug("Waiting for module: %s from process %d", module_name,
                 first_process_id)
    serialized_executable = global_client.blocking_key_value_get_bytes(
        cache_key, share_timeout
    )
    serialized_executable = compilation_cache.decompress_executable(
        serialized_executable
    )
    executable = backend.deserialize_executable(
        serialized_executable, compile_options
    )

  _compile_and_share_module.modules_cache[cache_key] = executable
  return executable


_compile_and_share_module.modules_cache = {}
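
# Illustrative sketch (not part of the library) of the rendezvous pattern used
# above: one process publishes bytes under a key in the distributed K-V store,
# the others block until the value appears. Assumes jax.distributed.initialize()
# has been called; the key, payload, and 10-second timeout are example values.
#
#   client = distributed.global_state.client
#   if distributed.global_state.process_id == 0:
#     client.key_value_set_bytes("my_key", b"payload")
#   else:
#     payload = client.blocking_key_value_get_bytes("my_key", 10_000)  # ms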


def _compile_and_write_cache(
    backend: xc.Client,
    computation: ir.Module,
    compile_options: xc.CompileOptions,
    host_callbacks: Sequence[Any],
    module_name: str,
    cache_key: str,
) -> xc.LoadedExecutable:
  start_time = time.monotonic()
  executable = backend_compile(
      backend, computation, compile_options, host_callbacks
  )
  compile_time = time.monotonic() - start_time
  _cache_write(
      cache_key, compile_time, module_name, backend, executable, host_callbacks
  )
  return executable


def _is_executable_in_cache(backend, cache_key) -> bool:
  """Checks whether an executable is present in the cache for the given key."""
  try:
    return compilation_cache.is_executable_in_cache(backend, cache_key)
  except Exception as ex:
    if config.raise_persistent_cache_errors.value:
      raise
    warnings.warn(
        f"Error reading persistent compilation cache entry for "
        f"'{cache_key}': {type(ex).__name__}: {ex}")
    return False


def _cache_read(
    module_name: str, cache_key: str, compile_options: xc.CompileOptions,
    backend: xc.Client
) -> tuple[xc.LoadedExecutable | None, int | None]:
  """Looks up the `computation` and its compilation time in the persistent
  compilation cache repository.
  """
  try:
    return compilation_cache.get_executable_and_time(
        cache_key, compile_options, backend)
  except Exception as ex:
    if config.raise_persistent_cache_errors.value:
      raise
    warnings.warn(
        f"Error reading persistent compilation cache entry for "
        f"'{module_name}': {type(ex).__name__}: {ex}")
    return None, None


def _cache_write(cache_key: str,
                 compile_time_secs: float,
                 module_name: str,
                 backend: xc.Client, executable: xc.LoadedExecutable,
                 host_callbacks: Sequence[Any]) -> None:
  """Writes the `executable` and its compilation time to the persistent
  compilation cache repository.
  """
  # Only write cache entries from the first process. Otherwise we create
  # problems with contention for writes on some filesystems, e.g., GCS.
  log_priority = (logging.WARNING
                  if config.explain_cache_misses.value
                  and compilation_cache.is_persistent_cache_enabled()
                  else logging.DEBUG)
  if distributed.global_state.process_id != 0:
    logger.log(log_priority,
               "Not writing persistent cache entry since process_id != 0")
    return

  if host_callbacks:
    logger.log(
        log_priority,
        "Not writing persistent cache entry for '%s' because it uses host "
        "callbacks (e.g. from jax.debug.print or breakpoint)", module_name)
    return

  min_compile_time = config.persistent_cache_min_compile_time_secs.value
  if compile_time_secs < min_compile_time:
    logger.log(
        log_priority,
        "Not writing persistent cache entry for '%s' because it took < %.2f "
        "seconds to compile (%.2fs)", module_name, min_compile_time,
        compile_time_secs)
    return
  else:
    logger.debug(
        "'%s' took at least %.2f seconds to compile (%.2fs)",
        module_name, min_compile_time, compile_time_secs)

  try:
    compilation_cache.put_executable_and_time(
        cache_key, module_name, executable, backend, int(compile_time_secs))
  except Exception as ex:
    if config.raise_persistent_cache_errors.value:
      raise
    warnings.warn(
        f"Error writing persistent compilation cache entry for "
        f"'{module_name}': {type(ex).__name__}: {ex}")