Move compiler APIs out of dispatch.py and xla_bridge.py into a new jax._src.compiler module.

Refactoring only, no user-visible changes intended.

PiperOrigin-RevId: 557116160
This commit is contained in:
Peter Hawkins 2023-08-15 06:38:56 -07:00 committed by jax authors
parent 423c8d8d4f
commit a259df0d76
13 changed files with 408 additions and 327 deletions

View File

@ -190,6 +190,7 @@ py_library_providing_imports_info(
":basearray",
":cloud_tpu_init",
":compilation_cache_internal",
":compiler",
":config",
":core",
":custom_api_util",
@ -320,6 +321,20 @@ pytype_strict_library(
srcs = ["_src/logging_config.py"],
)
pytype_strict_library(
name = "compiler",
srcs = ["_src/compiler.py"],
deps = [
":compilation_cache_internal",
":config",
":monitoring",
":path",
":profiler",
":traceback_util",
"//jax/_src/lib",
] + py_deps("numpy"),
)
pytype_strict_library(
name = "core",
srcs = [
@ -727,7 +742,7 @@ pytype_strict_library(
":traceback_util",
":util",
"//jax/_src/lib",
] + py_deps("numpy"),
],
)
# Public JAX libraries below this point.

342
jax/_src/compiler.py Normal file
View File

@ -0,0 +1,342 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Interface to the compiler
from __future__ import annotations
from collections.abc import Sequence
import io
import itertools
import time
from typing import Any
import logging
import os
import re
import warnings
import numpy as np
from jax._src import lib
from jax._src import compilation_cache
from jax._src import config as jax_config
from jax._src import monitoring
from jax._src import path
from jax._src import profiler
from jax._src import traceback_util
from jax._src.config import config
from jax._src.lib.mlir import ir
from jax._src.lib import xla_client as xc
# Boolean flag (default taken from the JAX_DISABLE_MOST_OPTIMIZATIONS
# environment variable, else False). When set, get_compile_options() sets the
# XLA backend optimization level to 0 and disables expensive LLVM passes.
_DISABLE_MOST_OPTIMIZATIONS = jax_config.DEFINE_bool(
    'jax_disable_most_optimizations',
    jax_config.bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
    'Try not to do much optimization work. This can be useful if the cost of '
    'optimization is greater than that of running a less-optimized program.')

# String flag (default taken from the JAX_DUMP_IR_TO environment variable,
# else ''). When non-empty, compile_or_get_cached() writes each module's
# textual IR into this directory before compiling.
_DUMP_IR_TO = jax_config.DEFINE_string(
    'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
    help="Path to which the IR that is emitted by JAX as input to the "
    "compiler should be dumped as text files. Optional. If omitted, JAX "
    "will not dump IR.")

# Register this file with traceback_util's exclusion list.
# NOTE(review): presumably this hides frames from this module in user-facing
# tracebacks -- confirm against traceback_util.
traceback_util.register_exclusion(__file__)

# Re-export of xla_client.CompileOptions under this module's namespace.
CompileOptions = xc.CompileOptions

# Module-level logger used by the compile and cache helpers below.
logger = logging.getLogger(__name__)
# Will be monkeypatched with the function that gets the XLA-AutoFDO profile
# version. The default (-1) takes care of errors.
def get_latest_profile_version() -> int:
  """Returns the default XLA-AutoFDO profile version, which is always -1.

  This is only the placeholder implementation; it is expected to be replaced
  (monkeypatched) by a function that looks up the real profile version.
  """
  default_profile_version = -1
  return default_profile_version
def get_compile_options(
    num_replicas: int,
    num_partitions: int,
    device_assignment=None,
    use_spmd_partitioning: bool = True,
    use_auto_spmd_partitioning: bool = False,
    auto_spmd_partitioning_mesh_shape: list[int] | None = None,
    auto_spmd_partitioning_mesh_ids: list[int] | None = None,
    env_options_overrides: dict[str, str] | None = None,
    fdo_profile: bytes | None = None,
) -> xc.CompileOptions:
  """Returns the compile options to use, as derived from flag values.

  Args:
    num_replicas: Number of replicas for which to compile.
    num_partitions: Number of partitions for which to compile.
    device_assignment: Optional ndarray of jax devices indicating the assignment
      of logical replicas to physical devices (default inherited from
      xla_client.CompileOptions). Must be consistent with `num_replicas` and
      `num_partitions`. A 1D assignment is accepted only when
      `num_partitions` is 1.
    use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
      partitioning in XLA.
    use_auto_spmd_partitioning: boolean indicating whether to automatically
      generate XLA shardings for SPMD partitioner.
    auto_spmd_partitioning_mesh_shape: device mesh shape used to create
      auto_spmd_partitioning search space.
    auto_spmd_partitioning_mesh_ids: device ids used to create
      auto_spmd_partitioning search space.
    env_options_overrides: dict of additional options parsed by the compiler
    fdo_profile: Optional profile for feedback-directed optimization passed to
      XLA.

  Returns:
    A freshly constructed `xc.CompileOptions` populated from the arguments and
    from the relevant JAX flags.

  Raises:
    ValueError: if `device_assignment` is not two-dimensional (after the 1D
      promotion) or its shape disagrees with `num_replicas` /
      `num_partitions`.
  """
  compile_options = xc.CompileOptions()
  compile_options.num_replicas = num_replicas
  compile_options.num_partitions = num_partitions
  build_options = compile_options.executable_build_options
  build_options.use_spmd_partitioning = use_spmd_partitioning
  build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
  if fdo_profile is not None:
    build_options.fdo_profile = fdo_profile
  if use_auto_spmd_partitioning:
    build_options.auto_spmd_partitioning_mesh_shape = auto_spmd_partitioning_mesh_shape or []
    build_options.auto_spmd_partitioning_mesh_ids = auto_spmd_partitioning_mesh_ids or []
  if device_assignment is not None:
    logger.debug(
        'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
        num_replicas, num_partitions, device_assignment)
    device_assignment = np.array(device_assignment)

    # Allow 1D device assignment if num_partitions is 1.
    if (device_assignment.ndim == 1) and (num_partitions == 1):
      device_assignment = device_assignment[:, None]

    # Reject any other rank up front: indexing shape[0]/shape[1] below would
    # otherwise surface as an opaque IndexError instead of a ValueError.
    if device_assignment.ndim != 2:
      raise ValueError(
          'device_assignment must be a 2D array of shape '
          '(num_replicas, num_partitions), or 1D when num_partitions is 1; '
          f'got shape {device_assignment.shape}.')

    if num_replicas != device_assignment.shape[0]:
      msg = 'device_assignment does not match num_replicas: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_replicas))

    if num_partitions != device_assignment.shape[1]:
      msg = 'device_assignment does not match num_partitions: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_partitions))

    # An object array holds device objects; replace each with its integer id.
    if device_assignment.dtype == object:
      device_assignment = np.vectorize(lambda d: d.id, otypes=[int])(
          device_assignment)
    device_assignment = xc.DeviceAssignment.create(device_assignment)
    assert device_assignment.replica_count() == num_replicas
    assert device_assignment.computation_count() == num_partitions
    compile_options.device_assignment = device_assignment

  if env_options_overrides is not None:
    compile_options.env_option_overrides = list(env_options_overrides.items())

  debug_options = compile_options.executable_build_options.debug_options
  if lib.cuda_path is not None:
    debug_options.xla_gpu_cuda_data_dir = lib.cuda_path

  if _DISABLE_MOST_OPTIMIZATIONS.value:
    debug_options.xla_backend_optimization_level = 0
    debug_options.xla_llvm_disable_expensive_passes = True
    debug_options.xla_test_all_input_layouts = False

  # XLA-AutoFDO profile version: precedence order is:
  # 1. Whatever --jax_xla_profile_version is set to.
  # 2. If --jax_xla_profile_version is not set (i.e., 0), call the function
  #    set in get_latest_profile_version and use the return value if non-zero.
  #    If the function returns 0, set -1; this is an error.
  # -1 indicates that no attempt should be made to retrieve the latest profile
  # later on.
  jax_xla_profile_version = config.jax_xla_profile_version
  if jax_xla_profile_version > 0:
    compile_options.profile_version = jax_xla_profile_version
    logger.debug("get_compile_options XLA-AutoFDO profile: "
                 "using JAX XLA profile version %d from flag",
                 jax_xla_profile_version)
  else:
    fdo_profile_version = get_latest_profile_version()
    if fdo_profile_version != 0:
      compile_options.profile_version = fdo_profile_version
      logger.debug("get_compile_options XLA-AutoFDO profile: "
                   "using XLA-AutoFDO profile version %d",
                   fdo_profile_version)
    else:
      no_profile_dont_retrieve = -1
      compile_options.profile_version = no_profile_dont_retrieve
      logger.error("get_compile_options XLA-AutoFDO profile: "
                   "XLA-AutoFDO profile version is 0; this should not happen")

  return compile_options
def _module_to_string(module: ir.Module) -> str:
  """Returns the text printed by `module.operation`, with debug info enabled."""
  with io.StringIO() as text_buffer:
    module.operation.print(file=text_buffer, enable_debug_info=True)
    return text_buffer.getvalue()
def _module_to_bytecode(module: ir.Module) -> bytes:
  """Returns the bytes that `module.operation.write_bytecode` emits."""
  with io.BytesIO() as byte_buffer:
    module.operation.write_bytecode(file=byte_buffer)
    return byte_buffer.getvalue()
@profiler.annotate_function
def backend_compile(
    backend: xc.Client,
    module: ir.Module,
    options: xc.CompileOptions,
    host_callbacks: Sequence[Any],
) -> xc.LoadedExecutable:
  """Compiles `module` with `backend.compile` and returns the result.

  The module is serialized to bytecode first, unless the backend exposes a
  falsy `needs_str_ir` attribute, in which case the module object is handed
  over directly (avoiding the overhead of back and forth conversions).

  Args:
    backend: client whose `compile` method is invoked.
    module: the MLIR module to compile.
    options: passed to `backend.compile` as `compile_options`.
    host_callbacks: forwarded as the `host_callbacks` keyword only when
      non-empty.
  """
  wants_serialized_ir = getattr(backend, "needs_str_ir", True)
  built_c = _module_to_bytecode(module) if wants_serialized_ir else module

  # Some backends don't have `host_callbacks` option yet, so the keyword is
  # only supplied when there is something to pass.
  # TODO(sharadmv): remove this fallback when all backends allow `compile`
  # to take in `host_callbacks`
  compile_kwargs: dict[str, Any] = {"compile_options": options}
  if host_callbacks:
    compile_kwargs["host_callbacks"] = host_callbacks

  # we use a separate function call to ensure that XLA compilation appears
  # separately in Python profiling results
  return backend.compile(built_c, **compile_kwargs)
# Module-level counter: _dump_ir_to_file() takes the next value for every dump
# and embeds it in the file name.
_ir_dump_counter = itertools.count()
def _make_string_safe_for_filename(s: str) -> str:
  """Returns `s` with every character removed that is not a word character,
  a dot, a parenthesis, a space or a hyphen."""
  disallowed_chars = r'[^\w.)( -]'
  return re.sub(pattern=disallowed_chars, repl='', string=s)
def _dump_ir_to_file(name: str, ir: str):
  """Writes the text `ir` to a numbered `.mlir` file under the dump directory.

  The file is named `jax_ir<N>_<sanitized name>.mlir`, where N is the next
  value of the module-level dump counter, and is created inside the directory
  given by the `jax_dump_ir_to` flag.
  """
  dump_index = next(_ir_dump_counter)
  file_name = f"jax_ir{dump_index}_{_make_string_safe_for_filename(name)}.mlir"
  destination = path.Path(_DUMP_IR_TO.value) / file_name
  destination.write_text(ir)
def compile_or_get_cached(
    backend: xc.Client,
    computation: ir.Module,
    devices: np.ndarray,
    compile_options: xc.CompileOptions,
    host_callbacks: Sequence[Any],
) -> xc.LoadedExecutable:
  """Returns an executable for `computation`, using the persistent cache if possible.

  If the `jax_dump_ir_to` flag is set, the module's text is dumped first.
  The persistent compilation cache is consulted only when it has been
  initialized and the backend platform is supported; otherwise the module is
  compiled directly. On a cache hit the retrieval time and the compile time
  saved are recorded as monitoring events; on a miss the module is compiled
  and the result is offered to the cache.
  """
  module_name = ir.StringAttr(
      computation.operation.attributes['sym_name']).value

  if _DUMP_IR_TO.value:
    _dump_ir_to_file(module_name, _module_to_string(computation))

  # Persistent compilation cache only implemented on TPU and GPU.
  # TODO(skye): add warning when initializing cache on unsupported default platform
  # (b/233850967) CPU caching can be enabled if XLA Runtime is enabled.
  cpu_caching_enabled = (
      "--xla_cpu_use_xla_runtime=true" in os.environ.get("XLA_FLAGS", ""))
  cacheable_platforms = ["tpu", "gpu"] + (["cpu"] if cpu_caching_enabled else [])

  cache_usable = (compilation_cache.is_initialized() and
                  backend.platform in cacheable_platforms)
  if not cache_usable:
    return backend_compile(backend, computation, compile_options,
                           host_callbacks)

  cache_key = compilation_cache.get_cache_key(
      computation, devices, compile_options, backend,
      jax_config.config.jax_use_original_compilation_cache_key_generation,
  )

  lookup_started = time.monotonic()
  cached_executable, cached_compile_time = _cache_read(
      module_name, cache_key, compile_options, backend)
  lookup_duration = time.monotonic() - lookup_started

  # Cache miss: compile now, time it, and offer the result to the cache.
  if cached_executable is None:
    compile_started = time.monotonic()
    fresh_executable = backend_compile(backend, computation,
                                       compile_options, host_callbacks)
    compile_duration = time.monotonic() - compile_started
    _cache_write(cache_key, compile_duration, module_name, backend,
                 fresh_executable, host_callbacks)
    return fresh_executable

  # Cache hit: report how long retrieval took and how much time it saved.
  assert cached_compile_time is not None
  logger.info("Persistent compilation cache hit for '%s'", module_name)
  monitoring.record_event_duration_secs(
      "/jax/compilation_cache/cache_retrieval_time_sec", lookup_duration)
  # TODO(b/293308239) Instrument a metric for new cache savings once the
  # enabling flag is added.
  # TODO(b/293308239) Remove the metric for original cache savings after the
  # new compilation cache key implementation is fully rolled out.
  monitoring.record_event_duration_secs(
      "/jax/compilation_cache/original_compile_time_saved_sec",
      cached_compile_time - lookup_duration)
  return cached_executable
def _cache_read(
    module_name: str, cache_key: str, compile_options: xc.CompileOptions,
    backend: xc.Client
) -> tuple[xc.LoadedExecutable | None, int | None]:
  """Fetches the executable and its compile time stored under `cache_key`.

  Returns whatever `compilation_cache.get_executable_and_time` returns. If the
  lookup raises, the error is re-raised when the
  `jax_raise_persistent_cache_errors` config value is set; otherwise a warning
  naming `module_name` is issued and `(None, None)` is returned.
  """
  try:
    cache_entry = compilation_cache.get_executable_and_time(
        cache_key, compile_options, backend)
  except Exception as ex:
    if config.jax_raise_persistent_cache_errors:
      raise
    warnings.warn(
        f"Error reading persistent compilation cache entry for "
        f"'{module_name}': {type(ex).__name__}: {ex}")
    return None, None
  return cache_entry
def _cache_write(cache_key: str,
                 compile_time_secs: float,
                 module_name: str,
                 backend: xc.Client, executable: xc.LoadedExecutable,
                 host_callbacks: Sequence[Any]) -> None:
  """Stores `executable` and its compile time in the persistent cache.

  Nothing is written when the computation uses host callbacks, or when a
  minimum compile time is configured
  (`jax_persistent_cache_min_compile_time_secs`) and this compilation was
  faster than that. The compile time is stored truncated to whole seconds.
  If the write raises, the error is re-raised when
  `jax_raise_persistent_cache_errors` is set and reported as a warning
  otherwise.
  """
  if host_callbacks:
    logger.info(
        "Not writing persistent cache entry for '%s' because it uses host "
        "callbacks (e.g. from jax.debug.print or breakpoint)", module_name)
    return

  min_compile_time = config.jax_persistent_cache_min_compile_time_secs
  if min_compile_time and compile_time_secs < min_compile_time:
    logger.info(
        "Not writing persistent cache entry for '%s' because it took < %.2f "
        "seconds to compile (%.2fs)", module_name, min_compile_time,
        compile_time_secs)
    return
  if min_compile_time:
    logger.info(
        "'%s' took at least %.2f seconds to compile (%.2fs), writing "
        "persistent cache entry", module_name, min_compile_time,
        compile_time_secs)

  whole_seconds = int(compile_time_secs)
  try:
    compilation_cache.put_executable_and_time(
        cache_key, module_name, executable, backend, whole_seconds)
  except Exception as ex:
    if config.jax_raise_persistent_cache_errors:
      raise
    warnings.warn(
        f"Error writing persistent compilation cache entry for "
        f"'{module_name}': {type(ex).__name__}: {ex}")

View File

@ -24,22 +24,15 @@ import itertools
import time
from typing import Any, Callable, NamedTuple
import logging
import os
import re
import threading
import warnings
import numpy as np
from jax._src import basearray
from jax._src import compilation_cache
from jax._src import config as jax_config
from jax._src import core
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import api_util
from jax._src import path
from jax._src import profiler
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
@ -50,7 +43,6 @@ from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax._src.lib.mlir import ir
from jax._src.lib import xla_client as xc
from jax._src.monitoring import record_event_duration_secs
from jax._src.partition_spec import PartitionSpec
@ -64,13 +56,6 @@ JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
_DUMP_IR_TO = jax_config.DEFINE_string(
'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
help="Path to which the IR that is emitted by JAX as input to the "
"compiler should be dumped as text files. Optional. If omitted, JAX "
"will not dump IR.")
traceback_util.register_exclusion(__file__)
xe = xc._xla
@ -404,162 +389,6 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
if config.jax_debug_infs and np.any(np.isinf(np.asarray(buf))):
raise FloatingPointError(f"invalid value (inf) encountered in {name}")
@profiler.annotate_function
def backend_compile(
backend: Backend,
module: ir.Module,
options: xc.CompileOptions,
host_callbacks: Sequence[Any],
) -> xc.LoadedExecutable:
# Convert ir.Module to a string representation, unless the
# back-end expliclity flags the ability to handle a module directly
# (avoiding the overhead of back and forth conversions)
if getattr(backend, "needs_str_ir", True):
built_c = mlir.module_to_bytecode(module)
else:
built_c = module
# we use a separate function call to ensure that XLA compilation appears
# separately in Python profiling results
if host_callbacks:
return backend.compile(built_c, compile_options=options,
host_callbacks=host_callbacks)
# Some backends don't have `host_callbacks` option yet
# TODO(sharadmv): remove this fallback when all backends allow `compile`
# to take in `host_callbacks`
return backend.compile(built_c, compile_options=options)
_ir_dump_counter = itertools.count()
def _make_string_safe_for_filename(s: str) -> str:
return re.sub(r'[^\w.)( -]', '', s)
def _dump_ir_to_file(name: str, ir: str):
id = next(_ir_dump_counter)
name = f"jax_ir{id}_{_make_string_safe_for_filename(name)}.mlir"
name = path.Path(_DUMP_IR_TO.value) / name
name.write_text(ir)
def compile_or_get_cached(
backend: Backend,
computation: ir.Module,
devices: np.ndarray,
compile_options: xc.CompileOptions,
host_callbacks: Sequence[Any],
) -> xc.LoadedExecutable:
sym_name = computation.operation.attributes['sym_name']
module_name = ir.StringAttr(sym_name).value
if _DUMP_IR_TO.value:
_dump_ir_to_file(module_name, mlir.module_to_string(computation))
# Persistent compilation cache only implemented on TPU and GPU.
# TODO(skye): add warning when initializing cache on unsupported default platform
supported_platforms = ["tpu", "gpu"]
# (b/233850967) CPU caching can be enabled if XLA Runtime is enabled.
if "--xla_cpu_use_xla_runtime=true" in os.environ.get("XLA_FLAGS", ""):
supported_platforms.append("cpu")
use_compilation_cache = (compilation_cache.is_initialized() and
backend.platform in supported_platforms)
if not use_compilation_cache:
return backend_compile(backend, computation, compile_options,
host_callbacks)
cache_key = compilation_cache.get_cache_key(
computation, devices, compile_options, backend,
jax_config.config.jax_use_original_compilation_cache_key_generation,
)
cache_retrieval_start = time.monotonic()
retrieved_executable, retrieved_compile_time = _cache_read(
module_name, cache_key, compile_options, backend)
cache_retrieval_time = time.monotonic() - cache_retrieval_start
if retrieved_executable is not None:
assert retrieved_compile_time is not None
logger.info("Persistent compilation cache hit for '%s'", module_name)
record_event_duration_secs(
"/jax/compilation_cache/cache_retrieval_time_sec", cache_retrieval_time)
# TODO(b/293308239) Instrument a metric for new cache savings once the
# enabling flag is added.
# TODO(b/293308239) Remove the metric for original cache savings after the
# new compilation cache key implementation is fully rolled out.
record_event_duration_secs(
"/jax/compilation_cache/original_compile_time_saved_sec",
retrieved_compile_time - cache_retrieval_time)
return retrieved_executable
else:
start_time = time.monotonic()
executable = backend_compile(backend, computation,
compile_options, host_callbacks)
compile_time = time.monotonic() - start_time
_cache_write(cache_key, compile_time, module_name, backend, executable,
host_callbacks)
return executable
def _cache_read(
module_name: str, cache_key: str, compile_options: xc.CompileOptions,
backend: Backend
) -> tuple[xc.LoadedExecutable | None, int | None]:
"""Looks up the `computation` and it's compilation time in the persistent
compilation cache repository.
"""
try:
return compilation_cache.get_executable_and_time(
cache_key, compile_options, backend)
except Exception as ex:
if config.jax_raise_persistent_cache_errors:
raise
warnings.warn(
f"Error reading persistent compilation cache entry for "
f"'{module_name}': {type(ex).__name__}: {ex}")
return None, None
def _cache_write(cache_key: str,
compile_time_secs: float,
module_name: str,
backend: Backend, executable: xc.LoadedExecutable,
host_callbacks: Sequence[Any]) -> None:
"""Writes the `serialized_computation` and its compilation time to the
persistent compilation cache repository.
"""
if host_callbacks:
logger.info(
"Not writing persistent cache entry for '%s' because it uses host "
"callbacks (e.g. from jax.debug.print or breakpoint)", module_name)
return
min_compile_time = config.jax_persistent_cache_min_compile_time_secs
if min_compile_time:
if compile_time_secs < min_compile_time:
logger.info(
"Not writing persistent cache entry for '%s' because it took < %.2f "
"seconds to compile (%.2fs)", module_name, min_compile_time,
compile_time_secs)
return
else:
logger.info(
"'%s' took at least %.2f seconds to compile (%.2fs), writing "
"persistent cache entry", module_name, min_compile_time,
compile_time_secs)
try:
compilation_cache.put_executable_and_time(
cache_key, module_name, executable, backend, int(compile_time_secs))
except Exception as ex:
if config.jax_raise_persistent_cache_errors:
raise
warnings.warn(
f"Error writing persistent compilation cache entry for "
f"'{module_name}': {type(ex).__name__}: {ex}")
# TODO(yashkatariya): Generalize is_compatible_aval (maybe renamed) and use that
# to check if shardings are compatible with the input.
def _check_sharding(aval: core.AbstractValue, s: Sharding):

View File

@ -34,6 +34,7 @@ from jax.errors import JAXTypeError
from jax._src import api_util
from jax._src import core
from jax._src import compiler
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
@ -903,7 +904,7 @@ class UnloadedPmapExecutable:
num_partitions = 1
device_assignment: np.ndarray = np.array(devices).reshape(
(replicas.num_global_replicas, num_partitions))
compile_options = xb.get_compile_options(
compile_options = compiler.get_compile_options(
num_replicas=replicas.num_global_replicas,
num_partitions=num_partitions,
device_assignment=device_assignment,
@ -953,7 +954,7 @@ class UnloadedPmapExecutable:
with dispatch.log_elapsed_time(
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT):
compiled = dispatch.compile_or_get_cached(
compiled = compiler.compile_or_get_cached(
pci.backend, hlo, device_assignment, compile_options,
host_callbacks)
@ -2482,7 +2483,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
fdo_profile = (None if compiler_options is None else
compiler_options.pop("fdo_profile", None))
compile_options = xb.get_compile_options(
compile_options = compiler.get_compile_options(
num_replicas=num_replicas,
num_partitions=num_partitions,
device_assignment=xla_device_assignment,
@ -2508,7 +2509,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
with dispatch.log_elapsed_time(
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
xla_executable = dispatch.compile_or_get_cached(
xla_executable = compiler.compile_or_get_cached(
backend, computation, dev, compile_options, host_callbacks)
return xla_executable, compile_options

View File

@ -33,12 +33,9 @@ import threading
from typing import Any, Callable, Optional, Union
import warnings
import numpy as np
from jax._src import lib
from jax._src import distributed
from jax._src import config as jax_config
from jax._src.config import bool_env, config
from jax._src.config import config
from jax._src.lib import xla_client
from jax._src import traceback_util
from jax._src import util
@ -73,11 +70,6 @@ _PLATFORM_NAME = jax_config.DEFINE_string(
'jax_platform_name',
os.getenv('JAX_PLATFORM_NAME', '').lower(),
'Deprecated, please use --jax_platforms instead.')
_DISABLE_MOST_OPTIMIZATIONS = jax_config.DEFINE_bool(
'jax_disable_most_optimizations',
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
'Try not to do much optimization work. This can be useful if the cost of '
'optimization is greater than that of running a less-optimized program.')
CUDA_VISIBLE_DEVICES = jax_config.DEFINE_string(
'jax_cuda_visible_devices', 'all',
'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
@ -88,122 +80,6 @@ _ROCM_VISIBLE_DEVICES = jax_config.DEFINE_string(
'comma-separate list of integer device IDs.')
# Will be monkeypatched with the function that gets the XLA-AutoFDO profile
# version. The default (-1) takes care of errors.
def get_latest_profile_version() -> int:
return -1
def get_compile_options(
num_replicas: int,
num_partitions: int,
device_assignment=None,
use_spmd_partitioning: bool = True,
use_auto_spmd_partitioning: bool = False,
auto_spmd_partitioning_mesh_shape: Optional[list[int]] = None,
auto_spmd_partitioning_mesh_ids: Optional[list[int]] = None,
env_options_overrides: Optional[dict[str, str]] = None,
fdo_profile: Optional[bytes] = None,
) -> xla_client.CompileOptions:
"""Returns the compile options to use, as derived from flag values.
Args:
num_replicas: Number of replicas for which to compile.
num_partitions: Number of partitions for which to compile.
device_assignment: Optional ndarray of jax devices indicating the assignment
of logical replicas to physical devices (default inherited from
xla_client.CompileOptions). Must be consistent with `num_replicas` and
`num_partitions`.
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
partitioning in XLA.
use_auto_spmd_partitioning: boolean indicating whether to automatically
generate XLA shardings for SPMD partitioner.
auto_spmd_partitioning_mesh_shape: device mesh shape used to create
auto_spmd_partitioning search space.
auto_spmd_partitioning_mesh_ids: device ids used to create
auto_spmd_partitioning search space.
env_options_overrides: dict of additional options parsed by the compiler
fdo_profile: Optional profile for feedback-directed optimization passed to
XLA.
"""
compile_options = xla_client.CompileOptions()
compile_options.num_replicas = num_replicas
compile_options.num_partitions = num_partitions
build_options = compile_options.executable_build_options
build_options.use_spmd_partitioning = use_spmd_partitioning
build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
if fdo_profile is not None:
build_options.fdo_profile = fdo_profile
if use_auto_spmd_partitioning:
build_options.auto_spmd_partitioning_mesh_shape = auto_spmd_partitioning_mesh_shape or []
build_options.auto_spmd_partitioning_mesh_ids = auto_spmd_partitioning_mesh_ids or []
if device_assignment is not None:
logger.debug(
'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
num_replicas, num_partitions, device_assignment)
device_assignment = np.array(device_assignment)
# Allow 1D device assignment if num_partitions is 1.
if (device_assignment.ndim == 1) and (num_partitions == 1):
device_assignment = device_assignment[:, None]
if num_replicas != device_assignment.shape[0]:
msg = 'device_assignment does not match num_replicas: {} vs {}.'
raise ValueError(msg.format(device_assignment, num_replicas))
if num_partitions != device_assignment.shape[1]:
msg = 'device_assignment does not match num_partitions: {} vs {}.'
raise ValueError(msg.format(device_assignment, num_partitions))
if device_assignment.dtype == object:
device_assignment = np.vectorize(lambda d: d.id, otypes=[int])(
device_assignment)
device_assignment = xla_client.DeviceAssignment.create(device_assignment)
assert device_assignment.replica_count() == num_replicas
assert device_assignment.computation_count() == num_partitions
compile_options.device_assignment = device_assignment
if env_options_overrides is not None:
compile_options.env_option_overrides = list(env_options_overrides.items())
debug_options = compile_options.executable_build_options.debug_options
if lib.cuda_path is not None:
debug_options.xla_gpu_cuda_data_dir = lib.cuda_path
if _DISABLE_MOST_OPTIMIZATIONS.value:
debug_options.xla_backend_optimization_level = 0
debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False
# XLA-AutoFDO profile version: precedence order is:
# 1. Whatever --jax_xla_profile_version is set to.
# 2. If --jax_xla_profile_version is not set (i.e., 0), call the function
# set in get_latest_profile_version and use the return value if non-zero.
# If the function returns 0, set -1; this is an error.
# -1 indicates that no attempt should be made to retrieve the latest profile
# later on.
jax_xla_profile_version = config.jax_xla_profile_version
if jax_xla_profile_version > 0:
compile_options.profile_version = jax_xla_profile_version
logger.debug("get_compile_options XLA-AutoFDO profile: " +
"using JAX XLA profile version %d from flag",
jax_xla_profile_version)
else:
fdo_profile_version = get_latest_profile_version()
if fdo_profile_version != 0:
compile_options.profile_version = fdo_profile_version
logger.debug("get_compile_options XLA-AutoFDO profile: " +
"using XLA-AutoFDO profile version %d",
fdo_profile_version)
else:
no_profile_dont_retrieve = -1
compile_options.profile_version = no_profile_dont_retrieve
logger.error("get_compile_options XLA-AutoFDO profile: " +
"XLA-AutoFDO profile version is 0; this should not happen")
return compile_options
# Backends
def tpu_client_timer_callback(timer_secs: float) -> Optional[xla_client.Client]:

View File

@ -517,6 +517,7 @@ from jax._src.interpreters import mlir
from jax._src.interpreters import partial_eval as pe
from jax._src.interpreters import xla
from jax._src import ad_checkpoint
from jax._src import compiler
from jax._src import dispatch
from jax._src import pretty_printer as pp
from jax._src import sharding_impls
@ -2014,7 +2015,7 @@ def _initialize_outfeed_receiver(
_callback_handler_data.receiver = outfeed_receiver_module.start(
_callback_input_received, tuple(clients_with_outfeed),
max_callback_queue_size_bytes,
xb.get_compile_options(1, 1).executable_build_options) # type:ignore
compiler.get_compile_options(1, 1).executable_build_options) # type:ignore
def exit_handler():
# Prevent logging usage during compilation, gives errors under pytest

View File

@ -40,6 +40,7 @@ from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P
import jax.numpy as jnp
from jax._src import compiler
from jax._src import xla_bridge
import numpy as np
@ -108,7 +109,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
device_assignment = np.arange(num_partitions * num_replicas)
device_assignment = np.reshape(device_assignment, (-1, num_partitions))
use_spmd_partitioning = num_partitions > 1
compile_options = xla_bridge.get_compile_options(
compile_options = compiler.get_compile_options(
num_replicas=num_replicas,
num_partitions=num_partitions,
device_assignment=device_assignment,

View File

@ -35,9 +35,12 @@ from jax._src.core import (
)
# TODO(phawkins): update users.
from jax._src.compiler import (
backend_compile as backend_compile,
)
from jax._src.dispatch import (
apply_primitive as apply_primitive,
backend_compile as backend_compile,
)
from jax._src.sharding_impls import (

View File

@ -16,7 +16,10 @@
from jax._src.xla_bridge import (
default_backend as default_backend,
get_backend as get_backend,
get_compile_options as get_compile_options,
xla_client as xla_client,
_backends as _backends,
)
from jax._src.compiler import (
get_compile_options as get_compile_options,
)

View File

@ -911,6 +911,7 @@ py_test(
data = ["testdata/example_pjrt_plugin_config.json"],
deps = [
"//jax",
"//jax:compiler",
"//jax:test_util",
] + py_deps("absl/logging"),
)
@ -928,13 +929,19 @@ py_test(
jax_test(
name = "compilation_cache_test",
srcs = ["compilation_cache_test.py"],
deps = ["//jax:compilation_cache_internal"],
deps = [
"//jax:compilation_cache_internal",
"//jax:compiler",
],
)
jax_test(
name = "cache_key_test",
srcs = ["cache_key_test.py"],
deps = ["//jax:cache_key"],
deps = [
"//jax:cache_key",
"//jax:compiler",
],
)
jax_test(

View File

@ -24,6 +24,7 @@ import jax
from jax import config
from jax import lax
from jax._src import cache_key
from jax._src import compiler
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.config import compilation_cache_include_metadata_in_key
@ -38,7 +39,7 @@ FLAGS = config.FLAGS
class CacheKeyTest(jtu.JaxTestCase):
def test_compile_options(self):
compile_options_not_filled = xla_bridge.get_compile_options(
compile_options_not_filled = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
compile_options_filled = self.filled_compile_options()
@ -55,7 +56,7 @@ class CacheKeyTest(jtu.JaxTestCase):
self.assertNotEqual(filled_hash1, not_filled_hash3)
def test_executable_build_options(self):
compile_options_not_filled = xla_bridge.get_compile_options(
compile_options_not_filled = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
compile_options_filled = self.filled_compile_options()
@ -75,7 +76,7 @@ class CacheKeyTest(jtu.JaxTestCase):
self.assertNotEqual(filled_hash1, not_filled_hash3)
def test_debug_options(self):
compile_options = xla_bridge.get_compile_options(
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
hash1 = self.get_hashed_value(
@ -141,7 +142,7 @@ class CacheKeyTest(jtu.JaxTestCase):
def test_same_key(self):
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = xla_bridge.get_compile_options(
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
@ -153,7 +154,7 @@ class CacheKeyTest(jtu.JaxTestCase):
def test_different_key(self):
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options_not_filled = xla_bridge.get_compile_options(
compile_options_not_filled = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
compile_options_filled = self.filled_compile_options()
@ -169,7 +170,7 @@ class CacheKeyTest(jtu.JaxTestCase):
computation1 = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
computation2 = jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = xla_bridge.get_compile_options(
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
@ -186,7 +187,7 @@ class CacheKeyTest(jtu.JaxTestCase):
computation1 = jax.jit(f).lower(1, 1).compiler_ir()
computation2 = jax.jit(g).lower(2, 3).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = xla_bridge.get_compile_options(
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
@ -201,7 +202,7 @@ class CacheKeyTest(jtu.JaxTestCase):
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = xla_bridge.get_compile_options(
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()

View File

@ -28,6 +28,7 @@ from jax import jit
from jax import lax
from jax import pmap
from jax._src import compilation_cache as cc
from jax._src import compiler
from jax._src import monitoring
from jax._src import test_util as jtu
from jax._src import xla_bridge
@ -78,7 +79,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
cc.initialize_cache(tmpdir)
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.array([[jax.local_devices()[0]]])
compile_options = xla_bridge.get_compile_options(
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
@ -93,7 +94,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
cc.initialize_cache(tmpdir)
computation1 = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir())
computation2 = str(jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir())
compile_options = xla_bridge.get_compile_options(
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()
@ -117,7 +118,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
.compiler_ir()
)
devices = np.array([[jax.local_devices()[0]]])
compile_options = xla_bridge.get_compile_options(
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1
)
backend = xla_bridge.get_backend()

View File

@ -19,6 +19,7 @@ import warnings
from absl import logging
from absl.testing import absltest
from jax._src import compiler
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
@ -35,7 +36,7 @@ mock = absltest.mock
class XlaBridgeTest(jtu.JaxTestCase):
def test_set_device_assignment_no_partition(self):
compile_options = xb.get_compile_options(
compile_options = compiler.get_compile_options(
num_replicas=4, num_partitions=1, device_assignment=[0, 1, 2, 3])
expected_device_assignment = ("Computations: 1 Replicas: 4\nComputation 0: "
"0 1 2 3 \n")
@ -43,7 +44,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
expected_device_assignment)
def test_set_device_assignment_with_partition(self):
compile_options = xb.get_compile_options(
compile_options = compiler.get_compile_options(
num_replicas=2, num_partitions=2, device_assignment=[[0, 1], [2, 3]])
expected_device_assignment = ("Computations: 2 Replicas: 2\nComputation 0: "
"0 2 \nComputation 1: 1 3 \n")
@ -51,7 +52,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
expected_device_assignment)
def test_set_fdo_profile(self):
compile_options = xb.get_compile_options(
compile_options = compiler.get_compile_options(
num_replicas=1, num_partitions=1, fdo_profile=b"test_profile"
)
self.assertEqual(
@ -63,10 +64,10 @@ class XlaBridgeTest(jtu.JaxTestCase):
jax_flag_profile = 1
another_profile = 2
with jax_config.jax_xla_profile_version(jax_flag_profile):
with mock.patch.object(xb, "get_latest_profile_version",
with mock.patch.object(compiler, "get_latest_profile_version",
side_effect=lambda: another_profile):
self.assertEqual(
xb.get_compile_options(
compiler.get_compile_options(
num_replicas=3, num_partitions=4
).profile_version,
jax_flag_profile,
@ -75,10 +76,10 @@ class XlaBridgeTest(jtu.JaxTestCase):
# Use whatever non-zero value the function get_latest_profile_version
# returns if --jax_xla_profile_version is not set.
profile_version = 1
with mock.patch.object(xb, "get_latest_profile_version",
with mock.patch.object(compiler, "get_latest_profile_version",
side_effect=lambda: profile_version):
self.assertEqual(
xb.get_compile_options(
compiler.get_compile_options(
num_replicas=3, num_partitions=4
).profile_version,
profile_version,
@ -89,10 +90,10 @@ class XlaBridgeTest(jtu.JaxTestCase):
# retrieve the latest profile later.
error_return = 0
no_profile_dont_retrieve = -1
with mock.patch.object(xb, "get_latest_profile_version",
with mock.patch.object(compiler, "get_latest_profile_version",
side_effect=lambda: error_return):
self.assertEqual(
xb.get_compile_options(
compiler.get_compile_options(
num_replicas=3, num_partitions=4
).profile_version,
no_profile_dont_retrieve,