mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 20:06:05 +00:00

As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically: 1) MLIR lowerings now produce StableHLO ops instead of MHLO ops. 2) Fallback lowerings now produce StableHLO ops as well. 3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs). From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO): a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing. b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath. c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo". d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues. PiperOrigin-RevId: 497978733
1432 lines
56 KiB
Python
1432 lines
56 KiB
Python
# Copyright 2018 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# Primitive dispatch and jit dispatch.
|
|
from __future__ import annotations
|
|
|
|
import atexit
|
|
import contextlib
|
|
from functools import partial
|
|
import itertools
|
|
import time
|
|
from typing import (
|
|
Any, Callable, Dict, Iterable, Iterator, Optional, Protocol,
|
|
Sequence, Set, Tuple, List, Type, Union)
|
|
import logging
|
|
import os
|
|
import re
|
|
import threading
|
|
import warnings
|
|
|
|
import numpy as np
|
|
|
|
import jax
|
|
from jax._src import linear_util as lu
|
|
from jax.errors import UnexpectedTracerError
|
|
from jax.monitoring import record_event_duration_secs
|
|
import jax.interpreters.ad as ad
|
|
import jax.interpreters.batching as batching
|
|
import jax.interpreters.mlir as mlir
|
|
import jax.interpreters.xla as xla
|
|
from jax.interpreters import pxla
|
|
import jax.interpreters.partial_eval as pe
|
|
from jax._src import array
|
|
from jax._src import core
|
|
from jax._src import device_array
|
|
from jax._src import dtypes
|
|
from jax._src import profiler
|
|
from jax._src import stages
|
|
from jax._src import traceback_util
|
|
from jax._src.sharding import (PmapSharding, SingleDeviceSharding,
|
|
OpShardingSharding, Sharding)
|
|
from jax._src.abstract_arrays import array_types
|
|
from jax._src.config import config, flags
|
|
from jax._src.lib.mlir import ir
|
|
from jax._src.lib.mlir.dialects import use_stablehlo
|
|
from jax._src.lib import xla_bridge as xb
|
|
from jax._src.lib import xla_client as xc
|
|
import jax._src.util as util
|
|
from jax._src.util import flatten, unflatten
|
|
from jax._src import path
|
|
|
|
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
|
|
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
|
|
|
|
FLAGS = flags.FLAGS
|
|
|
|
flags.DEFINE_string(
|
|
'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
|
|
help="Path to which the IR that is emitted by JAX as input to the "
|
|
"compiler should be dumped as text files. Optional. If omitted, JAX "
|
|
"will not dump IR.")
|
|
|
|
|
|
traceback_util.register_exclusion(__file__)
|
|
|
|
MYPY = False # Are we currently type checking with mypy?
|
|
|
|
xe = xc._xla
|
|
|
|
Backend = xe.Client
|
|
Device = xc.Device
|
|
Buffer = xe.Buffer
|
|
|
|
XlaLoadedExecutable = xla.XlaLoadedExecutable
|
|
CompileOptions = xc.CompileOptions
|
|
|
|
map, unsafe_map = util.safe_map, map
|
|
zip, unsafe_zip = util.safe_zip, zip
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# This flag is set on exit; no logging should be attempted
|
|
_on_exit = False
|
|
|
|
### op-by-op execution
|
|
|
|
ArgSpec = Tuple[core.AbstractValue, Optional[Device]]
|
|
|
|
def arg_spec(x: Any) -> ArgSpec:
|
|
from jax.experimental import pjit
|
|
|
|
aval = xla.abstractify(x)
|
|
try:
|
|
if config.jax_array:
|
|
if isinstance(x.sharding, PmapSharding):
|
|
return aval, None
|
|
return aval, (pjit.to_op_sharding_sharding(x.sharding, x.ndim) # type: ignore
|
|
if x._committed else None)
|
|
else:
|
|
return aval, x._device
|
|
except:
|
|
return aval, None
|
|
|
|
|
|
def apply_primitive(prim, *args, **params):
|
|
"""Impl rule that compiles and runs a single primitive 'prim' using XLA."""
|
|
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
|
|
**params)
|
|
return compiled_fun(*args)
|
|
|
|
# TODO(phawkins,frostig,mattjj): update code referring to
|
|
# xla.apply_primitive to point here, or use simple_impl if that's why
|
|
# it is using apply_primitive to begin with
|
|
xla.apply_primitive = apply_primitive
|
|
|
|
def simple_impl(prim):
|
|
prim.def_impl(partial(apply_primitive, prim))
|
|
|
|
RuntimeToken = Any
|
|
|
|
class RuntimeTokenSet(threading.local):
|
|
tokens: Dict[core.Effect, Tuple[RuntimeToken, Device]]
|
|
output_tokens: Dict[Device, RuntimeToken]
|
|
output_runtime_tokens: Dict[Device, RuntimeToken]
|
|
|
|
def __init__(self):
|
|
self.tokens = {}
|
|
# TODO(sharadmv): remove redundant output token dictionary when minimum
|
|
# jaxlib version is bumped to 0.3.16.
|
|
self.output_tokens = {}
|
|
self.output_runtime_tokens = {}
|
|
|
|
def get_token(self, eff: core.Effect, device: Device) -> RuntimeToken:
|
|
if eff not in self.tokens:
|
|
self.tokens[eff] = device_put(np.zeros(0, np.bool_), device), device
|
|
elif self.tokens[eff][1] != device:
|
|
(old_token,), _ = self.tokens[eff]
|
|
old_token.aval = core.ShapedArray((0,), np.bool_)
|
|
self.tokens[eff] = device_put(old_token, device), device
|
|
return self.tokens[eff][0]
|
|
|
|
def update_token(self, eff: core.Effect, token: RuntimeToken):
|
|
self.tokens[eff] = token, self.tokens[eff][1]
|
|
|
|
def set_output_token(self, device: Device, token: RuntimeToken):
|
|
# We're free to clobber the previous output token because on each
|
|
# device we have a total ordering of computations. Only the token
|
|
# from the latest computation matters. If this weren't the case
|
|
# we'd need to store a set of output tokens.
|
|
self.output_tokens[device] = token
|
|
|
|
def set_output_runtime_token(self, device: Device, token: RuntimeToken):
|
|
# TODO(sharadmv): remove this method when minimum jaxlib version is bumped
|
|
self.output_runtime_tokens[device] = token
|
|
|
|
def clear(self):
|
|
self.tokens = {}
|
|
self.output_tokens = {}
|
|
self.output_runtime_tokens = {}
|
|
|
|
def block_until_ready(self):
|
|
for token, _ in self.tokens.values():
|
|
token[0].block_until_ready()
|
|
for token in self.output_tokens.values():
|
|
token[0].block_until_ready()
|
|
for token in self.output_runtime_tokens.values():
|
|
token.block_until_ready()
|
|
self.clear()
|
|
|
|
runtime_tokens: RuntimeTokenSet = RuntimeTokenSet()
|
|
|
|
@atexit.register
|
|
def wait_for_tokens():
|
|
runtime_tokens.block_until_ready()
|
|
|
|
@util.cache()
|
|
def xla_primitive_callable(prim, *arg_specs: ArgSpec, **params):
|
|
_, arg_devices = util.unzip2(arg_specs)
|
|
donated_invars = (False,) * len(arg_specs)
|
|
if config.jax_array:
|
|
# This will be resolved in sharded_lowering.
|
|
device = None
|
|
else:
|
|
device = _device_from_arg_devices(arg_devices)
|
|
def prim_fun(*args):
|
|
out = prim.bind(*args, **params)
|
|
if prim.multiple_results:
|
|
return out
|
|
else:
|
|
return out,
|
|
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
|
|
prim.name, donated_invars, False, *arg_specs)
|
|
if not prim.multiple_results:
|
|
return lambda *args, **kw: compiled(*args, **kw)[0]
|
|
else:
|
|
return compiled
|
|
|
|
|
|
def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[Device]:
|
|
"""Given devices of inputs, determine where to perform a computation.
|
|
|
|
Args:
|
|
devices: list where each element is a either a `Device` instance or `None`.
|
|
Returns:
|
|
A `Device` instance or None.
|
|
Raises:
|
|
ValueError if input devices are inconsistent.
|
|
"""
|
|
try:
|
|
device, = {d for d in devices if d is not None} or (None,)
|
|
return device
|
|
except ValueError as err:
|
|
msg = "primitive arguments must be colocated on the same device, got {}"
|
|
raise ValueError(msg.format(", ".join(map(str, devices)))) from err
|
|
|
|
|
|
# JIT execution
|
|
|
|
|
|
def _xla_call_impl_lazy(fun: lu.WrappedFun, *args, device, backend, name,
|
|
donated_invars, inline, keep_unused: bool):
|
|
del inline # Only used at tracing time
|
|
if fun.in_type is None:
|
|
arg_specs: Iterable[Any] = unsafe_map(arg_spec, args)
|
|
else:
|
|
# fun.in_type is used for dynamic shapes.
|
|
if config.jax_array:
|
|
raise NotImplementedError('Dynamic shapes do not work with Array.')
|
|
arg_specs = [(None, getattr(x, '_device', None)) for x in args]
|
|
return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
|
|
*arg_specs)
|
|
|
|
|
|
def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
|
|
donated_invars, inline, keep_unused: bool):
|
|
compiled_fun = _xla_call_impl_lazy(
|
|
fun,
|
|
*args,
|
|
device=device,
|
|
backend=backend,
|
|
name=name,
|
|
donated_invars=donated_invars,
|
|
inline=inline,
|
|
keep_unused=keep_unused)
|
|
# TODO(parkers): Maybe apply this more generally in the case of the c++
|
|
# fallback.
|
|
try:
|
|
return compiled_fun(*args)
|
|
except FloatingPointError:
|
|
assert config.jax_debug_nans or config.jax_debug_infs # compiled_fun can only raise in this case
|
|
print("Invalid value encountered in the output of a jit-decorated function. "
|
|
"Calling the de-optimized version.")
|
|
# We want to run the wrapped function again (after xla_callable already ran
|
|
# it), but linear_util.WrappedFun instances are meant to be run only once.
|
|
# In addition to re-executing the Python code, which is usually undesirable
|
|
# but which config.jax_debug_nans is meant to opt into, we'll be
|
|
# re-executing any linear_util.py-style side effects, i.e. re-populating
|
|
# Stores created by any transformation_with_aux's applied to fun. Since this
|
|
# is intentional here, to avoid "Store occupied" errors we clone the
|
|
# WrappedFun with empty stores.
|
|
stores = [lu.Store() for _ in fun.stores]
|
|
clone = lu.WrappedFun(fun.f, fun.transforms, stores, fun.params,
|
|
fun.in_type)
|
|
|
|
with core.new_sublevel():
|
|
_ = clone.call_wrapped(*args) # may raise, not return
|
|
|
|
# If control reaches this line, we got a NaN on the output of `compiled_fun`
|
|
# but not `clone.call_wrapped` on the same arguments. Let's tell the user.
|
|
fun_info = pe.fun_sourceinfo(fun.f)
|
|
msg = ("An invalid value was encountered in the output of the "
|
|
f"`jit`-decorated function {fun_info}. Because "
|
|
"config.jax_debug_nans and/or config.jax_debug_infs is set, the "
|
|
"de-optimized function (i.e., the function as if the `jit` "
|
|
"decorator were removed) was called in an attempt to get a more "
|
|
"precise error message. However, the de-optimized function did not "
|
|
"produce invalid values during its execution. This behavior can "
|
|
"result from `jit` optimizations causing the invalud value to be "
|
|
"produced. It may also arise from having nan/inf constants as "
|
|
"outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
|
|
"\n\n"
|
|
"It may be possible to avoid the invalid value by removing the "
|
|
"`jit` decorator, at the cost of losing optimizations. "
|
|
"\n\n"
|
|
"If you see this error, consider opening a bug report at "
|
|
"https://github.com/google/jax.")
|
|
raise FloatingPointError(msg)
|
|
|
|
xla.xla_call_p.def_impl(_xla_call_impl)
|
|
|
|
|
|
# TODO(yashkatariya,mattjj): Try to handle this in api.py via a device_put and
|
|
# don't pass the device and backend argument to `_xla_callable_uncached`.
|
|
def not_none_device_or_backend_on_jit(backend, device, num_ins):
|
|
"""This is to support the backend and device argument on jit. It's a feature
|
|
that's deprecated but needs to be supported for feature parity and so that we
|
|
can delete the non-Array paths when Array is switched on.
|
|
"""
|
|
# TODO(yashkatariya): Remove this entire function when backend and device are
|
|
# removed as arguments on jit.
|
|
if device is not None and backend is not None:
|
|
raise ValueError("can't specify both a device and a backend for jit, "
|
|
"got device={} and backend={}".format(device, backend))
|
|
|
|
if backend is not None:
|
|
da = [xb.get_backend(backend).get_default_device_assignment(1)[0]]
|
|
else:
|
|
assert device is not None
|
|
da = [device]
|
|
|
|
assert len(da) == 1
|
|
# in_shardings will be marked as replicated regardless of whatever the input
|
|
# had. Given that only a single device is allowed above, this is correct.
|
|
in_shardings = [OpShardingSharding.get_replicated(da)] * num_ins
|
|
return da, in_shardings
|
|
|
|
|
|
def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
|
|
keep_unused, *arg_specs):
|
|
# TODO(yashkatariya): Remove the local imports from here when the functions
|
|
# in pxla.py move to dispatch.py or a utils file.
|
|
from jax.interpreters import pxla
|
|
from jax.experimental import pjit
|
|
|
|
in_avals, in_shardings = util.unzip2(arg_specs)
|
|
|
|
da = None
|
|
if backend is not None or device is not None:
|
|
da, in_shardings = not_none_device_or_backend_on_jit(
|
|
backend, device, len(in_shardings))
|
|
|
|
in_shardings = [pxla._UNSPECIFIED if i is None else i for i in in_shardings]
|
|
|
|
# Pass in a singleton `_UNSPECIFIED` for out_shardings because we don't know
|
|
# the number of output avals at this stage. lower_sharding_computation will
|
|
# apply it to all out_avals.
|
|
return pxla.lower_sharding_computation(
|
|
fun, 'jit', name, in_shardings, pjit._UNSPECIFIED, donated_invars,
|
|
in_avals, in_is_global=(True,) * len(arg_specs), keep_unused=keep_unused,
|
|
always_lower=always_lower, devices_from_context=da)
|
|
|
|
|
|
def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
|
|
donated_invars, keep_unused, *arg_specs):
|
|
if config.jax_array:
|
|
computation = sharded_lowering(fun, device, backend, name, donated_invars,
|
|
False, keep_unused, *arg_specs)
|
|
return computation.compile(_allow_propagation_to_outputs=True).unsafe_call
|
|
else:
|
|
return lower_xla_callable(fun, device, backend, name, donated_invars, False,
|
|
keep_unused, *arg_specs).compile().unsafe_call
|
|
|
|
xla_callable = lu.cache(_xla_callable_uncached)
|
|
|
|
|
|
def is_single_device_sharding(sharding) -> bool:
|
|
# Special case PmapSharding here because PmapSharding maps away an axis
|
|
# and needs to be handled separately.test_pjit_single_device_sharding_add
|
|
return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
|
|
|
|
|
|
@contextlib.contextmanager
|
|
def log_elapsed_time(fmt: str, event: Optional[str] = None):
|
|
if _on_exit:
|
|
yield
|
|
else:
|
|
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
|
start_time = time.time()
|
|
yield
|
|
elapsed_time = time.time() - start_time
|
|
logger.log(log_priority, fmt.format(elapsed_time=elapsed_time))
|
|
if event is not None:
|
|
record_event_duration_secs(event, elapsed_time)
|
|
|
|
|
|
def should_tuple_args(num_args: int, platform: str):
|
|
# CPU and GPU do not need tuples as they use host-side data structures that
|
|
# do not have small bounds.
|
|
# TPU only needs a tuple for very long lists
|
|
if platform == "tpu":
|
|
return num_args > 2000
|
|
else:
|
|
return False
|
|
|
|
|
|
def raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, name, jaxpr):
|
|
if nreps > 1:
|
|
warnings.warn(
|
|
f"The jitted function {name} includes a pmap. Using "
|
|
"jit-of-pmap can lead to inefficient data movement, as the outer jit "
|
|
"does not preserve sharded data representations and instead collects "
|
|
"input and output arrays onto a single device. "
|
|
"Consider removing the outer jit unless you know what you're doing. "
|
|
"See https://github.com/google/jax/issues/2926.")
|
|
|
|
if nreps > xb.device_count(backend):
|
|
raise ValueError(
|
|
f"compiling computation `{name}` that requires {nreps} replicas, but "
|
|
f"only {xb.device_count(backend)} XLA devices are available.")
|
|
|
|
if xb.process_count() > 1 and (nreps > 1 or
|
|
jaxpr_has_primitive(jaxpr, "xla_pmap")):
|
|
raise NotImplementedError(
|
|
"jit of multi-host pmap not implemented (and jit-of-pmap can cause "
|
|
"extra data movement anyway, so maybe you don't want it after all).")
|
|
|
|
|
|
@profiler.annotate_function
|
|
def lower_xla_callable(
|
|
fun: lu.WrappedFun, device, backend, name, donated_invars,
|
|
always_lower: bool, keep_unused: bool, *arg_specs):
|
|
"""Lower into XLA.
|
|
|
|
Args:
|
|
always_lower: If `True`, even trivial programs (not doing any computation
|
|
such as lambda x: x) will be lowered into an XLA program.
|
|
keep_unused: If `False` (the default), arguments that JAX determines to be
|
|
unused by `fun` *may* be dropped from resulting compiled XLA executables.
|
|
Such arguments will not be transferred to the device nor provided to the
|
|
underlying executable. If `True`, unused arguments will not be pruned.
|
|
"""
|
|
if device is not None and backend is not None:
|
|
raise ValueError("can't specify both a device and a backend for jit, "
|
|
"got device={} and backend={}".format(device, backend))
|
|
abstract_args, arg_devices = util.unzip2(arg_specs)
|
|
if fun.in_type is None:
|
|
# Add an annotation inferred from the arguments; no dynamic axes here.
|
|
in_type = tuple(unsafe_zip(abstract_args, itertools.repeat(True)))
|
|
fun = lu.annotate(fun, in_type)
|
|
else:
|
|
assert abstract_args == (None,) * len(abstract_args)
|
|
abstract_args = tuple(aval for aval, _ in fun.in_type)
|
|
|
|
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
|
|
"for jit in {elapsed_time} sec",
|
|
event=JAXPR_TRACE_EVENT):
|
|
jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
|
|
fun, pe.debug_info_final(fun, "jit"))
|
|
out_avals, kept_outputs = util.unzip2(out_type)
|
|
|
|
if any(isinstance(c, core.Tracer) for c in consts):
|
|
raise UnexpectedTracerError("Encountered an unexpected tracer.")
|
|
|
|
if config.jax_dynamic_shapes:
|
|
keep_unused = True
|
|
has_outfeed = False
|
|
donated_invars = [False] * len(fun.in_type)
|
|
else:
|
|
has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
|
|
jaxpr = apply_outfeed_rewriter(jaxpr)
|
|
|
|
if not keep_unused:
|
|
jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
|
|
consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
|
|
abstract_args, arg_devices = util.unzip2(
|
|
[a for i, a in enumerate(arg_specs) if i in kept_var_idx])
|
|
donated_invars = [x for i, x in enumerate(donated_invars)
|
|
if i in kept_var_idx]
|
|
del kept_const_idx
|
|
else:
|
|
kept_var_idx = set(range(len(fun.in_type)))
|
|
|
|
nreps = jaxpr_replicas(jaxpr)
|
|
device = _xla_callable_device(nreps, backend, device, arg_devices)
|
|
backend = xb.get_device_backend(device) if device else xb.get_backend(backend)
|
|
|
|
if config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr):
|
|
jaxpr, consts = pe.pad_jaxpr(jaxpr, consts)
|
|
|
|
map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))
|
|
|
|
# Computations that only produce constants and/or only rearrange their inputs,
|
|
# which are often produced from partial evaluation, don't need compilation,
|
|
# and don't need to evaluate their arguments.
|
|
if (not always_lower and not (jaxpr.effects or has_outfeed) and
|
|
(not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars)):
|
|
return XlaComputation(
|
|
name, None, True, None, None, None, jaxpr=jaxpr, consts=consts,
|
|
device=device, in_avals=abstract_args, out_avals=out_avals,
|
|
has_unordered_effects=False, ordered_effects=[],
|
|
kept_var_idx=kept_var_idx, keepalive=None, host_callbacks=[])
|
|
|
|
if not _on_exit:
|
|
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
|
|
if len(abstract_args) > 10:
|
|
msg = f"Compiling {fun.__name__} ({id(fun)}) for {len(abstract_args)} args."
|
|
else:
|
|
msg = f"Compiling {fun.__name__} ({id(fun)} for args {abstract_args}."
|
|
logger.log(log_priority, msg)
|
|
|
|
raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, name, jaxpr)
|
|
|
|
# pass long arg lists as tuple for TPU
|
|
tuple_args = should_tuple_args(len(abstract_args), backend.platform)
|
|
axis_env = xla.AxisEnv(nreps, (), ())
|
|
name_stack = util.new_name_stack(util.wrap_name(name, 'jit'))
|
|
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
|
closed_out_type = [
|
|
(a.update(shape=tuple(pe.InDBIdx(d.val - len(consts))
|
|
if type(d) is pe.InDBIdx else d for d in a.shape))
|
|
if type(a) is core.DShapedArray else a, b) for a, b in out_type]
|
|
module_name = f"jit_{fun.__name__}"
|
|
unordered_effects = [eff for eff in closed_jaxpr.effects
|
|
if eff not in core.ordered_effects]
|
|
ordered_effects = [eff for eff in closed_jaxpr.effects
|
|
if eff in core.ordered_effects]
|
|
lowering_result = mlir.lower_jaxpr_to_module(
|
|
module_name, closed_jaxpr, unordered_effects,
|
|
ordered_effects, backend, backend.platform,
|
|
mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
|
|
module, keepalive, host_callbacks = (
|
|
lowering_result.module, lowering_result.keepalive,
|
|
lowering_result.host_callbacks)
|
|
return XlaComputation(
|
|
name, module, False, donated_invars, fun.in_type, tuple(closed_out_type),
|
|
nreps=nreps, device=device, backend=backend, tuple_args=tuple_args,
|
|
in_avals=abstract_args, out_avals=out_avals,
|
|
has_unordered_effects=bool(unordered_effects),
|
|
ordered_effects=ordered_effects, kept_var_idx=kept_var_idx,
|
|
keepalive=keepalive, host_callbacks=host_callbacks)
|
|
|
|
|
|
def _backend_supports_unbounded_dynamic_shapes(backend: Backend) -> bool:
|
|
return backend.platform == 'iree'
|
|
|
|
|
|
def prefetch(x):
|
|
if isinstance(x, device_array.DeviceArray):
|
|
x.copy_to_host_async()
|
|
return x
|
|
|
|
|
|
def jaxpr_literals(jaxpr):
|
|
"""Generates all the literals inside a jaxpr, including nested subjaxprs."""
|
|
for eqn in jaxpr.eqns:
|
|
for v in eqn.invars:
|
|
if type(v) is core.Literal:
|
|
yield v.val
|
|
for subjaxpr in core.subjaxprs(jaxpr):
|
|
yield from jaxpr_literals(subjaxpr)
|
|
|
|
|
|
def jaxpr_has_primitive(jaxpr, prim_name: str):
|
|
"""Whether there is a primitive given by user anywhere inside a Jaxpr."""
|
|
for eqn in jaxpr.eqns:
|
|
if prim_name in eqn.primitive.name:
|
|
return True
|
|
for subjaxpr in core.subjaxprs(jaxpr):
|
|
if jaxpr_has_primitive(subjaxpr, prim_name):
|
|
return True
|
|
return False
|
|
|
|
|
|
def jaxpr_shardings(jaxpr) -> Iterator[jax.sharding.XLACompatibleSharding]:
|
|
from jax.experimental import pjit
|
|
|
|
for eqn in jaxpr.eqns:
|
|
if eqn.primitive is pjit.sharding_constraint_p:
|
|
yield eqn.params['sharding']
|
|
elif eqn.primitive is pjit.pjit_p:
|
|
yield from eqn.params['in_shardings']
|
|
yield from eqn.params['out_shardings']
|
|
for subjaxpr in core.subjaxprs(jaxpr):
|
|
yield from jaxpr_shardings(subjaxpr)
|
|
|
|
|
|
def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:
|
|
return (any(type(v.aval.dtype) is core.bint for v in jaxpr.invars
|
|
if isinstance(v.aval, core.UnshapedArray)) or
|
|
any(_is_bint_axis_size(d)
|
|
for j in itertools.chain([jaxpr], core.subjaxprs(jaxpr))
|
|
for e in j.eqns for v in e.outvars
|
|
if isinstance(v.aval, core.DShapedArray) for d in v.aval.shape))
|
|
|
|
def _is_bint_axis_size(d: core.AxisSize) -> bool:
|
|
if isinstance(d, core.DArray):
|
|
assert not d.shape
|
|
return type(d.dtype) is core.bint
|
|
elif isinstance(d, core.Var):
|
|
return (isinstance(d.aval, core.DShapedArray) and
|
|
type(d.aval.dtype) is core.bint)
|
|
return False
|
|
|
|
def _prune_unused_inputs(
|
|
jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
|
|
used_outputs = [True] * len(jaxpr.outvars)
|
|
new_jaxpr, used_consts, used_inputs = pe.dce_jaxpr_consts(jaxpr, used_outputs)
|
|
kept_const_idx = {i for i, b in enumerate(used_consts) if b}
|
|
kept_var_idx = {i for i, b in enumerate(used_inputs) if b}
|
|
return new_jaxpr, kept_const_idx, kept_var_idx
|
|
|
|
|
|
# We can optionally set a Jaxpr rewriter that can be applied just before
|
|
# compilation. This mechanism is used for compiling id_tap, we can
|
|
# remove it once we bring the id_tap implementation into the core.
|
|
outfeed_rewriter: Optional[Callable[[core.Jaxpr], core.Jaxpr]] = None
|
|
def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
|
|
if outfeed_rewriter is not None:
|
|
return outfeed_rewriter(jaxpr)
|
|
else:
|
|
return jaxpr
|
|
|
|
|
|
# TODO(mattjj,necula): this duplicates code in core.valid_jaxtype, but one
|
|
# internal user relies on it for duck-typing. must fix downstream user!
|
|
def _valid_jaxtype(arg):
|
|
try:
|
|
xla.abstractify(arg) # faster than core.get_aval
|
|
except TypeError:
|
|
return core.valid_jaxtype(arg)
|
|
else:
|
|
return True
|
|
|
|
def check_arg(arg):
|
|
if not (isinstance(arg, core.Tracer) or _valid_jaxtype(arg)):
|
|
raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid "
|
|
"JAX type.")
|
|
|
|
|
|
def jaxpr_replicas(jaxpr) -> int:
|
|
"""The number of replicas needed for a jaxpr.
|
|
|
|
For a eqn, multiply the `axis_size` with the `jaxpr_replicas` of the
|
|
subjaxprs. For a list of eqns, take the maximum number of replicas.
|
|
"""
|
|
if isinstance(jaxpr, core.ClosedJaxpr):
|
|
jaxpr = jaxpr.jaxpr
|
|
return max(unsafe_map(eqn_replicas, jaxpr.eqns), default=1)
|
|
|
|
# TODO(mattjj): this function assumes that only pmap has a parameter named
|
|
# axis_size, and that it corresponds to cross-replica mapping
|
|
def eqn_replicas(eqn):
|
|
call_jaxpr = eqn.params.get("call_jaxpr")
|
|
if call_jaxpr:
|
|
return eqn.params.get('axis_size', 1) * jaxpr_replicas(call_jaxpr)
|
|
elif eqn.primitive in xla._initial_style_primitives:
|
|
return initial_style_primitive_replicas(eqn.params)
|
|
else:
|
|
return 1
|
|
|
|
def initial_style_primitive_replicas(params):
|
|
return max(core.traverse_jaxpr_params(jaxpr_replicas, params).values(), default=1)
|
|
|
|
|
|
def _xla_callable_device(nreps, backend, device,
|
|
arg_devices) -> Optional[Device]:
|
|
if nreps > 1:
|
|
if device is not None or backend is not None:
|
|
raise ValueError(f"can't specify device or backend for jit-of-pmap, "
|
|
f"got {device=} and {backend=}")
|
|
return None
|
|
else:
|
|
# TODO(skye): dedup with C++ jit logic for determining jit device?
|
|
if device is not None:
|
|
assert backend is None
|
|
return device
|
|
|
|
if backend is not None:
|
|
return xb.get_backend(backend).get_default_device_assignment(1)[0]
|
|
|
|
arg_device = _device_from_arg_devices(arg_devices)
|
|
if arg_device is not None:
|
|
return arg_device
|
|
|
|
return config.jax_default_device
|
|
|
|
|
|
# Argument and result handlers
|
|
|
|
num_buffers_handlers: Dict[Type[core.AbstractValue],
|
|
Callable[[core.AbstractValue], int]] = {}
|
|
|
|
def aval_to_num_buffers(aval: core.AbstractValue) -> int:
|
|
"""Returns the number of buffers in the runtime representation of `aval`.
|
|
|
|
In general this may differ from the number of buffers in the compiler-IR
|
|
representation of the same value.
|
|
"""
|
|
try:
|
|
return num_buffers_handlers[type(aval)](aval)
|
|
except KeyError as err:
|
|
raise TypeError(f"No num_buffers handler for type: {type(aval)}") from err
|
|
|
|
num_buffers_handlers[core.AbstractToken] = lambda _: 1
|
|
num_buffers_handlers[core.ShapedArray] = lambda _: 1
|
|
num_buffers_handlers[core.DShapedArray] = lambda _: 1
|
|
num_buffers_handlers[core.ConcreteArray] = lambda _: 1
|
|
|
|
|
|
def _input_handler(backend: Backend,
|
|
in_type: Optional[pe.InputType],
|
|
out_type: Optional[pe.OutputType],
|
|
) -> Optional[Callable]:
|
|
if in_type is None:
|
|
assert out_type is None
|
|
return None
|
|
in_avals, which_explicit = util.unzip2(in_type)
|
|
# Check whether we actually need an input_handler.
|
|
needs_implicit = which_explicit and not all(which_explicit)
|
|
needs_out_handling = any(type(d) is core.InDBIdx for a, _ in out_type or []
|
|
if type(a) is core.DShapedArray for d in a.shape)
|
|
|
|
if not needs_implicit and not needs_out_handling:
|
|
return None
|
|
assert config.jax_dynamic_shapes
|
|
|
|
# Precompute how to grab implicit inputs from explicit inputs' axis sizes.
|
|
which_explicit = which_explicit or (True,) * len(in_avals)
|
|
implicit_idxs = {i for i, ex in enumerate(which_explicit) if not ex}
|
|
implicit_args_from_axes: List[Tuple[int, int, int]] = []
|
|
for arg_idx, aval in enumerate(in_avals):
|
|
if isinstance(aval, core.DShapedArray):
|
|
for axis_idx, d in enumerate(aval.shape):
|
|
if isinstance(d, core.DBIdx) and d.val in implicit_idxs:
|
|
implicit_args_from_axes.append((d.val, arg_idx, axis_idx))
|
|
assert {i for i, _, _ in implicit_args_from_axes} == implicit_idxs
|
|
|
|
# Precompute which input values are needed for output types.
|
|
inputs_needed_for_out_types = out_type and [
|
|
d.val for aval, _ in out_type if type(aval) is core.DShapedArray # type: ignore
|
|
for d in aval.shape if type(d) is core.InDBIdx]
|
|
|
|
def elaborate(explicit_args: Sequence[Any]) -> Tuple[Tuple, Optional[Tuple]]:
|
|
if needs_implicit:
|
|
# Build full argument list, leaving Nones for implicit arguments.
|
|
explicit_args_ = iter(explicit_args)
|
|
args = [next(explicit_args_) if ex else None for ex in which_explicit]
|
|
assert next(explicit_args_, None) is None
|
|
# Populate implicit arguments.
|
|
for i, j, k in implicit_args_from_axes:
|
|
if args[i] is None:
|
|
args[i] = args[j].shape[k] # type: ignore
|
|
else:
|
|
if args[i] != args[j].shape[k]:
|
|
raise Exception("inconsistent argument axis sizes for type")
|
|
else:
|
|
args = list(explicit_args)
|
|
|
|
if needs_out_handling:
|
|
# Make a list of inputs needed by output types, leaving unneeded as None.
|
|
out_type_env = [None] * len(args)
|
|
for i in inputs_needed_for_out_types or []:
|
|
out_type_env[i] = args[i]
|
|
else:
|
|
out_type_env = None # type: ignore
|
|
|
|
return tuple(args), out_type_env and tuple(out_type_env) # type: ignore
|
|
return elaborate
|
|
|
|
def _result_handler(backend: Backend,
|
|
sticky_device: Optional[Device],
|
|
out_type: pe.OutputType,
|
|
) -> Callable:
|
|
out_avals, kept_outputs = util.unzip2(out_type)
|
|
handlers = map(partial(aval_to_result_handler, sticky_device), out_avals)
|
|
dyn_outs = any(type(aval) is core.DShapedArray and
|
|
any(type(d) in (core.InDBIdx, core.OutDBIdx) for d in aval.shape)
|
|
for aval in out_avals)
|
|
if not dyn_outs:
|
|
return SimpleResultHandler(handlers)
|
|
assert config.jax_dynamic_shapes
|
|
|
|
def result_handler(input_env, lists_of_bufs):
|
|
results = []
|
|
for handler, bufs in unsafe_zip(handlers, lists_of_bufs):
|
|
results.append(handler((input_env, results), *bufs))
|
|
return [r for r, keep in unsafe_zip(results, kept_outputs) if keep]
|
|
return result_handler
|
|
|
|
class SimpleResultHandler:
|
|
handlers: Sequence[ResultHandler]
|
|
def __init__(self, handlers): self.handlers = handlers
|
|
def __iter__(self): return iter(self.handlers)
|
|
def __len__(self): return len(self.handlers)
|
|
def __call__(self, env, lists_of_bufs):
|
|
return tuple(h(env, *bs) for h, bs in zip(self.handlers, lists_of_bufs))
|
|
|
|
|
|
def maybe_create_array_from_da(buf, aval, device):
|
|
if config.jax_array:
|
|
return array.ArrayImpl(aval, SingleDeviceSharding(buf.device()), [buf],
|
|
committed=(device is not None), _skip_checks=True)
|
|
else:
|
|
return device_array.make_device_array(aval, device, buf)
|
|
|
|
|
|
if MYPY:
|
|
ResultHandler = Any
|
|
else:
|
|
class ResultHandler(Protocol):
|
|
def __call__(self, env: Optional[Sequence[Any]], *args: xla.Buffer) -> Any:
|
|
"""Boxes raw buffers into their user-facing representation."""
|
|
|
|
def aval_to_result_handler(sticky_device: Optional[Device],
|
|
aval: core.AbstractValue) -> ResultHandler:
|
|
try:
|
|
return result_handlers[type(aval)](sticky_device, aval)
|
|
except KeyError as err:
|
|
raise TypeError(f"No result handler for type: {type(aval)}") from err
|
|
|
|
def array_result_handler(sticky_device: Optional[Device],
|
|
aval: core.ShapedArray):
|
|
if not core.is_opaque_dtype(aval.dtype) and aval.dtype == dtypes.float0:
|
|
return lambda _, __: np.zeros(aval.shape, dtypes.float0)
|
|
aval = core.raise_to_shaped(aval)
|
|
if core.is_opaque_dtype(aval.dtype):
|
|
return aval.dtype._rules.result_handler(sticky_device, aval)
|
|
handler = lambda _, b: maybe_create_array_from_da(b, aval, sticky_device)
|
|
handler.args = aval, sticky_device # for C++ dispatch path in api.py
|
|
return handler
|
|
|
|
def dynamic_array_result_handler(sticky_device: Optional[Device],
|
|
aval: core.DShapedArray):
|
|
if not core.is_opaque_dtype(aval.dtype) and aval.dtype == dtypes.float0:
|
|
return lambda _: np.zeros(aval.shape, dtypes.float0) # type: ignore
|
|
else:
|
|
return partial(_dynamic_array_result_handler, sticky_device, aval)
|
|
|
|
def _dynamic_array_result_handler(sticky_device, aval, env, buf):
|
|
in_env, out_env = env or (None, None)
|
|
shape = [in_env[d.val] if type(d) is core.InDBIdx else
|
|
out_env[d.val] if type(d) is core.OutDBIdx else d
|
|
for d in aval.shape]
|
|
if all(type(d) is int for d in shape) and type(aval.dtype) is not core.bint:
|
|
aval = core.ShapedArray(tuple(shape), buf.dtype)
|
|
return maybe_create_array_from_da(buf, aval, sticky_device)
|
|
else:
|
|
pad_shape = [d.dtype.bound if _is_bint_axis_size(d) else d for d in shape]
|
|
buf_aval = core.ShapedArray(tuple(pad_shape), buf.dtype, aval.weak_type)
|
|
data = maybe_create_array_from_da(buf, buf_aval, sticky_device)
|
|
return core.DArray(aval.update(shape=tuple(shape)), data)
|
|
|
|
|
|
|
|
result_handlers: Dict[
|
|
Type[core.AbstractValue],
|
|
Callable[[Optional[Device], Any], ResultHandler]] = {}
|
|
result_handlers[core.AbstractToken] = lambda _, __: lambda _, __: core.token
|
|
result_handlers[core.ShapedArray] = array_result_handler
|
|
result_handlers[core.DShapedArray] = dynamic_array_result_handler
|
|
result_handlers[core.ConcreteArray] = array_result_handler
|
|
|
|
|
|
def needs_check_special():
|
|
return config.jax_debug_infs or config.jax_debug_nans
|
|
|
|
def check_special(name, bufs):
|
|
if needs_check_special():
|
|
for buf in bufs:
|
|
_check_special(name, buf.dtype, buf)
|
|
|
|
def _check_special(name, dtype, buf):
|
|
if dtypes.issubdtype(dtype, np.inexact):
|
|
if config.jax_debug_nans and np.any(np.isnan(np.asarray(buf))):
|
|
raise FloatingPointError(f"invalid value (nan) encountered in {name}")
|
|
if config.jax_debug_infs and np.any(np.isinf(np.asarray(buf))):
|
|
raise FloatingPointError(f"invalid value (inf) encountered in {name}")
|
|
|
|
def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
|
|
has_host_callbacks: bool, device: Device, input_bufs):
|
|
tokens = [runtime_tokens.get_token(eff, device) for eff in ordered_effects]
|
|
tokens_flat = flatten(tokens)
|
|
input_bufs = [*tokens_flat, *input_bufs]
|
|
def _remove_tokens(output_bufs, runtime_token):
|
|
num_output_tokens = len(ordered_effects)
|
|
token_bufs, output_bufs = util.split_list(output_bufs, [num_output_tokens])
|
|
if has_unordered_effects or has_host_callbacks:
|
|
runtime_tokens.set_output_runtime_token(device, runtime_token)
|
|
for eff, token_buf in zip(ordered_effects, token_bufs):
|
|
runtime_tokens.update_token(eff, token_buf)
|
|
return output_bufs
|
|
return input_bufs, _remove_tokens
|
|
|
|
|
|
def _execute_compiled(name: str, compiled: XlaLoadedExecutable,
|
|
input_handler: Optional[Callable],
|
|
output_buffer_counts: Sequence[int],
|
|
result_handler: Callable, has_unordered_effects: bool,
|
|
ordered_effects: List[core.Effect], kept_var_idx,
|
|
has_host_callbacks: bool, *args):
|
|
device, = compiled.local_devices()
|
|
args, env = input_handler(args) if input_handler else (args, None)
|
|
in_flat = flatten(device_put(x, device) for i, x in enumerate(args)
|
|
if i in kept_var_idx)
|
|
if has_unordered_effects or ordered_effects or has_host_callbacks:
|
|
in_flat, token_handler = _add_tokens(
|
|
has_unordered_effects, ordered_effects, has_host_callbacks, device,
|
|
in_flat)
|
|
out_flat, runtime_token = compiled.execute_with_token(in_flat)
|
|
else:
|
|
out_flat = compiled.execute(in_flat)
|
|
check_special(name, out_flat)
|
|
out_bufs = unflatten(out_flat, output_buffer_counts)
|
|
if ordered_effects or has_unordered_effects or has_host_callbacks:
|
|
out_bufs = token_handler(out_bufs, runtime_token)
|
|
return result_handler(env, out_bufs)
|
|
|
|
|
|
def _execute_replicated(name: str,
|
|
compiled: XlaLoadedExecutable,
|
|
input_handler: Optional[Callable],
|
|
output_buffer_counts: Sequence[int],
|
|
result_handler: Callable,
|
|
has_unordered_effects: bool,
|
|
ordered_effects: List[core.Effect],
|
|
kept_var_idx,
|
|
has_host_callbacks: bool,
|
|
*args,
|
|
from_lower_sharding_computation: bool = False):
|
|
if has_unordered_effects or ordered_effects:
|
|
# TODO(sharadmv): support jit-of-pmap with effects
|
|
raise NotImplementedError(
|
|
"Cannot execute replicated computation with effects.")
|
|
if input_handler: raise NotImplementedError # TODO(mattjj, dougalm)
|
|
input_bufs = [flatten(device_put(x, device) for i, x in enumerate(args)
|
|
if i in kept_var_idx)
|
|
for device in compiled.local_devices()]
|
|
input_bufs_flip = list(unsafe_zip(*input_bufs))
|
|
out_bufs_flat_rep = compiled.execute_sharded_on_local_devices(input_bufs_flip)
|
|
out_flat = [bufs[0] for bufs in out_bufs_flat_rep]
|
|
check_special(name, out_flat)
|
|
out_bufs = unflatten(out_flat, output_buffer_counts)
|
|
if from_lower_sharding_computation:
|
|
return result_handler(out_bufs)
|
|
return result_handler(None, out_bufs)
|
|
|
|
|
|
def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
|
|
has_unordered_effects: bool,
|
|
ordered_effects: List[core.Effect], kept_var_idx,
|
|
host_callbacks, *args):
|
|
env: Dict[core.Var, Any] = {}
|
|
pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
|
|
map(env.setdefault, jaxpr.invars, pruned_args)
|
|
map(env.setdefault, jaxpr.constvars, consts)
|
|
outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
|
|
for v in jaxpr.outvars]
|
|
return [_copy_device_array_to_device(x, device) if device_array.type_is_device_array(x)
|
|
else h(None, *device_put(x, device)) for h, x in zip(handlers, outs)]
|
|
|
|
|
|
class XlaComputation(stages.XlaLowering):
|
|
name: str
|
|
_is_trivial: bool
|
|
_executable: Optional[XlaCompiledComputation]
|
|
_donated_invars: Optional[Sequence[bool]]
|
|
|
|
def __init__(self, name: str, hlo: Optional[ir.Module], is_trivial: bool,
|
|
donated_invars: Optional[Sequence[bool]],
|
|
in_type: Optional[pe.InputType],
|
|
out_type: Optional[pe.OutputType],
|
|
**compile_args):
|
|
self.name = name
|
|
self._hlo = hlo
|
|
self._is_trivial = is_trivial
|
|
self._donated_invars = donated_invars
|
|
self._in_type = in_type
|
|
self._out_type = out_type
|
|
self._executable = None
|
|
self.compile_args = compile_args
|
|
|
|
def is_trivial(self):
|
|
return self._is_trivial
|
|
|
|
# -- stages.XlaLowering overrides
|
|
|
|
def hlo(self) -> xc.XlaComputation:
|
|
if self.is_trivial():
|
|
raise ValueError("A trivial computation has no HLO")
|
|
return xe.mlir.mlir_module_to_xla_computation(
|
|
mlir.module_to_string(self._hlo),
|
|
use_tuple_args=self.compile_args["tuple_args"])
|
|
|
|
def mhlo(self) -> ir.Module:
|
|
if use_stablehlo:
|
|
return super().mhlo()
|
|
else:
|
|
if self.is_trivial():
|
|
raise ValueError("A trivial computation has no MHLO")
|
|
return self._hlo
|
|
|
|
def stablehlo(self) -> ir.Module:
|
|
if use_stablehlo:
|
|
if self.is_trivial():
|
|
raise ValueError("A trivial computation has no StableHLO")
|
|
return self._hlo
|
|
else:
|
|
return super().stablehlo()
|
|
|
|
def compile(self) -> XlaCompiledComputation:
|
|
if self._executable is None:
|
|
if self.is_trivial():
|
|
self._executable = XlaCompiledComputation.from_trivial_jaxpr(
|
|
**self.compile_args)
|
|
else:
|
|
assert self._out_type is not None
|
|
self._executable = XlaCompiledComputation.from_xla_computation(
|
|
self.name, self._hlo, self._in_type, self._out_type,
|
|
**self.compile_args)
|
|
|
|
return self._executable
|
|
|
|
@profiler.annotate_function
|
|
def backend_compile(backend, built_c, options, host_callbacks):
|
|
# we use a separate function call to ensure that XLA compilation appears
|
|
# separately in Python profiling results
|
|
if host_callbacks:
|
|
return backend.compile(built_c, compile_options=options,
|
|
host_callbacks=host_callbacks)
|
|
# Some backends don't have `host_callbacks` option yet
|
|
# TODO(sharadmv): remove this fallback when all backends allow `compile`
|
|
# to take in `host_callbacks`
|
|
return backend.compile(built_c, compile_options=options)
|
|
|
|
# TODO(phawkins): update users.
|
|
xla.backend_compile = backend_compile
|
|
|
|
_ir_dump_counter = itertools.count()
|
|
|
|
def _make_string_safe_for_filename(s: str) -> str:
|
|
return re.sub(r'[^\w.)( -]', '', s)
|
|
|
|
def _dump_ir_to_file(name: str, ir: str):
|
|
id = next(_ir_dump_counter)
|
|
name = f"jax_ir{id}_{_make_string_safe_for_filename(name)}.mlir"
|
|
name = path.Path(FLAGS.jax_dump_ir_to) / name
|
|
name.write_text(ir)
|
|
|
|
|
|
def compile_or_get_cached(backend, computation: ir.Module, compile_options,
|
|
host_callbacks):
|
|
# Avoid import cycle between jax and jax.experimental
|
|
from jax.experimental.compilation_cache import compilation_cache as cc
|
|
|
|
sym_name = computation.operation.attributes['sym_name']
|
|
module_name = ir.StringAttr(sym_name).value
|
|
|
|
if FLAGS.jax_dump_ir_to:
|
|
_dump_ir_to_file(module_name, mlir.module_to_string(computation))
|
|
|
|
# Convert ir.Module to a string representation, unless the
|
|
# back-end expliclity flags the ability to handle a module directly
|
|
# (avoiding the overhead of back and forth conversions)
|
|
serialized_computation: Union[str, bytes, ir.Module]
|
|
if getattr(backend, "needs_str_ir", True):
|
|
serialized_computation = mlir.module_to_bytecode(computation)
|
|
else:
|
|
serialized_computation = computation
|
|
|
|
# Persistent compilation cache only implemented on TPU.
|
|
# TODO(skye): add warning when initializing cache on unsupported default platform
|
|
supported_platforms = ["tpu"]
|
|
# GPU caching can be enabled if JitRt is enabled.
|
|
# TODO(b/232263664): Remove check when JitRt is enabled by default.
|
|
if "--xla_gpu_enable_xla_runtime_executable=true" in os.environ.get("XLA_FLAGS", ""):
|
|
supported_platforms.append("gpu")
|
|
# (b/233850967) CPU caching can be enabled if XLA Runtime is enabled.
|
|
if "--xla_cpu_use_xla_runtime=true" in os.environ.get("XLA_FLAGS", ""):
|
|
supported_platforms.append("cpu")
|
|
if cc.is_initialized() and backend.platform in supported_platforms:
|
|
cached_executable = _cache_read(serialized_computation, module_name,
|
|
compile_options, backend)
|
|
if cached_executable is not None:
|
|
logger.info("Persistent compilation cache hit for '%s'", module_name)
|
|
return cached_executable
|
|
else:
|
|
start_time = time.monotonic()
|
|
compiled = backend_compile(backend, serialized_computation,
|
|
compile_options, host_callbacks)
|
|
compile_time = time.monotonic() - start_time
|
|
_cache_write(serialized_computation, compile_time, module_name,
|
|
compile_options, backend, compiled)
|
|
return compiled
|
|
|
|
return backend_compile(backend, serialized_computation, compile_options,
|
|
host_callbacks)
|
|
|
|
|
|
def _cache_read(computation: Union[str, bytes, ir.Module], module_name: str,
|
|
compile_options: CompileOptions,
|
|
backend: Backend) -> Optional[XlaLoadedExecutable]:
|
|
"""Looks up `computation` in the persisent compilation cache."""
|
|
# Avoid import cycle between jax and jax.experimental
|
|
from jax.experimental.compilation_cache import compilation_cache as cc
|
|
|
|
try:
|
|
return cc.get_executable(computation, compile_options, backend)
|
|
except Exception as ex:
|
|
if config.jax_raise_persistent_cache_errors:
|
|
raise
|
|
warnings.warn(
|
|
f"Error reading persistent compilation cache entry for "
|
|
f"'{module_name}': {type(ex).__name__}: {ex}")
|
|
return None
|
|
|
|
|
|
def _cache_write(serialized_computation: Union[str, bytes, ir.Module],
|
|
compile_time_secs: float,
|
|
module_name: str, compile_options: CompileOptions,
|
|
backend: Backend, compiled: XlaLoadedExecutable):
|
|
"""Writes `serialized_computation` to the persistent compilation cache."""
|
|
# Avoid import cycle between jax and jax.experimental
|
|
from jax.experimental.compilation_cache import compilation_cache as cc
|
|
|
|
min_compile_time = config.jax_persistent_cache_min_compile_time_secs
|
|
if min_compile_time:
|
|
if compile_time_secs < min_compile_time:
|
|
logger.info(
|
|
"Not writing persistent cache entry for '%s' because it took < %.2f "
|
|
"seconds to compile (%.2fs)", module_name, min_compile_time,
|
|
compile_time_secs)
|
|
return
|
|
else:
|
|
logger.info(
|
|
"'%s' took at least %.2f seconds to compile (%.2fs), writing "
|
|
"persistent cache entry", module_name, min_compile_time,
|
|
compile_time_secs)
|
|
|
|
try:
|
|
cc.put_executable(module_name, serialized_computation, compile_options,
|
|
compiled, backend)
|
|
except Exception as ex:
|
|
if config.jax_raise_persistent_cache_errors:
|
|
raise
|
|
warnings.warn(
|
|
f"Error writing persistent compilation cache entry for "
|
|
f"'{module_name}': {type(ex).__name__}: {ex}")
|
|
|
|
|
|
def _instruction_count(module: ir.Module, max_count: Optional[int] = None):
|
|
|
|
def _blocks_count(blocks, count):
|
|
for block in blocks:
|
|
for op in block.operations:
|
|
count += 1
|
|
# Untested premature performance optimization
|
|
if max_count is not None and count >= max_count:
|
|
return max_count
|
|
for region in op.regions:
|
|
count = _blocks_count(region.blocks, count)
|
|
return count
|
|
|
|
count = 0
|
|
for func in module.body.operations:
|
|
count = _blocks_count(func.body.blocks, count)
|
|
return count
|
|
|
|
|
|
def get_buffer_counts(out_avals, ordered_effects, has_unordered_effects):
|
|
buffer_counts = [aval_to_num_buffers(aval) for aval in out_avals]
|
|
if ordered_effects or has_unordered_effects:
|
|
num_output_tokens = len(ordered_effects)
|
|
# TODO(sharadmv): remove check when minimum jaxlib version is bumped
|
|
buffer_counts = ([1] * num_output_tokens) + buffer_counts
|
|
return buffer_counts
|
|
|
|
|
|
class XlaCompiledComputation(stages.XlaExecutable):
|
|
def __init__(self, xla_executable, in_avals, kept_var_idx, unsafe_call,
|
|
keepalive: Any):
|
|
self._xla_executable = xla_executable
|
|
self.in_avals = in_avals
|
|
self._kept_var_idx = kept_var_idx
|
|
self.unsafe_call = unsafe_call
|
|
# Only the `unsafe_call` function is cached, so to avoid the `keepalive`
|
|
# being garbage collected we attach it to `unsafe_call`.
|
|
self.unsafe_call.keepalive = keepalive
|
|
|
|
@staticmethod
|
|
def from_xla_computation(name: str, xla_computation: Optional[ir.Module],
|
|
in_type: Optional[pe.InputType],
|
|
out_type: pe.OutputType, nreps: int,
|
|
device: Optional[Device], backend: Backend,
|
|
tuple_args: bool,
|
|
in_avals: Sequence[core.AbstractValue],
|
|
out_avals: Sequence[core.AbstractValue],
|
|
has_unordered_effects: bool,
|
|
ordered_effects: List[core.Effect],
|
|
kept_var_idx: Set[int], keepalive: Optional[Any],
|
|
host_callbacks: List[Any]) -> XlaCompiledComputation:
|
|
sticky_device = device
|
|
input_handler = _input_handler(backend, in_type, out_type)
|
|
result_handler = _result_handler(backend, sticky_device, out_type)
|
|
options = xb.get_compile_options(
|
|
num_replicas=nreps, num_partitions=1,
|
|
device_assignment=(sticky_device,) if sticky_device else None)
|
|
options.parameter_is_tupled_arguments = tuple_args
|
|
with log_elapsed_time(f"Finished XLA compilation of {name} "
|
|
"in {elapsed_time} sec",
|
|
event=BACKEND_COMPILE_EVENT):
|
|
compiled = compile_or_get_cached(backend, xla_computation, options,
|
|
host_callbacks)
|
|
buffer_counts = get_buffer_counts(out_avals, ordered_effects,
|
|
has_unordered_effects)
|
|
execute = _execute_compiled if nreps == 1 else _execute_replicated
|
|
unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts, # type: ignore # noqa: F811
|
|
result_handler, has_unordered_effects,
|
|
ordered_effects, kept_var_idx, bool(host_callbacks))
|
|
return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call,
|
|
keepalive)
|
|
|
|
def is_trivial(self):
|
|
return self._xla_executable == None
|
|
|
|
@property
|
|
def xla_executable(self):
|
|
# TODO(frostig): remove in favor of runtime_executable?
|
|
if self.is_trivial():
|
|
raise ValueError("A trivial compiled computation has no XLA executable")
|
|
return self._xla_executable
|
|
|
|
@staticmethod
|
|
def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals,
|
|
has_unordered_effects, ordered_effects, kept_var_idx,
|
|
keepalive: Optional[Any],
|
|
host_callbacks: List[Any]) -> XlaCompiledComputation:
|
|
assert keepalive is None
|
|
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
|
|
unsafe_call = partial(_execute_trivial, jaxpr, device, consts, out_avals,
|
|
result_handlers, has_unordered_effects,
|
|
ordered_effects, kept_var_idx, bool(host_callbacks))
|
|
return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call,
|
|
keepalive)
|
|
|
|
# -- stages.XlaExecutable overrides
|
|
|
|
def xla_extension_executable(self):
|
|
return self.xla_executable
|
|
|
|
def call(self, *args):
|
|
arg_specs = unsafe_map(arg_spec, args)
|
|
arg_avals = [spec[0] for i, spec in enumerate(arg_specs)
|
|
if i in self._kept_var_idx]
|
|
check_arg_avals_for_call(self.in_avals, arg_avals)
|
|
return self.unsafe_call(*args)
|
|
|
|
def check_arg_avals_for_call(ref_avals, arg_avals):
|
|
if len(ref_avals) != len(arg_avals):
|
|
raise TypeError(
|
|
f"Computation compiled for {len(ref_avals)} inputs "
|
|
f"but called with {len(arg_avals)}")
|
|
for ref_aval, arg_aval in zip(ref_avals, arg_avals):
|
|
if not core.typematch(ref_aval, arg_aval):
|
|
raise TypeError(
|
|
"Computation was compiled for different input types and called with "
|
|
"different types. One of the mismatches is:\n"
|
|
f"Compiled with:\n {ref_aval}\n"
|
|
f"called with:\n {arg_aval}")
|
|
|
|
|
|
def _set_aval(val):
|
|
if val.aval is None:
|
|
val.aval = core.ShapedArray(val.shape, val.dtype)
|
|
return val
|
|
|
|
|
|
def device_put(x, device: Optional[Device] = None) -> Tuple[Any, ...]:
|
|
x = xla.canonicalize_dtype(x)
|
|
try:
|
|
return device_put_handlers[type(x)](x, device)
|
|
except KeyError as err:
|
|
raise TypeError(f"No device_put handler for type: {type(x)}") from err
|
|
|
|
# TODO(phawkins): update users.
|
|
xla.device_put = device_put
|
|
|
|
def _device_put_array(x, device: Optional[Device]):
|
|
backend = xb.get_device_backend(device)
|
|
if x.dtype == dtypes.float0:
|
|
x = np.zeros(x.shape, dtype=np.dtype(bool))
|
|
return (backend.buffer_from_pyval(x, device),)
|
|
|
|
def _device_put_scalar(x, device):
|
|
return _device_put_array(dtypes.coerce_to_array(x), device)
|
|
|
|
def _device_put_token(_, device):
|
|
backend = xb.get_device_backend(device)
|
|
return (backend.buffer_from_pyval(np.zeros((), dtype=np.dtype(np.bool_)),
|
|
device),)
|
|
|
|
_scalar_types = dtypes.python_scalar_dtypes.keys()
|
|
|
|
device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]],
|
|
Tuple[Any, ...]]] = {}
|
|
device_put_handlers.update((t, _device_put_array) for t in array_types)
|
|
device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
|
|
device_put_handlers[core.Token] = _device_put_token
|
|
|
|
|
|
def _device_put_device_array(x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray], device: Optional[Device]):
|
|
x = _copy_device_array_to_device(x, device)
|
|
return (x.device_buffer,)
|
|
for t in device_array.device_array_types:
|
|
device_put_handlers[t] = _device_put_device_array
|
|
|
|
def _device_put_jax_array(x, device: Optional[Device]):
|
|
if is_single_device_sharding(x.sharding):
|
|
x = _copy_device_array_to_device(_set_aval(x._arrays[0]), device)
|
|
return (x,)
|
|
else:
|
|
# Round trip via host if x is sharded. SDA also does a round trip via host.
|
|
return _device_put_array(x._value, device)
|
|
device_put_handlers[array.ArrayImpl] = _device_put_jax_array
|
|
|
|
device_put_handlers[pxla._ShardedDeviceArray] = _device_put_array
|
|
device_put_handlers[pxla.pmap_lib.ShardedDeviceArray] = _device_put_array
|
|
|
|
device_put_handlers[core.DArray] = lambda x, d: device_put(x._data, d)
|
|
|
|
def _copy_device_array_to_device(
|
|
x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray],
|
|
device: Optional[xc.Device]
|
|
) -> Union[device_array.DeviceArrayProtocol, device_array._DeviceArray]:
|
|
if device is None:
|
|
# no copying to be done because there's no target specified
|
|
return x
|
|
elif xb.get_device_backend(device).platform == x.device_buffer.platform():
|
|
# source and target platforms are the same
|
|
if x.device_buffer.device() == device:
|
|
# no copying to be done because source equals target
|
|
if x._device == device:
|
|
return x
|
|
else:
|
|
moved_buf = x.device_buffer # We need to change stickyness
|
|
else:
|
|
# move the buffer with a device-to-device copy
|
|
moved_buf = x.device_buffer.copy_to_device(device)
|
|
else:
|
|
# buffers from different XLA backends are passed through the host.
|
|
backend = xb.get_device_backend(device)
|
|
moved_buf = backend.buffer_from_pyval(np.asarray(x.device_buffer), device)
|
|
return device_array.make_device_array(x.aval, device, moved_buf)
|
|
|
|
|
|
def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Array:
|
|
"""Copies `Array`s with SingleDeviceSharding to a different device."""
|
|
if device is None:
|
|
# no copying to be done because there's no target specified
|
|
return x
|
|
|
|
buf = x._arrays[0]
|
|
if xb.get_device_backend(device).platform == buf.platform():
|
|
# source and target platforms are the same
|
|
if x.device() == device:
|
|
# no copying to be done because source equals target
|
|
if x._committed:
|
|
return x
|
|
else:
|
|
moved_buf = buf # We need to change stickyness
|
|
else:
|
|
# move the buffer with a device-to-device copy
|
|
moved_buf = buf.copy_to_device(device)
|
|
else:
|
|
# buffers from different XLA backends are passed through the host.
|
|
backend = xb.get_device_backend(device)
|
|
moved_buf = backend.buffer_from_pyval(np.asarray(buf), device)
|
|
return array.ArrayImpl(
|
|
x.aval, SingleDeviceSharding(moved_buf.device()), [moved_buf],
|
|
committed=(device is not None))
|
|
|
|
|
|
def _device_put_impl(
|
|
x, device: Optional[Union[Device, jax.sharding.Sharding]] = None):
|
|
from jax.interpreters import pxla
|
|
|
|
try:
|
|
a = xla.abstractify(x)
|
|
except TypeError as err:
|
|
raise TypeError(
|
|
f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
|
|
|
|
if isinstance(device, Sharding):
|
|
if not jax.config.jax_array:
|
|
raise RuntimeError(
|
|
"Please enable `jax_array` to use device_put with a `Sharding`. "
|
|
"You can use jax.config.update('jax_array', True) or set JAX_ARRAY=1 "
|
|
"environment variable or set the `jax_array` boolean flag to "
|
|
"something true-like.")
|
|
s = device
|
|
if not s.is_fully_addressable: # type: ignore
|
|
raise ValueError(
|
|
"device_put's second argument must be a Device or a Sharding which "
|
|
f"represents addressable devices, but got {s}")
|
|
if getattr(x, 'sharding', None) == s:
|
|
return x
|
|
# TODO(mattjj,yashkatariya,phawkins): more runtime fast resharding here?
|
|
arg_handler = pxla.shard_arg_handlers[type(x)]
|
|
result_handler = pxla.global_aval_to_result_handler(a, s, True, False)
|
|
map_ = s.devices_indices_map(x.shape) # type: ignore
|
|
return result_handler(arg_handler(x, list(map_), list(map_.values()),
|
|
pxla.InputsHandlerMode.pjit_or_xmap))
|
|
|
|
# Only `Device` exists below. `Sharding` instance is handled above.
|
|
if isinstance(x, array.ArrayImpl):
|
|
if not x.is_fully_addressable:
|
|
raise ValueError(
|
|
"device_put's first argument must be a fully addressable array, but "
|
|
f"got value with devices {x.devices()}")
|
|
if device is None:
|
|
return x
|
|
elif is_single_device_sharding(x.sharding):
|
|
return _copy_array_to_device(x, device)
|
|
|
|
if device_array.type_is_device_array(x):
|
|
return _copy_device_array_to_device(x, device)
|
|
|
|
return aval_to_result_handler(device, a)(None, *device_put(x, device))
|
|
|
|
|
|
device_put_p = core.Primitive('device_put')
|
|
device_put_p.def_impl(_device_put_impl)
|
|
device_put_p.def_abstract_eval(lambda x, device=None: x)
|
|
ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent])
|
|
batching.defvectorized(device_put_p)
|
|
|
|
def _device_put_lowering(ctx, x, *, device):
|
|
return [x]
|
|
|
|
|
|
mlir.register_lowering(device_put_p, _device_put_lowering)
|