Refactor handling of XLA backends.

Use the new xla_client.get_local_backend() method if present; it will be available in a future Jaxlib release.
Use 'cpu' and 'gpu' to name platforms instead of 'Host' and 'CUDA'.

Move backend-initialization logic into get_backend() instead of get_xla_client().
Remove xla_bridge.get_xla_client(). Just import xla_bridge.xla_client instead.

Remove _platform_name. Instead, ask the backend for its platform name.
Peter Hawkins 2019-03-29 11:09:56 -04:00
parent 77615e2727
commit 42c62d0dad
5 changed files with 69 additions and 91 deletions
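
In the new scheme, callers obtain a backend object lazily and query it, instead of initializing a platform by name up front. A minimal usage sketch against the APIs in the diffs below (a sketch, not part of this commit; module path as imported in the changed files):

from jax.lib import xla_bridge as xb

backend = xb.get_backend()   # constructs and memoizes the backend on first use
print(backend.platform)      # 'cpu' or 'gpu', replacing 'Host' and 'CUDA'
print(xb.device_count())     # now simply get_backend().device_count()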


@@ -59,7 +59,8 @@ def xla_primitive_callable(prim, *abstract_args, **kwargs):
   built_c = primitive_computation(prim, *shapes, **kwargs)
   result_shape = xla_shape_to_result_shape(built_c.GetReturnValueShape())
   handle_result = result_handler(result_shape)
-  compiled = built_c.Compile(shapes, xb.get_compile_options())
+  compiled = built_c.Compile(shapes, xb.get_compile_options(),
+                             backend=xb.get_backend())
   return partial(execute_compiled_primitive, compiled, handle_result)
 
 @memoize
@@ -188,7 +189,8 @@ def tuple_constant(c, val, canonicalize_types=True):
 xb.register_constant_handler(JaxTuple, tuple_constant)
 
 def translation_rule(p):
-  backend_specific_rule = backend_specific_translations[xb._platform_name].get(p)
+  backend = xb.get_backend()
+  backend_specific_rule = backend_specific_translations[backend.platform].get(p)
   try:
     return backend_specific_rule or translations[p]
   except KeyError:

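The lookup above is the core dispatch rule: a platform-specific translation, keyed by the active backend's platform string, takes precedence over the generic table. A self-contained sketch of that precedence (the tables and primitive names here are illustrative stand-ins, not JAX's real registries):

translations = {'add': lambda: 'generic rule'}
backend_specific_translations = {
    'cpu': {'add': lambda: 'cpu-specific rule'},
    'gpu': {},
}

def lookup_rule(prim, platform):
  # Prefer the platform-specific rule; fall back to the generic table.
  specific = backend_specific_translations[platform].get(prim)
  return specific or translations[prim]

print(lookup_rule('add', 'cpu')())  # cpu-specific rule
print(lookup_rule('add', 'gpu')())  # generic rule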

@@ -56,6 +56,7 @@ from .interpreters import parallel
 from .util import curry, memoize, safe_zip, unzip2, prod
 from .tree_util import build_tree, tree_unflatten
 from .lib import xla_bridge
+from .lib.xla_bridge import xla_client
 
 
 FLAGS = flags.FLAGS
@@ -3342,7 +3343,7 @@ def _reduce_window_sum_transpose_rule(cotangent, window_dimensions,
                     for (lo, hi), stride in zip(pads, window_strides)]
   pad_cotangent = pad(cotangent, _zero(cotangent), padding_config)
   result = _reduce_window_sum(pad_cotangent, window_dimensions, ones,
-                              xla_bridge.get_xla_client().PaddingType.VALID)
+                              xla_client.PaddingType.VALID)
   assert result.shape == input_shape
   return [result]
@@ -4121,7 +4122,7 @@ def _dilate_shape(shape, dilation):
 
 def padtype_to_pads(in_shape, window_shape, window_strides, padding):
   """Convert padding string to list of pairs of pad values."""
-  PaddingType = xla_bridge.get_xla_client().PaddingType
+  PaddingType = xla_client.PaddingType
 
   if isinstance(padding, str):
     mapping = {'VALID': PaddingType.VALID, 'SAME': PaddingType.SAME}

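For context, the 'VALID' and 'SAME' strings accepted by padtype_to_pads follow the usual convolution padding conventions: 'VALID' pads nothing, while 'SAME' pads so the output has ceil(input / stride) elements. A rough standalone sketch of the 'SAME' arithmetic (an assumption-laden illustration, not the code in this commit, which delegates to xla_client.PaddingType):

import math

def same_pads(in_shape, window_shape, window_strides):
  # 'SAME': pad so that out_size == ceil(in_size / stride), splitting the
  # padding as evenly as possible, with any odd remainder on the high side.
  pads = []
  for size, window, stride in zip(in_shape, window_shape, window_strides):
    out_size = math.ceil(size / stride)
    total = max((out_size - 1) * stride + window - size, 0)
    pads.append((total // 2, total - total // 2))
  return pads

print(same_pads((10,), (3,), (1,)))  # [(1, 1)]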

@@ -112,7 +112,7 @@ def cholesky_cpu_translation_rule(c, operand):
     # TODO(phawkins): support LAPACK primitives in batched mode.
     return c.Cholesky(operand)
 
-xla.backend_specific_translations['Host'][cholesky_p] = cholesky_cpu_translation_rule
+xla.backend_specific_translations['cpu'][cholesky_p] = cholesky_cpu_translation_rule
 
 # Symmetric/Hermitian eigendecomposition
@@ -178,7 +178,7 @@ eigh_p.def_impl(eigh_impl)
 eigh_p.def_abstract_eval(eigh_abstract_eval)
 xla.translations[eigh_p] = eigh_translation_rule
 ad.primitive_jvps[eigh_p] = eigh_jvp_rule
-xla.backend_specific_translations['Host'][eigh_p] = eigh_cpu_translation_rule
+xla.backend_specific_translations['cpu'][eigh_p] = eigh_cpu_translation_rule
@@ -261,7 +261,7 @@ def triangular_solve_cpu_translation_rule(
     # TODO(phawkins): support BLAS primitives in batched mode.
     return c.TriangularSolve(a, b, left_side, lower, transpose_a, conjugate_a)
 
-xla.backend_specific_translations['Host'][triangular_solve_p] = triangular_solve_cpu_translation_rule
+xla.backend_specific_translations['cpu'][triangular_solve_p] = triangular_solve_cpu_translation_rule
 
 # LU decomposition
@@ -360,7 +360,7 @@ def lu_cpu_translation_rule(c, operand):
 # TODO(phawkins): The hasattr() test here is to avoid incompatibilities between
 # jax and an older jaxlib. Remove after a jaxlib release includes jax_getrf.
 if hasattr(lapack, "jax_getrf"):
-  xla.backend_specific_translations['Host'][lu_p] = lu_cpu_translation_rule
+  xla.backend_specific_translations['cpu'][lu_p] = lu_cpu_translation_rule
 
 
 def lu_pivots_to_permutation(swaps, k):
@@ -471,4 +471,4 @@ svd_p = Primitive('svd')
 svd_p.def_impl(svd_impl)
 svd_p.def_abstract_eval(svd_abstract_eval)
 xla.translations[svd_p] = svd_translation_rule
-xla.backend_specific_translations['Host'][svd_p] = svd_cpu_translation_rule
+xla.backend_specific_translations['cpu'][svd_p] = svd_cpu_translation_rule

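One detail worth noting in the lu_p hunk above: registration is guarded by hasattr(lapack, "jax_getrf"), so the CPU rule is installed only when the loaded jaxlib actually ships that kernel. The guard idiom in isolation (a sketch; lapack here is a placeholder namespace, not the real jaxlib module):

import types

lapack = types.SimpleNamespace()  # stand-in for jaxlib's lapack module

def lu_cpu_translation_rule(c, operand):
  raise NotImplementedError  # would emit a call to lapack.jax_getrf

backend_specific_translations = {'cpu': {}}
if hasattr(lapack, 'jax_getrf'):
  # Skipped here: the placeholder namespace has no jax_getrf attribute.
  backend_specific_translations['cpu']['lu'] = lu_cpu_translation_rule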

@@ -34,7 +34,7 @@ import jaxlib
 # Check the jaxlib version before importing anything else from jaxlib.
 def _check_jaxlib_version():
-  minimum_version = (0, 1, 9)
+  minimum_version = (0, 1, 11)
   if hasattr(jaxlib, '__version__'):
     version = tuple(int(x) for x in jaxlib.__version__.split('.'))
   else:
@@ -65,18 +65,16 @@ flags.DEFINE_string(
     'jax_platform_name', '',
     'Platform name for XLA. The default is to attempt to use a '
     'GPU if available, but fall back to CPU otherwise. To set '
-    'the platform manually, pass "Host" for CPU or "CUDA" for '
+    'the platform manually, pass "cpu" for CPU or "gpu" for '
     'GPU.')
 
-_platform_name = None  # set to the active platform name
-
 def get_compile_options(num_replicas=None):
   """Returns the compile options to use, as derived from flag values."""
   compile_options = None
   if num_replicas is not None:
-    compile_options = compile_options or get_xla_client().CompileOptions()
+    compile_options = compile_options or xla_client.CompileOptions()
     compile_options.num_replicas = num_replicas
   return compile_options
@@ -94,59 +92,55 @@ def memoize_thunk(func):
   return lambda: cached[0] if cached else (cached.append(func()) or cached[0])
 
-@memoize_thunk
-def get_xla_client():
-  return _get_xla_client(FLAGS.jax_xla_backend, FLAGS.jax_platform_name)
-
-def _get_xla_client(backend_name, platform_name):
-  """Configures and returns a handle to the XLA client.
-
-  Args:
-    backend_name: backend name, 'xla' or 'xrt'
-    platform_name: platform name for XLA backend
-
-  Returns:
-    A client library module, or an object that behaves identically to one.
-  """
-  global _platform_name
-  if backend_name == 'xla':
-    if platform_name:
-      xla_client.initialize_platform_name(platform_name)
-      _platform_name = platform_name
-    else:
-      try:
-        xla_client.initialize_platform_name('CUDA')
-        _platform_name = 'CUDA'
-      except RuntimeError:
-        warnings.warn('No GPU found, falling back to CPU.')
-        xla_client.initialize_platform_name('Host')
-        _platform_name = 'Host'
-  return xla_client
-
 _backends = {}
 
 def register_backend(name, factory):
   _backends[name] = factory
 
-if hasattr(xla_client, 'XlaLocalBackend'):
-  register_backend('xla', lambda: xla_client.XlaLocalBackend())
-  register_backend('xrt',
-                   lambda: xla_client.XrtBackend(FLAGS.jax_backend_target))
-else:
-  # TODO(phawkins): this case is for cross-version compatibility. Delete this
-  # case after a Jaxlib update.
-  register_backend(
-      'xla', lambda: xla_client.BackendSpec(xla_client.BackendType.XLA_LOCAL, ''))
-  register_backend(
-      'xrt', lambda: xla_client.BackendSpec(xla_client.BackendType.XRT,
-                                            FLAGS.jax_backend_target))
+def _get_local_backend():
+  platform = FLAGS.jax_platform_name
+
+  # Canonicalize platform names.
+  cpu = 'cpu'
+  gpu = 'gpu'
+  if platform == 'Host':
+    platform = cpu
+  elif platform == 'CUDA':
+    platform = gpu
+  elif platform == '':
+    platform = None
+
+  backend = None
+  if hasattr(xla_client, 'get_local_backend'):
+    backend = xla_client.get_local_backend(platform)
+  else:
+    # This case is for backward compatibility with Jaxlib versions that don't
+    # have xla_client.get_local_backend().
+    platforms = [gpu, cpu] if platform is None else [platform]
+    for p in platforms:
+      try:
+        backend = xla_client.XlaLocalBackend(p)
+        backend.platform = p
+        break
+      except RuntimeError:
+        continue
+
+  if backend is None:
+    raise RuntimeError("No local XLA backends found.")
+
+  if backend.platform == cpu and platform != cpu:
+    warnings.warn('No GPU/TPU found, falling back to CPU.')
+  return backend
+
+register_backend('xla', _get_local_backend)
+register_backend('xrt',
+                 lambda: xla_client.XrtBackend(FLAGS.jax_backend_target))
 
 @memoize_thunk
-def _get_backend():
+def get_backend():
   backend = _backends.get(FLAGS.jax_xla_backend)
   if backend is None:
     msg = 'Unknown jax_xla_backend value "{}".'
@@ -155,13 +149,12 @@ def _get_backend():
 
 def device_count():
-  _ = get_xla_client()  # ensure initialize_platform_name is called
-  return _get_backend().device_count()
+  return get_backend().device_count()
 
 
 def device_put(pyval, device_num=0):
-  client = get_xla_client()
-  return client.LocalBuffer.from_pyval(pyval, device_num, backend=_get_backend())
+  return xla_client.LocalBuffer.from_pyval(pyval, device_num,
+                                           backend=get_backend())
 
 
 Shape = xla_client.Shape  # pylint: disable=invalid-name
@@ -241,44 +234,32 @@ def shape_of(value):
 
 def infeed_put(replica_id, pyval):
   pyval = normalize_to_xla_dtypes(pyval)
-  return get_xla_client().transfer_to_infeed(
+  return xla_client.transfer_to_infeed(
       pyval, replica_number=replica_id)
 
-class _JaxComputationBuilderBase(object):
+class _JaxComputationBuilder(xla_client.ComputationBuilder):
   """Base class implementing all of JaxComputationBuilder.
 
   This class is intended to override and augment the interface of an XLA
-  ComputationBuilder to form JaxComputationBuilder, as made clear by
-  `get_jax_computation_builder_class`, which relies on Python's
-  method-resolution order to set up inheritance-like behavior. The class
-  inheritance setup is deferred because the choice of the XLA ComputationBuilder
-  class is based on the result of `get_xla_client()`. That is, the choice is
-  based at least on the setting of flags, which are available only after module
-  initialization time.
+  ComputationBuilder to form JaxComputationBuilder
   """
 
-  # The JAXComputationBuilder is implemented using subclassing and inheritance
-  # (via this base class), rather than a wrap-and-delegate style, simply to
-  # avoid having to spell out all the methods to be forwarded to a wrapped
-  # ComputationBuilder, especially since the underlying ComputationBuilders are
-  # likely to be revised in the future. An alternative is to generate these
-  # forwarding methods programmatically.
-
   # Method name case follows that of the XLA ComputationBuilder
   # pylint: disable=invalid-name
 
   def Build(self, *args, **kwargs):
-    return super(_JaxComputationBuilderBase, self).Build(
-        *args, backend=_get_backend(), **kwargs)
+    return super(_JaxComputationBuilder, self).Build(
+        *args, **kwargs)
 
   def Parameter(self, value, name=None, parameter_num=None):
-    return super(_JaxComputationBuilderBase, self).ParameterWithShape(
+    return super(_JaxComputationBuilder, self).ParameterWithShape(
         shape_of(value), name=name, parameter_num=parameter_num)
 
   def NumpyArrayConstant(self, value, canonicalize_types=True):
     if canonicalize_types:
       value = normalize_to_xla_dtypes(value)
-    return super(_JaxComputationBuilderBase, self).Constant(value)
+    return super(_JaxComputationBuilder, self).Constant(value)
 
   def ConstantLike(self, example_value, value, canonicalize_types=True):
     example_value = onp.asarray(example_value)
@@ -304,19 +285,12 @@ class _JaxComputationBuilderBase(object):
     if split_dimension == concat_dimension and len(replica_groups[0]) == 1:
       return operand
     else:
-      return super(_JaxComputationBuilderBase, self).AllToAll(
+      return super(_JaxComputationBuilder, self).AllToAll(
          operand, split_dimension, concat_dimension, replica_groups)
 
-@memoize_thunk
-def get_jax_computation_builder_class():
-  xla_base = get_xla_client().ComputationBuilder
-  jax_base = _JaxComputationBuilderBase
-  return type('JaxComputationBuilder', (jax_base, xla_base), {})
-
 def make_computation_builder(name):
-  return get_jax_computation_builder_class()(name)
+  return _JaxComputationBuilder(name)
 
 def register_constant_handler(type_, handler_fun):

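The xla_bridge diff above boils down to a registry of backend factories plus a memoized getter: factories are registered by name, FLAGS.jax_xla_backend selects one, and memoization ensures the backend is constructed at most once. A simplified, self-contained sketch of that shape (FakeBackend stands in for the objects xla_client would construct):

_backends = {}

def register_backend(name, factory):
  _backends[name] = factory

class FakeBackend(object):
  # Stand-in for an xla_client backend object.
  platform = 'cpu'
  def device_count(self):
    return 1

register_backend('xla', FakeBackend)

_cached = []
def get_backend(name='xla'):
  # Equivalent of @memoize_thunk: build the backend at most once.
  if not _cached:
    factory = _backends.get(name)
    if factory is None:
      raise RuntimeError('Unknown jax_xla_backend value "{}".'.format(name))
    _cached.append(factory())
  return _cached[0]

print(get_backend().platform)        # 'cpu'
print(get_backend().device_count())  # 1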

@@ -34,6 +34,7 @@ from ..interpreters.xla import DeviceArray
 from .. import lax
 from ..util import memoize, partial, get_module_functions, unzip2, prod as _prod
 from ..lib import xla_bridge
+from ..lib.xla_bridge import xla_client
 
 if six.PY3:
   def removechars(s, chars):
@@ -1027,7 +1028,7 @@ def _make_cumulative_reduction(onp_reduction, window_reduce, init_val,
     window_dims = [1] * num_dims
     window_dims[axis] = a_shape[axis]
     return window_reduce(
-        a, window_dims, strides, xla_bridge.get_xla_client().PaddingType.VALID)
+        a, window_dims, strides, xla_client.PaddingType.VALID)
   return cumulative_reduction