mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00

Users of this mechanism should migrate to the newer PJRT plugin registration mechanism (see the comments on discover_plugins() in this file).
811 lines
29 KiB
Python
811 lines
29 KiB
Python
# Copyright 2018 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Interface and utility functions to XLA.
|
|
|
|
This module wraps the XLA client(s) and builders to standardize their interfaces
|
|
and provide some automatic type mapping logic for converting between Numpy and
|
|
XLA. There are also a handful of related casting utilities.
|
|
"""
|
|
|
|
from functools import partial, lru_cache
|
|
import importlib
|
|
import io
|
|
import json
|
|
import logging
|
|
import os
|
|
import platform as py_platform
|
|
import pkgutil
|
|
import sys
|
|
import threading
|
|
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
|
|
import warnings
|
|
import numpy as np
|
|
|
|
from jax._src import lib
|
|
from jax._src import distributed
|
|
from jax._src.config import flags, bool_env, config, int_env
|
|
from jax._src.lib import xla_client
|
|
from jax._src.lib import xla_extension_version
|
|
from jax._src import traceback_util
|
|
from jax._src import util
|
|
|
|
# Optional IREE support: `iree` stays None when the module cannot be imported.
iree: Optional[Any]
try:
  import jax._src.iree as iree  # type: ignore
except ImportError:  # ModuleNotFoundError is a subclass of ImportError.
  iree = None

logger = logging.getLogger(__name__)

# Optional `jax_plugins` namespace package; None when it is unavailable.
jax_plugins: Optional[Any]
try:
  import jax_plugins  # type: ignore
except ModuleNotFoundError:
  # The package is simply not installed: nothing to report.
  jax_plugins = None
except ImportError as e:
  # The package exists but failed to import: report it and continue without.
  logger.error("Failed to import jax_plugins: %s", e)
  jax_plugins = None

traceback_util.register_exclusion(__file__)

XlaBackend = xla_client.Client

FLAGS = flags.FLAGS
|
|
|
|
|
|
# TODO(phawkins): Remove jax_xla_backend.
flags.DEFINE_string(
    'jax_xla_backend',
    '',
    'Deprecated, please use --jax_platforms instead.',
)
flags.DEFINE_string(
    'jax_backend_target',
    os.getenv('JAX_BACKEND_TARGET', '').lower(),
    'Either "local" or "rpc:address" to connect to a remote service target.',
)
# TODO(skye): warn when this is used once we test out --jax_platforms a bit
flags.DEFINE_string(
    'jax_platform_name',
    os.getenv('JAX_PLATFORM_NAME', '').lower(),
    'Deprecated, please use --jax_platforms instead.',
)
flags.DEFINE_bool(
    'jax_disable_most_optimizations',
    bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
    'Try not to do much optimization work. This can be useful if the cost of '
    'optimization is greater than that of running a less-optimized program.',
)
flags.DEFINE_integer(
    'jax_xla_profile_version',
    int_env('JAX_XLA_PROFILE_VERSION', 0),
    'Optional profile version for XLA compilation. '
    'This is meaningful only when XLA is configured to '
    'support the remote compilation profile feature.',
)
flags.DEFINE_string(
    'jax_cuda_visible_devices',
    'all',
    'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
    'comma-separate list of integer device IDs.',
)
flags.DEFINE_string(
    'jax_rocm_visible_devices',
    'all',
    'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
    'comma-separate list of integer device IDs.',
)
|
|
|
|
def get_compile_options(
    num_replicas: int,
    num_partitions: int,
    device_assignment=None,
    use_spmd_partitioning: bool = True,
    use_auto_spmd_partitioning: bool = False,
    auto_spmd_partitioning_mesh_shape=None,
    auto_spmd_partitioning_mesh_ids=None,
    env_options_overrides: Optional[Dict[str, str]] = None,
) -> xla_client.CompileOptions:
  """Returns the compile options to use, as derived from flag values.

  Args:
    num_replicas: Number of replicas for which to compile.
    num_partitions: Number of partitions for which to compile.
    device_assignment: Optional ndarray of jax devices indicating the assignment
      of logical replicas to physical devices (default inherited from
      xla_client.CompileOptions). Must be consistent with `num_replicas` and
      `num_partitions`. A 1D assignment is accepted when `num_partitions` is 1.
    use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
      partitioning in XLA.
    use_auto_spmd_partitioning: boolean indicating whether to automatically
      generate XLA shardings for SPMD partitioner.
    auto_spmd_partitioning_mesh_shape: device mesh shape used to create
      auto_spmd_partitioning search space. None (the default) means an empty
      list.
    auto_spmd_partitioning_mesh_ids: device ids used to create
      auto_spmd_partitioning search space. None (the default) means an empty
      list.
    env_options_overrides: dict of additional options parsed by the compiler

  Returns:
    The populated `xla_client.CompileOptions`.

  Raises:
    ValueError: if `device_assignment` is not 2D (after the 1D promotion) or
      its shape does not match `(num_replicas, num_partitions)`.
  """
  # Use None sentinels rather than mutable list defaults shared across calls.
  if auto_spmd_partitioning_mesh_shape is None:
    auto_spmd_partitioning_mesh_shape = []
  if auto_spmd_partitioning_mesh_ids is None:
    auto_spmd_partitioning_mesh_ids = []

  compile_options = xla_client.CompileOptions()
  compile_options.num_replicas = num_replicas
  compile_options.num_partitions = num_partitions
  build_options = compile_options.executable_build_options
  build_options.use_spmd_partitioning = use_spmd_partitioning
  build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
  if use_auto_spmd_partitioning:
    build_options.auto_spmd_partitioning_mesh_shape = auto_spmd_partitioning_mesh_shape
    build_options.auto_spmd_partitioning_mesh_ids = auto_spmd_partitioning_mesh_ids
  if device_assignment is not None:
    logger.debug(
        'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
        num_replicas, num_partitions, device_assignment)
    device_assignment = np.array(device_assignment)

    # Allow 1D device assignment if num_partitions is 1.
    if (device_assignment.ndim == 1) and (num_partitions == 1):
      device_assignment = device_assignment[:, None]

    # Reject other ranks up front so the shape checks below cannot fail with
    # an unhelpful IndexError.
    if device_assignment.ndim != 2:
      raise ValueError(
          'device_assignment must have shape (num_replicas, num_partitions) '
          '(or be 1D when num_partitions is 1); got shape '
          f'{device_assignment.shape}.')

    if num_replicas != device_assignment.shape[0]:
      msg = 'device_assignment does not match num_replicas: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_replicas))

    if num_partitions != device_assignment.shape[1]:
      msg = 'device_assignment does not match num_partitions: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_partitions))

    # An array of device objects is converted to an array of their integer ids.
    if device_assignment.dtype == object:
      device_assignment = np.vectorize(lambda d: d.id, otypes=[int])(
          device_assignment)
    device_assignment = xla_client.DeviceAssignment.create(device_assignment)
    assert device_assignment.replica_count() == num_replicas
    assert device_assignment.computation_count() == num_partitions
    compile_options.device_assignment = device_assignment

  if env_options_overrides is not None:
    compile_options.env_option_overrides = list(env_options_overrides.items())

  debug_options = compile_options.executable_build_options.debug_options
  if lib.cuda_path is not None:
    debug_options.xla_gpu_cuda_data_dir = lib.cuda_path

  if FLAGS.jax_disable_most_optimizations:
    debug_options.xla_backend_optimization_level = 0
    debug_options.xla_llvm_disable_expensive_passes = True
    debug_options.xla_test_all_input_layouts = False

  compile_options.profile_version = FLAGS.jax_xla_profile_version
  return compile_options
|
|
|
|
|
|
# Backends
|
|
|
|
def tpu_client_timer_callback(timer_secs: float) -> Optional[xla_client.Client]:
  """Creates a TPU client, warning if creation runs longer than `timer_secs`.

  Args:
    timer_secs: number of seconds after which the warning is issued if the
      client has not been created yet.

  Returns:
    The value returned by `xla_client.make_tpu_client()`.
  """
  slow_init_message = (
      f'TPU backend initialization is taking more than {timer_secs} seconds. '
      'Did you run your code on all TPU hosts? '
      'See https://jax.readthedocs.io/en/latest/multi_process.html '
      'for more information.')

  def _warn_slow_init():
    warnings.warn(slow_init_message)

  # The timer fires the warning after `timer_secs` unless cancelled first.
  watchdog = threading.Timer(timer_secs, _warn_slow_init)
  watchdog.start()
  try:
    return xla_client.make_tpu_client()
  finally:
    # Runs on success and on failure, so the warning never fires late.
    watchdog.cancel()
|
|
|
|
|
|
# Backends, in increasing order of preference.
# We have no particular opinion about how "backends" relate to "devices". For
# example, there could be multiple backends that provide the same kind of
# device.
BackendFactory = Callable[[], Optional[xla_client.Client]]

# Platform name -> (factory, priority), filled by register_backend_factory().
_backend_factories: Dict[str, Tuple[BackendFactory, int]] = {}
# Highest-priority backend that initialized; set by backends().
_default_backend: Optional[xla_client.Client] = None
# Platform name -> initialized client.
_backends: Dict[str, xla_client.Client] = {}
# Platform name -> error text for backends that failed to initialize.
_backends_errors: Dict[str, str] = {}
# Guards the registry state above.
_backend_lock = threading.Lock()
|
|
|
|
def register_backend_factory(name: str, factory: BackendFactory, *,
                             priority: int = 0) -> None:
  """Registers `factory` as the client factory for platform `name`.

  Args:
    name: platform name the factory is stored under.
    factory: zero-argument callable that creates the client.
    priority: priority stored alongside the factory.

  Raises:
    RuntimeError: if a backend named `name` has already been initialized.
  """
  with _backend_lock:
    already_initialized = name in _backends
    if already_initialized:
      raise RuntimeError(f"Backend {name} already initialized")
    _backend_factories[name] = (factory, priority)
|
|
|
|
|
|
# The CPU backend is always registered, at priority 0.
register_backend_factory(
    name='cpu',
    factory=partial(xla_client.make_cpu_client, use_tfrt=True),
    priority=0,
)
|
|
|
|
|
|
def make_gpu_client(
    *, platform_name: str, visible_devices_flag: str
) -> xla_client.Client:
  """Creates a GPU client for `platform_name`.

  Args:
    platform_name: platform name forwarded to `xla_client.make_gpu_client`.
    visible_devices_flag: name of the flag holding either "all" or a
      comma-separated list of integer device ids; a missing flag is treated as
      "all".

  Returns:
    The client returned by `xla_client.make_gpu_client`.
  """
  devices_spec = getattr(FLAGS, visible_devices_flag, "all")
  if devices_spec == "all":
    allowed = None
  else:
    allowed = {int(token) for token in devices_spec.split(",")}

  client_kwargs = dict(
      distributed_client=distributed.global_state.client,
      node_id=distributed.global_state.process_id,
      platform_name=platform_name,
      allowed_devices=allowed,
  )
  # `num_nodes` is only passed for xla_extension_version >= 160.
  if xla_extension_version >= 160:
    client_kwargs['num_nodes'] = distributed.global_state.num_processes

  # Remove `type: ignore` when the min jaxlib version (xla_extension_version)
  # >= 160.
  return xla_client.make_gpu_client(**client_kwargs)  # type: ignore
|
|
|
|
|
|
# GPU factories are registered only when this xla_client can make GPU clients.
if hasattr(xla_client, "make_gpu_client"):
  register_backend_factory(
      name='cuda',
      factory=partial(
          make_gpu_client,
          platform_name='cuda',
          visible_devices_flag='jax_cuda_visible_devices',
      ),
      priority=200,
  )
  register_backend_factory(
      name='rocm',
      factory=partial(
          make_gpu_client,
          platform_name='rocm',
          visible_devices_flag='jax_rocm_visible_devices',
      ),
      priority=200,
  )

# Likewise for TPU; it is registered at a higher priority than the GPU entries.
if hasattr(xla_client, "make_tpu_client"):
  register_backend_factory(
      name='tpu',
      factory=partial(tpu_client_timer_callback, timer_secs=60.0),
      priority=300,
  )
|
|
|
|
|
|
def _get_pjrt_plugin_names_and_library_paths(
    plugins_from_env: str,
) -> Dict[str, str]:
  """Parses PJRT plugin names and library paths from an env var value.

  Args:
    plugins_from_env: plugin names and paths from the env var, in the format
      'name1:path1,name2:path2' ('name1;path1,name2;path2' for windows).

  Returns:
    A dict of {plugin_name: library path} for the PJRT plugins to load.
    Entries that do not split into exactly a name and a path are skipped with
    a warning.
  """
  name_to_path: Dict[str, str] = {}
  if not plugins_from_env:
    return name_to_path

  for entry in plugins_from_env.split(','):
    fields = entry.split(os.path.pathsep)
    if len(fields) != 2:
      logger.warning(
          'invalid value %s in env var PJRT_NAMES_AND_LIBRARY_PATHS %s',
          entry,
          plugins_from_env,
      )
      continue
    plugin_name, plugin_path = fields
    name_to_path[plugin_name] = plugin_path
  return name_to_path
|
|
|
|
|
|
def _get_pjrt_plugin_config(
    json_path: str,
) -> Tuple[str, Optional[Mapping[str, Union[str, int, List[int], float]]]]:
  """Gets PJRT plugin configuration from a json file.

  The json file needs to have a "library_path" field for the plugin library
  path. It can have an optional "create_options" field for the options used
  when creating a PJRT plugin client. The value of "create_options" is
  key-value pairs. Please see xla_client._NameValueMapping for the supported
  types of values.

  Args:
    json_path: path of the json configuration file.

  Returns:
    A `(library_path, create_options)` tuple. `create_options` is None when
    the file has no "create_options" field.

  Raises:
    ValueError: if the file has no "library_path" field.
  """
  # JSON text is UTF-8; do not depend on the locale's default encoding.
  with io.open(json_path, 'r', encoding='utf-8') as f:
    # Named `plugin_config` so it does not shadow the module-level `config`.
    plugin_config = json.load(f)
  if not isinstance(plugin_config, dict) or 'library_path' not in plugin_config:
    raise ValueError(
        'PJRT plugin config file should contain "library_path" field.'
    )
  return (plugin_config['library_path'], plugin_config.get('create_options'))
|
|
|
|
def discover_pjrt_plugins() -> None:
  """Discovers plugins in the namespace package `jax_plugins` and import them.

  There are two methods used to discover plugin modules. They are intended
  to be used together by implementors in order to cover all packaging and
  development cases:

  1. Define a globally unique module under the `jax_plugins` namespace
     package (i.e. just create a `jax_plugins` directory and define your
     module below it).
  2. If building a package via pyproject.toml or setup.py, advertise your
     plugin module name by including an entry-point under the `jax_plugins`
     group which points to your full module name.

  During Jax startup, Jax will load each module discovered in such a way and
  call its `initialize()` function. It is expected that this function should
  register its concrete plugin name/implementations via call(s) to
  `jax._src.xla_bridge.register_plugin(name, priority=, library_path=,
  options=)`. Since `initialize()` functions are called for all installed
  plugins, they should avoid doing expensive, non-registration related work.

  Failures to import a plugin module or to run its `initialize()` are logged
  and do not stop the remaining plugins from being processed.

  TODO: We should provide a variant of `register_plugin` which allows the
  library_path and options to be resolved via a callback. This would enable
  light-weight plugin registration in cases where options need to be derived
  from heavy-weight system initialization.
  """
  plugin_modules = set()
  # Scan installed modules under |jax_plugins|. Note that not all packaging
  # scenarios are amenable to such scanning, so we also use the entry-point
  # method to seed the list.
  if jax_plugins:
    for _, name, _ in pkgutil.iter_modules(
        jax_plugins.__path__, jax_plugins.__name__ + '.'
    ):
      logger.debug("Discovered path based JAX plugin: %s", name)
      plugin_modules.add(name)
  else:
    logger.debug("No jax_plugins namespace packages available")

  # Augment with advertised entrypoints.
  if sys.version_info < (3, 10):
    # Use the backport library because it provides a forward-compatible
    # implementation.
    from importlib_metadata import entry_points
  else:
    from importlib.metadata import entry_points

  for entry_point in entry_points(group="jax_plugins"):
    logger.debug("Discovered entry-point based JAX plugin: %s",
                 entry_point.value)
    plugin_modules.add(entry_point.value)

  # Now load and initialize them all.
  for plugin_module_name in plugin_modules:
    logger.debug("Loading plugin module %s", plugin_module_name)
    try:
      plugin_module = importlib.import_module(plugin_module_name)
    except ModuleNotFoundError:
      logger.warning("Jax plugin configuration error: Plugin module %s "
                     "does not exist", plugin_module_name)
      continue
    except ImportError:
      # The module name is passed so the %s placeholder is actually filled in.
      logger.exception("Jax plugin configuration error: Plugin module %s "
                       "could not be loaded", plugin_module_name)
      continue

    try:
      plugin_module.initialize()
    except Exception:
      # `Exception` rather than a bare `except:` so KeyboardInterrupt and
      # SystemExit still propagate.
      logger.exception("Jax plugin configuration error: Exception when "
                       "calling %s.initialize()", plugin_module_name)
|
|
|
|
|
|
# TODO(b/261345120): decide on a public name and expose a public method which is
|
|
# an alias of this method.
|
|
def register_plugin(
    plugin_name: str,
    *,
    priority: int = 400,
    library_path: Optional[str] = None,
    options: Optional[Mapping[str, Union[str, int, List[int], float]]] = None,
) -> None:
  """Registers a backend factory for the PJRT plugin.

  The plugin library is not loaded here; loading happens when the registered
  factory is called.

  Args:
    plugin_name: the name of the plugin.
    priority: the priority this plugin should be registered in jax backends.
      Default to be 400.
    library_path: Optional. The full path to the .so file of the plugin.
      Required when the plugin is dynamically linked.
    options: Optional. It is used when creating a PJRT plugin client.
  """
  def _make_plugin_client():
    # Plugin may already be statically linked in some configurations.
    needs_dynamic_load = not xla_client.pjrt_plugin_loaded(plugin_name)
    if needs_dynamic_load and library_path is None:
      raise ValueError(
          'The library path is None when trying to dynamically load the'
          ' plugin.'
      )
    if needs_dynamic_load:
      xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
    return xla_client.make_c_api_client(plugin_name, options)

  logger.debug(
      'registering PJRT plugin %s from %s', plugin_name, library_path
  )
  register_backend_factory(plugin_name, _make_plugin_client, priority=priority)
|
|
|
|
|
|
def register_pjrt_plugin_factories_from_env() -> None:
  """Registers backend factories for PJRT plugins.

  A backend factory will be registered for every PJRT plugin listed in the
  PJRT_NAMES_AND_LIBRARY_PATHS environment variable, in the format of
  'name1:path1,name2:path2' ('name1;path1,name2;path2' for windows). The path
  can be a path to the plugin library or a path to the plugin configuration
  json file (recognized by its '.json' suffix). The json file needs to have a
  "library_path" field for the plugin library path. It can have an optional
  "create_options" field for the options used when creating a PJRT plugin
  client. The value of "create_options" is key-value pairs. Please see
  xla_client._NameValueMapping for the supported types of values.

  TPU PJRT plugin will be loaded and registered separately in make_tpu_client.
  """
  env_value = os.getenv('PJRT_NAMES_AND_LIBRARY_PATHS', '')
  name_to_path = _get_pjrt_plugin_names_and_library_paths(env_value)
  for plugin_name, path in name_to_path.items():
    is_json_config = path.endswith('.json')
    if is_json_config:
      library_path, options = _get_pjrt_plugin_config(path)
    else:
      library_path, options = path, None
    logger.debug(
        'registering PJRT plugin %s from %s', plugin_name, library_path
    )
    register_plugin(plugin_name, library_path=library_path, options=options)
|
|
|
|
|
|
# Plugins in the namespace package `jax_plugins` will be imported.
discover_pjrt_plugins()
# Registers plugins names and paths set in env var PJRT_NAMES_AND_LIBRARY_PATHS,
# in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2' for
# windows).
register_pjrt_plugin_factories_from_env()

# IREE is registered, at a negative priority, only when its module imported.
if iree is not None:
  register_backend_factory("iree", iree.iree_client_factory, priority=-100)

# Concrete platform name -> alias.
_platform_aliases = {"cuda": "gpu", "rocm": "gpu"}

# Reverse table: alias -> list of concrete platform names.
_alias_to_platforms: Dict[str, List[str]] = {}
for _platform, _alias in _platform_aliases.items():
  _alias_to_platforms.setdefault(_alias, []).append(_platform)
|
|
|
|
|
|
def is_known_platform(platform: str) -> bool:
  """Returns True if `platform` is a registered platform or a known alias name."""
  # A platform is valid if there is a registered factory for it. It does not
  # matter if we were unable to initialize that platform; we only care that
  # we've heard of it and it isn't, e.g., a typo.
  if platform in _backend_factories:
    return True
  return platform in _platform_aliases
|
|
|
|
|
|
def canonicalize_platform(platform: str) -> str:
  """Replaces platform aliases with their concrete equivalent.

  In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
  hardware is actually present. We want to distinguish "cuda" and "rocm" for
  purposes such as MLIR lowering rules, but in many cases we don't want to
  force users to care.

  Args:
    platform: a platform name or alias.

  Returns:
    `platform` unchanged when it is not an alias; otherwise the first platform
    behind the alias that has an initialized backend.

  Raises:
    RuntimeError: if `platform` is an alias and none of its platforms has an
      initialized backend.
  """
  candidates = _alias_to_platforms.get(platform, None)
  if candidates is None:
    return platform

  available = backends()
  for candidate in candidates:
    if candidate in available:
      return candidate
  present = ",".join(available.keys())
  raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
                     f"platforms that are instances of {platform} are present. "
                     "Platforms present are: " + present)
|
|
|
|
|
|
def expand_platform_alias(platform: str) -> List[str]:
  """Expands, e.g., "gpu" to ["cuda", "rocm"].

  This is used for convenience reasons: we expect cuda and rocm to act similarly
  in many respects since they share most of the same code.

  A name that is not an alias is returned as a one-element list.
  """
  if platform in _alias_to_platforms:
    return _alias_to_platforms[platform]
  return [platform]
|
|
|
|
def is_gpu(platform):
  """Returns True if `platform` is "cuda" or "rocm"."""
  gpu_platforms = ("cuda", "rocm")
  return platform in gpu_platforms
|
|
|
|
def backends() -> Dict[str, xla_client.Client]:
  """Initializes the backends on first use and returns them by platform name.

  Once any backend has been initialized, later calls return the same dict
  without re-initializing. When `config.jax_platforms` is set, only the listed
  platforms (aliases expanded) are initialized and any failure is an error.
  Otherwise every registered factory is tried, and failures other than
  'cpu'/'interpreter' are recorded and logged.

  Raises:
    RuntimeError: if a platform listed in `config.jax_platforms` fails to
      initialize.
  """
  global _default_backend

  with _backend_lock:
    if _backends:
      return _backends

    if config.jax_platforms:
      # Allow platform aliases in the list of platforms.
      requested = [
          concrete
          for name in config.jax_platforms.split(",")
          for concrete in expand_platform_alias(name)
      ]
      # Earlier entries in the list get higher priority.
      candidates = list(zip(requested, range(len(requested), 0, -1)))
    else:
      candidates = [
          (name, registered_priority)
          for name, (_, registered_priority) in _backend_factories.items()
      ]

    best_priority = -1000
    for platform, priority in candidates:
      try:
        backend = _init_backend(platform)
        _backends[platform] = backend

        if priority > best_priority:
          _default_backend = backend
          best_priority = priority
      except Exception as err:
        if platform in ('cpu', 'interpreter'):
          # We always expect the CPU and interpreter backends to initialize
          # successfully.
          raise
        # If the backend isn't built into the binary, or if it has no devices,
        # we expect a RuntimeError.
        err_msg = f"Unable to initialize backend '{platform}': {err}"
        if config.jax_platforms:
          err_msg += " (set JAX_PLATFORMS='' to automatically choose an available backend)"
          raise RuntimeError(err_msg)
        _backends_errors[platform] = str(err)
        logger.info(err_msg)

    assert _default_backend is not None
    # We don't warn about falling back to CPU on Mac OS, because we don't
    # support anything else there at the moment and warning would be pointless.
    if (py_platform.system() != "Darwin" and
        _default_backend.platform == "cpu" and
        FLAGS.jax_platform_name != 'cpu'):
      logger.warning('No GPU/TPU found, falling back to CPU. '
                     '(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
    return _backends
|
|
|
|
|
|
def _clear_backends() -> None:
  """Drops all initialized backends and the caches derived from them.

  Resets the backend registry state (initialized clients, recorded
  initialization errors and the default backend) and clears the `lru_cache`s
  of `get_backend`, `local_devices` and `process_count`, so none of them can
  return results computed from a discarded client.
  """
  global _backends
  global _backends_errors
  global _default_backend

  logger.info("Clearing JAX backend caches.")
  with _backend_lock:
    _backends = {}
    _backends_errors = {}
    _default_backend = None

  get_backend.cache_clear()
  # These are memoized too and would otherwise keep returning stale results.
  local_devices.cache_clear()
  process_count.cache_clear()
|
|
|
|
|
|
def _init_backend(platform: str) -> xla_client.Client:
  """Creates the client for `platform` using its registered factory.

  Raises:
    RuntimeError: if no factory is registered for `platform`, if the factory
      returns None, or if the resulting backend reports zero devices.
  """
  factory, _ = _backend_factories.get(platform, (None, None))
  if factory is None:
    known = list(_backend_factories.keys())
    raise RuntimeError(
        f"Backend '{platform}' is not in the list of known backends: "
        f"{known}.")

  logger.debug("Initializing backend '%s'", platform)
  client = factory()
  # TODO(skye): consider raising more descriptive errors directly from backend
  # factories instead of returning None.
  if client is None:
    raise RuntimeError(f"Could not initialize backend '{platform}'")
  if client.device_count() == 0:
    raise RuntimeError(f"Backend '{platform}' provides no devices.")
  util.distributed_debug_log(
      ("Initialized backend", client.platform),
      ("process_index", client.process_index()),
      ("device_count", client.device_count()),
      ("local_devices", client.local_devices()),
  )
  logger.debug("Backend '%s' initialized", platform)
  return client
|
|
|
|
|
|
def _get_backend_uncached(
    platform: Union[None, str, xla_client.Client] = None
) -> xla_client.Client:
  """Resolves `platform` to a backend client, without caching.

  A non-string, non-None `platform` is returned as is. Otherwise the name is
  taken from `platform`, then `FLAGS.jax_xla_backend`, then
  `FLAGS.jax_platform_name`; if none is set the default backend is returned.

  Raises:
    RuntimeError: if the named backend failed to initialize or is unknown.
  """
  # TODO(mattjj,skyewm): remove this input polymorphism after we clean up how
  # 'backend' values are handled
  if not (platform is None or isinstance(platform, str)):
    return platform

  requested = (platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name
               or None)

  available = backends()
  if requested is None:
    assert _default_backend is not None
    return _default_backend

  requested = canonicalize_platform(requested)
  backend = available.get(requested, None)
  if backend is not None:
    return backend
  if requested in _backends_errors:
    raise RuntimeError(f"Backend '{requested}' failed to initialize: "
                       f"{_backends_errors[requested]}")
  raise RuntimeError(f"Unknown backend {requested}")
|
|
|
|
|
|
@lru_cache(maxsize=None)  # don't use util.memoize because there is no X64 dependence.
def get_backend(
    platform: Union[None, str, xla_client.Client] = None
) -> xla_client.Client:
  """Memoized front end to `_get_backend_uncached`, keyed on `platform`."""
  backend = _get_backend_uncached(platform)
  return backend
|
|
|
|
|
|
def get_device_backend(
    device: Optional[xla_client.Device] = None,
) -> xla_client.Client:
  """Returns the Backend associated with `device`, or the default Backend."""
  if device is None:
    return get_backend()
  return device.client
|
|
|
|
|
|
def device_count(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> int:
  """Returns the total number of devices.

  On most platforms, this is the same as :py:func:`jax.local_device_count`.
  However, on multi-process platforms where different devices are associated
  with different processes, this will return the total number of devices across
  all processes.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    Number of devices.

  """
  client = get_backend(backend)
  return int(client.device_count())
|
|
|
|
|
|
def local_device_count(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> int:
  """Returns the number of devices addressable by this process."""
  client = get_backend(backend)
  return int(client.local_device_count())
|
|
|
|
|
|
def devices(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> List[xla_client.Device]:
  """Returns a list of all devices for a given backend.

  .. currentmodule:: jaxlib.xla_extension

  Each device is represented by a subclass of :class:`Device` (e.g.
  :class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
  equal to ``device_count(backend)``. Local devices can be identified by
  comparing :attr:`Device.process_index` to the value returned by
  :py:func:`jax.process_index`.

  If ``backend`` is ``None``, returns all the devices from the default backend.
  The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
  otherwise ``'cpu'``.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    List of Device subclasses.
  """
  client = get_backend(backend)
  return client.devices()
|
|
|
|
|
|
def default_backend() -> str:
  """Returns the platform name of the default XLA backend."""
  client = get_backend(None)
  return client.platform
|
|
|
|
@lru_cache
def local_devices(process_index: Optional[int] = None,
                  backend: Optional[Union[str, xla_client.Client]] = None,
                  host_id: Optional[int] = None) -> List[xla_client.Device]:
  """Like :py:func:`jax.devices`, but only returns devices local to a given process.

  If ``process_index`` is ``None``, returns devices local to this process.

  Args:
    process_index: the integer index of the process. Valid process indices are
      ``range(jax.process_count())``.
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.
    host_id: deprecated alias of ``process_index``; when given it takes
      precedence and a warning is issued.

  Returns:
    List of Device subclasses.

  Raises:
    ValueError: if ``process_index`` is outside
      ``range(process_count(backend))``.
  """
  if host_id is not None:
    warnings.warn(
        "The argument to jax.local_devices has been renamed from `host_id` to "
        "`process_index`. This alias will eventually be removed; please update "
        "your code.")
    process_index = host_id
  if process_index is None:
    process_index = get_backend(backend).process_index()
  # Validate against the requested backend, not the default one.
  if not (0 <= process_index < process_count(backend)):
    raise ValueError(f"Unknown process_index {process_index}")
  return [d for d in devices(backend) if d.process_index == process_index]
|
|
|
|
|
|
def process_index(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> int:
  """Returns the integer process index of this process.

  On most platforms, this will always be 0. This will vary on multi-process
  platforms though.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    Integer process index.
  """
  client = get_backend(backend)
  return client.process_index()
|
|
|
|
|
|
# TODO: remove this sometime after jax 0.2.13 is released
|
|
def host_id(backend: Optional[Union[str, xla_client.Client]] = None) -> int:
  """Deprecated alias of `process_index`; warns, then delegates to it."""
  rename_notice = (
      "jax.host_id has been renamed to jax.process_index. This alias "
      "will eventually be removed; please update your code.")
  warnings.warn(rename_notice)
  return process_index(backend)
|
|
|
|
|
|
@lru_cache
def process_count(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> int:
  """Returns the number of JAX processes associated with the backend."""
  # Computed as one more than the largest process index among the devices.
  highest_index = max(d.process_index for d in devices(backend))
  return highest_index + 1
|
|
|
|
|
|
# TODO: remove this sometime after jax 0.2.13 is released
|
|
def host_count(backend: Optional[Union[str, xla_client.Client]] = None) -> int:
  """Deprecated alias of `process_count`; warns, then delegates to it."""
  rename_notice = (
      "jax.host_count has been renamed to jax.process_count. This alias "
      "will eventually be removed; please update your code.")
  warnings.warn(rename_notice)
  return process_count(backend)
|
|
|
|
|
|
# TODO: remove this sometime after jax 0.2.13 is released
|
|
def host_ids(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> List[int]:
  """Deprecated; warns, then returns ``list(range(process_count(backend)))``."""
  deprecation_notice = (
      "jax.host_ids has been deprecated; please use range(jax.process_count()) "
      "instead. jax.host_ids will eventually be removed; please update your "
      "code.")
  warnings.warn(deprecation_notice)
  num_processes = process_count(backend)
  return list(range(num_processes))
|
|
|
|
|
|
def using_pjrt_c_api(backend=None):
  """Returns whether the backend's `platform_version` contains "PJRT C API"."""
  version_string = get_backend(backend).platform_version
  return "PJRT C API" in version_string
|
|
|
|
|
|
# TODO(parkers): Get rid of this in favor of a generic way to get topologies.
|
|
def make_pjrt_tpu_topology(topology_name=None, **kwargs):
  """Initializes the default backend, then builds a TPU device topology.

  Args:
    topology_name: forwarded to
      `xla_client.make_tfrt_tpu_c_api_device_topology`.
    **kwargs: forwarded unchanged to the same function.

  Returns:
    Whatever `xla_client.make_tfrt_tpu_c_api_device_topology` returns.
  """
  # TODO(b/261484192): Make a system for lazily loading libtpu.so and call
  # that inside make_tfrt_tpu_c_api_device_topology.
  get_backend()  # Properly initialize libtpu.so.
  make_topology = xla_client.make_tfrt_tpu_c_api_device_topology
  return make_topology(topology_name, **kwargs)
|