mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Move compiler APIs out of dispatch.py and xla_bridge.py into a new jax._src.compiler module.
Refactoring only, no user-visible changes intended. PiperOrigin-RevId: 557116160
This commit is contained in:
parent
423c8d8d4f
commit
a259df0d76
17
jax/BUILD
17
jax/BUILD
@ -190,6 +190,7 @@ py_library_providing_imports_info(
|
||||
":basearray",
|
||||
":cloud_tpu_init",
|
||||
":compilation_cache_internal",
|
||||
":compiler",
|
||||
":config",
|
||||
":core",
|
||||
":custom_api_util",
|
||||
@ -320,6 +321,20 @@ pytype_strict_library(
|
||||
srcs = ["_src/logging_config.py"],
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "compiler",
|
||||
srcs = ["_src/compiler.py"],
|
||||
deps = [
|
||||
":compilation_cache_internal",
|
||||
":config",
|
||||
":monitoring",
|
||||
":path",
|
||||
":profiler",
|
||||
":traceback_util",
|
||||
"//jax/_src/lib",
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
pytype_strict_library(
|
||||
name = "core",
|
||||
srcs = [
|
||||
@ -727,7 +742,7 @@ pytype_strict_library(
|
||||
":traceback_util",
|
||||
":util",
|
||||
"//jax/_src/lib",
|
||||
] + py_deps("numpy"),
|
||||
],
|
||||
)
|
||||
|
||||
# Public JAX libraries below this point.
|
||||
|
342
jax/_src/compiler.py
Normal file
342
jax/_src/compiler.py
Normal file
@ -0,0 +1,342 @@
|
||||
# Copyright 2018 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Interface to the compiler
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import io
|
||||
import itertools
|
||||
import time
|
||||
from typing import Any
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lib
|
||||
from jax._src import compilation_cache
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import monitoring
|
||||
from jax._src import path
|
||||
from jax._src import profiler
|
||||
from jax._src import traceback_util
|
||||
from jax._src.config import config
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
||||
|
||||
_DISABLE_MOST_OPTIMIZATIONS = jax_config.DEFINE_bool(
|
||||
'jax_disable_most_optimizations',
|
||||
jax_config.bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
|
||||
'Try not to do much optimization work. This can be useful if the cost of '
|
||||
'optimization is greater than that of running a less-optimized program.')
|
||||
|
||||
_DUMP_IR_TO = jax_config.DEFINE_string(
|
||||
'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
|
||||
help="Path to which the IR that is emitted by JAX as input to the "
|
||||
"compiler should be dumped as text files. Optional. If omitted, JAX "
|
||||
"will not dump IR.")
|
||||
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
CompileOptions = xc.CompileOptions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Will be monkeypatched with the function that gets the XLA-AutoFDO profile
|
||||
# version. The default (-1) takes care of errors.
|
||||
def get_latest_profile_version() -> int:
|
||||
return -1
|
||||
|
||||
|
||||
def get_compile_options(
|
||||
num_replicas: int,
|
||||
num_partitions: int,
|
||||
device_assignment=None,
|
||||
use_spmd_partitioning: bool = True,
|
||||
use_auto_spmd_partitioning: bool = False,
|
||||
auto_spmd_partitioning_mesh_shape: list[int] | None = None,
|
||||
auto_spmd_partitioning_mesh_ids: list[int] | None = None,
|
||||
env_options_overrides: dict[str, str] | None = None,
|
||||
fdo_profile: bytes | None = None,
|
||||
) -> xc.CompileOptions:
|
||||
"""Returns the compile options to use, as derived from flag values.
|
||||
|
||||
Args:
|
||||
num_replicas: Number of replicas for which to compile.
|
||||
num_partitions: Number of partitions for which to compile.
|
||||
device_assignment: Optional ndarray of jax devices indicating the assignment
|
||||
of logical replicas to physical devices (default inherited from
|
||||
xla_client.CompileOptions). Must be consistent with `num_replicas` and
|
||||
`num_partitions`.
|
||||
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
|
||||
partitioning in XLA.
|
||||
use_auto_spmd_partitioning: boolean indicating whether to automatically
|
||||
generate XLA shardings for SPMD partitioner.
|
||||
auto_spmd_partitioning_mesh_shape: device mesh shape used to create
|
||||
auto_spmd_partitioning search space.
|
||||
auto_spmd_partitioning_mesh_ids: device ids used to create
|
||||
auto_spmd_partitioning search space.
|
||||
env_options_overrides: dict of additional options parsed by the compiler
|
||||
fdo_profile: Optional profile for feedback-directed optimization passed to
|
||||
XLA.
|
||||
"""
|
||||
compile_options = xc.CompileOptions()
|
||||
compile_options.num_replicas = num_replicas
|
||||
compile_options.num_partitions = num_partitions
|
||||
build_options = compile_options.executable_build_options
|
||||
build_options.use_spmd_partitioning = use_spmd_partitioning
|
||||
build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
|
||||
if fdo_profile is not None:
|
||||
build_options.fdo_profile = fdo_profile
|
||||
if use_auto_spmd_partitioning:
|
||||
build_options.auto_spmd_partitioning_mesh_shape = auto_spmd_partitioning_mesh_shape or []
|
||||
build_options.auto_spmd_partitioning_mesh_ids = auto_spmd_partitioning_mesh_ids or []
|
||||
if device_assignment is not None:
|
||||
logger.debug(
|
||||
'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
|
||||
num_replicas, num_partitions, device_assignment)
|
||||
device_assignment = np.array(device_assignment)
|
||||
|
||||
# Allow 1D device assignment if num_partitions is 1.
|
||||
if (device_assignment.ndim == 1) and (num_partitions == 1):
|
||||
device_assignment = device_assignment[:, None]
|
||||
|
||||
if num_replicas != device_assignment.shape[0]:
|
||||
msg = 'device_assignment does not match num_replicas: {} vs {}.'
|
||||
raise ValueError(msg.format(device_assignment, num_replicas))
|
||||
|
||||
if num_partitions != device_assignment.shape[1]:
|
||||
msg = 'device_assignment does not match num_partitions: {} vs {}.'
|
||||
raise ValueError(msg.format(device_assignment, num_partitions))
|
||||
|
||||
if device_assignment.dtype == object:
|
||||
device_assignment = np.vectorize(lambda d: d.id, otypes=[int])(
|
||||
device_assignment)
|
||||
device_assignment = xc.DeviceAssignment.create(device_assignment)
|
||||
assert device_assignment.replica_count() == num_replicas
|
||||
assert device_assignment.computation_count() == num_partitions
|
||||
compile_options.device_assignment = device_assignment
|
||||
|
||||
if env_options_overrides is not None:
|
||||
compile_options.env_option_overrides = list(env_options_overrides.items())
|
||||
|
||||
debug_options = compile_options.executable_build_options.debug_options
|
||||
if lib.cuda_path is not None:
|
||||
debug_options.xla_gpu_cuda_data_dir = lib.cuda_path
|
||||
|
||||
if _DISABLE_MOST_OPTIMIZATIONS.value:
|
||||
debug_options.xla_backend_optimization_level = 0
|
||||
debug_options.xla_llvm_disable_expensive_passes = True
|
||||
debug_options.xla_test_all_input_layouts = False
|
||||
|
||||
# XLA-AutoFDO profile version: precedence order is:
|
||||
# 1. Whatever --jax_xla_profile_version is set to.
|
||||
# 2. If --jax_xla_profile_version is not set (i.e., 0), call the function
|
||||
# set in get_latest_profile_version and use the return value if non-zero.
|
||||
# If the function returns 0, set -1; this is an error.
|
||||
# -1 indicates that no attempt should be made to retrieve the latest profile
|
||||
# later on.
|
||||
jax_xla_profile_version = config.jax_xla_profile_version
|
||||
if jax_xla_profile_version > 0:
|
||||
compile_options.profile_version = jax_xla_profile_version
|
||||
logger.debug("get_compile_options XLA-AutoFDO profile: " +
|
||||
"using JAX XLA profile version %d from flag",
|
||||
jax_xla_profile_version)
|
||||
else:
|
||||
fdo_profile_version = get_latest_profile_version()
|
||||
if fdo_profile_version != 0:
|
||||
compile_options.profile_version = fdo_profile_version
|
||||
logger.debug("get_compile_options XLA-AutoFDO profile: " +
|
||||
"using XLA-AutoFDO profile version %d",
|
||||
fdo_profile_version)
|
||||
else:
|
||||
no_profile_dont_retrieve = -1
|
||||
compile_options.profile_version = no_profile_dont_retrieve
|
||||
logger.error("get_compile_options XLA-AutoFDO profile: " +
|
||||
"XLA-AutoFDO profile version is 0; this should not happen")
|
||||
|
||||
return compile_options
|
||||
|
||||
|
||||
def _module_to_string(module: ir.Module) -> str:
|
||||
output = io.StringIO()
|
||||
module.operation.print(file=output, enable_debug_info=True)
|
||||
return output.getvalue()
|
||||
|
||||
def _module_to_bytecode(module: ir.Module) -> bytes:
|
||||
output = io.BytesIO()
|
||||
module.operation.write_bytecode(file=output)
|
||||
return output.getvalue()
|
||||
|
||||
|
||||
@profiler.annotate_function
|
||||
def backend_compile(
|
||||
backend: xc.Client,
|
||||
module: ir.Module,
|
||||
options: xc.CompileOptions,
|
||||
host_callbacks: Sequence[Any],
|
||||
) -> xc.LoadedExecutable:
|
||||
# Convert ir.Module to a string representation, unless the
|
||||
# back-end expliclity flags the ability to handle a module directly
|
||||
# (avoiding the overhead of back and forth conversions)
|
||||
if getattr(backend, "needs_str_ir", True):
|
||||
built_c = _module_to_bytecode(module)
|
||||
else:
|
||||
built_c = module
|
||||
|
||||
# we use a separate function call to ensure that XLA compilation appears
|
||||
# separately in Python profiling results
|
||||
if host_callbacks:
|
||||
return backend.compile(built_c, compile_options=options,
|
||||
host_callbacks=host_callbacks)
|
||||
# Some backends don't have `host_callbacks` option yet
|
||||
# TODO(sharadmv): remove this fallback when all backends allow `compile`
|
||||
# to take in `host_callbacks`
|
||||
return backend.compile(built_c, compile_options=options)
|
||||
|
||||
|
||||
_ir_dump_counter = itertools.count()
|
||||
|
||||
def _make_string_safe_for_filename(s: str) -> str:
|
||||
return re.sub(r'[^\w.)( -]', '', s)
|
||||
|
||||
def _dump_ir_to_file(name: str, ir: str):
|
||||
id = next(_ir_dump_counter)
|
||||
name = f"jax_ir{id}_{_make_string_safe_for_filename(name)}.mlir"
|
||||
name = path.Path(_DUMP_IR_TO.value) / name
|
||||
name.write_text(ir)
|
||||
|
||||
|
||||
def compile_or_get_cached(
|
||||
backend: xc.Client,
|
||||
computation: ir.Module,
|
||||
devices: np.ndarray,
|
||||
compile_options: xc.CompileOptions,
|
||||
host_callbacks: Sequence[Any],
|
||||
) -> xc.LoadedExecutable:
|
||||
sym_name = computation.operation.attributes['sym_name']
|
||||
module_name = ir.StringAttr(sym_name).value
|
||||
|
||||
if _DUMP_IR_TO.value:
|
||||
_dump_ir_to_file(module_name, _module_to_string(computation))
|
||||
|
||||
# Persistent compilation cache only implemented on TPU and GPU.
|
||||
# TODO(skye): add warning when initializing cache on unsupported default platform
|
||||
supported_platforms = ["tpu", "gpu"]
|
||||
# (b/233850967) CPU caching can be enabled if XLA Runtime is enabled.
|
||||
if "--xla_cpu_use_xla_runtime=true" in os.environ.get("XLA_FLAGS", ""):
|
||||
supported_platforms.append("cpu")
|
||||
use_compilation_cache = (compilation_cache.is_initialized() and
|
||||
backend.platform in supported_platforms)
|
||||
|
||||
if not use_compilation_cache:
|
||||
return backend_compile(backend, computation, compile_options,
|
||||
host_callbacks)
|
||||
|
||||
cache_key = compilation_cache.get_cache_key(
|
||||
computation, devices, compile_options, backend,
|
||||
jax_config.config.jax_use_original_compilation_cache_key_generation,
|
||||
)
|
||||
|
||||
cache_retrieval_start = time.monotonic()
|
||||
retrieved_executable, retrieved_compile_time = _cache_read(
|
||||
module_name, cache_key, compile_options, backend)
|
||||
cache_retrieval_time = time.monotonic() - cache_retrieval_start
|
||||
|
||||
if retrieved_executable is not None:
|
||||
assert retrieved_compile_time is not None
|
||||
logger.info("Persistent compilation cache hit for '%s'", module_name)
|
||||
monitoring.record_event_duration_secs(
|
||||
"/jax/compilation_cache/cache_retrieval_time_sec", cache_retrieval_time)
|
||||
# TODO(b/293308239) Instrument a metric for new cache savings once the
|
||||
# enabling flag is added.
|
||||
# TODO(b/293308239) Remove the metric for original cache savings after the
|
||||
# new compilation cache key implementation is fully rolled out.
|
||||
monitoring.record_event_duration_secs(
|
||||
"/jax/compilation_cache/original_compile_time_saved_sec",
|
||||
retrieved_compile_time - cache_retrieval_time)
|
||||
return retrieved_executable
|
||||
else:
|
||||
start_time = time.monotonic()
|
||||
executable = backend_compile(backend, computation,
|
||||
compile_options, host_callbacks)
|
||||
compile_time = time.monotonic() - start_time
|
||||
_cache_write(cache_key, compile_time, module_name, backend, executable,
|
||||
host_callbacks)
|
||||
return executable
|
||||
|
||||
|
||||
def _cache_read(
|
||||
module_name: str, cache_key: str, compile_options: xc.CompileOptions,
|
||||
backend: xc.Client
|
||||
) -> tuple[xc.LoadedExecutable | None, int | None]:
|
||||
"""Looks up the `computation` and it's compilation time in the persistent
|
||||
compilation cache repository.
|
||||
"""
|
||||
try:
|
||||
return compilation_cache.get_executable_and_time(
|
||||
cache_key, compile_options, backend)
|
||||
except Exception as ex:
|
||||
if config.jax_raise_persistent_cache_errors:
|
||||
raise
|
||||
warnings.warn(
|
||||
f"Error reading persistent compilation cache entry for "
|
||||
f"'{module_name}': {type(ex).__name__}: {ex}")
|
||||
return None, None
|
||||
|
||||
|
||||
def _cache_write(cache_key: str,
|
||||
compile_time_secs: float,
|
||||
module_name: str,
|
||||
backend: xc.Client, executable: xc.LoadedExecutable,
|
||||
host_callbacks: Sequence[Any]) -> None:
|
||||
"""Writes the `serialized_computation` and its compilation time to the
|
||||
persistent compilation cache repository.
|
||||
"""
|
||||
if host_callbacks:
|
||||
logger.info(
|
||||
"Not writing persistent cache entry for '%s' because it uses host "
|
||||
"callbacks (e.g. from jax.debug.print or breakpoint)", module_name)
|
||||
return
|
||||
|
||||
min_compile_time = config.jax_persistent_cache_min_compile_time_secs
|
||||
if min_compile_time:
|
||||
if compile_time_secs < min_compile_time:
|
||||
logger.info(
|
||||
"Not writing persistent cache entry for '%s' because it took < %.2f "
|
||||
"seconds to compile (%.2fs)", module_name, min_compile_time,
|
||||
compile_time_secs)
|
||||
return
|
||||
else:
|
||||
logger.info(
|
||||
"'%s' took at least %.2f seconds to compile (%.2fs), writing "
|
||||
"persistent cache entry", module_name, min_compile_time,
|
||||
compile_time_secs)
|
||||
|
||||
try:
|
||||
compilation_cache.put_executable_and_time(
|
||||
cache_key, module_name, executable, backend, int(compile_time_secs))
|
||||
except Exception as ex:
|
||||
if config.jax_raise_persistent_cache_errors:
|
||||
raise
|
||||
warnings.warn(
|
||||
f"Error writing persistent compilation cache entry for "
|
||||
f"'{module_name}': {type(ex).__name__}: {ex}")
|
@ -24,22 +24,15 @@ import itertools
|
||||
import time
|
||||
from typing import Any, Callable, NamedTuple
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import basearray
|
||||
from jax._src import compilation_cache
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import api_util
|
||||
from jax._src import path
|
||||
from jax._src import profiler
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
@ -50,7 +43,6 @@ from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.monitoring import record_event_duration_secs
|
||||
from jax._src.partition_spec import PartitionSpec
|
||||
@ -64,13 +56,6 @@ JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
|
||||
JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
|
||||
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
|
||||
|
||||
_DUMP_IR_TO = jax_config.DEFINE_string(
|
||||
'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
|
||||
help="Path to which the IR that is emitted by JAX as input to the "
|
||||
"compiler should be dumped as text files. Optional. If omitted, JAX "
|
||||
"will not dump IR.")
|
||||
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
xe = xc._xla
|
||||
@ -404,162 +389,6 @@ def _check_special(name: str, dtype: np.dtype, buf: basearray.Array) -> None:
|
||||
if config.jax_debug_infs and np.any(np.isinf(np.asarray(buf))):
|
||||
raise FloatingPointError(f"invalid value (inf) encountered in {name}")
|
||||
|
||||
|
||||
@profiler.annotate_function
|
||||
def backend_compile(
|
||||
backend: Backend,
|
||||
module: ir.Module,
|
||||
options: xc.CompileOptions,
|
||||
host_callbacks: Sequence[Any],
|
||||
) -> xc.LoadedExecutable:
|
||||
# Convert ir.Module to a string representation, unless the
|
||||
# back-end expliclity flags the ability to handle a module directly
|
||||
# (avoiding the overhead of back and forth conversions)
|
||||
if getattr(backend, "needs_str_ir", True):
|
||||
built_c = mlir.module_to_bytecode(module)
|
||||
else:
|
||||
built_c = module
|
||||
|
||||
# we use a separate function call to ensure that XLA compilation appears
|
||||
# separately in Python profiling results
|
||||
if host_callbacks:
|
||||
return backend.compile(built_c, compile_options=options,
|
||||
host_callbacks=host_callbacks)
|
||||
# Some backends don't have `host_callbacks` option yet
|
||||
# TODO(sharadmv): remove this fallback when all backends allow `compile`
|
||||
# to take in `host_callbacks`
|
||||
return backend.compile(built_c, compile_options=options)
|
||||
|
||||
|
||||
_ir_dump_counter = itertools.count()
|
||||
|
||||
def _make_string_safe_for_filename(s: str) -> str:
|
||||
return re.sub(r'[^\w.)( -]', '', s)
|
||||
|
||||
def _dump_ir_to_file(name: str, ir: str):
|
||||
id = next(_ir_dump_counter)
|
||||
name = f"jax_ir{id}_{_make_string_safe_for_filename(name)}.mlir"
|
||||
name = path.Path(_DUMP_IR_TO.value) / name
|
||||
name.write_text(ir)
|
||||
|
||||
|
||||
def compile_or_get_cached(
|
||||
backend: Backend,
|
||||
computation: ir.Module,
|
||||
devices: np.ndarray,
|
||||
compile_options: xc.CompileOptions,
|
||||
host_callbacks: Sequence[Any],
|
||||
) -> xc.LoadedExecutable:
|
||||
sym_name = computation.operation.attributes['sym_name']
|
||||
module_name = ir.StringAttr(sym_name).value
|
||||
|
||||
if _DUMP_IR_TO.value:
|
||||
_dump_ir_to_file(module_name, mlir.module_to_string(computation))
|
||||
|
||||
# Persistent compilation cache only implemented on TPU and GPU.
|
||||
# TODO(skye): add warning when initializing cache on unsupported default platform
|
||||
supported_platforms = ["tpu", "gpu"]
|
||||
# (b/233850967) CPU caching can be enabled if XLA Runtime is enabled.
|
||||
if "--xla_cpu_use_xla_runtime=true" in os.environ.get("XLA_FLAGS", ""):
|
||||
supported_platforms.append("cpu")
|
||||
use_compilation_cache = (compilation_cache.is_initialized() and
|
||||
backend.platform in supported_platforms)
|
||||
|
||||
if not use_compilation_cache:
|
||||
return backend_compile(backend, computation, compile_options,
|
||||
host_callbacks)
|
||||
|
||||
cache_key = compilation_cache.get_cache_key(
|
||||
computation, devices, compile_options, backend,
|
||||
jax_config.config.jax_use_original_compilation_cache_key_generation,
|
||||
)
|
||||
|
||||
cache_retrieval_start = time.monotonic()
|
||||
retrieved_executable, retrieved_compile_time = _cache_read(
|
||||
module_name, cache_key, compile_options, backend)
|
||||
cache_retrieval_time = time.monotonic() - cache_retrieval_start
|
||||
|
||||
if retrieved_executable is not None:
|
||||
assert retrieved_compile_time is not None
|
||||
logger.info("Persistent compilation cache hit for '%s'", module_name)
|
||||
record_event_duration_secs(
|
||||
"/jax/compilation_cache/cache_retrieval_time_sec", cache_retrieval_time)
|
||||
# TODO(b/293308239) Instrument a metric for new cache savings once the
|
||||
# enabling flag is added.
|
||||
# TODO(b/293308239) Remove the metric for original cache savings after the
|
||||
# new compilation cache key implementation is fully rolled out.
|
||||
record_event_duration_secs(
|
||||
"/jax/compilation_cache/original_compile_time_saved_sec",
|
||||
retrieved_compile_time - cache_retrieval_time)
|
||||
return retrieved_executable
|
||||
else:
|
||||
start_time = time.monotonic()
|
||||
executable = backend_compile(backend, computation,
|
||||
compile_options, host_callbacks)
|
||||
compile_time = time.monotonic() - start_time
|
||||
_cache_write(cache_key, compile_time, module_name, backend, executable,
|
||||
host_callbacks)
|
||||
return executable
|
||||
|
||||
|
||||
def _cache_read(
|
||||
module_name: str, cache_key: str, compile_options: xc.CompileOptions,
|
||||
backend: Backend
|
||||
) -> tuple[xc.LoadedExecutable | None, int | None]:
|
||||
"""Looks up the `computation` and it's compilation time in the persistent
|
||||
compilation cache repository.
|
||||
"""
|
||||
try:
|
||||
return compilation_cache.get_executable_and_time(
|
||||
cache_key, compile_options, backend)
|
||||
except Exception as ex:
|
||||
if config.jax_raise_persistent_cache_errors:
|
||||
raise
|
||||
warnings.warn(
|
||||
f"Error reading persistent compilation cache entry for "
|
||||
f"'{module_name}': {type(ex).__name__}: {ex}")
|
||||
return None, None
|
||||
|
||||
|
||||
def _cache_write(cache_key: str,
|
||||
compile_time_secs: float,
|
||||
module_name: str,
|
||||
backend: Backend, executable: xc.LoadedExecutable,
|
||||
host_callbacks: Sequence[Any]) -> None:
|
||||
"""Writes the `serialized_computation` and its compilation time to the
|
||||
persistent compilation cache repository.
|
||||
"""
|
||||
if host_callbacks:
|
||||
logger.info(
|
||||
"Not writing persistent cache entry for '%s' because it uses host "
|
||||
"callbacks (e.g. from jax.debug.print or breakpoint)", module_name)
|
||||
return
|
||||
|
||||
min_compile_time = config.jax_persistent_cache_min_compile_time_secs
|
||||
if min_compile_time:
|
||||
if compile_time_secs < min_compile_time:
|
||||
logger.info(
|
||||
"Not writing persistent cache entry for '%s' because it took < %.2f "
|
||||
"seconds to compile (%.2fs)", module_name, min_compile_time,
|
||||
compile_time_secs)
|
||||
return
|
||||
else:
|
||||
logger.info(
|
||||
"'%s' took at least %.2f seconds to compile (%.2fs), writing "
|
||||
"persistent cache entry", module_name, min_compile_time,
|
||||
compile_time_secs)
|
||||
|
||||
try:
|
||||
compilation_cache.put_executable_and_time(
|
||||
cache_key, module_name, executable, backend, int(compile_time_secs))
|
||||
except Exception as ex:
|
||||
if config.jax_raise_persistent_cache_errors:
|
||||
raise
|
||||
warnings.warn(
|
||||
f"Error writing persistent compilation cache entry for "
|
||||
f"'{module_name}': {type(ex).__name__}: {ex}")
|
||||
|
||||
|
||||
# TODO(yashkatariya): Generalize is_compatible_aval (maybe renamed) and use that
|
||||
# to check if shardings are compatible with the input.
|
||||
def _check_sharding(aval: core.AbstractValue, s: Sharding):
|
||||
|
@ -34,6 +34,7 @@ from jax.errors import JAXTypeError
|
||||
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src import compiler
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
@ -903,7 +904,7 @@ class UnloadedPmapExecutable:
|
||||
num_partitions = 1
|
||||
device_assignment: np.ndarray = np.array(devices).reshape(
|
||||
(replicas.num_global_replicas, num_partitions))
|
||||
compile_options = xb.get_compile_options(
|
||||
compile_options = compiler.get_compile_options(
|
||||
num_replicas=replicas.num_global_replicas,
|
||||
num_partitions=num_partitions,
|
||||
device_assignment=device_assignment,
|
||||
@ -953,7 +954,7 @@ class UnloadedPmapExecutable:
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
|
||||
fun_name=pci.name, event=dispatch.BACKEND_COMPILE_EVENT):
|
||||
compiled = dispatch.compile_or_get_cached(
|
||||
compiled = compiler.compile_or_get_cached(
|
||||
pci.backend, hlo, device_assignment, compile_options,
|
||||
host_callbacks)
|
||||
|
||||
@ -2482,7 +2483,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
|
||||
fdo_profile = (None if compiler_options is None else
|
||||
compiler_options.pop("fdo_profile", None))
|
||||
|
||||
compile_options = xb.get_compile_options(
|
||||
compile_options = compiler.get_compile_options(
|
||||
num_replicas=num_replicas,
|
||||
num_partitions=num_partitions,
|
||||
device_assignment=xla_device_assignment,
|
||||
@ -2508,7 +2509,7 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
|
||||
with dispatch.log_elapsed_time(
|
||||
"Finished XLA compilation of {fun_name} in {elapsed_time} sec",
|
||||
fun_name=name, event=dispatch.BACKEND_COMPILE_EVENT):
|
||||
xla_executable = dispatch.compile_or_get_cached(
|
||||
xla_executable = compiler.compile_or_get_cached(
|
||||
backend, computation, dev, compile_options, host_callbacks)
|
||||
return xla_executable, compile_options
|
||||
|
||||
|
@ -33,12 +33,9 @@ import threading
|
||||
from typing import Any, Callable, Optional, Union
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src import lib
|
||||
from jax._src import distributed
|
||||
from jax._src import config as jax_config
|
||||
from jax._src.config import bool_env, config
|
||||
from jax._src.config import config
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
@ -73,11 +70,6 @@ _PLATFORM_NAME = jax_config.DEFINE_string(
|
||||
'jax_platform_name',
|
||||
os.getenv('JAX_PLATFORM_NAME', '').lower(),
|
||||
'Deprecated, please use --jax_platforms instead.')
|
||||
_DISABLE_MOST_OPTIMIZATIONS = jax_config.DEFINE_bool(
|
||||
'jax_disable_most_optimizations',
|
||||
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
|
||||
'Try not to do much optimization work. This can be useful if the cost of '
|
||||
'optimization is greater than that of running a less-optimized program.')
|
||||
CUDA_VISIBLE_DEVICES = jax_config.DEFINE_string(
|
||||
'jax_cuda_visible_devices', 'all',
|
||||
'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
|
||||
@ -88,122 +80,6 @@ _ROCM_VISIBLE_DEVICES = jax_config.DEFINE_string(
|
||||
'comma-separate list of integer device IDs.')
|
||||
|
||||
|
||||
# Will be monkeypatched with the function that gets the XLA-AutoFDO profile
|
||||
# version. The default (-1) takes care of errors.
|
||||
def get_latest_profile_version() -> int:
|
||||
return -1
|
||||
|
||||
|
||||
def get_compile_options(
|
||||
num_replicas: int,
|
||||
num_partitions: int,
|
||||
device_assignment=None,
|
||||
use_spmd_partitioning: bool = True,
|
||||
use_auto_spmd_partitioning: bool = False,
|
||||
auto_spmd_partitioning_mesh_shape: Optional[list[int]] = None,
|
||||
auto_spmd_partitioning_mesh_ids: Optional[list[int]] = None,
|
||||
env_options_overrides: Optional[dict[str, str]] = None,
|
||||
fdo_profile: Optional[bytes] = None,
|
||||
) -> xla_client.CompileOptions:
|
||||
"""Returns the compile options to use, as derived from flag values.
|
||||
|
||||
Args:
|
||||
num_replicas: Number of replicas for which to compile.
|
||||
num_partitions: Number of partitions for which to compile.
|
||||
device_assignment: Optional ndarray of jax devices indicating the assignment
|
||||
of logical replicas to physical devices (default inherited from
|
||||
xla_client.CompileOptions). Must be consistent with `num_replicas` and
|
||||
`num_partitions`.
|
||||
use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
|
||||
partitioning in XLA.
|
||||
use_auto_spmd_partitioning: boolean indicating whether to automatically
|
||||
generate XLA shardings for SPMD partitioner.
|
||||
auto_spmd_partitioning_mesh_shape: device mesh shape used to create
|
||||
auto_spmd_partitioning search space.
|
||||
auto_spmd_partitioning_mesh_ids: device ids used to create
|
||||
auto_spmd_partitioning search space.
|
||||
env_options_overrides: dict of additional options parsed by the compiler
|
||||
fdo_profile: Optional profile for feedback-directed optimization passed to
|
||||
XLA.
|
||||
"""
|
||||
compile_options = xla_client.CompileOptions()
|
||||
compile_options.num_replicas = num_replicas
|
||||
compile_options.num_partitions = num_partitions
|
||||
build_options = compile_options.executable_build_options
|
||||
build_options.use_spmd_partitioning = use_spmd_partitioning
|
||||
build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
|
||||
if fdo_profile is not None:
|
||||
build_options.fdo_profile = fdo_profile
|
||||
if use_auto_spmd_partitioning:
|
||||
build_options.auto_spmd_partitioning_mesh_shape = auto_spmd_partitioning_mesh_shape or []
|
||||
build_options.auto_spmd_partitioning_mesh_ids = auto_spmd_partitioning_mesh_ids or []
|
||||
if device_assignment is not None:
|
||||
logger.debug(
|
||||
'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
|
||||
num_replicas, num_partitions, device_assignment)
|
||||
device_assignment = np.array(device_assignment)
|
||||
|
||||
# Allow 1D device assignment if num_partitions is 1.
|
||||
if (device_assignment.ndim == 1) and (num_partitions == 1):
|
||||
device_assignment = device_assignment[:, None]
|
||||
|
||||
if num_replicas != device_assignment.shape[0]:
|
||||
msg = 'device_assignment does not match num_replicas: {} vs {}.'
|
||||
raise ValueError(msg.format(device_assignment, num_replicas))
|
||||
|
||||
if num_partitions != device_assignment.shape[1]:
|
||||
msg = 'device_assignment does not match num_partitions: {} vs {}.'
|
||||
raise ValueError(msg.format(device_assignment, num_partitions))
|
||||
|
||||
if device_assignment.dtype == object:
|
||||
device_assignment = np.vectorize(lambda d: d.id, otypes=[int])(
|
||||
device_assignment)
|
||||
device_assignment = xla_client.DeviceAssignment.create(device_assignment)
|
||||
assert device_assignment.replica_count() == num_replicas
|
||||
assert device_assignment.computation_count() == num_partitions
|
||||
compile_options.device_assignment = device_assignment
|
||||
|
||||
if env_options_overrides is not None:
|
||||
compile_options.env_option_overrides = list(env_options_overrides.items())
|
||||
|
||||
debug_options = compile_options.executable_build_options.debug_options
|
||||
if lib.cuda_path is not None:
|
||||
debug_options.xla_gpu_cuda_data_dir = lib.cuda_path
|
||||
|
||||
if _DISABLE_MOST_OPTIMIZATIONS.value:
|
||||
debug_options.xla_backend_optimization_level = 0
|
||||
debug_options.xla_llvm_disable_expensive_passes = True
|
||||
debug_options.xla_test_all_input_layouts = False
|
||||
|
||||
# XLA-AutoFDO profile version: precedence order is:
|
||||
# 1. Whatever --jax_xla_profile_version is set to.
|
||||
# 2. If --jax_xla_profile_version is not set (i.e., 0), call the function
|
||||
# set in get_latest_profile_version and use the return value if non-zero.
|
||||
# If the function returns 0, set -1; this is an error.
|
||||
# -1 indicates that no attempt should be made to retrieve the latest profile
|
||||
# later on.
|
||||
jax_xla_profile_version = config.jax_xla_profile_version
|
||||
if jax_xla_profile_version > 0:
|
||||
compile_options.profile_version = jax_xla_profile_version
|
||||
logger.debug("get_compile_options XLA-AutoFDO profile: " +
|
||||
"using JAX XLA profile version %d from flag",
|
||||
jax_xla_profile_version)
|
||||
else:
|
||||
fdo_profile_version = get_latest_profile_version()
|
||||
if fdo_profile_version != 0:
|
||||
compile_options.profile_version = fdo_profile_version
|
||||
logger.debug("get_compile_options XLA-AutoFDO profile: " +
|
||||
"using XLA-AutoFDO profile version %d",
|
||||
fdo_profile_version)
|
||||
else:
|
||||
no_profile_dont_retrieve = -1
|
||||
compile_options.profile_version = no_profile_dont_retrieve
|
||||
logger.error("get_compile_options XLA-AutoFDO profile: " +
|
||||
"XLA-AutoFDO profile version is 0; this should not happen")
|
||||
|
||||
return compile_options
|
||||
|
||||
|
||||
# Backends
|
||||
|
||||
def tpu_client_timer_callback(timer_secs: float) -> Optional[xla_client.Client]:
|
||||
|
@ -517,6 +517,7 @@ from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import compiler
|
||||
from jax._src import dispatch
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src import sharding_impls
|
||||
@ -2014,7 +2015,7 @@ def _initialize_outfeed_receiver(
|
||||
_callback_handler_data.receiver = outfeed_receiver_module.start(
|
||||
_callback_input_received, tuple(clients_with_outfeed),
|
||||
max_callback_queue_size_bytes,
|
||||
xb.get_compile_options(1, 1).executable_build_options) # type:ignore
|
||||
compiler.get_compile_options(1, 1).executable_build_options) # type:ignore
|
||||
|
||||
def exit_handler():
|
||||
# Prevent logging usage during compilation, gives errors under pytest
|
||||
|
@ -40,6 +40,7 @@ from jax.experimental.shard_map import shard_map
|
||||
from jax.sharding import Mesh
|
||||
from jax.sharding import PartitionSpec as P
|
||||
import jax.numpy as jnp
|
||||
from jax._src import compiler
|
||||
from jax._src import xla_bridge
|
||||
|
||||
import numpy as np
|
||||
@ -108,7 +109,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
|
||||
device_assignment = np.arange(num_partitions * num_replicas)
|
||||
device_assignment = np.reshape(device_assignment, (-1, num_partitions))
|
||||
use_spmd_partitioning = num_partitions > 1
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
compile_options = compiler.get_compile_options(
|
||||
num_replicas=num_replicas,
|
||||
num_partitions=num_partitions,
|
||||
device_assignment=device_assignment,
|
||||
|
@ -35,9 +35,12 @@ from jax._src.core import (
|
||||
)
|
||||
|
||||
# TODO(phawkins): update users.
|
||||
from jax._src.compiler import (
|
||||
backend_compile as backend_compile,
|
||||
)
|
||||
|
||||
from jax._src.dispatch import (
|
||||
apply_primitive as apply_primitive,
|
||||
backend_compile as backend_compile,
|
||||
)
|
||||
|
||||
from jax._src.sharding_impls import (
|
||||
|
@ -16,7 +16,10 @@
|
||||
from jax._src.xla_bridge import (
|
||||
default_backend as default_backend,
|
||||
get_backend as get_backend,
|
||||
get_compile_options as get_compile_options,
|
||||
xla_client as xla_client,
|
||||
_backends as _backends,
|
||||
)
|
||||
|
||||
from jax._src.compiler import (
|
||||
get_compile_options as get_compile_options,
|
||||
)
|
||||
|
11
tests/BUILD
11
tests/BUILD
@ -911,6 +911,7 @@ py_test(
|
||||
data = ["testdata/example_pjrt_plugin_config.json"],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:compiler",
|
||||
"//jax:test_util",
|
||||
] + py_deps("absl/logging"),
|
||||
)
|
||||
@ -928,13 +929,19 @@ py_test(
|
||||
jax_test(
|
||||
name = "compilation_cache_test",
|
||||
srcs = ["compilation_cache_test.py"],
|
||||
deps = ["//jax:compilation_cache_internal"],
|
||||
deps = [
|
||||
"//jax:compilation_cache_internal",
|
||||
"//jax:compiler",
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "cache_key_test",
|
||||
srcs = ["cache_key_test.py"],
|
||||
deps = ["//jax:cache_key"],
|
||||
deps = [
|
||||
"//jax:cache_key",
|
||||
"//jax:compiler",
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
|
@ -24,6 +24,7 @@ import jax
|
||||
from jax import config
|
||||
from jax import lax
|
||||
from jax._src import cache_key
|
||||
from jax._src import compiler
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.config import compilation_cache_include_metadata_in_key
|
||||
@ -38,7 +39,7 @@ FLAGS = config.FLAGS
|
||||
class CacheKeyTest(jtu.JaxTestCase):
|
||||
|
||||
def test_compile_options(self):
|
||||
compile_options_not_filled = xla_bridge.get_compile_options(
|
||||
compile_options_not_filled = compiler.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
compile_options_filled = self.filled_compile_options()
|
||||
@ -55,7 +56,7 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
self.assertNotEqual(filled_hash1, not_filled_hash3)
|
||||
|
||||
def test_executable_build_options(self):
|
||||
compile_options_not_filled = xla_bridge.get_compile_options(
|
||||
compile_options_not_filled = compiler.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
compile_options_filled = self.filled_compile_options()
|
||||
@ -75,7 +76,7 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
self.assertNotEqual(filled_hash1, not_filled_hash3)
|
||||
|
||||
def test_debug_options(self):
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
compile_options = compiler.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
hash1 = self.get_hashed_value(
|
||||
@ -141,7 +142,7 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
def test_same_key(self):
|
||||
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
compile_options = compiler.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
backend = xla_bridge.get_backend()
|
||||
@ -153,7 +154,7 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
def test_different_key(self):
|
||||
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
compile_options_not_filled = xla_bridge.get_compile_options(
|
||||
compile_options_not_filled = compiler.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
compile_options_filled = self.filled_compile_options()
|
||||
@ -169,7 +170,7 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
computation1 = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
|
||||
computation2 = jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir()
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
compile_options = compiler.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
backend = xla_bridge.get_backend()
|
||||
@ -186,7 +187,7 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
computation1 = jax.jit(f).lower(1, 1).compiler_ir()
|
||||
computation2 = jax.jit(g).lower(2, 3).compiler_ir()
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
compile_options = compiler.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
backend = xla_bridge.get_backend()
|
||||
@ -201,7 +202,7 @@ class CacheKeyTest(jtu.JaxTestCase):
|
||||
|
||||
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
compile_options = compiler.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
backend = xla_bridge.get_backend()
|
||||
|
@ -28,6 +28,7 @@ from jax import jit
|
||||
from jax import lax
|
||||
from jax import pmap
|
||||
from jax._src import compilation_cache as cc
|
||||
from jax._src import compiler
|
||||
from jax._src import monitoring
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
@ -78,7 +79,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
cc.initialize_cache(tmpdir)
|
||||
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
compile_options = compiler.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
backend = xla_bridge.get_backend()
|
||||
@ -93,7 +94,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
cc.initialize_cache(tmpdir)
|
||||
computation1 = str(jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir())
|
||||
computation2 = str(jax.jit(lambda x, y: x * y).lower(2, 2).compiler_ir())
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
compile_options = compiler.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
backend = xla_bridge.get_backend()
|
||||
@ -117,7 +118,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
.compiler_ir()
|
||||
)
|
||||
devices = np.array([[jax.local_devices()[0]]])
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
compile_options = compiler.get_compile_options(
|
||||
num_replicas=1, num_partitions=1
|
||||
)
|
||||
backend = xla_bridge.get_backend()
|
||||
|
@ -19,6 +19,7 @@ import warnings
|
||||
|
||||
from absl import logging
|
||||
from absl.testing import absltest
|
||||
from jax._src import compiler
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
@ -35,7 +36,7 @@ mock = absltest.mock
|
||||
class XlaBridgeTest(jtu.JaxTestCase):
|
||||
|
||||
def test_set_device_assignment_no_partition(self):
|
||||
compile_options = xb.get_compile_options(
|
||||
compile_options = compiler.get_compile_options(
|
||||
num_replicas=4, num_partitions=1, device_assignment=[0, 1, 2, 3])
|
||||
expected_device_assignment = ("Computations: 1 Replicas: 4\nComputation 0: "
|
||||
"0 1 2 3 \n")
|
||||
@ -43,7 +44,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
expected_device_assignment)
|
||||
|
||||
def test_set_device_assignment_with_partition(self):
|
||||
compile_options = xb.get_compile_options(
|
||||
compile_options = compiler.get_compile_options(
|
||||
num_replicas=2, num_partitions=2, device_assignment=[[0, 1], [2, 3]])
|
||||
expected_device_assignment = ("Computations: 2 Replicas: 2\nComputation 0: "
|
||||
"0 2 \nComputation 1: 1 3 \n")
|
||||
@ -51,7 +52,7 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
expected_device_assignment)
|
||||
|
||||
def test_set_fdo_profile(self):
|
||||
compile_options = xb.get_compile_options(
|
||||
compile_options = compiler.get_compile_options(
|
||||
num_replicas=1, num_partitions=1, fdo_profile=b"test_profile"
|
||||
)
|
||||
self.assertEqual(
|
||||
@ -63,10 +64,10 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
jax_flag_profile = 1
|
||||
another_profile = 2
|
||||
with jax_config.jax_xla_profile_version(jax_flag_profile):
|
||||
with mock.patch.object(xb, "get_latest_profile_version",
|
||||
with mock.patch.object(compiler, "get_latest_profile_version",
|
||||
side_effect=lambda: another_profile):
|
||||
self.assertEqual(
|
||||
xb.get_compile_options(
|
||||
compiler.get_compile_options(
|
||||
num_replicas=3, num_partitions=4
|
||||
).profile_version,
|
||||
jax_flag_profile,
|
||||
@ -75,10 +76,10 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
# Use whatever non-zero value the function get_latest_profile_version
|
||||
# returns if --jax_xla_profile_version is not set.
|
||||
profile_version = 1
|
||||
with mock.patch.object(xb, "get_latest_profile_version",
|
||||
with mock.patch.object(compiler, "get_latest_profile_version",
|
||||
side_effect=lambda: profile_version):
|
||||
self.assertEqual(
|
||||
xb.get_compile_options(
|
||||
compiler.get_compile_options(
|
||||
num_replicas=3, num_partitions=4
|
||||
).profile_version,
|
||||
profile_version,
|
||||
@ -89,10 +90,10 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
# retrieve the latest profile later.
|
||||
error_return = 0
|
||||
no_profile_dont_retrieve = -1
|
||||
with mock.patch.object(xb, "get_latest_profile_version",
|
||||
with mock.patch.object(compiler, "get_latest_profile_version",
|
||||
side_effect=lambda: error_return):
|
||||
self.assertEqual(
|
||||
xb.get_compile_options(
|
||||
compiler.get_compile_options(
|
||||
num_replicas=3, num_partitions=4
|
||||
).profile_version,
|
||||
no_profile_dont_retrieve,
|
||||
|
Loading…
x
Reference in New Issue
Block a user