Move jax.interpreters.xla to jax._src.interpreters.xla.

Replace jax.interpreters.xla with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 507895040
Peter Hawkins 2023-02-07 15:00:56 -08:00 committed by jax authors
parent 9c827fbd9a
commit 6860cb8d2a
10 changed files with 656 additions and 597 deletions
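The shim appears in full in the jax/interpreters/xla.py diff near the end of this page; in condensed form it uses the explicit `name as name` re-export style. An abridged sketch (only a few of the re-exported names are shown):

    # jax/interpreters/xla.py (shim, abridged)
    from jax._src.interpreters.xla import (
        parameter as parameter,
        register_translation as register_translation,
        xla_call_p as xla_call_p,
    )
    # apply_primitive, backend_compile and device_put now live in
    # jax._src.dispatch and are re-exported here for external callers.
    from jax._src.dispatch import (
        apply_primitive as apply_primitive,
        backend_compile as backend_compile,
        device_put as device_put,
    )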

View File

@@ -30,8 +30,9 @@ from jax._src.util import prod, safe_zip, use_cpp_class, use_cpp_method
from jax._src.lib import xla_client as xc
from jax._src import api
from jax._src.typing import ArrayLike
from jax.interpreters import mlir
from jax._src.interpreters import pxla
from jax.interpreters import xla, mlir
from jax._src.interpreters import xla
from jax._src.sharding import (
Sharding, SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
device_replica_id_map)

View File

@@ -36,7 +36,7 @@ from jax.errors import UnexpectedTracerError
from jax.monitoring import record_event_duration_secs
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
import jax.interpreters.xla as xla
import jax._src.interpreters.xla as xla
from jax.interpreters import pxla
import jax.interpreters.partial_eval as pe
@@ -123,11 +123,6 @@ def apply_primitive(prim, *args, **params):
**params)
return compiled_fun(*args)
# TODO(phawkins,frostig,mattjj): update code referring to
# xla.apply_primitive to point here, or use simple_impl if that's why
# it is using apply_primitive to begin with
xla.apply_primitive = apply_primitive
def simple_impl(prim):
prim.def_impl(partial(apply_primitive, prim))
@@ -646,7 +641,7 @@ def eqn_replicas(eqn):
call_jaxpr = eqn.params.get("call_jaxpr")
if call_jaxpr:
return eqn.params.get('axis_size', 1) * jaxpr_replicas(call_jaxpr)
elif eqn.primitive in xla._initial_style_primitives:
elif eqn.primitive in xla.initial_style_primitives:
return initial_style_primitive_replicas(eqn.params)
else:
return 1
@@ -1030,9 +1025,6 @@ def backend_compile(backend, built_c, options, host_callbacks):
# to take in `host_callbacks`
return backend.compile(built_c, compile_options=options)
# TODO(phawkins): update users.
xla.backend_compile = backend_compile
_ir_dump_counter = itertools.count()
def _make_string_safe_for_filename(s: str) -> str:
@@ -1263,9 +1255,6 @@ def device_put(x, device: Optional[Device] = None) -> Tuple[Any, ...]:
except KeyError as err:
raise TypeError(f"No device_put handler for type: {type(x)}") from err
# TODO(phawkins): update users.
xla.device_put = device_put
def _device_put_masked_array(x, device: Optional[Device]):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
"Use arr.filled() to convert the value to a standard numpy array.")

View File

@@ -50,7 +50,6 @@ from jax.errors import JAXTypeError
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_map
from jax._src import abstract_arrays
@@ -71,6 +70,7 @@ from jax._src.config import config
from jax._src.config import flags
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import xla
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version

View File

@@ -0,0 +1,586 @@
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lowering of jaxprs into XLA (HLO) computations.
from collections import defaultdict
import dataclasses
import functools
from functools import partial
import itertools as it
import operator
import re
from typing import (Any, Callable, Dict, List, NamedTuple, Optional,
Protocol, Sequence, Set, Type, Tuple, Union)
import numpy as np
from jax.config import config
from jax.interpreters import partial_eval as pe
from jax._src import core
from jax._src import device_array
from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src import source_info_util
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ConcreteArray, ShapedArray, str_eqn_compact
from jax._src.interpreters import ad
from jax._src.util import (prod, new_name_stack, safe_zip, safe_map,
partition_list)
# TODO: update callers to refer to new location.
from jax._src.util import extend_name_stack as extend_name_stack # noqa: F401
from jax._src.util import wrap_name as wrap_name # noqa: F401
from jax._src.typing import Shape
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
xe = xc._xla
xops = xc._xla.ops
# Types
Backend = xe.Client
Device = xc.Device
Buffer = xe.Buffer
XlaOp = xc.XlaOp
XlaShape = xc.Shape
XlaBuilder = xc.XlaBuilder
XlaLoadedExecutable = Any
XlaLoadedExecutable = xc.LoadedExecutable # type:ignore
# TODO(phawkins): update code to point to new locations.
DeviceArray = device_array.DeviceArray
_DeviceArray = device_array._DeviceArray
_CppDeviceArray = xe.Buffer
make_device_array = device_array.make_device_array
def identity(x): return x
_scalar_types = dtypes.python_scalar_dtypes.keys()
def _make_array_shape(a: ShapedArray) -> Sequence[XlaShape]:
if a.dtype == dtypes.float0:
return (xc.Shape.array_shape(np.dtype('bool'), a.shape),)
else:
return (xc.Shape.array_shape(a.dtype, a.shape),)
def get_canonical_source_file(frame: source_info_util.Frame):
source_file = frame.file_name
if config.jax_hlo_source_file_canonicalization_regex:
source_file = re.sub(config.jax_hlo_source_file_canonicalization_regex,
'', source_file)
return source_file
tracebacks = {}
def make_op_metadata(primitive: core.Primitive,
params: Dict, *,
source_info: source_info_util.SourceInfo,
name_stack: Union[str, source_info_util.NameStack] = "",
) -> xc.OpMetadata:
eqn_str = (str(source_info.name_stack) + '/'
+ str_eqn_compact(primitive.name, params))
tracebacks[eqn_str] = source_info.traceback
frame = source_info_util.user_frame(source_info)
return xc.OpMetadata(
op_type=primitive.name,
op_name=eqn_str,
source_file=get_canonical_source_file(frame) if frame else None,
source_line=frame.start_line if frame else None)
# Utilities
def parameter(builder, num, shape, name=None, replicated=None):
if name is None:
name = ''
if replicated is None:
replicated = []
elif isinstance(replicated, bool):
replicated = [replicated] * shape.leaf_count()
return xops.Parameter(builder, num,
shape.with_major_to_minor_layout_if_absent(), name,
replicated)
# HLO instructions optionally can be annotated to say how the output should be
# spatially partitioned (represented in XLA as OpSharding protos, see
# sharding_to_proto). For array outputs, the annotation is either an int per
# dimension specifying the number of ways that dimension is divided (i.e. the total
# number of shards is the product), or None to indicate the array should be
# replicated. Tuple outputs are represented as tuples thereof. XLA supports
# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
# checkers don't support recursive types), so we only represent one level of
# nesting in this type definition.
SpatialSharding = Union[Shape,
None,
Tuple[Optional[Shape], ...]]
def sharding_to_proto(sharding: SpatialSharding):
"""Converts a SpatialSharding to an OpSharding.
See
https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601
for details on the OpSharding proto.
"""
proto = xc.OpSharding()
if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
assert all(s is None or isinstance(s, tuple) for s in sharding)
return tuple_sharding_proto(list(map(sharding_to_proto, sharding))) # type: ignore
if sharding is None:
proto.type = xc.OpSharding.Type.REPLICATED
else:
proto.type = xc.OpSharding.Type.OTHER
proto.tile_assignment_dimensions = list(sharding) # type: ignore
proto.tile_assignment_devices = list(range(np.product(sharding))) # type: ignore
return proto
def tuple_sharding_proto(elems):
proto = xc.OpSharding()
assert all(isinstance(e, type(proto)) for e in elems)
proto.type = xc.OpSharding.Type.TUPLE
proto.tuple_shardings = elems
return proto
def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
builder.set_sharding(sharding_proto)
try:
return op_fn(*args, **kwargs)
finally:
builder.clear_sharding()
def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
return with_sharding_proto(builder, sharding_to_proto(sharding), op_fn, *args,
**kwargs)
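# Example (illustrative): sharding_to_proto((2, 2)) returns an OpSharding of
# type OTHER with tile_assignment_dimensions == [2, 2] and
# tile_assignment_devices == [0, 1, 2, 3] (2 x 2 = 4 shards);
# sharding_to_proto(None) returns a REPLICATED OpSharding, and a tuple of such
# annotations is turned into a TUPLE OpSharding via tuple_sharding_proto.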
### handlers
# Numpy dtypes -> XLA primitive types
_dtype_to_primitive_type: Dict[np.dtype, xc.PrimitiveType] = {
np.dtype('bool'): xc.PrimitiveType.PRED,
np.dtype('int8'): xc.PrimitiveType.S8,
np.dtype('int16'): xc.PrimitiveType.S16,
np.dtype('int32'): xc.PrimitiveType.S32,
np.dtype('int64'): xc.PrimitiveType.S64,
np.dtype('uint8'): xc.PrimitiveType.U8,
np.dtype('uint16'): xc.PrimitiveType.U16,
np.dtype('uint32'): xc.PrimitiveType.U32,
np.dtype('uint64'): xc.PrimitiveType.U64,
np.dtype(dtypes.bfloat16): xc.PrimitiveType.BF16,
np.dtype('float16'): xc.PrimitiveType.F16,
np.dtype('float32'): xc.PrimitiveType.F32,
np.dtype('float64'): xc.PrimitiveType.F64,
np.dtype('complex64'): xc.PrimitiveType.C64,
np.dtype('complex128'): xc.PrimitiveType.C128,
}
def dtype_to_primitive_type(dtype: np.dtype) -> xc.PrimitiveType:
"""Converts a NumPy dtype into an XLA PrimitiveType."""
# Many things (e.g., strings, scalar types) can be compared with NumPy dtypes,
# but may not hash correctly. Make sure we have a true np.dtype.
assert isinstance(dtype, np.dtype), type(dtype)
try:
return _dtype_to_primitive_type[dtype]
except KeyError as err:
raise TypeError(f"No XLA lowering for NumPy dtype: {dtype}") from err
# JAX abstract values -> XLA shapes
def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[XlaShape]:
try:
return xla_shape_handlers[type(aval)](aval)
except KeyError as err:
raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err
xla_shape_handlers: Dict[Type[core.AbstractValue],
Callable[[Any], Sequence[XlaShape]]] = {
ShapedArray: _make_array_shape,
ConcreteArray: _make_array_shape,
}
xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)
# IR constants
# TODO(mattjj): try to remove this canonicalize_dtype stuff
def canonicalize_dtype(x):
typ = type(x)
handler = canonicalize_dtype_handlers.get(typ)
if handler: return handler(x)
for typ in typ.__mro__:
handler = canonicalize_dtype_handlers.get(typ)
if handler: return handler(x)
if hasattr(x, '__jax_array__'):
return canonicalize_dtype(x.__jax_array__())
raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")
def _canonicalize_masked_array_dtype(x):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
"Use arr.filled() to convert the value to a standard numpy array.")
def _canonicalize_ndarray_dtype(x):
return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
def _canonicalize_python_scalar_dtype(typ, x):
return np.asarray(
x, dtypes.canonicalize_dtype(dtypes._scalar_type_to_dtype(typ, x)))
canonicalize_dtype_handlers: Dict[Any, Callable] = {}
for t in device_array.device_array_types:
canonicalize_dtype_handlers[t] = identity
canonicalize_dtype_handlers.update(
(t, _canonicalize_ndarray_dtype) for t in numpy_scalar_types)
canonicalize_dtype_handlers[np.ndarray] = _canonicalize_ndarray_dtype
canonicalize_dtype_handlers[np.ma.MaskedArray] = _canonicalize_masked_array_dtype
canonicalize_dtype_handlers.update(
(t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types)
canonicalize_dtype_handlers[core.Token] = identity
canonicalize_dtype_handlers[core.DArray] = identity
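# Example (illustrative, assuming the default configuration with 64-bit mode
# disabled): canonicalize_dtype(1.0) returns np.asarray(1.0, np.float32) via
# the Python-scalar handler, canonicalize_dtype(np.arange(3, dtype=np.int64))
# returns an int32 array via the ndarray handler, and a value whose type has
# no handler (and no __jax_array__ method) raises a TypeError.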
def abstractify(x) -> core.AbstractValue:
typ = type(x)
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
for typ in typ.__mro__:
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
if hasattr(x, '__jax_array__'):
return abstractify(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
def _make_abstract_python_scalar(typ, val):
# Note: all python scalar types are weak except bool, because bool only
# comes in a single width.
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
weak_type=typ is not bool)
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
dtype = np.dtype(x)
dtypes.check_valid_dtype(dtype)
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
pytype_aval_mappings: Dict[Any, Callable[[Any], core.AbstractValue]] = {}
for t in device_array.device_array_types:
pytype_aval_mappings[t] = operator.attrgetter('aval')
pytype_aval_mappings[core.DArray] = operator.attrgetter('_aval')
pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
pytype_aval_mappings.update((t, _make_shaped_array_for_numpy_scalar)
for t in numpy_scalar_types)
pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
pytype_aval_mappings.update(
(t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)
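# Example (illustrative): abstractify(np.zeros((2, 3), np.float32)) returns
# ShapedArray((2, 3), float32); abstractify(1) returns a scalar ShapedArray
# with weak_type=True, whereas abstractify(True) is not weak because bool only
# comes in a single width. Types with no entry in pytype_aval_mappings (and no
# __jax_array__ method) raise a TypeError.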
def primitive_subcomputation(platform: str, axis_env: 'AxisEnv',
prim: core.Primitive,
avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
**params):
c = xc.XlaBuilder(f"primitive_computation_{prim.name}")
counts = it.count()
xla_args = [parameter(c, next(counts), xla_shape)
for a in avals_in for xla_shape in aval_to_xla_shapes(a)]
if (platform is not None and
prim in _backend_specific_translations[platform]):
rule = _backend_specific_translations[platform][prim]
elif prim in _translations:
rule = _translations[prim]
ctx = TranslationContext(builder=c, platform=platform, axis_env=axis_env,
name_stack=new_name_stack())
ans = rule(ctx, avals_in, avals_out, *xla_args, **params)
if prim.multiple_results:
return c.build(xops.Tuple(c, ans))
else:
x, = ans
return c.build(x)
### compiling jaxprs
class AxisEnv(NamedTuple):
"""Represents a pmap mesh (only along the replica axes)."""
nreps: int
names: Tuple[Any, ...]
sizes: Tuple[int, ...]
@dataclasses.dataclass
class TranslationContext:
builder: xc.XlaBuilder
# TODO(phawkins): make platform non-optional. We should always be translating
# with a specific platform in mind.
platform: Optional[str]
axis_env: AxisEnv
name_stack: Union[str, source_info_util.NameStack]
def replace(self, **kw): return dataclasses.replace(self, **kw)
def xla_destructure(c, ans):
num_elements = len(c.get_shape(ans).tuple_shapes())
return [xops.GetTupleElement(ans, i) for i in range(num_elements)]
def check_backend_matches(inner_backend, outer_backend):
# For nested calls, the outermost call sets the backend for all inner calls;
# it's an error if the inner call has a conflicting explicit backend spec.
if inner_backend is None:
return
if (inner_backend != outer_backend and
outer_backend not in xb.expand_platform_alias(inner_backend)):
raise ValueError(
f"Outer-jit backend specification {outer_backend} must match explicit "
f"inner-jit backend specification {inner_backend}.")
def extend_axis_env(env: AxisEnv, name, size: int):
return AxisEnv(env.nreps, env.names + (name,), env.sizes + (size,))
def axis_read(axis_env, axis_name):
try:
return max(i for i, name in enumerate(axis_env.names) if name == axis_name)
except ValueError:
raise NameError(f"unbound axis name: {axis_name}") from None
def axis_groups(axis_env: AxisEnv, name) -> Tuple[Tuple[int, ...]]:
if not isinstance(name, (list, tuple)):
name = (name,)
mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name))
trailing_size, ragged = divmod(axis_env.nreps, prod(axis_env.sizes))
assert not ragged
mesh_spec = axis_env.sizes + (trailing_size,)
return _axis_groups(mesh_spec, mesh_axes)
def _axis_groups(mesh_spec, mesh_axes):
"""Computes replica group ids for a collective performed over a subset of the mesh.
Args:
mesh_spec: A sequence of integers representing the mesh shape.
mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive)
indicating over which axes the collective is performed.
Returns:
A tuple of replica groups (i.e. tuples containing replica ids).
"""
iota = np.arange(prod(mesh_spec)).reshape(mesh_spec)
groups = np.reshape(
np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
(prod(np.take(mesh_spec, mesh_axes)), -1))
return tuple(unsafe_map(tuple, groups.T))
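# Example (illustrative): with four replicas arranged as a 2x2 mesh,
# _axis_groups((2, 2), (0,)) groups the replicas as ((0, 2), (1, 3)), one
# group per position along the remaining mesh axis, while
# _axis_groups((2, 2), (0, 1)) yields a single group containing all four
# replicas: ((0, 1, 2, 3),).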
# TODO(mattjj,skyewm): the functions here are utilities for checking if
# not-yet-supported features are used with multi-host programming
def jaxpr_collectives(jaxpr):
"""Generates all the collective primitives anywhere inside a Jaxpr."""
for eqn in jaxpr.eqns:
if eqn.primitive in _collective_primitives:
yield eqn.primitive
for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_collectives(subjaxpr)
### xla_call underlying jit
xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')
xla_call = xla_call_p.bind
def _xla_call_partial_eval_update_params(
params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int
) -> core.ParamDict:
donated_invars = params['donated_invars']
if not kept_inputs and donated_invars:
# JaxprTrace.post_process_call creates a call with no input tracers
donated_invars = (False,) * num_new_inputs
else:
assert len(kept_inputs) == len(donated_invars)
# JaxprTrace.process_call drops known input tracers
donated_invars = [d for d, kept in zip(donated_invars, kept_inputs) if kept]
# Any new inputs are prepended to the left, so mark those as not donated.
donated_invars = [False] * num_new_inputs + donated_invars
return dict(params, donated_invars=tuple(donated_invars))
pe.call_param_updaters[xla_call_p] = _xla_call_partial_eval_update_params
def _xla_call_jvp_update_params(params, nz_tangents):
donated_invars = params['donated_invars']
donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
new_donated_invars = (*donated_invars, *donated_tangents)
return dict(params, donated_invars=new_donated_invars)
ad.call_param_updaters[xla_call_p] = _xla_call_jvp_update_params
def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
donated_invars = params['donated_invars']
donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
donated_cotangents = [False for nz in nonzero_cts if nz]
return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params
ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
def _xla_call_partial_eval_custom_params_updater(
unks_in: Sequence[bool], inst_in: Sequence[bool],
kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
num_res: int, params_known: dict, params_staged: dict
) -> Tuple[dict, dict]:
# pruned inputs to jaxpr_known according to unks_in, so prune donated_invars
donated_known, _ = partition_list(unks_in, params_known['donated_invars'])
new_params_known = dict(params_known, donated_invars=tuple(donated_known))
# added num_res new inputs to jaxpr_staged, so extend donated_invars
_, donated_staged_ = partition_list(inst_in, params_staged['donated_invars'])
donated_staged = [False] * num_res + donated_staged_
new_params_staged = dict(params_staged, donated_invars=tuple(donated_staged))
return new_params_known, new_params_staged
pe.partial_eval_jaxpr_custom_rules[xla_call_p] = \
partial(pe.call_partial_eval_custom_rule, 'call_jaxpr',
_xla_call_partial_eval_custom_params_updater)
pe.dce_rules[xla_call_p] = pe.dce_jaxpr_call_rule
pe.padding_rules[xla_call_p] = partial(pe.call_padding_rule, xla_call_p)
def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext,
settings: core.JaxprPpSettings,
) -> List[pp.Doc]:
printed_params = {k:v for k, v in eqn.params.items() if
k == 'call_jaxpr' or k == 'name' or
k == 'backend' and v is not None or
k == 'device' and v is not None or
k == 'donated_invars' and any(v)}
annotation = (source_info_util.summarize(eqn.source_info)
if settings.source_info else None)
lhs = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
rhs = [pp.text(eqn.primitive.name),
core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
pp.text(" ") + core.pp_vars(eqn.invars, context)]
return [lhs, pp.text(" = ", annotation=annotation), *rhs]
core.pp_eqn_rules[xla_call_p] = _pp_xla_call
### translation tables
MYPY = False
if not MYPY:
class TranslationRule(Protocol):
def __call__(self, ctx: TranslationContext,
avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
*args: XlaOp, **kw
) -> Sequence[XlaOp]:
"""A translation rule lowers a primitive invocation into an XLA HLO."""
else:
TranslationRule = Any
_translations: Dict[core.Primitive, TranslationRule] = {}
_backend_specific_translations: Dict[str, Dict[core.Primitive, TranslationRule]]
_backend_specific_translations = defaultdict(dict)
_collective_primitives: Set[core.Primitive] = set()
initial_style_primitives: Set[core.Primitive] = set()
def register_initial_style_primitive(prim: core.Primitive):
initial_style_primitives.add(prim)
def register_collective_primitive(prim: core.Primitive):
_collective_primitives.add(prim)
def register_translation(prim: core.Primitive, rule: TranslationRule, *,
platform: Optional[str] = None) -> None:
if platform is None:
_translations[prim] = rule
else:
# For backward compatibility reasons, we allow rules to be registered
# under "gpu" even though the platforms are now called "cuda" and "rocm".
# TODO(phawkins): fix up users to specify either "cuda" or "rocm" and remove
# this expansion.
for p in xb.expand_platform_alias(platform):
_backend_specific_translations[p][prim] = rule
# As a temporary backward compatibility measure, we use an adapter class to
# convert from the old styles of translation rules to the newer ones.
# TODO(phawkins): update users of the older translation rule styles and remove
# the adapters.
class _TranslationRuleAdapter:
def __init__(self, translations,
wrap_fn: Callable[[core.Primitive, Callable], TranslationRule]):
self._translations = translations
self._wrap_fn = wrap_fn
def __setitem__(self, key: core.Primitive, value: Callable):
wrapped = self._wrap_fn(key, value)
for translations in self._translations:
translations[key] = wrapped
def _wrap_old_translation(prim: core.Primitive, f: Callable) -> TranslationRule:
@functools.wraps(f)
def wrapped(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
*args: XlaOp, **kw) -> Sequence[XlaOp]:
ans = f(ctx.builder, *args, **kw)
if (prim.multiple_results or
any(len(aval_to_xla_shapes(aval)) > 1 for aval in avals_out)):
return xla_destructure(ctx.builder, ans)
else:
return [ans]
return wrapped
translations : _TranslationRuleAdapter
translations = _TranslationRuleAdapter([_translations], _wrap_old_translation)
class _BackendSpecificTranslationsAdapter(defaultdict):
def __missing__(self, key):
translation_tables = [_backend_specific_translations[p]
for p in xb.expand_platform_alias(key)]
ret = self[key] = _TranslationRuleAdapter(
translation_tables, _wrap_old_translation)
return ret
backend_specific_translations: Dict[str, _TranslationRuleAdapter]
backend_specific_translations = _BackendSpecificTranslationsAdapter()
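# Example (illustrative): registering an old-style translation rule through
# the adapters. `my_prim` and `_my_neg_translation` are hypothetical names; an
# old-style rule receives the XlaBuilder and XLA operands directly, and the
# adapter wraps it into a TranslationRule before storing it:
#
#   my_prim = core.Primitive('my_prim')
#
#   def _my_neg_translation(builder, x):   # old-style signature
#     return xops.Neg(x)
#
#   translations[my_prim] = _my_neg_translation
#   # equivalent to storing _wrap_old_translation(my_prim, _my_neg_translation)
#   # in _translations[my_prim].
#
#   # Registering under 'gpu' stores the wrapped rule for both 'cuda' and
#   # 'rocm' via xb.expand_platform_alias:
#   backend_specific_translations['gpu'][my_prim] = _my_neg_translation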
# TODO(phawkins): remove lower_fun completely after updating users.
def lower_fun(fun: Callable, *, multiple_results: bool, backend=None,
new_style: bool = False) -> Callable:
def f(*args, **kw):
raise RuntimeError("XLA translation rules are deprecated and "
"jax.interpreters.xla.lower_fun is no longer supported. "
"Add an MLIR lowering via jax.interpreters.mlir "
"instead.")
return f

View File

@@ -46,7 +46,7 @@ from jax._src import util
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import use_stablehlo
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.interpreters import xla
source_info_util.register_exclusion(__file__)

View File

@@ -511,9 +511,10 @@ from jax import custom_derivatives
from jax._src import dtypes
from jax import lax
from jax.experimental import pjit
from jax.interpreters import ad, xla, batching, pxla
from jax.interpreters import ad, batching, pxla
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax._src.interpreters import xla
from jax._src import ad_checkpoint
from jax._src import dispatch
from jax._src import pretty_printer as pp

View File

@@ -67,7 +67,7 @@ from jax._src import test_util as jtu
from jax.config import config
from jax.experimental import jax2tf
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.interpreters import xla
import numpy as np
import tensorflow as tf # type: ignore[import]

View File

@@ -34,7 +34,7 @@ from jax._src import linear_util as lu
from jax.config import config
from jax._src.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src.interpreters import xla
from jax._src import ad_util
from jax._src import core
from jax._src import device_array
@@ -328,7 +328,7 @@ def _source_info_to_location(
if frame is None:
loc = ir.Location.unknown()
else:
loc = ir.Location.file(xla._get_canonical_source_file(frame),
loc = ir.Location.file(xla.get_canonical_source_file(frame),
frame.start_line, frame.start_column)
loc = ir.Location.name(eqn_str, childLoc=loc)
# TODO(phawkins): also include primitive.name as the operator type.

View File

@@ -12,580 +12,61 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lowering of jaxprs into XLA (HLO) computations.
from collections import defaultdict
import dataclasses
import functools
from functools import partial
import itertools as it
import operator
import re
from typing import (Any, Callable, Dict, List, NamedTuple, Optional,
Protocol, Sequence, Set, Type, Tuple, Union)
import numpy as np
from jax.config import config
from jax.interpreters import partial_eval as pe
from jax._src import core
from jax._src import device_array
from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src import source_info_util
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ConcreteArray, ShapedArray, str_eqn_compact
from jax._src.interpreters import ad
from jax._src.util import (prod, new_name_stack, safe_zip, safe_map,
partition_list)
# TODO: update callers to refer to new location.
from jax._src.util import extend_name_stack as extend_name_stack # noqa: F401
from jax._src.util import wrap_name as wrap_name # noqa: F401
from jax._src.typing import Shape
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
xe = xc._xla
xops = xc._xla.ops
# Types
Backend = xe.Client
Device = xc.Device
Buffer = xe.Buffer
XlaOp = xc.XlaOp
XlaShape = xc.Shape
XlaBuilder = xc.XlaBuilder
XlaLoadedExecutable = Any
XlaLoadedExecutable = xc.LoadedExecutable # type:ignore
# apply_primitive is defined in jax._src.dispatch.
apply_primitive: Callable
backend_compile: Callable
device_put: Callable
# TODO(phawkins): update code to point to new locations.
DeviceArray = device_array.DeviceArray
_DeviceArray = device_array._DeviceArray
_CppDeviceArray = xe.Buffer
make_device_array = device_array.make_device_array
def identity(x): return x
_scalar_types = dtypes.python_scalar_dtypes.keys()
def _make_array_shape(a: ShapedArray) -> Sequence[XlaShape]:
if a.dtype == dtypes.float0:
return (xc.Shape.array_shape(np.dtype('bool'), a.shape),)
else:
return (xc.Shape.array_shape(a.dtype, a.shape),)
def _get_canonical_source_file(frame: source_info_util.Frame):
source_file = frame.file_name
if config.jax_hlo_source_file_canonicalization_regex:
source_file = re.sub(config.jax_hlo_source_file_canonicalization_regex,
'', source_file)
return source_file
tracebacks = {}
def make_op_metadata(primitive: core.Primitive,
params: Dict, *,
source_info: source_info_util.SourceInfo,
name_stack: Union[str, source_info_util.NameStack] = "",
) -> xc.OpMetadata:
eqn_str = (str(source_info.name_stack) + '/'
+ str_eqn_compact(primitive.name, params))
tracebacks[eqn_str] = source_info.traceback
frame = source_info_util.user_frame(source_info)
return xc.OpMetadata(
op_type=primitive.name,
op_name=eqn_str,
source_file=_get_canonical_source_file(frame) if frame else None,
source_line=frame.start_line if frame else None)
# Utilities
def parameter(builder, num, shape, name=None, replicated=None):
if name is None:
name = ''
if replicated is None:
replicated = []
elif isinstance(replicated, bool):
replicated = [replicated] * shape.leaf_count()
return xops.Parameter(builder, num,
shape.with_major_to_minor_layout_if_absent(), name,
replicated)
# HLO instructions optionally can be annotated to say how the output should be
# spatially partitioned (represented in XLA as OpSharding protos, see
# sharding_to_proto). For array outputs, the annotation is either an int per
# dimension specifying the number of ways that dimension is divided (i.e. the total
# number of shards is the product), or None to indicate the array should be
# replicated. Tuple outputs are represented as tuples thereof. XLA supports
# arbitrary tuple nesting, but JAX only uses one level of tupling (and our type
# checkers don't support recursive types), so we only represent one level of
# nesting in this type definition.
SpatialSharding = Union[Shape,
None,
Tuple[Optional[Shape], ...]]
def sharding_to_proto(sharding: SpatialSharding):
"""Converts a SpatialSharding to an OpSharding.
See
https://github.com/tensorflow/tensorflow/blob/main/tensorflow/compiler/xla/xla_data.proto#L601
for details on the OpSharding proto.
"""
proto = xc.OpSharding()
if isinstance(sharding, tuple) and not isinstance(sharding[0], int):
assert all(s is None or isinstance(s, tuple) for s in sharding)
return tuple_sharding_proto(list(map(sharding_to_proto, sharding))) # type: ignore
if sharding is None:
proto.type = xc.OpSharding.Type.REPLICATED
else:
proto.type = xc.OpSharding.Type.OTHER
proto.tile_assignment_dimensions = list(sharding) # type: ignore
proto.tile_assignment_devices = list(range(np.product(sharding))) # type: ignore
return proto
def tuple_sharding_proto(elems):
proto = xc.OpSharding()
assert all(isinstance(e, type(proto)) for e in elems)
proto.type = xc.OpSharding.Type.TUPLE
proto.tuple_shardings = elems
return proto
def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
builder.set_sharding(sharding_proto)
try:
return op_fn(*args, **kwargs)
finally:
builder.clear_sharding()
def with_sharding(builder, sharding: SpatialSharding, op_fn, *args, **kwargs):
"""Builds op_fn(*args, **kwargs) with sharding annotation."""
return with_sharding_proto(builder, sharding_to_proto(sharding), op_fn, *args,
**kwargs)
### handlers
# Numpy dtypes -> XLA primitive types
_dtype_to_primitive_type: Dict[np.dtype, xc.PrimitiveType] = {
np.dtype('bool'): xc.PrimitiveType.PRED,
np.dtype('int8'): xc.PrimitiveType.S8,
np.dtype('int16'): xc.PrimitiveType.S16,
np.dtype('int32'): xc.PrimitiveType.S32,
np.dtype('int64'): xc.PrimitiveType.S64,
np.dtype('uint8'): xc.PrimitiveType.U8,
np.dtype('uint16'): xc.PrimitiveType.U16,
np.dtype('uint32'): xc.PrimitiveType.U32,
np.dtype('uint64'): xc.PrimitiveType.U64,
np.dtype(dtypes.bfloat16): xc.PrimitiveType.BF16,
np.dtype('float16'): xc.PrimitiveType.F16,
np.dtype('float32'): xc.PrimitiveType.F32,
np.dtype('float64'): xc.PrimitiveType.F64,
np.dtype('complex64'): xc.PrimitiveType.C64,
np.dtype('complex128'): xc.PrimitiveType.C128,
}
def dtype_to_primitive_type(dtype: np.dtype) -> xc.PrimitiveType:
"""Converts a NumPy dtype into an XLA PrimitiveType."""
# Many things (e.g., strings, scalar types) can be compared with NumPy dtypes,
# but may not hash correctly. Make sure we have a true np.dtype.
assert isinstance(dtype, np.dtype), type(dtype)
try:
return _dtype_to_primitive_type[dtype]
except KeyError as err:
raise TypeError(f"No XLA lowering for NumPy dtype: {dtype}") from err
# JAX abstract values -> XLA shapes
def aval_to_xla_shapes(aval: core.AbstractValue) -> Sequence[XlaShape]:
try:
return xla_shape_handlers[type(aval)](aval)
except KeyError as err:
raise TypeError(f"No xla_shape_handler for type: {type(aval)}") from err
xla_shape_handlers: Dict[Type[core.AbstractValue],
Callable[[Any], Sequence[XlaShape]]] = {
ShapedArray: _make_array_shape,
ConcreteArray: _make_array_shape,
}
xla_shape_handlers[core.AbstractToken] = lambda _: (xc.Shape.token_shape(),)
# IR constants
# TODO(mattjj): try to remove this canonicalize_dtype stuff
def canonicalize_dtype(x):
typ = type(x)
handler = canonicalize_dtype_handlers.get(typ)
if handler: return handler(x)
for typ in typ.__mro__:
handler = canonicalize_dtype_handlers.get(typ)
if handler: return handler(x)
if hasattr(x, '__jax_array__'):
return canonicalize_dtype(x.__jax_array__())
raise TypeError(f"No canonicalize_dtype handler for type: {type(x)}")
def _canonicalize_masked_array_dtype(x):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
"Use arr.filled() to convert the value to a standard numpy array.")
def _canonicalize_ndarray_dtype(x):
return np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
def _canonicalize_python_scalar_dtype(typ, x):
return np.asarray(
x, dtypes.canonicalize_dtype(dtypes._scalar_type_to_dtype(typ, x)))
canonicalize_dtype_handlers: Dict[Any, Callable] = {}
for t in device_array.device_array_types:
canonicalize_dtype_handlers[t] = identity
canonicalize_dtype_handlers.update(
(t, _canonicalize_ndarray_dtype) for t in numpy_scalar_types)
canonicalize_dtype_handlers[np.ndarray] = _canonicalize_ndarray_dtype
canonicalize_dtype_handlers[np.ma.MaskedArray] = _canonicalize_masked_array_dtype
canonicalize_dtype_handlers.update(
(t, partial(_canonicalize_python_scalar_dtype, t)) for t in _scalar_types)
canonicalize_dtype_handlers[core.Token] = identity
canonicalize_dtype_handlers[core.DArray] = identity
def abstractify(x) -> core.AbstractValue:
typ = type(x)
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
for typ in typ.__mro__:
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
if hasattr(x, '__jax_array__'):
return abstractify(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
def _make_abstract_python_scalar(typ, val):
# Note: all python scalar types are weak except bool, because bool only
# comes in a single width.
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
weak_type=typ is not bool)
def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
dtype = np.dtype(x)
dtypes.check_valid_dtype(dtype)
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
pytype_aval_mappings: Dict[Any, Callable[[Any], core.AbstractValue]] = {}
for t in device_array.device_array_types:
pytype_aval_mappings[t] = operator.attrgetter('aval')
pytype_aval_mappings[core.DArray] = operator.attrgetter('_aval')
pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
pytype_aval_mappings.update((t, _make_shaped_array_for_numpy_scalar)
for t in numpy_scalar_types)
pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
pytype_aval_mappings.update(
(t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)
def primitive_subcomputation(platform: str, axis_env: 'AxisEnv',
prim: core.Primitive,
avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
**params):
c = xc.XlaBuilder(f"primitive_computation_{prim.name}")
counts = it.count()
xla_args = [parameter(c, next(counts), xla_shape)
for a in avals_in for xla_shape in aval_to_xla_shapes(a)]
if (platform is not None and
prim in _backend_specific_translations[platform]):
rule = _backend_specific_translations[platform][prim]
elif prim in _translations:
rule = _translations[prim]
ctx = TranslationContext(builder=c, platform=platform, axis_env=axis_env,
name_stack=new_name_stack())
ans = rule(ctx, avals_in, avals_out, *xla_args, **params)
if prim.multiple_results:
return c.build(xops.Tuple(c, ans))
else:
x, = ans
return c.build(x)
### compiling jaxprs
class AxisEnv(NamedTuple):
"""Represents a pmap mesh (only along the replica axes)."""
nreps: int
names: Tuple[Any, ...]
sizes: Tuple[int, ...]
@dataclasses.dataclass
class TranslationContext:
builder: xc.XlaBuilder
# TODO(phawkins): make platform non-optional. We should always be translating
# with a specific platform in mind.
platform: Optional[str]
axis_env: AxisEnv
name_stack: Union[str, source_info_util.NameStack]
def replace(self, **kw): return dataclasses.replace(self, **kw)
def xla_destructure(c, ans):
num_elements = len(c.get_shape(ans).tuple_shapes())
return [xops.GetTupleElement(ans, i) for i in range(num_elements)]
def check_backend_matches(inner_backend, outer_backend):
# For nested calls, the outermost call sets the backend for all inner calls;
# it's an error if the inner call has a conflicting explicit backend spec.
if inner_backend is None:
return
if (inner_backend != outer_backend and
outer_backend not in xb.expand_platform_alias(inner_backend)):
raise ValueError(
f"Outer-jit backend specification {outer_backend} must match explicit "
f"inner-jit backend specification {inner_backend}.")
def extend_axis_env(env: AxisEnv, name, size: int):
return AxisEnv(env.nreps, env.names + (name,), env.sizes + (size,))
def axis_read(axis_env, axis_name):
try:
return max(i for i, name in enumerate(axis_env.names) if name == axis_name)
except ValueError:
raise NameError(f"unbound axis name: {axis_name}") from None
def axis_groups(axis_env: AxisEnv, name) -> Tuple[Tuple[int, ...]]:
if not isinstance(name, (list, tuple)):
name = (name,)
mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name))
trailing_size, ragged = divmod(axis_env.nreps, prod(axis_env.sizes))
assert not ragged
mesh_spec = axis_env.sizes + (trailing_size,)
return _axis_groups(mesh_spec, mesh_axes)
def _axis_groups(mesh_spec, mesh_axes):
"""Computes replica group ids for a collective performed over a subset of the mesh.
Args:
mesh_spec: A sequence of integers representing the mesh shape.
mesh_axes: A sequence of integers between 0 and `len(mesh_spec)` (exclusive)
indicating over which axes the collective is performed.
Returns:
A tuple of replica groups (i.e. tuples containing replica ids).
"""
iota = np.arange(prod(mesh_spec)).reshape(mesh_spec)
groups = np.reshape(
np.moveaxis(iota, mesh_axes, np.arange(len(mesh_axes))),
(prod(np.take(mesh_spec, mesh_axes)), -1))
return tuple(unsafe_map(tuple, groups.T))
# TODO(mattjj,skyewm): the functions here are utilities for checking if
# not-yet-supported features are used with multi-host programming
def jaxpr_collectives(jaxpr):
"""Generates all the collective primitives anywhere inside a Jaxpr."""
for eqn in jaxpr.eqns:
if eqn.primitive in _collective_primitives:
yield eqn.primitive
for subjaxpr in core.subjaxprs(jaxpr): yield from jaxpr_collectives(subjaxpr)
### xla_call underlying jit
xla_call_p: core.CallPrimitive = core.CallPrimitive('xla_call')
xla_call = xla_call_p.bind
def _xla_call_partial_eval_update_params(
params: core.ParamDict, kept_inputs: Sequence[bool], num_new_inputs: int
) -> core.ParamDict:
donated_invars = params['donated_invars']
if not kept_inputs and donated_invars:
# JaxprTrace.post_process_call creates a call with no input tracers
donated_invars = (False,) * num_new_inputs
else:
assert len(kept_inputs) == len(donated_invars)
# JaxprTrace.process_call drops known input tracers
donated_invars = [d for d, kept in zip(donated_invars, kept_inputs) if kept]
# Any new inputs are prepended to the left, so mark those as not donated.
donated_invars = [False] * num_new_inputs + donated_invars
return dict(params, donated_invars=tuple(donated_invars))
pe.call_param_updaters[xla_call_p] = _xla_call_partial_eval_update_params
def _xla_call_jvp_update_params(params, nz_tangents):
donated_invars = params['donated_invars']
donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
new_donated_invars = (*donated_invars, *donated_tangents)
return dict(params, donated_invars=new_donated_invars)
ad.call_param_updaters[xla_call_p] = _xla_call_jvp_update_params
def _xla_call_transpose_update_params(params, undef_primals, nonzero_cts):
donated_invars = params['donated_invars']
donated_primals = [d for d, u in zip(donated_invars, undef_primals) if not u]
donated_cotangents = [False for nz in nonzero_cts if nz]
return dict(params, donated_invars=(*donated_primals, *donated_cotangents))
ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params
ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
def _xla_call_partial_eval_custom_params_updater(
unks_in: Sequence[bool], inst_in: Sequence[bool],
kept_outs_known: Sequence[bool], kept_outs_staged: Sequence[bool],
num_res: int, params_known: dict, params_staged: dict
) -> Tuple[dict, dict]:
# pruned inputs to jaxpr_known according to unks_in, so prune donated_invars
donated_known, _ = partition_list(unks_in, params_known['donated_invars'])
new_params_known = dict(params_known, donated_invars=tuple(donated_known))
# added num_res new inputs to jaxpr_staged, so extend donated_invars
_, donated_staged_ = partition_list(inst_in, params_staged['donated_invars'])
donated_staged = [False] * num_res + donated_staged_
new_params_staged = dict(params_staged, donated_invars=tuple(donated_staged))
return new_params_known, new_params_staged
pe.partial_eval_jaxpr_custom_rules[xla_call_p] = \
partial(pe.call_partial_eval_custom_rule, 'call_jaxpr',
_xla_call_partial_eval_custom_params_updater)
pe.dce_rules[xla_call_p] = pe.dce_jaxpr_call_rule
pe.padding_rules[xla_call_p] = partial(pe.call_padding_rule, xla_call_p)
def _pp_xla_call(eqn: core.JaxprEqn, context: core.JaxprPpContext,
settings: core.JaxprPpSettings,
) -> List[pp.Doc]:
printed_params = {k:v for k, v in eqn.params.items() if
k == 'call_jaxpr' or k == 'name' or
k == 'backend' and v is not None or
k == 'device' and v is not None or
k == 'donated_invars' and any(v)}
annotation = (source_info_util.summarize(eqn.source_info)
if settings.source_info else None)
lhs = core.pp_vars(eqn.outvars, context, print_shapes=settings.print_shapes)
rhs = [pp.text(eqn.primitive.name),
core.pp_kv_pairs(sorted(printed_params.items()), context, settings),
pp.text(" ") + core.pp_vars(eqn.invars, context)]
return [lhs, pp.text(" = ", annotation=annotation), *rhs]
core.pp_eqn_rules[xla_call_p] = _pp_xla_call
### translation tables
MYPY = False
if not MYPY:
class TranslationRule(Protocol):
def __call__(self, ctx: TranslationContext,
avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
*args: XlaOp, **kw
) -> Sequence[XlaOp]:
"""A translation rule lowers a primitive invocation into an XLA HLO."""
else:
TranslationRule = Any
_translations: Dict[core.Primitive, TranslationRule] = {}
_backend_specific_translations: Dict[str, Dict[core.Primitive, TranslationRule]]
_backend_specific_translations = defaultdict(dict)
_collective_primitives: Set[core.Primitive] = set()
_initial_style_primitives: Set[core.Primitive] = set()
def register_initial_style_primitive(prim: core.Primitive):
_initial_style_primitives.add(prim)
def register_collective_primitive(prim: core.Primitive):
_collective_primitives.add(prim)
def register_translation(prim: core.Primitive, rule: TranslationRule, *,
platform: Optional[str] = None) -> None:
if platform is None:
_translations[prim] = rule
else:
# For backward compatibility reasons, we allow rules to be registered
# under "gpu" even though the platforms are now called "cuda" and "rocm".
# TODO(phawkins): fix up users to specify either "cuda" or "rocm" and remove
# this expansion.
for p in xb.expand_platform_alias(platform):
_backend_specific_translations[p][prim] = rule
# As a temporary backward compatibility measure, we use an adapter class to
# convert from the old styles of translation rules to the newer ones.
# TODO(phawkins): update users of the older translation rule styles and remove
# the adapters.
class _TranslationRuleAdapter:
def __init__(self, translations,
wrap_fn: Callable[[core.Primitive, Callable], TranslationRule]):
self._translations = translations
self._wrap_fn = wrap_fn
def __setitem__(self, key: core.Primitive, value: Callable):
wrapped = self._wrap_fn(key, value)
for translations in self._translations:
translations[key] = wrapped
def _wrap_old_translation(prim: core.Primitive, f: Callable) -> TranslationRule:
@functools.wraps(f)
def wrapped(ctx: TranslationContext, avals_in: Sequence[core.AbstractValue],
avals_out: Sequence[core.AbstractValue],
*args: XlaOp, **kw) -> Sequence[XlaOp]:
ans = f(ctx.builder, *args, **kw)
if (prim.multiple_results or
any(len(aval_to_xla_shapes(aval)) > 1 for aval in avals_out)):
return xla_destructure(ctx.builder, ans)
else:
return [ans]
return wrapped
translations : _TranslationRuleAdapter
translations = _TranslationRuleAdapter([_translations], _wrap_old_translation)
class _BackendSpecificTranslationsAdapter(defaultdict):
def __missing__(self, key):
translation_tables = [_backend_specific_translations[p]
for p in xb.expand_platform_alias(key)]
ret = self[key] = _TranslationRuleAdapter(
translation_tables, _wrap_old_translation)
return ret
backend_specific_translations: Dict[str, _TranslationRuleAdapter]
backend_specific_translations = _BackendSpecificTranslationsAdapter()
# TODO(phawkins): remove lower_fun completely after updating users.
def lower_fun(fun: Callable, *, multiple_results: bool, backend=None,
new_style: bool = False) -> Callable:
def f(*args, **kw):
raise RuntimeError("XLA translation rules are deprecated and "
"jax.interpreters.xla.lower_fun is no longer supported. "
"Add an MLIR lowering via jax.interpreters.mlir "
"instead.")
return f
from jax._src.interpreters.xla import (
AxisEnv as AxisEnv,
Backend as Backend,
Buffer as Buffer,
ConcreteArray as ConcreteArray,
Device as Device,
DeviceArray as DeviceArray,
Shape as Shape,
ShapedArray as ShapedArray,
SpatialSharding as SpatialSharding,
TranslationContext as TranslationContext,
TranslationRule as TranslationRule,
XlaBuilder as XlaBuilder,
XlaLoadedExecutable as XlaLoadedExecutable,
XlaOp as XlaOp,
XlaShape as XlaShape,
_CppDeviceArray as _CppDeviceArray,
_DeviceArray as _DeviceArray,
abstractify as abstractify,
aval_to_xla_shapes as aval_to_xla_shapes,
axis_groups as axis_groups,
axis_read as axis_read,
backend_specific_translations as backend_specific_translations,
canonicalize_dtype as canonicalize_dtype,
canonicalize_dtype_handlers as canonicalize_dtype_handlers,
check_backend_matches as check_backend_matches,
dtype_to_primitive_type as dtype_to_primitive_type,
extend_axis_env as extend_axis_env,
extend_name_stack as extend_name_stack,
jaxpr_collectives as jaxpr_collectives,
lower_fun as lower_fun,
make_device_array as make_device_array,
make_op_metadata as make_op_metadata,
new_name_stack as new_name_stack,
parameter as parameter,
partition_list as partition_list,
primitive_subcomputation as primitive_subcomputation,
pytype_aval_mappings as pytype_aval_mappings,
register_collective_primitive as register_collective_primitive,
register_initial_style_primitive as register_initial_style_primitive,
register_translation as register_translation,
sharding_to_proto as sharding_to_proto,
translations as translations,
xb as xb,
xc as xc,
xe as xe,
xla_call as xla_call,
xla_call_p as xla_call_p,
xla_destructure as xla_destructure,
xla_shape_handlers as xla_shape_handlers,
)
# TODO(phawkins): update users.
from jax._src.dispatch import (
apply_primitive as apply_primitive,
backend_compile as backend_compile,
device_put as device_put,
)

View File

@@ -37,6 +37,7 @@ per-file-ignores =
jax/flatten_util.py:F401
jax/interpreters/ad.py:F401
jax/interpreters/pxla.py:F401
jax/interpreters/xla.py:F401
jax/linear_util.py:F401
jax/prng.py:F401
jax/profiler.py:F401