# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lowering of jaxprs into XLA (HLO) computations.
from collections import defaultdict, deque
import collections.abc
import dataclasses
import functools
from functools import partial
import itertools as it
import operator
import re
from typing import (Any, Callable, Deque, Dict, List, NamedTuple, Optional,
Sequence, Set, Type, Tuple, Union)
from typing_extensions import Protocol
import numpy as np
from jax.config import config
from jax import core
from jax._src import ad_util
from jax._src import device_array
from jax._src import dtypes
from jax import linear_util as lu
from jax._src import source_info_util
from jax._src.abstract_arrays import (make_shaped_array, array_types)
from jax.core import (ConcreteArray, ShapedArray,
Literal, str_eqn_compact, abstract_token)
import jax._src.pretty_printer as pp
from jax._src import util
from jax._src.util import (prod, extend_name_stack, new_name_stack, wrap_name,
safe_zip, safe_map, partition_list)
from jax._src.lib import xla_client as xc
from jax.interpreters import partial_eval as pe
from jax.interpreters import ad
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
xe = xc._xla
xops = xc._xla.ops
# Types
Backend = xe.Client
Device = xc.Device
Buffer = xe.Buffer
XlaOp = xc.XlaOp
XlaShape = xc.Shape
XlaBuilder = xc.XlaBuilder
XlaExecutable = xc.Executable
# apply_primitive is defined in jax._src.dispatch.
apply_primitive: Callable
backend_compile: Callable
device_put: Callable
# TODO(phawkins): update code to point to new locations.
DeviceArray = device_array.DeviceArray
_DeviceArray = device_array._DeviceArray
_CppDeviceArray = xe.Buffer
make_device_array = device_array.make_device_array
def identity(x): return x
_scalar_types = dtypes.python_scalar_dtypes.keys()
# unit representation
def _make_unit_constant(c): return [
xops.Constant(c, np.zeros((), dtype=np.dtype('bool')))]
def _make_unit_shape(_): return (xc.Shape.array_shape(np.dtype('bool'), ()),)
def _make_array_shape(a: ShapedArray) -> Sequence[XlaShape]:
if a.dtype is dtypes.float0:
return (xc.Shape.array_shape(np.dtype('bool'), a.shape),)
else:
return (xc.Shape.array_shape(a.dtype, a.shape),)
def _get_canonical_source_file(frame: source_info_util.Frame):
source_file = frame.file_name
if config.jax_hlo_source_file_canonicalization_regex:
source_file = re.sub(config.jax_hlo_source_file_canonicalization_regex,
'', source_file)
return source_file
tracebacks = {}
def make_op_metadata(primitive: core.Primitive,
params: Dict, *,
source_info: source_info_util.SourceInfo,
name_stack: Union[str, source_info_util.NameStack] = "",
) -> xc.OpMetadata:
if config.jax_experimental_name_stack:
eqn_str = str(source_info.name_stack) + '/' + str_eqn_compact(primitive.name, params)
else:
assert isinstance(name_stack, str)
eqn_str = name_stack + str_eqn_compact(primitive.name, params)
tracebacks[eqn_str] = source_info.traceback
frame = source_info_util.user_frame(source_info)
return xc.OpMetadata(
op_type=primitive.name,
op_name=eqn_str,
source_file=_get_canonical_source_file(frame) if frame else None,
source_line=frame.line_num if frame else None)
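# Illustrative sketch (assuming a hypothetical jitted function `f` containing an
# `add` primitive): with the experimental name stack enabled, the metadata built
# above would look roughly like
#   xc.OpMetadata(op_type='add', op_name='jit(f)/add',
#                 source_file=<top-most user frame>, source_line=<its line>)
# i.e. op_type is the primitive name, op_name prefixes it with the name stack,
# and the source location comes from the first non-JAX stack frame.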
# Utilities
def parameter(builder, num, shape, name=None, replicated=None):
if name is None:
name = ''
if replicated is None:
replicated = []
elif isinstance(replicated, bool):
replicated = [replicated] * shape.leaf_count()
return xops.Parameter(builder, num,
shape.with_major_to_minor_layout_if_absent(), name,
replicated)
# HLO instructions optionally can be annotated to say how the output should be
# spatially partitioned (represented in XLA as OpSharding protos, see
# sharding_to_proto). For array outputs, the annotation is either an int per
# dimension specifying the number of ways that dimension is divided (i.e. the total
# number of shards is the product), or None to indicate the array should be
# replicated. Tuple outputs are represented as tuples thereof. XLA supports
# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
# checkers don't support recursive types), so we only represent one level of
# nesting in this type definition.
SpatialSharding = Union[Tuple[int, ...],
None,
Tuple[Optional[Tuple[int, ...]], ...]]
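# Illustrative examples of SpatialSharding values (not exhaustive):
#   (2, 1)          -> a 2-D array output split 2 ways along dim 0, unsplit along dim 1
#   None            -> the output is replicated across devices
#   ((2, 1), None)  -> a tuple output whose first element is split and second replicated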
def sharding_to_proto(sharding: SpatialSharding):
"""Converts a SpatialSharding to an OpSharding.
See
https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601
for details on the OpSharding proto.
"""
proto = xc.OpSharding()
if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
assert all(s is None or isinstance(s, tuple) for s in sharding)
return tuple_sharding_proto(list(map(sharding_to_proto, sharding))) # type: ignore
if sharding is None:
proto.type = xc.OpSharding.Type.REPLICATED
else:
proto.type = xc.OpSharding.Type.OTHER
proto.tile_assignment_dimensions = list(sharding)
proto.tile_assignment_devices = list(range(np.product(sharding))) # type: ignore
return proto
def tuple_sharding_proto(elems):
proto = xc.OpSharding()
assert all(isinstance(e, type(proto)) for e in elems)
proto.type = xc.OpSharding.Type.TUPLE
proto.tuple_shardings = elems
return proto
def set_sharding_proto(builder, op, sharding_proto, unspecified_dims=None):
"""Uses CustomCall to annotate a value as sharded."""
# "Sharding" is a built-in custom call target that acts like an identity
# function, and is used to attach an OpSharding to.
def _create_custom_call(x):
# unspecified_dims lists dimensions whose shardings are left unspecified, so
# that XLA sharding propagation is free to change them.
if unspecified_dims:
opaque = 'unspecified_dims=[' + ','.join(
[str(i) for i in unspecified_dims]) + ']'
opaque = bytes(opaque, 'utf-8')
return xops.CustomCall(
builder, b'Sharding', [x], builder.get_shape(x), opaque=opaque)
else:
return xops.CustomCall(builder, b'Sharding', [x], builder.get_shape(x))
return with_sharding_proto(builder, sharding_proto, _create_custom_call, op)
def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
builder.set_sharding(sharding_proto)
try:
return op_fn(*args, **kwargs)
finally:
builder.clear_sharding()
def set_sharding(builder, op, sharding: SpatialSharding, unspecified_dims=None):
"""Uses CustomCall to annotate a value as sharded."""
return set_sharding_proto(builder, op, sharding_to_proto(sharding),
unspecified_dims)
def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
return with_sharding_proto(builder, sharding_to_proto(sharding), op_fn, *args,
**kwargs)
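# Illustrative usage sketch (assuming an XlaBuilder `c`, an XlaOp `x` for a 2-D
# array, and an XlaShape `shape`):
#   set_sharding(c, x, (2, 1))   # wrap x in a "Sharding" CustomCall, split 2 ways on dim 0
#   set_sharding(c, x, None)     # annotate x as fully replicated
#   with_sharding(c, (2, 1), parameter, c, 0, shape)   # build a parameter carrying the annotation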
### handlers
# Numpy dtypes -> XLA primitive types
_dtype_to_primitive_type: Dict[np.dtype, xc.PrimitiveType] = {
np.dtype('bool'): xc.PrimitiveType.PRED,
np.dtype('int8'): xc.PrimitiveType.S8,
np.dtype('int16'): xc.PrimitiveType.S16,
np.dtype('int32'): xc.PrimitiveType.S32,
np.dtype('int64'): xc.PrimitiveType.S64,
np.dtype('uint8'): xc.PrimitiveType.U8,
np.dtype('uint16'): xc.PrimitiveType.U16,
np.dtype('uint32'): xc.PrimitiveType.U32,
np.dtype('uint64'): xc.PrimitiveType.U64,
np.dtype(dtypes.bfloat16): xc.PrimitiveType.BF16,
np.dtype('float16'): xc.PrimitiveType.F16,
np.dtype('float32'): xc.PrimitiveType.F32,
np.dtype('float64'): xc.PrimitiveType.F64,
np.dtype('complex64'): xc.PrimitiveType.C64,
np.dtype('complex128'): xc.PrimitiveType.C128,
}
def dtype_to_primitive_type(dtype: np.dtype) -> xc.PrimitiveType:
"""Converts a NumPy dtype into an XLA PrimitiveType."""
# Many things (e.g., strings, scalar types) can be compared with NumPy dtypes,
# but may not hash correctly. Make sure we have a true np.dtype.
assert isinstance(dtype, np.dtype), type(dtype)
try:
return _dtype_to_primitive_type[dtype]
except KeyError as err:
raise TypeError(f"No XLA lowering for NumPy dtype: {dtype}") from err
# JAX abstract values -> XLA shapes
def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[XlaShape]:
try:
return xla_shape_handlers[type(aval)](aval)
except KeyError as err:
raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err
xla_shape_handlers: Dict[Type[core.AbstractValue],
Callable[[Any], Sequence[XlaShape]]] = {
core.AbstractUnit: _make_unit_shape,
ShapedArray: _make_array_shape,
ConcreteArray: _make_array_shape,
}
xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)
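# For example, aval_to_xla_shapes(ShapedArray((2, 3), np.dtype('float32')))
# yields a single f32[2,3] array shape, abstract tokens map to a token shape,
# and the unit aval maps to the scalar pred placeholder defined above.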
# IR constants
_constant_handlers: Dict[type, Callable] = {}
def pyval_to_ir_constants(builder, py_val, canonicalize_types=True):
"""Translate a general constant `py_val` to a constant, canonicalizing its dtype.
Args:
py_val: a Python value to be translated to a constant.
Returns:
A representation of the constant as a list of xla ops.
"""
for t in type(py_val).__mro__:
handler = _constant_handlers.get(t)
if handler: return handler(builder, py_val, canonicalize_types)
if hasattr(py_val, '__jax_array__'):
return pyval_to_ir_constants(builder, py_val.__jax_array__(),
canonicalize_types)
raise TypeError("No constant handler for type: {}".format(type(py_val)))
def pyval_to_ir_constant(builder, py_val, canonicalize_types=True):
"""Translate constant `py_val` to a constant, canonicalizing its dtype.
Args:
py_val: a Python value to be translated to a constant.
Returns:
A representation of the constant as a single XlaOp.
"""
const = pyval_to_ir_constants(builder, py_val, canonicalize_types=canonicalize_types)
assert len(const) == 1, f"Internal error: cannot create constant from object of type {type(py_val)}"
return const[0]
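# Illustrative usage sketch (assuming a builder `c = xc.XlaBuilder("consts")`):
# pyval_to_ir_constants(c, np.ones((2,), np.float32)) returns a one-element list
# of XlaOps, and pyval_to_ir_constant(c, 3.0) returns that single op directly,
# with the dtype canonicalized (e.g. float64 -> float32 when x64 is disabled).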
def register_constant_handler(type_, handler_fun):
_constant_handlers[type_] = handler_fun
register_constant_handler(core.Unit, lambda c, *_: _make_unit_constant(c))
# TODO(mattjj,frostig): try to remove this function
def _normalize_to_xla_dtypes(val):
"""Normalize dtypes in a value."""
if hasattr(val, '__array__') or np.isscalar(val):
return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
elif isinstance(val, (tuple, list)):
return tuple(_normalize_to_xla_dtypes(x) for x in val)
raise TypeError('Can\'t convert to XLA: {}'.format(val))
def _numpy_array_constant(builder, value, canonicalize_types=True):
if canonicalize_types:
value = _normalize_to_xla_dtypes(value)
return [xops.Constant(builder, value)]
def _ndarray_constant_handler(c, val, canonicalize_types=True):
"""Constant handler for ndarray literals, handling zero-size strides.
This function essentially calls _numpy_array_constant(val) except it has
special handling of arrays with any strides of size zero: for those, it
generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose
to avoid staging in large literals that might arise from np.zeros or np.ones
or the output of lax.broadcast (which uses np.broadcast_to which in turn
uses size-zero strides).
Args:
c: an XlaBuilder
val: an ndarray.
Returns:
An XLA ComputationDataHandle / XlaOp representing the constant ndarray
staged into the XLA Computation.
"""
# TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose
if dtypes.result_type(val) == dtypes.float0:
return _numpy_array_constant(c, np.zeros(val.shape, dtype=np.bool_))
elif np.any(np.equal(0, val.strides)) and val.size > 0:
zero_stride_axes, = np.where(np.equal(0, val.strides))
other_axes, = np.where(np.not_equal(0, val.strides))
collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
for ax in range(val.ndim))]
xla_val = xops.Broadcast(
_numpy_array_constant(c, collapsed_val, canonicalize_types)[0],
np.take(val.shape, zero_stride_axes))
permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes))
return [xops.Transpose(xla_val, permutation)]
else:
return _numpy_array_constant(c, val, canonicalize_types)
register_constant_handler(np.ndarray, _ndarray_constant_handler)
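# Illustrative example of the zero-stride path above: the array
# np.broadcast_to(np.arange(3.0), (1000, 3)) has strides (0, 8), so the handler
# stages only the collapsed (3,)-element constant and rebuilds the full value
# with an XLA Broadcast (and Transpose), instead of embedding 3000 literals.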
def _scalar_constant_handler(c, val, canonicalize_types=True):
return _numpy_array_constant(c, val, canonicalize_types)
for scalar_type in [np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64,
np.float16, np.float32, np.float64,
np.bool_, np.longlong,
dtypes.bfloat16]:
register_constant_handler(scalar_type, _scalar_constant_handler)
# https://github.com/winpython/winpython/issues/613#issuecomment-380121523
if hasattr(np, "float128"):
register_constant_handler(np.float128, _scalar_constant_handler)
def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True):
return _numpy_array_constant(c, dtype.type(val))
for ptype, dtype in dtypes.python_scalar_dtypes.items():
register_constant_handler(ptype, partial(_python_scalar_handler, dtype))
def _device_array_constant_handler(c, val, canonicalize_types=True):
return pyval_to_ir_constants(c, val.device_buffer.to_py())
for t in device_array.device_array_types:
register_constant_handler(t, _device_array_constant_handler)
register_constant_handler(core.Token, lambda c, _, __: [xops.CreateToken(c)])
# TODO(mattjj): try to remove this canonicalize_dtype stuff
def canonicalize_dtype(x):
typ = type(x)
handler = canonicalize_dtype_handlers.get(typ)
if handler: return handler(x)
for typ in typ.__mro__:
handler = canonicalize_dtype_handlers.get(typ)
if handler: return handler(x)
if hasattr(x, '__jax_array__'):
return canonicalize_dtype(x.__jax_array__())
raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")
def _canonicalize_ndarray_dtype(x):
return np.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x)))
def _canonicalize_python_scalar_dtype(typ, x):
return np.asarray(
x, dtypes.canonicalize_dtype(dtypes._scalar_type_to_dtype(typ, x)))
canonicalize_dtype_handlers: Dict[Any, Callable] = {core.Unit: identity}
for t in device_array.device_array_types:
canonicalize_dtype_handlers[t] = lambda x: x
canonicalize_dtype_handlers.update(
(t, _canonicalize_ndarray_dtype) for t in array_types)
canonicalize_dtype_handlers.update(
(t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types)
canonicalize_dtype_handlers[core.Token] = lambda x: x
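# For example, with 64-bit mode disabled (the default), canonicalize_dtype(3)
# gives an int32 scalar ndarray and canonicalize_dtype(np.float64(1.0)) gives a
# float32 scalar, while DeviceArrays and tokens pass through unchanged.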
def abstractify(x) -> core.AbstractValue:
typ = type(x)
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
for typ in typ.__mro__:
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
if hasattr(x, '__jax_array__'):
return abstractify(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
def _make_abstract_python_scalar(typ, val):
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val), weak_type=True)
pytype_aval_mappings: Dict[Any, Callable[[Any], core.AbstractValue]] = {
core.Unit: lambda _: core.abstract_unit,
}
for t in device_array.device_array_types:
pytype_aval_mappings[t] = operator.attrgetter('aval')
pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
pytype_aval_mappings.update((t, make_shaped_array) for t in array_types)
pytype_aval_mappings.update(
(t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)
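# For example, abstractify(np.ones((2, 3), np.float32)) gives
# ShapedArray((2, 3), float32), abstractify(1.0) gives a weakly-typed scalar
# ShapedArray (float32 under the default 32-bit mode), and DeviceArrays report
# their own .aval.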
def primitive_subcomputation(platform: str, axis_env: 'AxisEnv',
prim: core.Primitive,
*avals: core.AbstractValue, **params):
c = xc.XlaBuilder(f"primitive_computation_{prim.name}")
f = lower_fun(prim.bind, multiple_results=prim.multiple_results,
new_style=True)
xla_args, _ = _xla_callable_args(c, avals, tuple_args=False,
filter_tokens=False)
ctx = TranslationContext(builder=c, platform=platform, axis_env=axis_env,
name_stack=new_name_stack())
ans = f(ctx.replace(builder=c), avals, None, *xla_args, **params)
if prim.multiple_results:
ans = xops.Tuple(c, ans)
else:
ans, = ans
return c.build(ans)
# Used within _xla_callable_args and _xla_param to distinguish between None (no
# sharding annotation set) and replicated.
_replicated_param = object()
def _token_param_shape():
"""Shape used in place of tokens as top-level computation arguments."""
return xc.Shape.array_shape(np.dtype(np.bool_), [])
def _make_token_return_value(c):
"""Value used in place of tokens as a top-level computation return value."""
return xops.Constant(c, np.zeros((), dtype=np.dtype(np.bool_)))
def _xla_callable_args(
c, avals, tuple_args, *,
replicated=None,
partitions=None,
partitions_proto: bool = False,
donated_invars=None,
filter_tokens=True):
assert partitions is None or len(partitions) == len(avals)
if not tuple_args:
if replicated is None:
replicated = [None] * len(avals)
if partitions is None:
parts: List[object] = [None] * len(avals)
elif partitions_proto:
parts = partitions
else:
parts = [_replicated_param if part is None else part
for part in partitions]
counts = it.count()
xla_args = [_xla_param(c, next(counts), xla_shape, r, p, partitions_proto,
filter_tokens)
for (a, r, p) in safe_zip(avals, replicated, parts)
for xla_shape in aval_to_xla_shapes(a)]
if donated_invars is not None:
donated_invars = [
d for (a, _, _, d) in zip(avals, replicated, parts, donated_invars)
for xla_shape in aval_to_xla_shapes(a)]
return xla_args, donated_invars
else:
if replicated is not None:
replicated = [r for a, r in zip(avals, replicated)
if a is not abstract_token]
if partitions is None:
tuple_parts = None
elif partitions_proto:
tuple_parts = tuple_sharding_proto(partitions)
else:
tuple_parts = tuple(partitions)
tuple_shape = xc.Shape.tuple_shape(
[shape if not (filter_tokens and a is abstract_token)
else _token_param_shape()
for a in avals for shape in aval_to_xla_shapes(a)])
tuple_param = _xla_param(c, 0, tuple_shape, replicated, tuple_parts,
partitions_proto, filter_tokens)
xla_args = [v if not (filter_tokens and a is abstract_token)
else xops.CreateToken(c)
for a, v in zip(avals, xla_destructure(c, tuple_param))]
return xla_args, donated_invars
def _xla_param(builder, param_num, xla_shape, replicated, partitions,
parts_proto, filter_tokens):
is_token = xla_shape.is_token()
if filter_tokens and is_token:
xla_shape = _token_param_shape()
make_param = partial(parameter, builder, param_num, xla_shape,
replicated=replicated)
with_sharding_fn = with_sharding_proto if parts_proto else with_sharding
if partitions is None:
out = make_param()
elif partitions is _replicated_param:
out = with_sharding_fn(builder, None, make_param)
else:
out = with_sharding_fn(builder, partitions, make_param)
if filter_tokens and is_token:
out = xops.CreateToken(builder)
return out
### compiling jaxprs
def _flatmap(func: Callable, vars: Sequence):
return list(it.chain.from_iterable(map(func, vars)))
def _partitionmap(func: Callable, vars: Sequence, nodes: Sequence):
return map(func, vars,
util.unflatten(nodes,
[len(aval_to_xla_shapes(v.aval)) for v in vars]))
class AxisEnv(NamedTuple):
"""Represents a pmap mesh (only along the replica axes)."""
nreps: int
names: Tuple[Any, ...]
sizes: Tuple[int, ...]
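# Purely illustrative example: a pmap over 8 replicas with an outer axis 'i' of
# size 4 and an inner axis 'j' of size 2 could be described as
#   AxisEnv(nreps=8, names=('i', 'j'), sizes=(4, 2))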
@dataclasses.dataclass
class TranslationContext:
builder: xc.XlaBuilder
# TODO(phawkins): make platform non-optional. We should always be translating
# with a specific platform in mind.
platform: Optional[str]
axis_env: AxisEnv
name_stack: Union[str, source_info_util.NameStack]
def replace(self, **kw): return dataclasses.replace(self, **kw)
def jaxpr_subcomp(ctx: TranslationContext, jaxpr: core.Jaxpr,
consts: Sequence[XlaOp], *args: XlaOp) -> Sequence[XlaOp]:
assert ctx.platform is not None
def read(v):
if type(v) is Literal:
return pyval_to_ir_constants(ctx.builder, canonicalize_dtype(v.val))
else:
return env[v]
2018-11-17 18:03:33 -08:00
def aval(v):
if type(v) is Literal:
return abstractify(v.val)
else:
return v.aval
def write(v, node):
assert node is not None
env[v] = node
env: Dict[core.Var, Sequence[XlaOp]] = {}
_partitionmap(write, [core.unitvar],
pyval_to_ir_constants(ctx.builder, core.unit))
_partitionmap(write, jaxpr.constvars, consts)
_partitionmap(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
if config.jax_experimental_name_stack:
assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
source_info = eqn.source_info.replace(
name_stack=ctx.name_stack + eqn.source_info.name_stack)
else:
source_info = eqn.source_info
op_metadata = make_op_metadata(
eqn.primitive, eqn.params, name_stack=ctx.name_stack,
source_info=source_info)
ctx.builder.set_op_metadata(op_metadata)
in_nodes = _flatmap(read, eqn.invars)
if (ctx.platform is not None and
eqn.primitive in _backend_specific_translations[ctx.platform]):
rule = _backend_specific_translations[ctx.platform][eqn.primitive]
elif eqn.primitive in _translations:
rule = _translations[eqn.primitive]
else:
raise NotImplementedError(
f"XLA translation rule for primitive '{eqn.primitive.name}' not found")
with source_info_util.user_context(eqn.source_info.traceback):
eqn_ctx = (ctx.replace(name_stack=source_info.name_stack) if
config.jax_experimental_name_stack else ctx)
ans = rule(eqn_ctx, map(aval, eqn.invars), map(aval, eqn.outvars),
*in_nodes, **eqn.params)
assert isinstance(ans, collections.abc.Sequence), (ans, eqn)
assert all(isinstance(x, xe.XlaOp) for x in ans), (ans, eqn)
map(ctx.builder.get_shape, ans) # force xla to do shape error checking
ctx.builder.clear_op_metadata()
_partitionmap(write, eqn.outvars, ans)
return _flatmap(read, jaxpr.outvars)
def xla_destructure(c, ans):
num_elements = len(c.get_shape(ans).tuple_shapes())
return [xops.GetTupleElement(ans, i) for i in range(num_elements)]
def check_backend_matches(inner_backend, outer_backend):
# For nested calls, the outermost call sets the backend for all inner calls;
# it's an error if the inner call has a conflicting explicit backend spec.
if inner_backend and inner_backend != outer_backend:
raise ValueError(
f"Outer-jit backend specification {outer_backend} must match explicit "
f"inner-jit backend specification {inner_backend}.")
def extend_axis_env(env: AxisEnv, name, size: int):
return AxisEnv(env.nreps, env.names + (name,), env.sizes + (size,))
def axis_read(axis_env, axis_name):
try:
return max(i for i, name in enumerate(axis_env.names) if name == axis_name)
except ValueError:
raise NameError("unbound axis name: {}".format(axis_name)) from None
def axis_groups(axis_env: AxisEnv, name) -> Tuple[Tuple[int, ...]]:
if not isinstance(name, (list, tuple)):
name = (name,)
mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name))
trailing_size, ragged = divmod(axis_env.nreps, prod(axis_env.sizes))
assert not ragged
mesh_spec = axis_env.sizes + (trailing_size,)
return _axis_groups(mesh_spec, mesh_axes)
def _axis_groups(mesh_spec, mesh_axes):
"""Computes replica group ids for a collective performed over a subset of the mesh.
Args:
mesh_spec: A sequence of integers representing the mesh shape.
mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive)
indicating over which axes the collective is performed.
Returns:
A tuple of replica groups (i.e. tuples containing replica ids).
"""
iota = np.arange(prod(mesh_spec)).reshape(mesh_spec)
groups = np.reshape(
np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
(prod(np.take(mesh_spec, mesh_axes)), -1))
return tuple(unsafe_map(tuple, groups.T))
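# Worked example for a 2x2 replica mesh (mesh_spec=(2, 2), replica ids 0..3):
#   _axis_groups((2, 2), (0,)) == ((0, 2), (1, 3))   # group over the first mesh axis
#   _axis_groups((2, 2), (1,)) == ((0, 1), (2, 3))   # group over the second mesh axis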
# TODO(mattjj,skyewm): the functions here are utilities for checking whether
# not-yet-supported features are used with multi-host programming.
def jaxpr_collectives(jaxpr):
"""Generates all the collective primitives anywhere inside a Jaxpr."""
for eqn in jaxpr.eqns:
if eqn.primitive in _collective_primitives:
yield eqn.primitive
for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_collectives(subjaxpr)
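# Example (an illustrative sketch, not part of the module's API; the helper
# name below is hypothetical): `jaxpr_collectives` can back a simple check
# that rejects jaxprs containing collectives, e.g. where a feature is not yet
# supported under multi-host programming.
def _example_check_no_collectives(jaxpr: core.Jaxpr) -> None:
  found = [str(prim) for prim in jaxpr_collectives(jaxpr)]
  if found:
    raise NotImplementedError(
        f"collective primitives are not supported here: {', '.join(found)}")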
### xla_call underlying jit
def flatten_shape(s: XlaShape) -> Sequence[Tuple[Sequence[int], XlaShape]]:
"""Expands a given shape tree into a flat list of indices to arrays.
Given the following computation:
>>> c = xc.XlaBuilder("example")
>>> p0 = parameter(c, 1, xc.shape_from_pyval(jnp.ones([1])))
>>> p1 = parameter(c, 2, xc.shape_from_pyval(jnp.ones([2])))
>>> p2 = parameter(c, 3, xc.shape_from_pyval(jnp.ones([3])))
Add support for buffer donation in `jit` and `pmap`. (#2936) For a computation of the form: >>> f = lambda x: x ** 2 >>> f = jax.jit(f) >>> while run: ... x = f(x) JAX must currently always have two copies of `x` in device memory since there is no reliable way in Python to determine whether there will be future uses of `x`. This causes two classes of problem: 1. Users at the limit of available device are constrained by the additional copy of their parameters and other state while they typically only require one copy. This typically frees 100M+ of device memory and is a critical optimization for larger models to match state of the art performance in other frameworks. 2. This constant alloc/free of the input/output buffers can cause memory fragmentation on some platforms (although having a reusing allocator and limiting run-ahead may be a better solution for this problem). We propose fixing this by using input/output aliasing as supported by XLA. We will support this in JAX by allowing certain arguments of jit/pmap decorated functions to be donated and reused as outputs: >>> f = lambda x: x ** 2 >>> f = jit(f, donate_argnums=0) >>> while run: ... x = f(x) JAX will determine that the donated input `x` can alias with the output of the function and it will instruct XLA it _must_ write the result to this buffer. If a user tries to reuse a buffer after it has been donated they get an error that the buffer is invalid: >>> y = f(x) >>> jax.device_get(x) ... RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer. The semantics of `donate_argnums` follows that of `static_argnums`, namely that it identifies positional arguments to the computation that are to be donated to the computation and used as part of the output. One feature that is also enabled by this is invalidating buffers that should only be used once, for example PRNGKeys: >>> @partial(jit, donate_argnums=0) ... def move(x): ... # Do something complex enough for JAX to just optimize it away. ... return tree_map(lambda x: x + x - x, x) >>> def safe_eager_uniform(key, *a, **k): ... assert hasattr(key, 'device_buffer'), "random must run eagerly" ... key = move(key) ... return jax.random.uniform(key, *a, **k) This is not a complete answer to random safety since it is still possible to reuse a key as part of a traced computation, however it can be used to support this feature (somewhat inefficiently) in eager mode.
2020-05-31 23:00:16 +01:00
>>> o = xops.Tuple(c, [p0, p1, p2])
We can query the arrays in the output tuple:
>>> flatten_shape(c.GetShape(o))
[((0,), f32[1]{0}), ((1,), f32[2]{0}), ((2,), f32[3]{0})]
Add support for buffer donation in `jit` and `pmap`. (#2936) For a computation of the form: >>> f = lambda x: x ** 2 >>> f = jax.jit(f) >>> while run: ... x = f(x) JAX must currently always have two copies of `x` in device memory since there is no reliable way in Python to determine whether there will be future uses of `x`. This causes two classes of problem: 1. Users at the limit of available device are constrained by the additional copy of their parameters and other state while they typically only require one copy. This typically frees 100M+ of device memory and is a critical optimization for larger models to match state of the art performance in other frameworks. 2. This constant alloc/free of the input/output buffers can cause memory fragmentation on some platforms (although having a reusing allocator and limiting run-ahead may be a better solution for this problem). We propose fixing this by using input/output aliasing as supported by XLA. We will support this in JAX by allowing certain arguments of jit/pmap decorated functions to be donated and reused as outputs: >>> f = lambda x: x ** 2 >>> f = jit(f, donate_argnums=0) >>> while run: ... x = f(x) JAX will determine that the donated input `x` can alias with the output of the function and it will instruct XLA it _must_ write the result to this buffer. If a user tries to reuse a buffer after it has been donated they get an error that the buffer is invalid: >>> y = f(x) >>> jax.device_get(x) ... RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer. The semantics of `donate_argnums` follows that of `static_argnums`, namely that it identifies positional arguments to the computation that are to be donated to the computation and used as part of the output. One feature that is also enabled by this is invalidating buffers that should only be used once, for example PRNGKeys: >>> @partial(jit, donate_argnums=0) ... def move(x): ... # Do something complex enough for JAX to just optimize it away. ... return tree_map(lambda x: x + x - x, x) >>> def safe_eager_uniform(key, *a, **k): ... assert hasattr(key, 'device_buffer'), "random must run eagerly" ... key = move(key) ... return jax.random.uniform(key, *a, **k) This is not a complete answer to random safety since it is still possible to reuse a key as part of a traced computation, however it can be used to support this feature (somewhat inefficiently) in eager mode.
2020-05-31 23:00:16 +01:00
Or the arrays in one of the parameters (which is itself an array):
>>> flatten_shape(c.GetShape(p0))
[((), f32[1]{0})]
Add support for buffer donation in `jit` and `pmap`. (#2936) For a computation of the form: >>> f = lambda x: x ** 2 >>> f = jax.jit(f) >>> while run: ... x = f(x) JAX must currently always have two copies of `x` in device memory since there is no reliable way in Python to determine whether there will be future uses of `x`. This causes two classes of problem: 1. Users at the limit of available device are constrained by the additional copy of their parameters and other state while they typically only require one copy. This typically frees 100M+ of device memory and is a critical optimization for larger models to match state of the art performance in other frameworks. 2. This constant alloc/free of the input/output buffers can cause memory fragmentation on some platforms (although having a reusing allocator and limiting run-ahead may be a better solution for this problem). We propose fixing this by using input/output aliasing as supported by XLA. We will support this in JAX by allowing certain arguments of jit/pmap decorated functions to be donated and reused as outputs: >>> f = lambda x: x ** 2 >>> f = jit(f, donate_argnums=0) >>> while run: ... x = f(x) JAX will determine that the donated input `x` can alias with the output of the function and it will instruct XLA it _must_ write the result to this buffer. If a user tries to reuse a buffer after it has been donated they get an error that the buffer is invalid: >>> y = f(x) >>> jax.device_get(x) ... RuntimeError: Invalid argument: CopyToHostAsync() called on invalid buffer. The semantics of `donate_argnums` follows that of `static_argnums`, namely that it identifies positional arguments to the computation that are to be donated to the computation and used as part of the output. One feature that is also enabled by this is invalidating buffers that should only be used once, for example PRNGKeys: >>> @partial(jit, donate_argnums=0) ... def move(x): ... # Do something complex enough for JAX to just optimize it away. ... return tree_map(lambda x: x + x - x, x) >>> def safe_eager_uniform(key, *a, **k): ... assert hasattr(key, 'device_buffer'), "random must run eagerly" ... key = move(key) ... return jax.random.uniform(key, *a, **k) This is not a complete answer to random safety since it is still possible to reuse a key as part of a traced computation, however it can be used to support this feature (somewhat inefficiently) in eager mode.
2020-05-31 23:00:16 +01:00
Args
s: The input shape.
Returns:
An iterable of pairs of indices and shapes for each array within the shape
tree.
"""
results: List[Tuple[Tuple[int, ...], XlaShape]] = []
_flatten_shape(s, (), results)
return results
def _flatten_shape(s: XlaShape, index: Tuple[int, ...],
results: List[Tuple[Tuple[int, ...], XlaShape]]) -> None:
if s.is_array() or s.is_token():
results.append((index, s))
else:
assert s.is_tuple()
for i, sub in enumerate(s.tuple_shapes()):
_flatten_shape(sub, index + (i,), results)
def _xla_consts(c, consts):
unique_consts = {id(const): const for const in consts}
xla_consts = {
id_: pyval_to_ir_constants(c, const) for id_, const in unique_consts.items()}
return [c for const in consts for c in xla_consts[id(const)]]
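# Note (illustrative): constants are deduplicated by object identity, so
# passing the same constant object twice yields the same lowered op(s) for
# both occurrences (`builder` below stands for some XlaBuilder):
#
#   k = np.arange(3)
#   ops = _xla_consts(builder, [k, k])  # both entries share one XLA constant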
def set_up_aliases(c, xla_args, out_shape: XlaShape, donated_args, tuple_args):
"""Configures input/output "must" aliasing based on `donated_args`."""
# First for every input array add it to `donations` iff it is a member of
# `donated_args`.
donations: Dict[Tuple[Tuple[int, ...], Any], Deque]
donations = defaultdict(deque)
for arg_index, arg in enumerate(xla_args):
if donated_args[arg_index]:
for param_index, element in flatten_shape(c.GetShape(arg)):
key = (element.dimensions(), element.xla_element_type())
if tuple_args:
param_number = 0
param_index = (arg_index,) + tuple(param_index)
donations[key].append((param_number, param_index, arg_index))
else:
param_number = arg_index
donations[key].append((param_number, param_index, arg_index))
# Consume donations for outputs.
out_donated_args = list(donated_args)
for output_index, element in flatten_shape(out_shape):
key = (element.dimensions(), element.xla_element_type())
if donations.get(key, ()):
param_number, param_index, arg_index = donations[key].popleft()
out_donated_args[arg_index] = False
c.setup_alias(output_index, param_number, param_index)
return tuple(out_donated_args)
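# Illustrative sketch of the user-visible effect of the aliasing configured by
# `set_up_aliases` (assumes a runtime that honors donation, e.g. TPU or GPU;
# names below are for illustration only):
#
#   add_one = jax.jit(lambda x: x + 1, donate_argnums=0)
#   y = add_one(jnp.ones(3))  # the donated input buffer may be reused for `y`;
#                             # reusing the donated argument afterwards errors.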
xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')
xla_call = xla_call_p.bind
def _xla_call_partial_eval_update_params(params, kept_inputs, num_new_inputs):
donated_invars = params['donated_invars']
if not kept_inputs and donated_invars:
# JaxprTrace.post_process_call creates a call with no input tracers
donated_invars = (False,) * num_new_inputs
else:
assert len(kept_inputs) == len(donated_invars)
# JaxprTrace.process_call drops known input tracers
donated_invars = [d for d, kept in zip(donated_invars, kept_inputs) if kept]
# Any new inputs are prepended to the left, so mark those as not donated.
donated_invars = [False] * num_new_inputs + donated_invars
return dict(params, donated_invars=tuple(donated_invars))
pe.call_param_updaters[xla_call_p] = _xla_call_partial_eval_update_params
def _xla_call_jvp_update_params(params, nz_tangents):
donated_invars = params['donated_invars']
donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
new_donated_invars = (*donated_invars, *donated_tangents)
return dict(params, donated_invars=new_donated_invars)
ad.call_param_updaters[xla_call_p] = _xla_call_jvp_update_params
def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
donated_invars = params['donated_invars']
donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
donated_cotangents = [False for nz in nonzero_cts if nz]
return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params
def _xla_call_translation_rule(ctx, avals_in, avals_out, *in_nodes, name,
backend=None, call_jaxpr, donated_invars,
inline=None, device=None):
del device, donated_invars, inline # Ignored.
c = ctx.builder
check_backend_matches(backend, ctx.platform)
subc = xc.XlaBuilder(f"jit_{name}")
args = [parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
sub_ctx = ctx.replace(
builder=subc,
name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'jit')))
out_nodes = jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
if len(out_nodes) == 1:
subc = subc.Build(out_nodes[0])
return [xops.Call(c, subc, list(in_nodes))]
else:
subc = subc.Build(xops.Tuple(subc, out_nodes))
return xla_destructure(c, xops.Call(c, subc, list(in_nodes)))
ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
def _xla_call_partial_eval_custom_params_updater(
unks_in: Sequence[bool],
kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
num_res: int, params_known: dict, params_staged: dict
) -> Tuple[dict, dict]:
# pruned inputs to jaxpr_known according to unks_in, so prune donated_invars
donated_invars_known, _ = partition_list(unks_in, params_known['donated_invars'])
new_params_known = dict(params_known, donated_invars=tuple(donated_invars_known))
# added num_res new inputs to jaxpr_staged, so extend donated_invars
donated_invars_staged = [*([False] * num_res), *params_staged['donated_invars']]
new_params_staged = dict(params_staged, donated_invars=tuple(donated_invars_staged))
return new_params_known, new_params_staged
pe.partial_eval_jaxpr_custom_rules[xla_call_p] = \
partial(pe.call_partial_eval_custom_rule, 'call_jaxpr',
_xla_call_partial_eval_custom_params_updater)
pe.dce_rules[xla_call_p] = pe.dce_jaxpr_call_rule
pe.padding_rules[xla_call_p] = partial(pe.call_padding_rule, xla_call_p)
def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext,
settings: core.JaxprPpSettings,
) -> List[pp.Doc]:
printed_params = {k:v for k, v in eqn.params.items() if
k == 'call_jaxpr' or k == 'name' or
k == 'backend' and v is not None or
k == 'device' and v is not None or
k == 'donated_invars' and any(v)}
return [pp.text(eqn.primitive.name),
core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
pp.text(" ") + core.pp_vars(eqn.invars, context)]
core.pp_eqn_rules[xla_call_p] = _pp_xla_call
### translation tables
MYPY = False
if not MYPY:
class TranslationRule(Protocol):
def __call__(self, ctx: TranslationContext,
avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
*args: XlaOp, **kw
) -> Sequence[XlaOp]:
"""A translation rule lowers a primitive invocation into an XLA HLO."""
else:
TranslationRule = Any
_translations: Dict[core.Primitive, TranslationRule] = {}
_backend_specific_translations: Dict[str, Dict[core.Primitive, TranslationRule]]
_backend_specific_translations = defaultdict(dict)
_collective_primitives: Set[core.Primitive] = set()
_initial_style_primitives: Set[core.Primitive] = set()
def register_initial_style_primitive(prim: core.Primitive):
_initial_style_primitives.add(prim)
def register_collective_primitive(prim: core.Primitive):
_collective_primitives.add(prim)
def register_translation(prim: core.Primitive, rule: TranslationRule, *,
platform: Optional[str] = None) -> None:
ts = (_translations if platform is None
else _backend_specific_translations[platform])
ts[prim] = rule
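# Illustrative sketch (hypothetical primitive `my_sin_p`; no such registration
# is performed by this module): a new-style rule follows the TranslationRule
# protocol above and returns one XLA op per output:
#
#   def _my_sin_translation(ctx, avals_in, avals_out, x):
#     return [xops.Sin(x)]
#
#   register_translation(my_sin_p, _my_sin_translation)
#   register_translation(my_sin_p, _my_sin_tpu_rule, platform='tpu')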
# As a temporary backward compatibility measure, we use an adapter class to
# convert from the old styles of translation rules to the newer ones.
# TODO(phawkins): update users of the older translation rule styles and remove
# the adapters.
class _TranslationRuleAdapter:
def __init__(self, translations,
wrap_fn: Callable[[core.Primitive, Callable], TranslationRule]):
self._translations = translations
self._wrap_fn = wrap_fn
def __setitem__(self, key: core.Primitive, value: Callable):
self._translations[key] = self._wrap_fn(key, value)
def _wrap_old_translation(prim: core.Primitive, f: Callable) -> TranslationRule:
@functools.wraps(f)
def wrapped(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
*args: XlaOp, **kw) -> Sequence[XlaOp]:
ans = f(ctx.builder, *args, **kw)
if (prim.multiple_results or
any(len(aval_to_xla_shapes(aval)) > 1 for aval in avals_out)):
return xla_destructure(ctx.builder, ans)
else:
return [ans]
return wrapped
def _wrap_old_call_translation(prim: core.Primitive,
f: Callable) -> TranslationRule:
@functools.wraps(f)
def wrapped(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
*args: XlaOp, **kw) -> Sequence[XlaOp]:
platform = kw.pop("backend", None)
check_backend_matches(platform, ctx.platform)
ans = f(ctx.builder, ctx.axis_env, args, ctx.name_stack,
backend=ctx.platform, **kw)
if (prim.multiple_results or
any(len(aval_to_xla_shapes(aval)) > 1 for aval in avals_out)):
return xla_destructure(ctx.builder, ans)
else:
return [ans]
return wrapped
translations : _TranslationRuleAdapter
translations = _TranslationRuleAdapter(_translations, _wrap_old_translation)
class _BackendSpecificTranslationsAdapter(defaultdict):
def __missing__(self, key):
ret = self[key] = _TranslationRuleAdapter(
_backend_specific_translations[key], _wrap_old_translation)
return ret
backend_specific_translations: Dict[str, _TranslationRuleAdapter]
backend_specific_translations = _BackendSpecificTranslationsAdapter()
call_translations : _TranslationRuleAdapter
call_translations = _TranslationRuleAdapter(
_translations, _wrap_old_call_translation)
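# Illustrative sketch of the old-style registration path these adapters keep
# working (hypothetical primitive `my_prim_p`): assigning into `translations`
# or `backend_specific_translations[...]` wraps the rule with
# `_wrap_old_translation` at insertion time:
#
#   def _my_prim_old_rule(c, x, **params):  # old style: builder-first signature
#     return xops.Neg(x)
#
#   translations[my_prim_p] = _my_prim_old_rule
#   backend_specific_translations['cpu'][my_prim_p] = _my_prim_old_rule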
register_translation(xla_call_p, _xla_call_translation_rule)
def zeros_like_translation_rule(c, x):
shape = c.get_shape(x)
assert not shape.is_tuple()
zero = xops.Constant(c, np.array(0, shape.element_type()))
return xops.Broadcast(zero, shape.dimensions())
translations[ad_util.zeros_like_p] = zeros_like_translation_rule
def add_jaxvals_translation_rule(c, x, y):
shape = c.get_shape(x)
assert not shape.is_tuple()
return xops.Add(x, y)
translations[ad_util.add_jaxvals_p] = add_jaxvals_translation_rule
translations[ad_util.stop_gradient_p] = lambda c, x: x
@lu.transformation
def _tuple_output(*args, **kwargs):
ans = yield args, kwargs
yield (ans,)
def lower_fun(fun: Callable, *, multiple_results: bool, backend=None,
new_style: bool = False) -> Callable:
if new_style:
def f_new(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
avals_out: Optional[Sequence[core.AbstractValue]],
*xla_args: xc.XlaOp,
**params) -> Sequence[xc.XlaOp]:
wrapped_fun = lu.wrap_init(fun, params)
if not multiple_results:
wrapped_fun = _tuple_output(wrapped_fun)
with core.extend_axis_env_nd(zip(ctx.axis_env.names, ctx.axis_env.sizes)):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals_in)
return jaxpr_subcomp(ctx, jaxpr, _xla_consts(ctx.builder, consts),
*xla_args)
return f_new
# TODO(phawkins): migrate dependent code & always use new_style=True.
if backend is None:
# The user didn't specify a backend. This isn't possible with the new style
# API.
backend = "backend_not_specified"
def f(c, *xla_args, **params):
avals = [_array_aval_from_xla_shape(c.get_shape(x)) for x in xla_args]
return f_with_avals(c, avals, xla_args, params)
def f_with_avals(c, avals, xla_args, params):
# parallelism is only supported via the new-style API.
axis_env = AxisEnv(1, (), ())
wrapped_fun = lu.wrap_init(fun, params)
if not multiple_results:
wrapped_fun = _tuple_output(wrapped_fun)
with core.extend_axis_env_nd(zip(axis_env.names, axis_env.sizes)):
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
ctx = TranslationContext(c, backend, axis_env, new_name_stack())
outs = jaxpr_subcomp(ctx, jaxpr, _xla_consts(c, consts), *xla_args)
if (multiple_results or
any(len(aval_to_xla_shapes(v.aval)) > 1 for v in jaxpr.outvars)):
return xops.Tuple(c, outs)
else:
assert len(outs) == 1, outs
return outs[0]
return f
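# Illustrative sketch (hypothetical primitive `my_prim_p` whose semantics are
# given by a JAX-traceable function `impl_fn`): `lower_fun` traces `impl_fn`
# to a jaxpr and lowers that, so no hand-written XLA rule is needed:
#
#   register_translation(
#       my_prim_p,
#       lower_fun(impl_fn, multiple_results=False, new_style=True))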
def _array_aval_from_xla_shape(xla_shape):
# This function instantiates the assumption that we can map from XLA array
# types to JAX array types.
# TODO(mattjj): remove assumption can map XLA array types to JAX array types
assert not xla_shape.is_tuple()
return ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())
ad.primitive_transposes[core.named_call_p] = partial(ad.call_transpose,
core.named_call_p)
def _named_call_translation_rule(ctx, avals_in, avals_out, *in_nodes,
name="core_call", backend=None, call_jaxpr):
check_backend_matches(backend, ctx.platform)
c = ctx.builder
subc = xc.XlaBuilder(name)
args = [parameter(subc, i, c.GetShape(n)) for i, n in enumerate(in_nodes)]
sub_ctx = ctx.replace(builder=subc,
name_stack=extend_name_stack(ctx.name_stack, name))
out_nodes = jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
if len(out_nodes) == 1:
subc = subc.Build(out_nodes[0])
return [xops.Call(c, subc, list(in_nodes))]
else:
subc = subc.Build(xops.Tuple(subc, out_nodes))
return xla_destructure(c, xops.Call(c, subc, list(in_nodes)))
register_translation(core.named_call_p, _named_call_translation_rule)
def _call_translation_rule(ctx, avals_in, avals_out, *in_nodes, backend=None,
call_jaxpr):
return _named_call_translation_rule(
ctx, avals_in, avals_out, *in_nodes, name="core_call", backend=backend,
call_jaxpr=call_jaxpr)
register_translation(core.call_p, _call_translation_rule)