mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Refactor handling of XLA backends.
Use the new xla_client.get_local_backend() method when it is available; it will ship in a future Jaxlib release. Use 'cpu' and 'gpu' to name platforms instead of 'Host' and 'CUDA'. Move the logic that initializes backends into get_backend() instead of get_xla_client(). Remove xla_bridge.get_xla_client(); just import xla_bridge.xla_client instead. Remove _platform_name; instead, ask the backend for its platform name.
This commit is contained in:
parent
77615e2727
commit
42c62d0dad
@ -59,7 +59,8 @@ def xla_primitive_callable(prim, *abstract_args, **kwargs):
|
||||
built_c = primitive_computation(prim, *shapes, **kwargs)
|
||||
result_shape = xla_shape_to_result_shape(built_c.GetReturnValueShape())
|
||||
handle_result = result_handler(result_shape)
|
||||
compiled = built_c.Compile(shapes, xb.get_compile_options())
|
||||
compiled = built_c.Compile(shapes, xb.get_compile_options(),
|
||||
backend=xb.get_backend())
|
||||
return partial(execute_compiled_primitive, compiled, handle_result)
|
||||
|
||||
@memoize
|
||||
@ -188,7 +189,8 @@ def tuple_constant(c, val, canonicalize_types=True):
|
||||
xb.register_constant_handler(JaxTuple, tuple_constant)
|
||||
|
||||
def translation_rule(p):
|
||||
backend_specific_rule = backend_specific_translations[xb._platform_name].get(p)
|
||||
backend = xb.get_backend()
|
||||
backend_specific_rule = backend_specific_translations[backend.platform].get(p)
|
||||
try:
|
||||
return backend_specific_rule or translations[p]
|
||||
except KeyError:
|
||||
|
@ -56,6 +56,7 @@ from .interpreters import parallel
|
||||
from .util import curry, memoize, safe_zip, unzip2, prod
|
||||
from .tree_util import build_tree, tree_unflatten
|
||||
from .lib import xla_bridge
|
||||
from .lib.xla_bridge import xla_client
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
@ -3342,7 +3343,7 @@ def _reduce_window_sum_transpose_rule(cotangent, window_dimensions,
|
||||
for (lo, hi), stride in zip(pads, window_strides)]
|
||||
pad_cotangent = pad(cotangent, _zero(cotangent), padding_config)
|
||||
result = _reduce_window_sum(pad_cotangent, window_dimensions, ones,
|
||||
xla_bridge.get_xla_client().PaddingType.VALID)
|
||||
xla_client.PaddingType.VALID)
|
||||
assert result.shape == input_shape
|
||||
return [result]
|
||||
|
||||
@ -4121,7 +4122,7 @@ def _dilate_shape(shape, dilation):
|
||||
|
||||
def padtype_to_pads(in_shape, window_shape, window_strides, padding):
|
||||
"""Convert padding string to list of pairs of pad values."""
|
||||
PaddingType = xla_bridge.get_xla_client().PaddingType
|
||||
PaddingType = xla_client.PaddingType
|
||||
|
||||
if isinstance(padding, str):
|
||||
mapping = {'VALID': PaddingType.VALID, 'SAME': PaddingType.SAME}
|
||||
|
@ -112,7 +112,7 @@ def cholesky_cpu_translation_rule(c, operand):
|
||||
# TODO(phawkins): support LAPACK primitives in batched mode.
|
||||
return c.Cholesky(operand)
|
||||
|
||||
xla.backend_specific_translations['Host'][cholesky_p] = cholesky_cpu_translation_rule
|
||||
xla.backend_specific_translations['cpu'][cholesky_p] = cholesky_cpu_translation_rule
|
||||
|
||||
|
||||
# Symmetric/Hermitian eigendecomposition
|
||||
@ -178,7 +178,7 @@ eigh_p.def_impl(eigh_impl)
|
||||
eigh_p.def_abstract_eval(eigh_abstract_eval)
|
||||
xla.translations[eigh_p] = eigh_translation_rule
|
||||
ad.primitive_jvps[eigh_p] = eigh_jvp_rule
|
||||
xla.backend_specific_translations['Host'][eigh_p] = eigh_cpu_translation_rule
|
||||
xla.backend_specific_translations['cpu'][eigh_p] = eigh_cpu_translation_rule
|
||||
|
||||
|
||||
|
||||
@ -261,7 +261,7 @@ def triangular_solve_cpu_translation_rule(
|
||||
# TODO(phawkins): support BLAS primitives in batched mode.
|
||||
return c.TriangularSolve(a, b, left_side, lower, transpose_a, conjugate_a)
|
||||
|
||||
xla.backend_specific_translations['Host'][triangular_solve_p] = triangular_solve_cpu_translation_rule
|
||||
xla.backend_specific_translations['cpu'][triangular_solve_p] = triangular_solve_cpu_translation_rule
|
||||
|
||||
|
||||
# LU decomposition
|
||||
@ -360,7 +360,7 @@ def lu_cpu_translation_rule(c, operand):
|
||||
# TODO(phawkins): The hasattr() test here is to avoid incompatibilities between
|
||||
# jax and an older jaxlib. Remove after a jaxlib release includes jax_getrf.
|
||||
if hasattr(lapack, "jax_getrf"):
|
||||
xla.backend_specific_translations['Host'][lu_p] = lu_cpu_translation_rule
|
||||
xla.backend_specific_translations['cpu'][lu_p] = lu_cpu_translation_rule
|
||||
|
||||
|
||||
def lu_pivots_to_permutation(swaps, k):
|
||||
@ -471,4 +471,4 @@ svd_p = Primitive('svd')
|
||||
svd_p.def_impl(svd_impl)
|
||||
svd_p.def_abstract_eval(svd_abstract_eval)
|
||||
xla.translations[svd_p] = svd_translation_rule
|
||||
xla.backend_specific_translations['Host'][svd_p] = svd_cpu_translation_rule
|
||||
xla.backend_specific_translations['cpu'][svd_p] = svd_cpu_translation_rule
|
||||
|
@ -34,7 +34,7 @@ import jaxlib
|
||||
|
||||
# Check the jaxlib version before importing anything else from jaxlib.
|
||||
def _check_jaxlib_version():
|
||||
minimum_version = (0, 1, 9)
|
||||
minimum_version = (0, 1, 11)
|
||||
if hasattr(jaxlib, '__version__'):
|
||||
version = tuple(int(x) for x in jaxlib.__version__.split('.'))
|
||||
else:
|
||||
@ -65,18 +65,16 @@ flags.DEFINE_string(
|
||||
'jax_platform_name', '',
|
||||
'Platform name for XLA. The default is to attempt to use a '
|
||||
'GPU if available, but fall back to CPU otherwise. To set '
|
||||
'the platform manually, pass "Host" for CPU or "CUDA" for '
|
||||
'the platform manually, pass "cpu" for CPU or "gpu" for '
|
||||
'GPU.')
|
||||
|
||||
|
||||
_platform_name = None # set to the active platform name
|
||||
|
||||
|
||||
def get_compile_options(num_replicas=None):
|
||||
"""Returns the compile options to use, as derived from flag values."""
|
||||
compile_options = None
|
||||
if num_replicas is not None:
|
||||
compile_options = compile_options or get_xla_client().CompileOptions()
|
||||
compile_options = compile_options or xla_client.CompileOptions()
|
||||
compile_options.num_replicas = num_replicas
|
||||
return compile_options
|
||||
|
||||
@ -94,59 +92,55 @@ def memoize_thunk(func):
|
||||
return lambda: cached[0] if cached else (cached.append(func()) or cached[0])
|
||||
|
||||
|
||||
@memoize_thunk
|
||||
def get_xla_client():
|
||||
return _get_xla_client(FLAGS.jax_xla_backend, FLAGS.jax_platform_name)
|
||||
|
||||
|
||||
def _get_xla_client(backend_name, platform_name):
|
||||
"""Configures and returns a handle to the XLA client.
|
||||
|
||||
Args:
|
||||
backend_name: backend name, 'xla' or 'xrt'
|
||||
platform_name: platform name for XLA backend
|
||||
|
||||
Returns:
|
||||
A client library module, or an object that behaves identically to one.
|
||||
"""
|
||||
global _platform_name
|
||||
if backend_name == 'xla':
|
||||
if platform_name:
|
||||
xla_client.initialize_platform_name(platform_name)
|
||||
_platform_name = platform_name
|
||||
else:
|
||||
try:
|
||||
xla_client.initialize_platform_name('CUDA')
|
||||
_platform_name = 'CUDA'
|
||||
except RuntimeError:
|
||||
warnings.warn('No GPU found, falling back to CPU.')
|
||||
xla_client.initialize_platform_name('Host')
|
||||
_platform_name = 'Host'
|
||||
return xla_client
|
||||
|
||||
|
||||
_backends = {}
|
||||
|
||||
def register_backend(name, factory):
|
||||
_backends[name] = factory
|
||||
|
||||
def _get_local_backend():
|
||||
platform = FLAGS.jax_platform_name
|
||||
|
||||
if hasattr(xla_client, 'XlaLocalBackend'):
|
||||
register_backend('xla', lambda: xla_client.XlaLocalBackend())
|
||||
register_backend('xrt',
|
||||
lambda: xla_client.XrtBackend(FLAGS.jax_backend_target))
|
||||
else:
|
||||
# TODO(phawkins): this case is for cross-version compatibility. Delete this
|
||||
# case after a Jaxlib update.
|
||||
register_backend(
|
||||
'xla', lambda: xla_client.BackendSpec(xla_client.BackendType.XLA_LOCAL, ''))
|
||||
register_backend(
|
||||
'xrt', lambda: xla_client.BackendSpec(xla_client.BackendType.XRT,
|
||||
FLAGS.jax_backend_target))
|
||||
# Canonicalize platform names.
|
||||
cpu = 'cpu'
|
||||
gpu = 'gpu'
|
||||
if platform == 'Host':
|
||||
platform = cpu
|
||||
elif platform == 'CUDA':
|
||||
platform = gpu
|
||||
elif platform == '':
|
||||
platform = None
|
||||
|
||||
backend = None
|
||||
if hasattr(xla_client, 'get_local_backend'):
|
||||
backend = xla_client.get_local_backend(platform)
|
||||
else:
|
||||
# This case is for backward compatibility with Jaxlib versions that don't
|
||||
# have xla_client.get_local_backend().
|
||||
platforms = [gpu, cpu] if platform is None else [platform]
|
||||
for p in platforms:
|
||||
try:
|
||||
backend = xla_client.XlaLocalBackend(p)
|
||||
backend.platform = p
|
||||
break
|
||||
except RuntimeError:
|
||||
continue
|
||||
|
||||
if backend is None:
|
||||
raise RuntimeError("No local XLA backends found.")
|
||||
|
||||
if backend.platform == cpu and platform != cpu:
|
||||
warnings.warn('No GPU/TPU found, falling back to CPU.')
|
||||
|
||||
return backend
|
||||
|
||||
|
||||
|
||||
register_backend('xla', _get_local_backend)
|
||||
register_backend('xrt',
|
||||
lambda: xla_client.XrtBackend(FLAGS.jax_backend_target))
|
||||
|
||||
@memoize_thunk
|
||||
def _get_backend():
|
||||
def get_backend():
|
||||
backend = _backends.get(FLAGS.jax_xla_backend)
|
||||
if backend is None:
|
||||
msg = 'Unknown jax_xla_backend value "{}".'
|
||||
@ -155,13 +149,12 @@ def _get_backend():
|
||||
|
||||
|
||||
def device_count():
|
||||
_ = get_xla_client() # ensure initialize_platform_name is called
|
||||
return _get_backend().device_count()
|
||||
return get_backend().device_count()
|
||||
|
||||
|
||||
def device_put(pyval, device_num=0):
|
||||
client = get_xla_client()
|
||||
return client.LocalBuffer.from_pyval(pyval, device_num, backend=_get_backend())
|
||||
return xla_client.LocalBuffer.from_pyval(pyval, device_num,
|
||||
backend=get_backend())
|
||||
|
||||
|
||||
Shape = xla_client.Shape # pylint: disable=invalid-name
|
||||
@ -241,44 +234,32 @@ def shape_of(value):
|
||||
|
||||
def infeed_put(replica_id, pyval):
|
||||
pyval = normalize_to_xla_dtypes(pyval)
|
||||
return get_xla_client().transfer_to_infeed(
|
||||
return xla_client.transfer_to_infeed(
|
||||
pyval, replica_number=replica_id)
|
||||
|
||||
|
||||
class _JaxComputationBuilderBase(object):
|
||||
class _JaxComputationBuilder(xla_client.ComputationBuilder):
|
||||
"""Base class implementing all of JaxComputationBuilder.
|
||||
|
||||
This class is intended to override and augment the interface of an XLA
|
||||
ComputationBuilder to form JaxComputationBuilder, as made clear by
|
||||
`get_jax_computation_builder_class`, which relies on Python's
|
||||
method-resolution order to set up inheritance-like behavior. The class
|
||||
inheritance setup is deferred because the choice of the XLA ComputationBuilder
|
||||
class is based on the result of `get_xla_client()`. That is, the choice is
|
||||
based at least on the setting of flags, which are available only after module
|
||||
initialization time.
|
||||
ComputationBuilder to form JaxComputationBuilder
|
||||
"""
|
||||
# The JAXComputationBuilder is implemented using subclassing and inheritance
|
||||
# (via this base class), rather than a wrap-and-delegate style, simply to
|
||||
# avoid having to spell out all the methods to be forwarded to a wrapped
|
||||
# ComputationBuilder, especially since the underlying ComputationBuilders are
|
||||
# likely to be revised in the future. An alternative is to generate these
|
||||
# forwarding methods programmatically.
|
||||
|
||||
# Method name case follows that of the XLA ComputationBuilder
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
def Build(self, *args, **kwargs):
|
||||
return super(_JaxComputationBuilderBase, self).Build(
|
||||
*args, backend=_get_backend(), **kwargs)
|
||||
return super(_JaxComputationBuilder, self).Build(
|
||||
*args, **kwargs)
|
||||
|
||||
def Parameter(self, value, name=None, parameter_num=None):
|
||||
return super(_JaxComputationBuilderBase, self).ParameterWithShape(
|
||||
return super(_JaxComputationBuilder, self).ParameterWithShape(
|
||||
shape_of(value), name=name, parameter_num=parameter_num)
|
||||
|
||||
def NumpyArrayConstant(self, value, canonicalize_types=True):
|
||||
if canonicalize_types:
|
||||
value = normalize_to_xla_dtypes(value)
|
||||
return super(_JaxComputationBuilderBase, self).Constant(value)
|
||||
return super(_JaxComputationBuilder, self).Constant(value)
|
||||
|
||||
def ConstantLike(self, example_value, value, canonicalize_types=True):
|
||||
example_value = onp.asarray(example_value)
|
||||
@ -304,19 +285,12 @@ class _JaxComputationBuilderBase(object):
|
||||
if split_dimension == concat_dimension and len(replica_groups[0]) == 1:
|
||||
return operand
|
||||
else:
|
||||
return super(_JaxComputationBuilderBase, self).AllToAll(
|
||||
return super(_JaxComputationBuilder, self).AllToAll(
|
||||
operand, split_dimension, concat_dimension, replica_groups)
|
||||
|
||||
|
||||
@memoize_thunk
|
||||
def get_jax_computation_builder_class():
|
||||
xla_base = get_xla_client().ComputationBuilder
|
||||
jax_base = _JaxComputationBuilderBase
|
||||
return type('JaxComputationBuilder', (jax_base, xla_base), {})
|
||||
|
||||
|
||||
def make_computation_builder(name):
|
||||
return get_jax_computation_builder_class()(name)
|
||||
return _JaxComputationBuilder(name)
|
||||
|
||||
|
||||
def register_constant_handler(type_, handler_fun):
|
||||
|
@ -34,6 +34,7 @@ from ..interpreters.xla import DeviceArray
|
||||
from .. import lax
|
||||
from ..util import memoize, partial, get_module_functions, unzip2, prod as _prod
|
||||
from ..lib import xla_bridge
|
||||
from ..lib.xla_bridge import xla_client
|
||||
|
||||
if six.PY3:
|
||||
def removechars(s, chars):
|
||||
@ -1027,7 +1028,7 @@ def _make_cumulative_reduction(onp_reduction, window_reduce, init_val,
|
||||
window_dims = [1] * num_dims
|
||||
window_dims[axis] = a_shape[axis]
|
||||
return window_reduce(
|
||||
a, window_dims, strides, xla_bridge.get_xla_client().PaddingType.VALID)
|
||||
a, window_dims, strides, xla_client.PaddingType.VALID)
|
||||
|
||||
return cumulative_reduction
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user