# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lowering of jaxprs into XLA (HLO) computations.
from collections import defaultdict, deque
import collections.abc
import dataclasses
import functools
from functools import partial
import itertools as it
import operator
import re
from typing import (Any, Callable, Deque, Dict, List, NamedTuple, Optional,
                    Sequence, Set, Type, Tuple, Union)

from typing_extensions import Protocol

import numpy as np

from jax.config import config
from jax import core
from jax._src import ad_util
from jax._src import device_array
from jax._src import dtypes
from jax import linear_util as lu
from jax._src import source_info_util
from jax._src.abstract_arrays import (make_shaped_array, array_types)
from jax.core import (ConcreteArray, ShapedArray,
                      Literal, str_eqn_compact, abstract_token)
import jax._src.pretty_printer as pp
from jax._src import util
from jax._src.util import (prod, extend_name_stack, new_name_stack, wrap_name,
                           safe_zip, safe_map, partition_list)
from jax._src.lib import xla_client as xc
from jax.interpreters import partial_eval as pe
from jax.interpreters import ad

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
xe = xc._xla
xops = xc._xla.ops

# Types
Backend = xe.Client
Device = xc.Device
Buffer = xe.Buffer

XlaOp = xc.XlaOp
XlaShape = xc.Shape
XlaBuilder = xc.XlaBuilder
XlaExecutable = xc.Executable

# apply_primitive is defined in jax._src.dispatch.
apply_primitive: Callable
backend_compile: Callable
device_put: Callable

# TODO(phawkins): update code to point to new locations.
DeviceArray = device_array.DeviceArray
_DeviceArray = device_array._DeviceArray
_CppDeviceArray = xe.Buffer
make_device_array = device_array.make_device_array


def identity(x): return x

_scalar_types = dtypes.python_scalar_dtypes.keys()
# unit representation
def _make_unit_constant(c): return [
    xops.Constant(c, np.zeros((), dtype=np.dtype('bool')))]
def _make_unit_shape(_): return (xc.Shape.array_shape(np.dtype('bool'), ()),)
def _make_array_shape(a: ShapedArray) -> Sequence[XlaShape]:
  if a.dtype is dtypes.float0:
    return (xc.Shape.array_shape(np.dtype('bool'), a.shape),)
  else:
    return (xc.Shape.array_shape(a.dtype, a.shape),)


def _get_canonical_source_file(frame: source_info_util.Frame):
  source_file = frame.file_name
  if config.jax_hlo_source_file_canonicalization_regex:
    source_file = re.sub(config.jax_hlo_source_file_canonicalization_regex,
                         '', source_file)
  return source_file
# Maps the op_name of each lowered equation to that equation's traceback, for
# debugging.
tracebacks = {}

def make_op_metadata(primitive: core.Primitive,
                     params: Dict, *,
                     source_info: source_info_util.SourceInfo,
                     name_stack: Union[str, source_info_util.NameStack] = "",
                     ) -> xc.OpMetadata:
  if config.jax_experimental_name_stack:
    eqn_str = str(source_info.name_stack) + '/' + str_eqn_compact(primitive.name, params)
  else:
    assert isinstance(name_stack, str)
    eqn_str = name_stack + str_eqn_compact(primitive.name, params)
  tracebacks[eqn_str] = source_info.traceback
  frame = source_info_util.user_frame(source_info)
  return xc.OpMetadata(
        op_type=primitive.name,
        op_name=eqn_str,
        source_file=_get_canonical_source_file(frame) if frame else None,
        source_line=frame.line_num if frame else None)
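
# Usage sketch (not part of the original module): builds OpMetadata for a
# hypothetical `sin` equation. It assumes source_info_util.current() returns a
# SourceInfo for the calling frame, as it does elsewhere in JAX at this time.
def _example_make_op_metadata():
  source_info = source_info_util.current()
  md = make_op_metadata(core.Primitive('sin'), {}, source_info=source_info,
                        name_stack='jit(f)/')
  assert md.op_type == 'sin'
  return md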
# Utilities

def parameter(builder, num, shape, name=None, replicated=None):
  if name is None:
    name = ''
  if replicated is None:
    replicated = []
  elif isinstance(replicated, bool):
    replicated = [replicated] * shape.leaf_count()

  return xops.Parameter(builder, num,
                        shape.with_major_to_minor_layout_if_absent(), name,
                        replicated)
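
# Usage sketch (not part of the original module): declares a 2x3 f32 parameter
# in a scratch builder. Passing replicated=True is expanded above into one
# boolean per leaf of the shape via shape.leaf_count().
def _example_parameter():
  b = xc.XlaBuilder("parameter_example")
  shape = xc.Shape.array_shape(np.dtype('float32'), (2, 3))
  return parameter(b, 0, shape, name='x', replicated=True)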
# HLO instructions optionally can be annotated to say how the output should be
# spatially partitioned (represented in XLA as OpSharding protos, see
# sharding_to_proto). For array outputs, the annotation is either an int per
# dimension specifying the number of ways that dimension is divided (i.e. the
# total number of shards is the product), or None to indicate the array should
# be replicated. Tuple outputs are represented as tuples thereof. XLA supports
# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
# checkers don't support recursive types), so we only represent one level of
# nesting in this type definition.
SpatialSharding = Union[Tuple[int, ...],
                        None,
                        Tuple[Optional[Tuple[int, ...]], ...]]
def sharding_to_proto(sharding: SpatialSharding):
  """Converts a SpatialSharding to an OpSharding.

  See
  https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601
  for details on the OpSharding proto.
  """
  proto = xc.OpSharding()
  if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
    assert all(s is None or isinstance(s, tuple) for s in sharding)
    return tuple_sharding_proto(list(map(sharding_to_proto, sharding)))  # type: ignore

  if sharding is None:
    proto.type = xc.OpSharding.Type.REPLICATED
  else:
    proto.type = xc.OpSharding.Type.OTHER
    proto.tile_assignment_dimensions = list(sharding)
    proto.tile_assignment_devices = list(range(np.product(sharding)))  # type: ignore
  return proto
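
# Worked example (not part of the original module): a (2, 1) sharding splits
# the leading dimension across two devices (2 * 1 = 2 shards in total), while
# None produces a REPLICATED annotation.
def _example_sharding_to_proto():
  tiled = sharding_to_proto((2, 1))
  assert tiled.type == xc.OpSharding.Type.OTHER
  assert list(tiled.tile_assignment_devices) == [0, 1]
  assert sharding_to_proto(None).type == xc.OpSharding.Type.REPLICATED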

def tuple_sharding_proto(elems):
  proto = xc.OpSharding()
  assert all(isinstance(e, type(proto)) for e in elems)
  proto.type = xc.OpSharding.Type.TUPLE
  proto.tuple_shardings = elems
  return proto

def set_sharding_proto(builder, op, sharding_proto, unspecified_dims=None):
  """Uses CustomCall to annotate a value as sharded."""
  # "Sharding" is a built-in custom call target that acts like an identity
  # function; it exists only as a place to attach an OpSharding to a value.
  def _create_custom_call(x):
    # unspecified_dims lists dimensions whose shardings are not specified and
    # that XLA sharding propagation is therefore free to change.
    if unspecified_dims:
      opaque = 'unspecified_dims=[' + ','.join(
          [str(i) for i in unspecified_dims]) + ']'
      opaque = bytes(opaque, 'utf-8')
      return xops.CustomCall(
          builder, b'Sharding', [x], builder.get_shape(x), opaque=opaque)
    else:
      return xops.CustomCall(builder, b'Sharding', [x], builder.get_shape(x))

  return with_sharding_proto(builder, sharding_proto, _create_custom_call, op)

def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
  """Builds op_fn(*args, **kwargs) with sharding annotation."""
  builder.set_sharding(sharding_proto)
  try:
    return op_fn(*args, **kwargs)
  finally:
    builder.clear_sharding()

def set_sharding(builder, op, sharding: SpatialSharding, unspecified_dims=None):
  """Uses CustomCall to annotate a value as sharded."""
  return set_sharding_proto(builder, op, sharding_to_proto(sharding),
                            unspecified_dims)

def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
  """Builds op_fn(*args, **kwargs) with sharding annotation."""
  return with_sharding_proto(builder, sharding_to_proto(sharding), op_fn, *args,
                             **kwargs)
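
# Usage sketch (not part of the original module): builds a parameter under a
# sharding scope, mirroring what _xla_param does later in this file. The
# builder-wide sharding is set, the op-building callable runs, and the
# try/finally in with_sharding_proto guarantees the scope is cleared.
def _example_with_sharding():
  b = xc.XlaBuilder("sharding_example")
  shape = xc.Shape.array_shape(np.dtype('float32'), (8, 4))
  return with_sharding(b, (2, 1), parameter, b, 0, shape, name='x')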

### handlers

# Numpy dtypes -> XLA primitive types

_dtype_to_primitive_type: Dict[np.dtype, xc.PrimitiveType] = {
  np.dtype('bool'): xc.PrimitiveType.PRED,
  np.dtype('int8'): xc.PrimitiveType.S8,
  np.dtype('int16'): xc.PrimitiveType.S16,
  np.dtype('int32'): xc.PrimitiveType.S32,
  np.dtype('int64'): xc.PrimitiveType.S64,
  np.dtype('uint8'): xc.PrimitiveType.U8,
  np.dtype('uint16'): xc.PrimitiveType.U16,
  np.dtype('uint32'): xc.PrimitiveType.U32,
  np.dtype('uint64'): xc.PrimitiveType.U64,
  np.dtype(dtypes.bfloat16): xc.PrimitiveType.BF16,
  np.dtype('float16'): xc.PrimitiveType.F16,
  np.dtype('float32'): xc.PrimitiveType.F32,
  np.dtype('float64'): xc.PrimitiveType.F64,
  np.dtype('complex64'): xc.PrimitiveType.C64,
  np.dtype('complex128'): xc.PrimitiveType.C128,
}

def dtype_to_primitive_type(dtype: np.dtype) -> xc.PrimitiveType:
  """Converts a NumPy dtype into an XLA PrimitiveType."""
  # Many things (e.g., strings, scalar types) can be compared with NumPy dtypes,
  # but may not hash correctly. Make sure we have a true np.dtype.
  assert isinstance(dtype, np.dtype), type(dtype)
  try:
    return _dtype_to_primitive_type[dtype]
  except KeyError as err:
    raise TypeError(f"No XLA lowering for NumPy dtype: {dtype}") from err
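
# Example (not part of the original module): lookups are keyed on true
# np.dtype instances, so callers convert with np.dtype(...) first.
def _example_dtype_to_primitive_type():
  assert dtype_to_primitive_type(np.dtype('float32')) == xc.PrimitiveType.F32
  assert dtype_to_primitive_type(np.dtype('bool')) == xc.PrimitiveType.PRED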

# JAX abstract values -> XLA shapes

def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[XlaShape]:
  try:
    return xla_shape_handlers[type(aval)](aval)
  except KeyError as err:
    raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err

xla_shape_handlers: Dict[Type[core.AbstractValue],
                         Callable[[Any], Sequence[XlaShape]]] = {
    core.AbstractUnit: _make_unit_shape,
    ShapedArray: _make_array_shape,
    ConcreteArray: _make_array_shape,
}
xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)
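
# Example (not part of the original module): a ShapedArray lowers to exactly
# one XLA array shape; handlers return sequences because, in general, one aval
# may lower to several XLA values.
def _example_aval_to_xla_shapes():
  shapes = aval_to_xla_shapes(ShapedArray((3, 2), np.dtype('float32')))
  assert len(shapes) == 1
  return shapes[0]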

# IR constants

_constant_handlers: Dict[type, Callable] = {}

def pyval_to_ir_constants(builder, py_val, canonicalize_types=True):
  """Translate a general constant `py_val` to an IR constant, canonicalizing its dtype.

  Args:
    py_val: a Python value to be translated to a constant.

  Returns:
    A representation of the constant as a list of XLA ops.
  """
  for t in type(py_val).__mro__:
    handler = _constant_handlers.get(t)
    if handler: return handler(builder, py_val, canonicalize_types)
  if hasattr(py_val, '__jax_array__'):
    return pyval_to_ir_constants(builder, py_val.__jax_array__(),
                                 canonicalize_types)
  raise TypeError(f"No constant handler for type: {type(py_val)}")

def pyval_to_ir_constant(builder, py_val, canonicalize_types=True):
  """Translate constant `py_val` to a single IR constant, canonicalizing its dtype.

  Args:
    py_val: a Python value to be translated to a constant.

  Returns:
    A single XlaOp representing the constant.
  """
  const = pyval_to_ir_constants(builder, py_val, canonicalize_types=canonicalize_types)
  assert len(const) == 1, f"Internal error: cannot create constant from object of type {type(py_val)}"
  return const[0]

def register_constant_handler(type_, handler_fun):
  _constant_handlers[type_] = handler_fun

register_constant_handler(core.Unit, lambda c, *_: _make_unit_constant(c))
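
# Registration sketch (not part of the original module): a hypothetical
# wrapper type becomes lowerable as a constant by delegating to the handlers
# above. Handlers take (builder, value, canonicalize_types) and return a list
# of XlaOps. The registration is wrapped in a function so defining this
# example does not mutate the handler table at import time.
class _ExampleBox:
  def __init__(self, value):
    self.value = value

def _example_register_box_handler():
  register_constant_handler(
      _ExampleBox,
      lambda c, box, canonicalize_types=True: pyval_to_ir_constants(
          c, box.value, canonicalize_types))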

# TODO(mattjj,frostig): try to remove this function
def _normalize_to_xla_dtypes(val):
  """Normalize dtypes in a value."""
  if hasattr(val, '__array__') or np.isscalar(val):
    return np.asarray(val, dtype=dtypes.canonicalize_dtype(dtypes.result_type(val)))
  elif isinstance(val, (tuple, list)):
    return tuple(_normalize_to_xla_dtypes(x) for x in val)
  raise TypeError('Can\'t convert to XLA: {}'.format(val))

def _numpy_array_constant(builder, value, canonicalize_types=True):
  if canonicalize_types:
    value = _normalize_to_xla_dtypes(value)
  return [xops.Constant(builder, value)]

def _ndarray_constant_handler(c, val, canonicalize_types=True):
  """Constant handler for ndarray literals, handling zero-size strides.

  This function essentially calls _numpy_array_constant(val) except it has
  special handling of arrays with any strides of size zero: for those, it
  generates appropriate calls to NumpyArrayConstant, Broadcast, and Transpose
  to avoid staging in large literals that might arise from np.zeros or np.ones
  or the output of lax.broadcast (which uses np.broadcast_to which in turn
  uses size-zero strides).

  Args:
    c: an XlaBuilder
    val: an ndarray.

  Returns:
    A list containing a single XlaOp representing the constant ndarray staged
    into the XLA Computation.
  """
  # TODO(mattjj): revise this to use xops.BroadcastInDim rather than Transpose
  if dtypes.result_type(val) == dtypes.float0:
    return _numpy_array_constant(c, np.zeros(val.shape, dtype=np.bool_))
  elif np.any(np.equal(0, val.strides)) and val.size > 0:
    zero_stride_axes, = np.where(np.equal(0, val.strides))
    other_axes, = np.where(np.not_equal(0, val.strides))
    collapsed_val = val[tuple(0 if ax in zero_stride_axes else slice(None)
                              for ax in range(val.ndim))]
    xla_val = xops.Broadcast(
        _numpy_array_constant(c, collapsed_val, canonicalize_types)[0],
        np.take(val.shape, zero_stride_axes))
    permutation = np.argsort(tuple(zero_stride_axes) + tuple(other_axes))
    return [xops.Transpose(xla_val, permutation)]
  else:
    return _numpy_array_constant(c, val, canonicalize_types)
register_constant_handler(np.ndarray, _ndarray_constant_handler)
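
# Worked example (not part of the original module): np.broadcast_to yields a
# view with stride 0 along the broadcast dimension, so the handler stages only
# the 3-element base array and rebuilds the 1000x3 value with Broadcast and
# Transpose instead of embedding the full literal in the HLO.
def _example_zero_stride_constant():
  big = np.broadcast_to(np.arange(3, dtype=np.float32), (1000, 3))
  assert big.strides[0] == 0
  c = xc.XlaBuilder("zero_stride_example")
  op, = _ndarray_constant_handler(c, big)
  return op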

def _scalar_constant_handler(c, val, canonicalize_types=True):
  return _numpy_array_constant(c, val, canonicalize_types)

for scalar_type in [np.int8, np.int16, np.int32, np.int64,
                    np.uint8, np.uint16, np.uint32, np.uint64,
                    np.float16, np.float32, np.float64,
                    np.bool_, np.longlong,
                    dtypes.bfloat16]:
  register_constant_handler(scalar_type, _scalar_constant_handler)

# https://github.com/winpython/winpython/issues/613#issuecomment-380121523
if hasattr(np, "float128"):
  register_constant_handler(np.float128, _scalar_constant_handler)

def _python_scalar_handler(dtype, c, val, canonicalize_dtypes=True):
  return _numpy_array_constant(c, dtype.type(val))

for ptype, dtype in dtypes.python_scalar_dtypes.items():
  register_constant_handler(ptype, partial(_python_scalar_handler, dtype))

def _device_array_constant_handler(c, val, canonicalize_types=True):
  return pyval_to_ir_constants(c, val.device_buffer.to_py())
for t in device_array.device_array_types:
  register_constant_handler(t, _device_array_constant_handler)

register_constant_handler(core.Token, lambda c, _, __: [xops.CreateToken(c)])

# TODO(mattjj): try to remove this canonicalize_dtype stuff
def canonicalize_dtype(x):
  typ = type(x)
  handler = canonicalize_dtype_handlers.get(typ)
  if handler: return handler(x)
  for typ in typ.__mro__:
    handler = canonicalize_dtype_handlers.get(typ)
    if handler: return handler(x)
  if hasattr(x, '__jax_array__'):
    return canonicalize_dtype(x.__jax_array__())
  raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")

def _canonicalize_ndarray_dtype(x):
  return np.asarray(x, dtypes.canonicalize_dtype(dtypes.result_type(x)))

def _canonicalize_python_scalar_dtype(typ, x):
  return np.asarray(
      x, dtypes.canonicalize_dtype(dtypes._scalar_type_to_dtype(typ, x)))

canonicalize_dtype_handlers: Dict[Any, Callable] = {core.Unit: identity}
for t in device_array.device_array_types:
  canonicalize_dtype_handlers[t] = lambda x: x
canonicalize_dtype_handlers.update(
    (t, _canonicalize_ndarray_dtype) for t in array_types)
canonicalize_dtype_handlers.update(
    (t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types)
canonicalize_dtype_handlers[core.Token] = lambda x: x
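
# Example (not part of the original module): under the default 32-bit mode
# (jax_enable_x64 disabled), canonicalization narrows Python ints to int32.
def _example_canonicalize_dtype():
  out = canonicalize_dtype(3)
  assert isinstance(out, np.ndarray)
  return out.dtype  # int32 by default; int64 when x64 is enabled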

def abstractify(x) -> core.AbstractValue:
  typ = type(x)
  aval_fn = pytype_aval_mappings.get(typ)
  if aval_fn: return aval_fn(x)
  for typ in typ.__mro__:
    aval_fn = pytype_aval_mappings.get(typ)
    if aval_fn: return aval_fn(x)
  if hasattr(x, '__jax_array__'):
    return abstractify(x.__jax_array__())
  raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")

def _make_abstract_python_scalar(typ, val):
  return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val), weak_type=True)

pytype_aval_mappings: Dict[Any, Callable[[Any], core.AbstractValue]] = {
    core.Unit: lambda _: core.abstract_unit,
}
for t in device_array.device_array_types:
  pytype_aval_mappings[t] = operator.attrgetter('aval')
pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
pytype_aval_mappings.update((t, make_shaped_array) for t in array_types)
pytype_aval_mappings.update(
    (t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)
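
# Example (not part of the original module): abstractify maps concrete values
# to the avals the lowering machinery consumes; Python scalars become
# weakly-typed rank-0 ShapedArrays, which lets array operands win later type
# promotion.
def _example_abstractify():
  aval = abstractify(1.0)
  assert isinstance(aval, ShapedArray)
  assert aval.shape == () and aval.weak_type
  return aval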

def primitive_subcomputation(platform: str, axis_env: 'AxisEnv',
                             prim: core.Primitive,
                             *avals: core.AbstractValue, **params):
  c = xc.XlaBuilder(f"primitive_computation_{prim.name}")
  f = lower_fun(prim.bind, multiple_results=prim.multiple_results,
                new_style=True)
  xla_args, _ = _xla_callable_args(c, avals, tuple_args=False,
                                   filter_tokens=False)
  ctx = TranslationContext(builder=c, platform=platform, axis_env=axis_env,
                           name_stack=new_name_stack())
  ans = f(ctx.replace(builder=c), avals, None, *xla_args, **params)
  if prim.multiple_results:
    ans = xops.Tuple(c, ans)
  else:
    ans, = ans
  return c.build(ans)


# Used within _xla_callable_args and _xla_param to distinguish between None (no
# sharding annotation set) and replicated.
_replicated_param = object()

def _token_param_shape():
  """Shape used in place of tokens as top-level computation arguments."""
  return xc.Shape.array_shape(np.dtype(np.bool_), [])

def _make_token_return_value(c):
  """Value used in place of tokens as a top-level computation return value."""
  return xops.Constant(c, np.zeros((), dtype=np.dtype(np.bool_)))

def _xla_callable_args(
    c, avals, tuple_args, *,
    replicated=None,
    partitions=None,
    partitions_proto: bool = False,
    donated_invars=None,
    filter_tokens=True):
  assert partitions is None or len(partitions) == len(avals)
  if not tuple_args:
    if replicated is None:
      replicated = [None] * len(avals)
    if partitions is None:
      parts: List[object] = [None] * len(avals)
    elif partitions_proto:
      parts = partitions
    else:
      parts = [_replicated_param if part is None else part
               for part in partitions]
    counts = it.count()
    xla_args = [_xla_param(c, next(counts), xla_shape, r, p, partitions_proto,
                           filter_tokens)
                for (a, r, p) in safe_zip(avals, replicated, parts)
                for xla_shape in aval_to_xla_shapes(a)]
    if donated_invars is not None:
      donated_invars = [
          d for (a, _, _, d) in zip(avals, replicated, parts, donated_invars)
          for xla_shape in aval_to_xla_shapes(a)]
    return xla_args, donated_invars
  else:
    if replicated is not None:
      replicated = [r for a, r in zip(avals, replicated)
                    if a is not abstract_token]
    if partitions is None:
      tuple_parts = None
    elif partitions_proto:
      tuple_parts = tuple_sharding_proto(partitions)
    else:
      tuple_parts = tuple(partitions)
    tuple_shape = xc.Shape.tuple_shape(
        [shape if not (filter_tokens and a is abstract_token)
         else _token_param_shape()
         for a in avals for shape in aval_to_xla_shapes(a)])
    tuple_param = _xla_param(c, 0, tuple_shape, replicated, tuple_parts,
                             partitions_proto, filter_tokens)
    xla_args = [v if not (filter_tokens and a is abstract_token)
                else xops.CreateToken(c)
                for a, v in zip(avals, xla_destructure(c, tuple_param))]
    return xla_args, donated_invars

def _xla_param(builder, param_num, xla_shape, replicated, partitions,
               parts_proto, filter_tokens):
  is_token = xla_shape.is_token()
  if filter_tokens and is_token:
    xla_shape = _token_param_shape()
  make_param = partial(parameter, builder, param_num, xla_shape,
                       replicated=replicated)
  with_sharding_fn = with_sharding_proto if parts_proto else with_sharding
  if partitions is None:
    out = make_param()
  elif partitions is _replicated_param:
    out = with_sharding_fn(builder, None, make_param)
  else:
    out = with_sharding_fn(builder, partitions, make_param)
  if filter_tokens and is_token:
    out = xops.CreateToken(builder)
  return out

### compiling jaxprs

def _flatmap(func: Callable, vars: Sequence):
  return list(it.chain.from_iterable(map(func, vars)))

def _partitionmap(func: Callable, vars: Sequence, nodes: Sequence):
  return map(func, vars,
             util.unflatten(nodes,
                            [len(aval_to_xla_shapes(v.aval)) for v in vars]))

class AxisEnv(NamedTuple):
  """Represents a pmap mesh (only along the replica axes)."""
  nreps: int
  names: Tuple[Any, ...]
  sizes: Tuple[int, ...]


@dataclasses.dataclass
class TranslationContext:
  builder: xc.XlaBuilder
  # TODO(phawkins): make platform non-optional. We should always be translating
  # with a specific platform in mind.
  platform: Optional[str]
  axis_env: AxisEnv
  name_stack: Union[str, source_info_util.NameStack]

  def replace(self, **kw): return dataclasses.replace(self, **kw)
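
# Construction sketch (not part of the original module): the single-replica,
# CPU-targeting context that functions like primitive_subcomputation assemble
# before calling jaxpr_subcomp below.
def _example_translation_context(builder: xc.XlaBuilder) -> TranslationContext:
  return TranslationContext(builder=builder, platform='cpu',
                            axis_env=AxisEnv(nreps=1, names=(), sizes=()),
                            name_stack=new_name_stack())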

def jaxpr_subcomp(ctx: TranslationContext, jaxpr: core.Jaxpr,
                  consts: Sequence[XlaOp], *args: XlaOp) -> Sequence[XlaOp]:
  assert ctx.platform is not None
  def read(v):
    if type(v) is Literal:
      return pyval_to_ir_constants(ctx.builder, canonicalize_dtype(v.val))
    else:
      return env[v]

  def aval(v):
    if type(v) is Literal:
      return abstractify(v.val)
    else:
      return v.aval

  def write(v, node):
    assert node is not None
    env[v] = node

  env: Dict[core.Var, Sequence[XlaOp]] = {}
  _partitionmap(write, [core.unitvar],
                pyval_to_ir_constants(ctx.builder, core.unit))
  _partitionmap(write, jaxpr.constvars, consts)
  _partitionmap(write, jaxpr.invars, args)
  for eqn in jaxpr.eqns:
    if config.jax_experimental_name_stack:
      assert isinstance(ctx.name_stack, source_info_util.NameStack), type(ctx.name_stack)
      source_info = eqn.source_info.replace(
          name_stack=ctx.name_stack + eqn.source_info.name_stack)
    else:
      source_info = eqn.source_info
    op_metadata = make_op_metadata(
        eqn.primitive, eqn.params, name_stack=ctx.name_stack,
        source_info=source_info)
    ctx.builder.set_op_metadata(op_metadata)
    in_nodes = _flatmap(read, eqn.invars)
    if (ctx.platform is not None and
        eqn.primitive in _backend_specific_translations[ctx.platform]):
      rule = _backend_specific_translations[ctx.platform][eqn.primitive]
    elif eqn.primitive in _translations:
      rule = _translations[eqn.primitive]
    else:
      raise NotImplementedError(
          f"XLA translation rule for primitive '{eqn.primitive.name}' not found")

    with source_info_util.user_context(eqn.source_info.traceback):
      eqn_ctx = (ctx.replace(name_stack=source_info.name_stack) if
                 config.jax_experimental_name_stack else ctx)
      ans = rule(eqn_ctx, map(aval, eqn.invars), map(aval, eqn.outvars),
                 *in_nodes, **eqn.params)
|
|
|
|
|
2021-10-18 18:06:48 -07:00
|
|
|
assert isinstance(ans, collections.abc.Sequence), (ans, eqn)
|
|
|
|
assert all(isinstance(x, xe.XlaOp) for x in ans), (ans, eqn)
|
Cleanup internal representation of XLA translation rules.
Over time JAX has sprouted many variants of XLA translation rules, each with slightly different but overlapping arguments. This change consolidates them into a new xla.TranslationRule signature:
rule(ctx, avals_in, avals_out, *args, **params)
where ctx contains the parts of the other signatures that were typically not specific to a particular equation.
Since there are many JAX rules to migrate, and even a number of translation rules belonging to projects downstream of JAX, we leave backwards compatibility shims around `xla.translations`, `xla.backend_specific_translations`, and `xla.call_translations` which seem to be the only ones used outside JAX itself.
In passing, this change alters the semantics of `backend` arguments to nested `jit` blocks. We now always canonicalize the backend to a specific backend at the outermost `jit`, and do not complain if an inner `jit` has an explicit `backend` that matches the current default choice.
PiperOrigin-RevId: 403607667
2021-10-16 07:52:57 -07:00
|
|
|
map(ctx.builder.get_shape, ans) # force xla to do shape error checking
|
|
|
|
ctx.builder.clear_op_metadata()
|
|
|
|
_partitionmap(write, eqn.outvars, ans)
|
2020-09-29 11:53:17 -07:00
|
|
|
return _flatmap(read, jaxpr.outvars)


def xla_destructure(c, ans):
  num_elements = len(c.get_shape(ans).tuple_shapes())
  return [xops.GetTupleElement(ans, i) for i in range(num_elements)]
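# E.g., for a call whose result shape is a 3-tuple, this yields three
# GetTupleElement ops, one per element of the tuple.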


def check_backend_matches(inner_backend, outer_backend):
  # For nested calls, the outermost call sets the backend for all inner calls;
  # it's an error if the inner call has a conflicting explicit backend spec.
  if inner_backend and inner_backend != outer_backend:
    raise ValueError(
        f"Outer-jit backend specification {outer_backend} must match explicit "
        f"inner-jit backend specification {inner_backend}.")
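# For example (illustrative values): an inner jit with no explicit backend, or
# one that matches the outer choice, is fine; a conflict raises ValueError:
#   check_backend_matches(None, 'cpu')    # OK
#   check_backend_matches('cpu', 'cpu')   # OK
#   check_backend_matches('gpu', 'cpu')   # ValueError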


def extend_axis_env(env: AxisEnv, name, size: int):
  return AxisEnv(env.nreps, env.names + (name,), env.sizes + (size,))
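# For example, since AxisEnv fields are (nreps, names, sizes):
#   extend_axis_env(AxisEnv(8, ('i',), (2,)), 'j', 4)
#   == AxisEnv(8, ('i', 'j'), (2, 4))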


def axis_read(axis_env, axis_name):
  try:
    return max(i for i, name in enumerate(axis_env.names) if name == axis_name)
  except ValueError:
    raise NameError("unbound axis name: {}".format(axis_name)) from None
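# For example, the innermost (rightmost) binding of a shadowed name wins:
#   axis_read(AxisEnv(8, ('i', 'j', 'i'), (2, 2, 2)), 'i') == 2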


def axis_groups(axis_env: AxisEnv, name) -> Tuple[Tuple[int, ...], ...]:
  if not isinstance(name, (list, tuple)):
    name = (name,)
  mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name))
  trailing_size, ragged = divmod(axis_env.nreps, prod(axis_env.sizes))
  assert not ragged
  mesh_spec = axis_env.sizes + (trailing_size,)
  return _axis_groups(mesh_spec, mesh_axes)
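# For example, with 4 replicas and a single mapped axis 'i' of size 2, the
# unmapped trailing factor has size 2, so the groups pair replicas that
# differ only along 'i':
#   axis_groups(AxisEnv(4, ('i',), (2,)), 'i') == ((0, 2), (1, 3))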


def _axis_groups(mesh_spec, mesh_axes):
  """Computes replica group ids for a collective performed over a subset of the mesh.

  Args:
    mesh_spec: A sequence of integers representing the mesh shape.
    mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive)
      indicating over which axes the collective is performed.

  Returns:
    A tuple of replica groups (i.e. tuples containing replica ids).
  """
  iota = np.arange(prod(mesh_spec)).reshape(mesh_spec)
  groups = np.reshape(
      np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
      (prod(np.take(mesh_spec, mesh_axes)), -1))
  return tuple(unsafe_map(tuple, groups.T))
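# For example, on a 2x2 mesh, a collective over mesh axis 0 groups replicas
# by column, and one over mesh axis 1 groups them by row:
#   _axis_groups((2, 2), (0,)) == ((0, 2), (1, 3))
#   _axis_groups((2, 2), (1,)) == ((0, 1), (2, 3))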


# TODO(mattjj,skyewm): the functions here are utilities for checking if
# not-yet-supported features are used with multi-host programming


def jaxpr_collectives(jaxpr):
  """Generates all the collective primitives anywhere inside a Jaxpr."""
  for eqn in jaxpr.eqns:
    if eqn.primitive in _collective_primitives:
      yield eqn.primitive
  for subjaxpr in core.subjaxprs(jaxpr):
    yield from jaxpr_collectives(subjaxpr)
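# A caller can, for example, test whether any collective appears at all:
#   uses_collectives = next(jaxpr_collectives(jaxpr), None) is not None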


### xla_call underlying jit


def flatten_shape(s: XlaShape) -> Sequence[Tuple[Sequence[int], XlaShape]]:
  """Expands a given shape tree into a flat list of indices to arrays.

  Given the following computation:

  >>> c = xc.XlaBuilder("example")
  >>> p0 = parameter(c, 1, xc.shape_from_pyval(jnp.ones([1])))
  >>> p1 = parameter(c, 2, xc.shape_from_pyval(jnp.ones([2])))
  >>> p2 = parameter(c, 3, xc.shape_from_pyval(jnp.ones([3])))
  >>> o = xops.Tuple(c, [p0, p1, p2])

  We can query the arrays in the output tuple:

  >>> flatten_shape(c.GetShape(o))
  [((0,), f32[1]{0}), ((1,), f32[2]{0}), ((2,), f32[3]{0})]

  Or the arrays in one of the parameters (which is itself an array):

  >>> flatten_shape(c.GetShape(p0))
  [((), f32[1]{0})]

  Args:
    s: The input shape.

  Returns:
    An iterable of pairs of indices and shapes for each array within the shape
    tree.
  """
  results: List[Tuple[Tuple[int, ...], XlaShape]] = []
  _flatten_shape(s, (), results)
  return results


def _flatten_shape(s: XlaShape, index: Tuple[int, ...],
                   results: List[Tuple[Tuple[int, ...], XlaShape]]) -> None:
  if s.is_array() or s.is_token():
    results.append((index, s))
  else:
    assert s.is_tuple()
    for i, sub in enumerate(s.tuple_shapes()):
      _flatten_shape(sub, index + (i,), results)


def _xla_consts(c, consts):
  unique_consts = {id(const): const for const in consts}
  xla_consts = {
      id_: pyval_to_ir_constants(c, const) for id_, const in unique_consts.items()}
  return [c for const in consts for c in xla_consts[id(const)]]
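# Constants are deduplicated by object identity, so a constant appearing
# several times in `consts` is lowered to IR once and its values are reused.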


def set_up_aliases(c, xla_args, out_shape: XlaShape, donated_args, tuple_args):
  """Configures input/output "must" aliasing based on `donated_args`."""
  # First, for every input array, add it to `donations` iff it is a member of
  # `donated_args`.
  donations: Dict[Tuple[Tuple[int, ...], Any], Deque] = defaultdict(deque)
  for arg_index, arg in enumerate(xla_args):
    if donated_args[arg_index]:
      for param_index, element in flatten_shape(c.GetShape(arg)):
        key = (element.dimensions(), element.xla_element_type())
        if tuple_args:
          param_number = 0
          param_index = (arg_index,) + tuple(param_index)
        else:
          param_number = arg_index
        donations[key].append((param_number, param_index, arg_index))

  # Consume donations for outputs.
  out_donated_args = list(donated_args)
  for output_index, element in flatten_shape(out_shape):
    key = (element.dimensions(), element.xla_element_type())
    if donations.get(key, ()):
      param_number, param_index, arg_index = donations[key].popleft()
      out_donated_args[arg_index] = False
      c.setup_alias(output_index, param_number, param_index)

  return tuple(out_donated_args)
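# Donations are matched to outputs purely by (dimensions, element type): each
# output consumes the oldest unconsumed donated buffer with the same shape and
# dtype. Entries still True in the returned tuple are donations that found no
# matching output.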


xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')
xla_call = xla_call_p.bind


def _xla_call_partial_eval_update_params(params, kept_inputs, num_new_inputs):
  donated_invars = params['donated_invars']
  if not kept_inputs and donated_invars:
    # JaxprTrace.post_process_call creates a call with no input tracers.
    donated_invars = (False,) * num_new_inputs
  else:
    assert len(kept_inputs) == len(donated_invars)
    # JaxprTrace.process_call drops known input tracers.
    donated_invars = [d for d, kept in zip(donated_invars, kept_inputs) if kept]
    # Any new inputs are prepended to the left, so mark those as not donated.
    donated_invars = [False] * num_new_inputs + donated_invars
  return dict(params, donated_invars=tuple(donated_invars))
pe.call_param_updaters[xla_call_p] = _xla_call_partial_eval_update_params
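# For example (illustrative values): with donated_invars=(True, False), only
# the second input kept, and one new input prepended, the update yields
#   [False] * 1 + [False]  ->  donated_invars == (False, False)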


def _xla_call_jvp_update_params(params, nz_tangents):
  donated_invars = params['donated_invars']
  donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
  new_donated_invars = (*donated_invars, *donated_tangents)
  return dict(params, donated_invars=new_donated_invars)
ad.call_param_updaters[xla_call_p] = _xla_call_jvp_update_params
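# For example, donated_invars=(True, False) with a nonzero tangent only for
# the first primal yields (True, False, True): a tangent is donated exactly
# when its primal is.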


def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
  donated_invars = params['donated_invars']
  donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
  donated_cotangents = [False for nz in nonzero_cts if nz]
  return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params
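# Note that cotangent inputs are never donated here: donation flags survive
# only for defined primals, and every nonzero cotangent is marked False.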


def _xla_call_translation_rule(ctx, avals_in, avals_out, *in_nodes, name,
                               backend=None, call_jaxpr, donated_invars,
                               inline=None, device=None):
  del device, donated_invars, inline  # Ignored.
  c = ctx.builder
  check_backend_matches(backend, ctx.platform)
  subc = xc.XlaBuilder(f"jit_{name}")
  args = [parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
  sub_ctx = ctx.replace(
      builder=subc,
      name_stack=extend_name_stack(ctx.name_stack, wrap_name(name, 'jit')))
  out_nodes = jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
  if len(out_nodes) == 1:
    subc = subc.Build(out_nodes[0])
    return [xops.Call(c, subc, list(in_nodes))]
  else:
    subc = subc.Build(xops.Tuple(subc, out_nodes))
    return xla_destructure(c, xops.Call(c, subc, list(in_nodes)))
ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
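# In other words, a jitted call lowers to a separate XLA computation named
# "jit_<name>" invoked via xops.Call; multiple results are packed into an XLA
# tuple inside the subcomputation and destructured at the call site.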


def _xla_call_partial_eval_custom_params_updater(
    unks_in: Sequence[bool],
    kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
    num_res: int, params_known: dict, params_staged: dict
  ) -> Tuple[dict, dict]:
  # pruned inputs to jaxpr_known according to unks_in, so prune donated_invars
  donated_invars_known, _ = partition_list(unks_in, params_known['donated_invars'])
  new_params_known = dict(params_known, donated_invars=tuple(donated_invars_known))
  # added num_res new inputs to jaxpr_staged, so extend donated_invars
  donated_invars_staged = [*([False] * num_res), *params_staged['donated_invars']]
  new_params_staged = dict(params_staged, donated_invars=tuple(donated_invars_staged))
  return new_params_known, new_params_staged
pe.partial_eval_jaxpr_custom_rules[xla_call_p] = \
    partial(pe.call_partial_eval_custom_rule, 'call_jaxpr',
            _xla_call_partial_eval_custom_params_updater)
pe.dce_rules[xla_call_p] = pe.dce_jaxpr_call_rule

pe.padding_rules[xla_call_p] = partial(pe.call_padding_rule, xla_call_p)


def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext,
                 settings: core.JaxprPpSettings,
                 ) -> List[pp.Doc]:
  printed_params = {k: v for k, v in eqn.params.items() if
                    k == 'call_jaxpr' or k == 'name' or
                    (k == 'backend' and v is not None) or
                    (k == 'device' and v is not None) or
                    (k == 'donated_invars' and any(v))}
  return [pp.text(eqn.primitive.name),
          core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
          pp.text(" ") + core.pp_vars(eqn.invars, context)]
core.pp_eqn_rules[xla_call_p] = _pp_xla_call
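# This keeps jaxpr printouts readable: default-valued parameters are elided,
# so an equation renders roughly as "xla_call[call_jaxpr={ ... } name=f] x"
# rather than listing every parameter.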


### translation tables

MYPY = False
if not MYPY:
  class TranslationRule(Protocol):
    def __call__(self, ctx: TranslationContext,
                 avals_in: Sequence[core.AbstractValue],
                 avals_out: Sequence[core.AbstractValue],
                 *args: XlaOp, **kw) -> Sequence[XlaOp]:
      """A translation rule lowers a primitive invocation into an XLA HLO."""
else:
  TranslationRule = Any


_translations: Dict[core.Primitive, TranslationRule] = {}
_backend_specific_translations: Dict[str, Dict[core.Primitive, TranslationRule]]
_backend_specific_translations = defaultdict(dict)

_collective_primitives: Set[core.Primitive] = set()
_initial_style_primitives: Set[core.Primitive] = set()


def register_initial_style_primitive(prim: core.Primitive):
  _initial_style_primitives.add(prim)


def register_collective_primitive(prim: core.Primitive):
  _collective_primitives.add(prim)


def register_translation(prim: core.Primitive, rule: TranslationRule, *,
                         platform: Optional[str] = None) -> None:
  ts = (_translations if platform is None
        else _backend_specific_translations[platform])
  ts[prim] = rule
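# Example (hypothetical primitive): registering a new-style rule that lowers
# a unary primitive `my_prim_p` to an XLA negation, for all platforms:
#
#   def _my_prim_translation(ctx, avals_in, avals_out, x):
#     return [xops.Neg(x)]
#   register_translation(my_prim_p, _my_prim_translation)
#
# Passing platform="gpu" instead would install the rule only for that backend
# via _backend_specific_translations.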


# As a temporary backward compatibility measure, we use an adapter class to
# convert from the old styles of translation rules to the newer ones.
# TODO(phawkins): update users of the older translation rule styles and remove
# the adapters.
class _TranslationRuleAdapter:
  def __init__(self, translations,
               wrap_fn: Callable[[core.Primitive, Callable], TranslationRule]):
    self._translations = translations
    self._wrap_fn = wrap_fn

  def __setitem__(self, key: core.Primitive, value: Callable):
    self._translations[key] = self._wrap_fn(key, value)


def _wrap_old_translation(prim: core.Primitive, f: Callable) -> TranslationRule:
  @functools.wraps(f)
  def wrapped(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
              avals_out: Sequence[core.AbstractValue],
              *args: XlaOp, **kw) -> Sequence[XlaOp]:
    ans = f(ctx.builder, *args, **kw)
    if (prim.multiple_results or
        any(len(aval_to_xla_shapes(aval)) > 1 for aval in avals_out)):
      return xla_destructure(ctx.builder, ans)
    else:
      return [ans]
  return wrapped


def _wrap_old_call_translation(prim: core.Primitive,
                               f: Callable) -> TranslationRule:
  @functools.wraps(f)
  def wrapped(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
              avals_out: Sequence[core.AbstractValue],
              *args: XlaOp, **kw) -> Sequence[XlaOp]:
    platform = kw.pop("backend", None)
    check_backend_matches(platform, ctx.platform)
    ans = f(ctx.builder, ctx.axis_env, args, ctx.name_stack,
            backend=ctx.platform, **kw)
    if (prim.multiple_results or
        any(len(aval_to_xla_shapes(aval)) > 1 for aval in avals_out)):
      return xla_destructure(ctx.builder, ans)
    else:
      return [ans]
  return wrapped


translations: _TranslationRuleAdapter
translations = _TranslationRuleAdapter(_translations, _wrap_old_translation)

class _BackendSpecificTranslationsAdapter(defaultdict):
  def __missing__(self, key):
    ret = self[key] = _TranslationRuleAdapter(
        _backend_specific_translations[key], _wrap_old_translation)
    return ret

backend_specific_translations: Dict[str, _TranslationRuleAdapter]
backend_specific_translations = _BackendSpecificTranslationsAdapter()
call_translations: _TranslationRuleAdapter
call_translations = _TranslationRuleAdapter(
    _translations, _wrap_old_call_translation)


register_translation(xla_call_p, _xla_call_translation_rule)
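# Example (hypothetical primitive): old-style rules still work through the
# adapters above; assigning into `translations` wraps the rule to the new
# signature on the fly:
#
#   translations[my_prim_p] = lambda c, x: xops.Neg(x)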


def zeros_like_translation_rule(c, x):
  shape = c.get_shape(x)
  assert not shape.is_tuple()
  zero = xops.Constant(c, np.array(0, shape.element_type()))
  return xops.Broadcast(zero, shape.dimensions())
translations[ad_util.zeros_like_p] = zeros_like_translation_rule

def add_jaxvals_translation_rule(c, x, y):
  shape = c.get_shape(x)
  assert not shape.is_tuple()
  return xops.Add(x, y)
translations[ad_util.add_jaxvals_p] = add_jaxvals_translation_rule

translations[ad_util.stop_gradient_p] = lambda c, x: x


@lu.transformation
def _tuple_output(*args, **kwargs):
  ans = yield args, kwargs
  yield (ans,)
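# _tuple_output wraps a single-result function so that it yields a 1-tuple,
# letting lower_fun treat the multiple_results=False case uniformly below.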

def lower_fun(fun: Callable, *, multiple_results: bool, backend=None,
              new_style: bool = False) -> Callable:
  if new_style:
    def f_new(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
              avals_out: Optional[Sequence[core.AbstractValue]],
              *xla_args: xc.XlaOp,
              **params) -> Sequence[xc.XlaOp]:
      wrapped_fun = lu.wrap_init(fun, params)
      if not multiple_results:
        wrapped_fun = _tuple_output(wrapped_fun)
      with core.extend_axis_env_nd(zip(ctx.axis_env.names,
                                       ctx.axis_env.sizes)):
        jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals_in)
      return jaxpr_subcomp(ctx, jaxpr, _xla_consts(ctx.builder, consts),
                           *xla_args)
    return f_new

  # TODO(phawkins): migrate dependent code & always use new_style=True.

  if backend is None:
    # The user didn't specify a backend. This isn't possible with the
    # new-style API.
    backend = "backend_not_specified"

  def f(c, *xla_args, **params):
    avals = [_array_aval_from_xla_shape(c.get_shape(x)) for x in xla_args]
    return f_with_avals(c, avals, xla_args, params)

  def f_with_avals(c, avals, xla_args, params):
    # Parallelism is only supported via the new-style API.
    axis_env = AxisEnv(1, (), ())
    wrapped_fun = lu.wrap_init(fun, params)
    if not multiple_results:
      wrapped_fun = _tuple_output(wrapped_fun)
    with core.extend_axis_env_nd(zip(axis_env.names, axis_env.sizes)):
      jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals)
    ctx = TranslationContext(c, backend, axis_env, new_name_stack())
    outs = jaxpr_subcomp(ctx, jaxpr, _xla_consts(c, consts), *xla_args)
    if (multiple_results or
        any(len(aval_to_xla_shapes(v.aval)) > 1 for v in jaxpr.outvars)):
      return xops.Tuple(c, outs)
    else:
      assert len(outs) == 1, outs
      return outs[0]

  return f
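# Example (hypothetical primitive): lowering a primitive by tracing a JAX
# implementation of it instead of emitting XLA ops by hand:
#
#   def _my_prim_impl(x):
#     return x * 2 + 1
#   register_translation(
#       my_prim_p,
#       lower_fun(_my_prim_impl, multiple_results=False, new_style=True))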


def _array_aval_from_xla_shape(xla_shape):
  # This function instantiates the assumption that we can map from XLA array
  # types to JAX array types.
  # TODO(mattjj): remove assumption can map XLA array types to JAX array types
  assert not xla_shape.is_tuple()
  return ShapedArray(xla_shape.dimensions(), xla_shape.numpy_dtype())


ad.primitive_transposes[core.named_call_p] = partial(ad.call_transpose,
                                                     core.named_call_p)


def _named_call_translation_rule(ctx, avals_in, avals_out, *in_nodes,
                                 name="core_call", backend=None, call_jaxpr):
  check_backend_matches(backend, ctx.platform)
  c = ctx.builder
  subc = xc.XlaBuilder(name)
  args = [parameter(subc, i, c.get_shape(n)) for i, n in enumerate(in_nodes)]
  sub_ctx = ctx.replace(builder=subc,
                        name_stack=extend_name_stack(ctx.name_stack, name))
  out_nodes = jaxpr_subcomp(sub_ctx, call_jaxpr, (), *args)
  if len(out_nodes) == 1:
    subc = subc.Build(out_nodes[0])
    return [xops.Call(c, subc, list(in_nodes))]
  else:
    subc = subc.Build(xops.Tuple(subc, out_nodes))
    return xla_destructure(c, xops.Call(c, subc, list(in_nodes)))
register_translation(core.named_call_p, _named_call_translation_rule)


def _call_translation_rule(ctx, avals_in, avals_out, *in_nodes, backend=None,
                           call_jaxpr):
  return _named_call_translation_rule(
      ctx, avals_in, avals_out, *in_nodes, name="core_call", backend=backend,
      call_jaxpr=call_jaxpr)
register_translation(core.call_p, _call_translation_rule)