# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Primitive dispatch and jit dispatch.
from __future__ import annotations

import atexit
import collections
import contextlib
from functools import partial
import itertools
import time
from typing import (
    Any, Callable, Dict, Optional, Sequence, Set, Tuple, List, Type, Union,
    TYPE_CHECKING)
from typing_extensions import Protocol
import os
import re
import threading
import warnings

from absl import logging
import numpy as np

import jax
from jax import core
from jax import linear_util as lu
from jax.errors import UnexpectedTracerError
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
import jax.interpreters.xla as xla
import jax.interpreters.partial_eval as pe
from jax._src import device_array
from jax._src import dtypes
from jax._src import profiler
from jax._src import stages
from jax._src import traceback_util
from jax._src.abstract_arrays import array_types
from jax._src.config import config, flags
from jax._src.lib.mlir import ir
from jax._src.lib import can_execute_with_token
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
import jax._src.util as util
from jax._src.util import flatten, unflatten
from etils import epath

if TYPE_CHECKING:
  from jax.experimental.array import Array


FLAGS = flags.FLAGS

flags.DEFINE_string(
    'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
    help="Path to which HLO/MHLO IR that is emitted by JAX as input to the "
         "compiler should be dumped as text files. Optional. If omitted, JAX "
         "will not dump IR.")


traceback_util.register_exclusion(__file__)

MYPY = False  # Are we currently type checking with mypy?
xe = xc._xla

Backend = xe.Client
Device = xc.Device
Buffer = xe.Buffer

XlaExecutable = xc.Executable

map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

# This flag is set on exit; no logging should be attempted
_on_exit = False

### op-by-op execution

ArgSpec = Tuple[core.AbstractValue, Optional[Device]]

def arg_spec(x: Any) -> ArgSpec:
  from jax.experimental.sharding import PmapSharding

  aval = xla.abstractify(x)
  try:
    if config.jax_array:
      if isinstance(x.sharding, PmapSharding):
        return aval, None
      return aval, (x.sharding if x._committed else None)
    else:
      return aval, x._device
  except:
    return aval, None


def apply_primitive(prim, *args, **params):
  """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
  compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
                                        **params)
  return compiled_fun(*args)

# TODO(phawkins,frostig,mattjj): update code referring to
# xla.apply_primitive to point here, or use simple_impl if that's why
# it is using apply_primitive to begin with
xla.apply_primitive = apply_primitive

def simple_impl(prim):
  prim.def_impl(partial(apply_primitive, prim))

RuntimeToken = Any

class RuntimeTokenSet(threading.local):
  tokens: Dict[core.Effect, Tuple[RuntimeToken, Device]]
  output_tokens: Dict[Device, RuntimeToken]
  output_runtime_tokens: Dict[Device, RuntimeToken]

  def __init__(self):
    self.tokens = {}
    # TODO(sharadmv): remove redundant output token dictionary when minimum
    # jaxlib version is bumped to 0.3.16.
    self.output_tokens = {}
    self.output_runtime_tokens = {}

  def get_token(self, eff: core.Effect, device: Device) -> RuntimeToken:
    if eff not in self.tokens:
      self.tokens[eff] = device_put(np.zeros(0, np.bool_), device), device
    elif self.tokens[eff][1] != device:
      (old_token,), _ = self.tokens[eff]
      old_token.aval = core.ShapedArray((0,), np.bool_)
      self.tokens[eff] = device_put(old_token, device), device
    return self.tokens[eff][0]

  def update_token(self, eff: core.Effect, token: RuntimeToken):
    self.tokens[eff] = token, self.tokens[eff][1]

  def set_output_token(self, device: Device, token: RuntimeToken):
    # We're free to clobber the previous output token because on each
    # device we have a total ordering of computations. Only the token
    # from the latest computation matters. If this weren't the case
    # we'd need to store a set of output tokens.
    self.output_tokens[device] = token

  def set_output_runtime_token(self, device: Device, token: RuntimeToken):
    # TODO(sharadmv): remove this method when minimum jaxlib version is bumped
    self.output_runtime_tokens[device] = token

  def clear(self):
    self.tokens = {}
    self.output_tokens = {}
    self.output_runtime_tokens = {}

  def block_until_ready(self):
    for token, _ in self.tokens.values():
      token[0].block_until_ready()
    for token in self.output_tokens.values():
      token[0].block_until_ready()
    for token in self.output_runtime_tokens.values():
      token.block_until_ready()
    self.clear()

runtime_tokens: RuntimeTokenSet = RuntimeTokenSet()

@atexit.register
def wait_for_tokens():
  runtime_tokens.block_until_ready()
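# Editorial note: op-by-op dispatch below is cached per primitive and argument
# spec via `@util.cache()` on `xla_primitive_callable`, much as jit dispatch is
# cached via `xla_callable = lu.cache(_xla_callable_uncached)` further down.
# An illustrative sketch (not part of the library):
#
#   import jax.numpy as jnp
#   jnp.sin(1.0)  # first call compiles a one-op XLA computation for f32[]
#   jnp.sin(2.0)  # same ArgSpec, so this call reuses the cached executable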
@util.cache()
def xla_primitive_callable(prim, *arg_specs: ArgSpec, **params):
  _, arg_devices = util.unzip2(arg_specs)
  donated_invars = (False,) * len(arg_specs)
  if config.jax_array:
    # This will be resolved in sharded_lowering.
    device = None
  else:
    device = _device_from_arg_devices(arg_devices)
  def prim_fun(*args):
    out = prim.bind(*args, **params)
    if prim.multiple_results:
      return out
    else:
      return out,
  compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
                                    prim.name, donated_invars, False,
                                    *arg_specs)
  if not prim.multiple_results:
    return lambda *args, **kw: compiled(*args, **kw)[0]
  else:
    return compiled


def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[Device]:
  """Given devices of inputs, determine where to perform a computation.

  Args:
    devices: list where each element is either a `Device` instance or `None`.
  Returns:
    A `Device` instance or None.
  Raises:
    ValueError if input devices are inconsistent.
  """
  try:
    device, = {d for d in devices if d is not None} or (None,)
    return device
  except ValueError as err:
    msg = "primitive arguments must be colocated on the same device, got {}"
    raise ValueError(msg.format(", ".join(map(str, devices)))) from err


# JIT execution

def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
                   donated_invars, inline, keep_unused: bool):
  del inline  # Only used at tracing time
  if fun.in_type is None:
    arg_specs = unsafe_map(arg_spec, args)
  else:
    # fun.in_type is used for dynamic shapes.
    if config.jax_array:
      raise NotImplementedError('Dynamic shapes do not work with Array.')
    arg_specs = [(None, getattr(x, '_device', None)) for x in args]
  compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
                              keep_unused, *arg_specs)
  try:
    return compiled_fun(*args)
  except FloatingPointError:
    assert config.jax_debug_nans or config.jax_debug_infs  # compiled_fun can only raise in this case
    print("Invalid value encountered in the output of a jit-decorated function. "
          "Calling the de-optimized version.")
    # We want to run the wrapped function again (after xla_callable already ran
    # it), but linear_util.WrappedFun instances are meant to be run only once.
    # In addition to re-executing the Python code, which is usually undesirable
    # but which config.jax_debug_nans is meant to opt into, we'll be
    # re-executing any linear_util.py-style side effects, i.e. re-populating
    # Stores created by any transformation_with_aux's applied to fun. Since this
    # is intentional here, to avoid "Store occupied" errors we clone the
    # WrappedFun with empty stores.
    stores = [lu.Store() for _ in fun.stores]
    clone = lu.WrappedFun(fun.f, fun.transforms, stores, fun.params,
                          fun.in_type)

    with core.new_sublevel():
      _ = clone.call_wrapped(*args)  # may raise, not return

    # If control reaches this line, we got a NaN on the output of `compiled_fun`
    # but not `clone.call_wrapped` on the same arguments. Let's tell the user.
    fun_info = pe.fun_sourceinfo(fun.f)
    msg = ("An invalid value was encountered in the output of the "
           f"`jit`-decorated function {fun_info}. Because "
           "config.jax_debug_nans and/or config.jax_debug_infs is set, the "
           "de-optimized function (i.e., the function as if the `jit` "
           "decorator were removed) was called in an attempt to get a more "
           "precise error message. However, the de-optimized function did not "
           "produce invalid values during its execution. This behavior can "
           "result from `jit` optimizations causing the invalid value to be "
           "produced. It may also arise from having nan/inf constants as "
           "outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
           "\n\n"
           "It may be possible to avoid the invalid value by removing the "
           "`jit` decorator, at the cost of losing optimizations. "
           "\n\n"
           "If you see this error, consider opening a bug report at "
           "https://github.com/google/jax.")
    raise FloatingPointError(msg)

xla.xla_call_p.def_impl(_xla_call_impl)


# TODO(yashkatariya,mattjj): Try to handle this in api.py via a device_put and
# don't pass the device and backend argument to `_xla_callable_uncached`.
def not_none_device_or_backend_on_jit(backend, device, num_ins):
  """This is to support the backend and device argument on jit. It's a feature
  that's deprecated but needs to be supported for feature parity and so that we
  can delete the non-Array paths when Array is switched on.
  """
  # TODO(yashkatariya): Remove this entire function when backend and device are
  # removed as arguments on jit.

  from jax.experimental import sharding

  if device is not None and backend is not None:
    raise ValueError("can't specify both a device and a backend for jit, "
                     "got device={} and backend={}".format(device, backend))

  if backend is not None:
    da = [xb.get_backend(backend).get_default_device_assignment(1)[0]]
  else:
    assert device is not None
    da = [device]

  assert len(da) == 1
  # Set committed to True for this path because it simulates a device_put on
  # behalf of a user.
  committed = True
  # in_shardings will be marked as replicated regardless of whatever the input
  # had. Given that only a single device is allowed above, this is correct.
  in_shardings = [sharding.OpShardingSharding.get_replicated(da)] * num_ins
  return committed, da, in_shardings


def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
                     keep_unused, *arg_specs):
  # TODO(yashkatariya): Remove the local imports from here when the functions
  # in pxla.py move to dispatch.py or a utils file.
  from jax.interpreters import pxla
  from jax.experimental import pjit, sharding

  in_avals, in_shardings = util.unzip2(arg_specs)

  if backend is not None or device is not None:
    committed, da, in_shardings = not_none_device_or_backend_on_jit(
        backend, device, len(in_shardings))
  else:
    committed = any(i is not None for i in in_shardings)
    da = pjit._get_and_check_device_assignment(
        (i for i in in_shardings if i is not None),
        pxla.EMPTY_ENV.physical_mesh)
    in_shardings = [sharding.OpShardingSharding.get_replicated(da)
                    if i is None else i for i in in_shardings]

  process_index = xb.process_index()
  local_da = [d for d in da if d.process_index == process_index]
  if len(local_da) != len(da):
    warnings.warn(
        "Running operations on `Array`s that are not fully addressable by this "
        "process (i.e. `Array`s with data sharded across multiple devices and "
        "processes) is dangerous. It’s very important that all processes run "
        "the same cross-process computations in the same order, otherwise it "
        "can lead to hangs.\n"
        "If you’re not already familiar with JAX’s multi-process "
        "programming model, please read "
        "https://jax.readthedocs.io/en/latest/multi_process.html.")

  if not in_shardings:
    inp_device_assignment = da
  else:
    inp_device_assignment = None

  # Pass in a singleton `_UNSPECIFIED` for out_shardings because we don't know
  # the number of output avals at this stage. lower_sharding_computation will
  # apply it to all out_avals.
  return pxla.lower_sharding_computation(
      fun, 'jit', name, in_shardings, pjit._UNSPECIFIED, donated_invars,
      in_avals, in_is_global=(True,) * len(arg_specs), keep_unused=keep_unused,
      committed=committed, always_lower=always_lower,
      inp_device_assignment=inp_device_assignment)


def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
                           donated_invars, keep_unused, *arg_specs):
  if config.jax_array:
    computation = sharded_lowering(fun, device, backend, name, donated_invars,
                                   False, keep_unused, *arg_specs)
    return computation.compile(_allow_propagation_to_outputs=True).unsafe_call
  else:
    return lower_xla_callable(fun, device, backend, name, donated_invars,
                              False, keep_unused, *arg_specs).compile().unsafe_call

xla_callable = lu.cache(_xla_callable_uncached)
def is_single_device_sharding(sharding) -> bool:
  from jax.experimental.sharding import PmapSharding
  # Special case PmapSharding here because PmapSharding maps away an axis
  # and needs to be handled separately.
  return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)


@contextlib.contextmanager
def log_elapsed_time(fmt: str):
  if _on_exit:
    yield
  else:
    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    start_time = time.time()
    yield
    elapsed_time = time.time() - start_time
    logging.log(log_priority, fmt.format(elapsed_time=elapsed_time))


def should_tuple_args(num_args: int, platform: str):
  # CPU does not need a tuple as it uses a buffer table.
  # TPU only needs a tuple for very long lists.
  if platform == "cpu":
    return False
  elif platform == "tpu":
    return num_args > 2000
  else:
    return num_args > 100


def raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, name, jaxpr):
  if nreps > 1:
    warnings.warn(
        f"The jitted function {name} includes a pmap. Using "
        "jit-of-pmap can lead to inefficient data movement, as the outer jit "
        "does not preserve sharded data representations and instead collects "
        "input and output arrays onto a single device. "
        "Consider removing the outer jit unless you know what you're doing. "
        "See https://github.com/google/jax/issues/2926.")

  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling computation `{name}` that requires {nreps} replicas, but "
        f"only {xb.device_count(backend)} XLA devices are available.")

  if xb.process_count() > 1 and (nreps > 1 or
                                 jaxpr_has_primitive(jaxpr, "xla_pmap")):
    raise NotImplementedError(
        "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
        "extra data movement anyway, so maybe you don't want it after all).")
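# Editorial note: an illustrative example of the check above. pmap under jit
# multiplies replica counts through nested call jaxprs (see `jaxpr_replicas`
# below), so on a hypothetical 8-device host:
#
#   f = jax.jit(jax.pmap(lambda x: x + 1))  # illustrative only
#   # lowering f computes nreps == 8 and emits the jit-of-pmap warning.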
@profiler.annotate_function
def lower_xla_callable(
    fun: lu.WrappedFun, device, backend, name, donated_invars,
    always_lower: bool, keep_unused: bool, *arg_specs):
  """Lower into XLA.

  Args:
    always_lower: If `True`, even trivial programs (not doing any computation
      such as lambda x: x) will be lowered into an XLA program.
    keep_unused: If `False` (the default), arguments that JAX determines to be
      unused by `fun` *may* be dropped from resulting compiled XLA executables.
      Such arguments will not be transferred to the device nor provided to the
      underlying executable. If `True`, unused arguments will not be pruned.
  """
  if device is not None and backend is not None:
    raise ValueError("can't specify both a device and a backend for jit, "
                     "got device={} and backend={}".format(device, backend))
  abstract_args, arg_devices = util.unzip2(arg_specs)
  if fun.in_type is None:
    # Add an annotation inferred from the arguments; no dynamic axes here.
    in_type = tuple(unsafe_zip(abstract_args, itertools.repeat(True)))
    fun = lu.annotate(fun, in_type)
  else:
    assert abstract_args == (None,) * len(abstract_args)
    abstract_args = [aval for aval, _ in fun.in_type]
  with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                        "for jit in {elapsed_time} sec"):
    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
        fun, pe.debug_info_final(fun, "jit"))
  out_avals, kept_outputs = util.unzip2(out_type)

  if any(isinstance(c, core.Tracer) for c in consts):
    raise UnexpectedTracerError("Encountered an unexpected tracer.")

  if config.jax_dynamic_shapes:
    keep_unused = True
    has_outfeed = False
    donated_invars = [False] * len(fun.in_type)
  else:
    has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
    jaxpr = apply_outfeed_rewriter(jaxpr)

  if not keep_unused:
    jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
    consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
    abstract_args, arg_devices = util.unzip2(
        [a for i, a in enumerate(arg_specs) if i in kept_var_idx])
    donated_invars = [x for i, x in enumerate(donated_invars)
                      if i in kept_var_idx]
    del kept_const_idx
  else:
    kept_var_idx = set(range(len(fun.in_type)))

  nreps = jaxpr_replicas(jaxpr)
  device = _xla_callable_device(nreps, backend, device, arg_devices)
  backend = xb.get_device_backend(device) if device else xb.get_backend(backend)

  if config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr):
    jaxpr, consts = pe.pad_jaxpr(jaxpr, consts)

  map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))

  # Computations that only produce constants and/or only rearrange their
  # inputs, which are often produced from partial evaluation, don't need
  # compilation, and don't need to evaluate their arguments.
  if (not always_lower and not (jaxpr.effects or has_outfeed) and
      (not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars)):
    return XlaComputation(
        name, None, True, None, None, None, jaxpr=jaxpr, consts=consts,
        device=device, in_avals=abstract_args, out_avals=out_avals,
        has_unordered_effects=False, ordered_effects=[],
        kept_var_idx=kept_var_idx, keepalive=None, host_callbacks=[])

  if not _on_exit:
    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    if len(abstract_args) > 10:
      msg = f"Compiling {fun.__name__} ({id(fun)}) for {len(abstract_args)} args."
    else:
      msg = f"Compiling {fun.__name__} ({id(fun)}) for args {abstract_args}."
    logging.log(log_priority, msg)

  raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, name, jaxpr)

  # pass long arg lists as tuple for TPU
  tuple_args = should_tuple_args(len(abstract_args), backend.platform)
  axis_env = xla.AxisEnv(nreps, (), ())
  name_stack = util.new_name_stack(util.wrap_name(name, 'jit'))
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  module_name = f"jit_{fun.__name__}"
  unordered_effects = [eff for eff in closed_jaxpr.effects
                       if eff not in core.ordered_effects]
  ordered_effects = [eff for eff in closed_jaxpr.effects
                     if eff in core.ordered_effects]
  lowering_result = mlir.lower_jaxpr_to_module(
      module_name, closed_jaxpr, unordered_effects, ordered_effects,
      backend, backend.platform, mlir.ReplicaAxisContext(axis_env),
      name_stack, donated_invars)
  module, keepalive, host_callbacks = (
      lowering_result.module, lowering_result.keepalive,
      lowering_result.host_callbacks)
  return XlaComputation(
      name, module, False, donated_invars, fun.in_type, out_type, nreps=nreps,
      device=device, backend=backend, tuple_args=tuple_args,
      in_avals=abstract_args, out_avals=out_avals,
      has_unordered_effects=bool(unordered_effects),
      ordered_effects=ordered_effects, kept_var_idx=kept_var_idx,
      keepalive=keepalive, host_callbacks=host_callbacks)
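# Editorial note: `lower_xla_callable` above is the heart of jit dispatch:
# trace to a jaxpr, optionally prune unused inputs, short-circuit trivial
# programs, then lower the jaxpr to an MHLO module wrapped in an
# XlaComputation, whose .compile() yields the XlaCompiledComputation that is
# ultimately called.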
def _backend_supports_unbounded_dynamic_shapes(backend: Backend) -> bool:
  return backend.platform == 'iree'


def prefetch(x):
  if isinstance(x, device_array.DeviceArray):
    x.copy_to_host_async()
  return x


def jaxpr_literals(jaxpr):
  """Generates all the literals inside a jaxpr, including nested subjaxprs."""
  for eqn in jaxpr.eqns:
    for v in eqn.invars:
      if type(v) is core.Literal:
        yield v.val
  for subjaxpr in core.subjaxprs(jaxpr):
    yield from jaxpr_literals(subjaxpr)


def jaxpr_has_primitive(jaxpr, prim_name: str):
  """Whether a primitive with the given name occurs anywhere inside a Jaxpr."""
  for eqn in jaxpr.eqns:
    if prim_name in eqn.primitive.name:
      return True
  for subjaxpr in core.subjaxprs(jaxpr):
    if jaxpr_has_primitive(subjaxpr, prim_name):
      return True
  return False


def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:
  return (any(type(v.aval) is core.AbstractBInt for v in jaxpr.invars) or
          any(type(v.aval) is core.AbstractBInt
              for j in itertools.chain([jaxpr], core.subjaxprs(jaxpr))
              for e in j.eqns for v in e.outvars))


def _prune_unused_inputs(
    jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
  used_outputs = [True] * len(jaxpr.outvars)
  new_jaxpr, used_consts, used_inputs = pe.dce_jaxpr_consts(jaxpr, used_outputs)
  kept_const_idx = {i for i, b in enumerate(used_consts) if b}
  kept_var_idx = {i for i, b in enumerate(used_inputs) if b}
  return new_jaxpr, kept_const_idx, kept_var_idx


# We can optionally set a Jaxpr rewriter that can be applied just before
# compilation. This mechanism is used for compiling id_tap; we can
# remove it once we bring the id_tap implementation into the core.
outfeed_rewriter: Optional[Callable[[core.Jaxpr], core.Jaxpr]] = None
def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
  if outfeed_rewriter is not None:
    return outfeed_rewriter(jaxpr)
  else:
    return jaxpr


def jaxpr_replicas(jaxpr) -> int:
  """The number of replicas needed for a jaxpr.

  For an eqn, multiply the `axis_size` by the `jaxpr_replicas` of the
  subjaxprs. For a list of eqns, take the maximum number of replicas.
  """
  if isinstance(jaxpr, core.ClosedJaxpr):
    jaxpr = jaxpr.jaxpr
  return max(unsafe_map(eqn_replicas, jaxpr.eqns), default=1)

# TODO(mattjj): this function assumes that only pmap has a parameter named
# axis_size, and that it corresponds to cross-replica mapping
def eqn_replicas(eqn):
  call_jaxpr = eqn.params.get("call_jaxpr")
  if call_jaxpr:
    return eqn.params.get('axis_size', 1) * jaxpr_replicas(call_jaxpr)
  elif eqn.primitive in xla._initial_style_primitives:
    return initial_style_primitive_replicas(eqn.params)
  else:
    return 1

def initial_style_primitive_replicas(params):
  return max(core.traverse_jaxpr_params(jaxpr_replicas, params).values(),
             default=1)


def _xla_callable_device(nreps, backend, device,
                         arg_devices) -> Optional[Device]:
  if nreps > 1:
    if device is not None or backend is not None:
      raise ValueError(f"can't specify device or backend for jit-of-pmap, "
                       f"got device={device} and backend={backend}")
    return None
  else:
    # TODO(skye): dedup with C++ jit logic for determining jit device?
    if device is not None:
      assert backend is None
      return device

    if backend is not None:
      return xb.get_backend(backend).get_default_device_assignment(1)[0]

    arg_device = _device_from_arg_devices(arg_devices)
    if arg_device is not None:
      return arg_device

    return config.jax_default_device


# Argument and result handlers

num_buffers_handlers: Dict[Type[core.AbstractValue],
                           Callable[[core.AbstractValue], int]] = {}

def aval_to_num_buffers(aval: core.AbstractValue) -> int:
  """Returns the number of buffers in the runtime representation of `aval`.

  In general this may differ from the number of buffers in the compiler-IR
  representation of the same value.
  """
  try:
    return num_buffers_handlers[type(aval)](aval)
  except KeyError as err:
    raise TypeError(f"No num_buffers handler for type: {type(aval)}") from err

num_buffers_handlers[core.AbstractToken] = lambda _: 1
num_buffers_handlers[core.ShapedArray] = lambda _: 1
num_buffers_handlers[core.DShapedArray] = lambda _: 1
num_buffers_handlers[core.ConcreteArray] = lambda _: 1
num_buffers_handlers[core.AbstractBInt] = lambda _: 1
def _input_handler(backend: Backend,
                   in_type: Optional[pe.InputType],
                   out_type: Optional[pe.OutputType],
                   ) -> Optional[Callable]:
  if in_type is None:
    assert out_type is None
    return None
  in_avals, which_explicit = util.unzip2(in_type)
  # Check whether we actually need an input_handler.
  needs_implicit = which_explicit and not all(which_explicit)
  needs_out_handling = any(type(d) is core.InDBIdx for a, _ in out_type or []
                           if type(a) is core.DShapedArray for d in a.shape)

  if not needs_implicit and not needs_out_handling:
    return None
  assert config.jax_dynamic_shapes

  # Precompute how to grab implicit inputs from explicit inputs' axis sizes.
  which_explicit = which_explicit or [True] * len(in_avals)
  implicit_idxs = {i for i, ex in enumerate(which_explicit) if not ex}
  implicit_args_from_axes: List[Tuple[int, int, int]] = []
  for arg_idx, aval in enumerate(in_avals):
    if isinstance(aval, core.DShapedArray):
      for axis_idx, d in enumerate(aval.shape):
        if isinstance(d, core.DBIdx) and d.val in implicit_idxs:
          implicit_args_from_axes.append((d.val, arg_idx, axis_idx))
  assert {i for i, _, _ in implicit_args_from_axes} == implicit_idxs

  # Precompute which input values are needed for output types.
  inputs_needed_for_out_types = out_type and [
      d.val for aval, _ in out_type if type(aval) is core.DShapedArray  # type: ignore
      for d in aval.shape if type(d) is core.InDBIdx]

  def elaborate(explicit_args: Sequence[Any]) -> Tuple[Tuple, Optional[Tuple]]:
    if needs_implicit:
      # Build full argument list, leaving Nones for implicit arguments.
      explicit_args_ = iter(explicit_args)
      args = [next(explicit_args_) if ex else None for ex in which_explicit]
      assert next(explicit_args_, None) is None
      # Populate implicit arguments.
      for i, j, k in implicit_args_from_axes:
        if args[i] is None:
          args[i] = args[j].shape[k]  # type: ignore
        else:
          if args[i] != args[j].shape[k]:
            raise Exception("inconsistent argument axis sizes for type")
    else:
      args = list(explicit_args)

    if needs_out_handling:
      # Make a list of inputs needed by output types, leaving unneeded as None.
      out_type_env = [None] * len(args)
      for i in inputs_needed_for_out_types or []:
        out_type_env[i] = args[i]
    else:
      out_type_env = None  # type: ignore

    return tuple(args), out_type_env and tuple(out_type_env)  # type: ignore

  return elaborate
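# Editorial note: `_input_handler` above is only non-None under
# `config.jax_dynamic_shapes`. A hypothetical sketch of what `elaborate` does:
# for a function whose type is (n:int, x:f32[n]) where `n` is implicit,
# callers pass only `x`, and elaborate recovers n as x.shape[0] before the
# executable runs.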
def _result_handler(backend: Backend,
                    sticky_device: Optional[Device],
                    out_type: Optional[pe.OutputType]
                    ) -> Callable:
  out_avals, kept_outputs = util.unzip2(out_type)
  handlers = map(partial(aval_to_result_handler, sticky_device), out_avals)
  dyn_outs = any(type(aval) is core.DShapedArray and
                 any(type(d) in (core.InDBIdx, core.OutDBIdx)
                     for d in aval.shape)
                 for aval in out_avals)
  if not dyn_outs:
    return SimpleResultHandler(handlers)
  assert config.jax_dynamic_shapes

  def result_handler(input_env, lists_of_bufs):
    results = []
    for handler, bufs in unsafe_zip(handlers, lists_of_bufs):
      results.append(handler((input_env, results), *bufs))
    return [r for r, keep in unsafe_zip(results, kept_outputs) if keep]
  return result_handler


class SimpleResultHandler:
  handlers: Sequence[ResultHandler]

  def __init__(self, handlers):
    self.handlers = handlers

  def __iter__(self):
    return iter(self.handlers)

  def __len__(self):
    return len(self.handlers)

  def __call__(self, env, lists_of_bufs):
    return tuple(h(env, *bs) for h, bs in zip(self.handlers, lists_of_bufs))


def maybe_create_array_from_da(buf, aval, device):
  if config.jax_array:
    from jax.experimental.array import Array
    from jax.experimental.sharding import SingleDeviceSharding
    return Array(aval, SingleDeviceSharding(buf.device()), [buf],
                 committed=(device is not None), _skip_checks=True)
  else:
    return device_array.make_device_array(aval, device, buf)


if MYPY:
  ResultHandler = Any
else:
  class ResultHandler(Protocol):
    def __call__(self, env: Optional[Sequence[Any]], *args: xla.Buffer) -> Any:
      """Boxes raw buffers into their user-facing representation."""

def aval_to_result_handler(sticky_device: Optional[Device],
                           aval: core.AbstractValue) -> ResultHandler:
  try:
    return result_handlers[type(aval)](sticky_device, aval)
  except KeyError as err:
    raise TypeError(f"No result handler for type: {type(aval)}") from err

def array_result_handler(sticky_device: Optional[Device],
                         aval: core.ShapedArray):
  if aval.dtype == dtypes.float0:
    return lambda _, __: np.zeros(aval.shape, dtypes.float0)
  aval = core.raise_to_shaped(aval)
  if core.is_opaque_dtype(aval.dtype):
    return aval.dtype._rules.result_handler(sticky_device, aval)
  handler = lambda _, b: maybe_create_array_from_da(b, aval, sticky_device)
  handler.args = aval, sticky_device  # for C++ dispatch path in api.py
  return handler

def dynamic_array_result_handler(sticky_device: Optional[Device],
                                 aval: core.DShapedArray):
  if aval.dtype == dtypes.float0:
    return lambda _: np.zeros(aval.shape, dtypes.float0)  # type: ignore
  else:
    return partial(_dynamic_array_result_handler, sticky_device, aval)

def _dynamic_array_result_handler(sticky_device, aval, env, buf):
  in_env, out_env = env or (None, None)
  shape = [in_env[d.val] if type(d) is core.InDBIdx else
           out_env[d.val] if type(d) is core.OutDBIdx else d
           for d in aval.shape]
  if all(type(d) is int for d in shape):
    aval = core.ShapedArray(tuple(shape), aval.dtype)
    return maybe_create_array_from_da(buf, aval, sticky_device)
  elif any(type(d) is core.BInt for d in shape):
    padded_shape = [d.bound if type(d) is core.BInt else d for d in shape]
    buf_aval = core.ShapedArray(tuple(padded_shape), aval.dtype, aval.weak_type)
    data = maybe_create_array_from_da(buf, buf_aval, sticky_device)
    return core.PaddedArray(aval.update(shape=tuple(shape)), data)
  else:
    aval = core.ShapedArray(tuple(shape), aval.dtype)
    return maybe_create_array_from_da(buf, aval, sticky_device)


result_handlers: Dict[
    Type[core.AbstractValue],
    Callable[[Optional[Device], Any], ResultHandler]] = {}
result_handlers[core.AbstractToken] = lambda _, __: lambda _, __: core.token
result_handlers[core.ShapedArray] = array_result_handler
result_handlers[core.DShapedArray] = dynamic_array_result_handler
result_handlers[core.ConcreteArray] = array_result_handler
result_handlers[core.AbstractBInt] = \
    lambda _, a: lambda _, b: core.BInt(int(b), a.bound)


def needs_check_special():
  return config.jax_debug_infs or config.jax_debug_nans

def check_special(name, bufs):
  if needs_check_special():
    for buf in bufs:
      _check_special(name, buf.xla_shape(), buf)

def _check_special(name, xla_shape, buf):
  assert not xla_shape.is_tuple()
  if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
    if config.jax_debug_nans and np.any(np.isnan(np.asarray(buf))):
      raise FloatingPointError(f"invalid value (nan) encountered in {name}")
    if config.jax_debug_infs and np.any(np.isinf(np.asarray(buf))):
      raise FloatingPointError(f"invalid value (inf) encountered in {name}")
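# Editorial note: an illustrative trigger for the checks above (assumes the
# jax_debug_nans config flag is enabled):
#
#   from jax.config import config  # illustrative only
#   config.update("jax_debug_nans", True)
#   jax.jit(lambda x: x / 0.)(0.)  # 0./0. -> nan -> FloatingPointError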
def _add_tokens(has_unordered_effects: bool,
                ordered_effects: List[core.Effect],
                has_host_callbacks: bool, device: Device, input_bufs):
  tokens = [runtime_tokens.get_token(eff, device) for eff in ordered_effects]
  tokens_flat = flatten(tokens)
  input_bufs = [*tokens_flat, *input_bufs]
  def _remove_tokens(output_bufs, runtime_token):
    # TODO(sharadmv): simplify when minimum jaxlib version is bumped
    num_output_tokens = len(ordered_effects) + (not can_execute_with_token and
                                                has_unordered_effects)
    token_bufs, output_bufs = util.split_list(output_bufs, [num_output_tokens])
    if has_unordered_effects or has_host_callbacks:
      if can_execute_with_token:
        runtime_tokens.set_output_runtime_token(device, runtime_token)
      else:
        output_token_buf, *token_bufs = token_bufs
        runtime_tokens.set_output_token(device, output_token_buf)
    for eff, token_buf in zip(ordered_effects, token_bufs):
      runtime_tokens.update_token(eff, token_buf)
    return output_bufs
  return input_bufs, _remove_tokens


def _execute_compiled(name: str, compiled: XlaExecutable,
                      input_handler: Optional[Callable],
                      output_buffer_counts: Sequence[int],
                      result_handler: Callable,
                      has_unordered_effects: bool,
                      ordered_effects: List[core.Effect], kept_var_idx,
                      has_host_callbacks: bool, *args):
  device, = compiled.local_devices()
  args, env = input_handler(args) if input_handler else (args, None)
  in_flat = flatten(device_put(x, device) for i, x in enumerate(args)
                    if i in kept_var_idx)
  if has_unordered_effects or ordered_effects or has_host_callbacks:
    in_flat, token_handler = _add_tokens(
        has_unordered_effects, ordered_effects, has_host_callbacks, device,
        in_flat)
    if can_execute_with_token:
      out_flat, runtime_token = compiled.execute_with_token(in_flat)
    else:
      out_flat = compiled.execute(in_flat)
      runtime_token = None
  else:
    out_flat = compiled.execute(in_flat)
  check_special(name, out_flat)
  out_bufs = unflatten(out_flat, output_buffer_counts)
  if ordered_effects or has_unordered_effects or has_host_callbacks:
    out_bufs = token_handler(out_bufs, runtime_token)
  return result_handler(env, out_bufs)


def _execute_replicated(name: str, compiled: XlaExecutable,
                        input_handler: Optional[Callable],
                        output_buffer_counts: Sequence[int],
                        result_handler: Callable,
                        has_unordered_effects: bool,
                        ordered_effects: List[core.Effect], kept_var_idx,
                        has_host_callbacks: bool, *args,
                        from_lower_sharding_computation: bool = False):
  if has_unordered_effects or ordered_effects:
    # TODO(sharadmv): support jit-of-pmap with effects
    raise NotImplementedError(
        "Cannot execute replicated computation with effects.")
  if input_handler: raise NotImplementedError  # TODO(mattjj, dougalm)
  input_bufs = [flatten(device_put(x, device) for i, x in enumerate(args)
                        if i in kept_var_idx)
                for device in compiled.local_devices()]
  input_bufs_flip = list(unsafe_zip(*input_bufs))
  out_bufs_flat_rep = compiled.execute_sharded_on_local_devices(input_bufs_flip)
  out_flat = [bufs[0] for bufs in out_bufs_flat_rep]
  check_special(name, out_flat)
  out_bufs = unflatten(out_flat, output_buffer_counts)
  if from_lower_sharding_computation:
    return result_handler(out_bufs)
  return result_handler(None, out_bufs)


def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
                     has_unordered_effects: bool,
                     ordered_effects: List[core.Effect], kept_var_idx,
                     host_callbacks, *args):
  env: Dict[core.Var, Any] = {}
  pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
  map(env.setdefault, jaxpr.invars, pruned_args)
  map(env.setdefault, jaxpr.constvars, consts)
  outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
          for v in jaxpr.outvars]
  return [_copy_device_array_to_device(x, device)
          if device_array.type_is_device_array(x)
          else h(None, *device_put(x, device))
          for h, x in zip(handlers, outs)]
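# Editorial note: of the three executors above, `_execute_compiled` serves the
# common nreps == 1 case, `_execute_replicated` is chosen by
# XlaCompiledComputation.from_xla_computation when jaxpr_replicas reports
# nreps > 1 (e.g. pmap under jit), and `_execute_trivial` backs computations
# that skipped XLA entirely (see the early return in lower_xla_callable).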
class XlaComputation(stages.XlaLowering):
  name: str
  _is_trivial: bool
  _executable: Optional[XlaCompiledComputation]
  _donated_invars: Optional[Sequence[bool]]

  def __init__(self, name: str, hlo, is_trivial: bool,
               donated_invars: Optional[Sequence[bool]],
               in_type: Optional[pe.InputType],
               out_type: Optional[pe.OutputType],
               **compile_args):
    self.name = name
    self._hlo = hlo
    self._is_trivial = is_trivial
    self._donated_invars = donated_invars
    self._in_type = in_type
    self._out_type = out_type
    self._executable = None
    self.compile_args = compile_args

  def is_trivial(self):
    return self._is_trivial

  # -- stages.XlaLowering overrides

  def hlo(self) -> xc.XlaComputation:
    if self.is_trivial():
      raise ValueError("A trivial computation has no HLO")
    if isinstance(self._hlo, xc.XlaComputation):
      return self._hlo
    return xe.mlir.mlir_module_to_xla_computation(
        mlir.module_to_string(self._hlo),
        use_tuple_args=self.compile_args["tuple_args"])

  def mhlo(self) -> ir.Module:
    if self.is_trivial():
      raise ValueError("A trivial computation has no MHLO")
    if isinstance(self._hlo, xc.XlaComputation):
      module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
      with mlir.make_ir_context():
        return ir.Module.parse(module_str)
    return self._hlo

  def compile(self) -> XlaCompiledComputation:
    if self._executable is None:
      if self.is_trivial():
        self._executable = XlaCompiledComputation.from_trivial_jaxpr(
            **self.compile_args)
      else:
        self._executable = XlaCompiledComputation.from_xla_computation(
            self.name, self._hlo, self._in_type, self._out_type,
            **self.compile_args)
    return self._executable


@profiler.annotate_function
def backend_compile(backend, built_c, options, host_callbacks):
  # We use a separate function call to ensure that XLA compilation appears
  # separately in Python profiling results.
  if host_callbacks:
    return backend.compile(built_c, compile_options=options,
                           host_callbacks=host_callbacks)
  # Some backends don't have `host_callbacks` option yet.
  # TODO(sharadmv): remove this fallback when all backends allow `compile`
  # to take in `host_callbacks`
  return backend.compile(built_c, compile_options=options)

# TODO(phawkins): update users.
xla.backend_compile = backend_compile

_ir_dump_counter = itertools.count()

def _make_string_safe_for_filename(s: str) -> str:
  return re.sub(r'[^\w.)( -]', '', s)

def _dump_ir_to_file(name: str, ir: str):
  id = next(_ir_dump_counter)
  name = f"jax_ir{id}_{_make_string_safe_for_filename(name)}.mlir"
  name = epath.Path(FLAGS.jax_dump_ir_to) / name
  name.write_text(ir)
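# Editorial note: an illustrative way to exercise `_dump_ir_to_file` via the
# flag defined at the top of this file:
#
#   JAX_DUMP_IR_TO=/tmp/jax_ir python my_program.py  # program name hypothetical
#
# Each compiled module is then written as jax_ir<counter>_<module name>.mlir.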
def compile_or_get_cached(backend, computation: ir.Module, compile_options,
                          host_callbacks):
  # Avoid import cycle between jax and jax.experimental
  from jax.experimental.compilation_cache import compilation_cache as cc

  sym_name = computation.operation.attributes['sym_name']
  module_name = ir.StringAttr(sym_name).value
  # Convert ir.Module to a string representation, unless the backend
  # explicitly flags the ability to handle a module directly (avoiding the
  # overhead of back and forth conversions).
  serialized_computation: Union[str, bytes, ir.Module]
  if getattr(backend, "needs_str_ir", True):
    if xc.mlir_api_version >= 34:
      serialized_computation = mlir.module_to_bytecode(computation)
    else:
      serialized_computation = mlir.module_to_string(computation)
  else:
    serialized_computation = computation

  # Persistent compilation cache only implemented on TPU.
  # TODO(skye): add warning when initializing cache on unsupported default platform
  supported_platforms = ["tpu"]
  # GPU caching can be enabled if JitRt is enabled.
  # TODO(b/232263664): Remove check when JitRt is enabled by default.
  if "--xla_gpu_enable_xla_runtime_executable=true" in os.environ.get("XLA_FLAGS", ""):
    supported_platforms.append("gpu")
  if cc.is_initialized() and backend.platform in supported_platforms:
    cached_executable = cc.get_executable(serialized_computation,
                                          compile_options, backend)
    if cached_executable is not None:
      logging.info('Persistent compilation cache hit for %s.', module_name)
      return cached_executable
    else:
      compiled = backend_compile(backend, serialized_computation,
                                 compile_options, host_callbacks)
      cc.put_executable(module_name, serialized_computation, compile_options,
                        compiled, backend)
      return compiled

  if FLAGS.jax_dump_ir_to:
    _dump_ir_to_file(module_name, mlir.module_to_string(computation))
  return backend_compile(backend, serialized_computation, compile_options,
                         host_callbacks)
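# Editorial note: an illustrative way to turn on the persistent cache consulted
# above (TPU only, unless the XLA runtime executable flag is set for GPU):
#
#   from jax.experimental.compilation_cache import compilation_cache as cc
#   cc.initialize_cache("/path/to/cache/dir")  # path is illustrative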
def get_buffer_counts(out_avals, ordered_effects, has_unordered_effects):
  buffer_counts = [aval_to_num_buffers(aval) for aval in out_avals]
  if ordered_effects or has_unordered_effects:
    num_output_tokens = len(ordered_effects)
    # TODO(sharadmv): remove check when minimum jaxlib version is bumped
    if not can_execute_with_token:
      num_output_tokens += has_unordered_effects
    buffer_counts = ([1] * num_output_tokens) + buffer_counts
  return buffer_counts


class XlaCompiledComputation(stages.XlaExecutable):
  def __init__(self, xla_executable, in_avals, kept_var_idx, unsafe_call,
               keepalive: Any):
    self._xla_executable = xla_executable
    self.in_avals = in_avals
    self._kept_var_idx = kept_var_idx
    self.unsafe_call = unsafe_call
    # Only the `unsafe_call` function is cached, so to avoid the `keepalive`
    # being garbage collected we attach it to `unsafe_call`.
    self.unsafe_call.keepalive = keepalive

  @staticmethod
  def from_xla_computation(name: str, xla_computation: Optional[ir.Module],
                           in_type: Optional[pe.InputType],
                           out_type: Optional[pe.OutputType], nreps: int,
                           device: Optional[Device], backend: Backend,
                           tuple_args: bool,
                           in_avals: Sequence[core.AbstractValue],
                           out_avals: Sequence[core.AbstractValue],
                           has_unordered_effects: bool,
                           ordered_effects: List[core.Effect],
                           kept_var_idx: Set[int], keepalive: Optional[Any],
                           host_callbacks: List[Any]) -> XlaCompiledComputation:
    sticky_device = device
    input_handler = _input_handler(backend, in_type, out_type)
    result_handler = _result_handler(backend, sticky_device, out_type)
    options = xb.get_compile_options(
        num_replicas=nreps, num_partitions=1,
        device_assignment=(sticky_device,) if sticky_device else None)
    options.parameter_is_tupled_arguments = tuple_args
    with log_elapsed_time(f"Finished XLA compilation of {name} "
                          "in {elapsed_time} sec"):
      compiled = compile_or_get_cached(backend, xla_computation, options,
                                       host_callbacks)
    buffer_counts = get_buffer_counts(out_avals, ordered_effects,
                                      has_unordered_effects)
    execute = _execute_compiled if nreps == 1 else _execute_replicated
    unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts,  # type: ignore  # noqa: F811
                          result_handler, has_unordered_effects,
                          ordered_effects, kept_var_idx, bool(host_callbacks))
    return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call,
                                  keepalive)

  def is_trivial(self):
    return self._xla_executable is None

  @property
  def xla_executable(self):
    # TODO(frostig): remove in favor of runtime_executable?
    if self.is_trivial():
      raise ValueError("A trivial compiled computation has no XLA executable")
    return self._xla_executable

  @staticmethod
  def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals,
                         has_unordered_effects, ordered_effects, kept_var_idx,
                         keepalive: Optional[Any],
                         host_callbacks: List[Any]) -> XlaCompiledComputation:
    assert keepalive is None
    result_handlers = map(partial(aval_to_result_handler, device), out_avals)
    unsafe_call = partial(_execute_trivial, jaxpr, device, consts, out_avals,
                          result_handlers, has_unordered_effects,
                          ordered_effects, kept_var_idx, bool(host_callbacks))
    return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call,
                                  keepalive)

  # -- stages.XlaExecutable overrides

  def xla_extension_executable(self):
    return self.xla_executable

  def call(self, *args):
    arg_specs = unsafe_map(arg_spec, args)
    arg_avals = [spec[0] for i, spec in enumerate(arg_specs)
                 if i in self._kept_var_idx]
    check_arg_avals_for_call(self.in_avals, arg_avals)
    return self.unsafe_call(*args)


def check_arg_avals_for_call(ref_avals, arg_avals):
  if len(ref_avals) != len(arg_avals):
    raise TypeError(
        f"Computation compiled for {len(ref_avals)} inputs "
        f"but called with {len(arg_avals)}")
  for ref_aval, arg_aval in zip(ref_avals, arg_avals):
    if not core.typematch(ref_aval, arg_aval):
      ref_avals_fmt = ', '.join(str(a) for a in ref_avals)
      arg_avals_fmt = ', '.join(str(a) for a in arg_avals)
      raise TypeError(
          f"Computation compiled for input types:\n  {ref_avals_fmt}\n"
          f"called with:\n  {arg_avals_fmt}")
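# Editorial note: an illustrative mismatch rejected by check_arg_avals_for_call
# above: an executable compiled for f32[3] and then called with f32[4] (or
# with an int32 argument) raises the TypeError listing both the compiled and
# the called avals.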
def device_put(x, device: Optional[Device] = None) -> Tuple[Any, ...]:
  x = xla.canonicalize_dtype(x)
  try:
    return device_put_handlers[type(x)](x, device)
  except KeyError as err:
    raise TypeError(f"No device_put handler for type: {type(x)}") from err

# TODO(phawkins): update users.
xla.device_put = device_put

def _device_put_array(x, device: Optional[Device]):
  backend = xb.get_device_backend(device)
  if x.dtype == dtypes.float0:
    x = np.zeros(x.shape, dtype=np.dtype(bool))
  return (backend.buffer_from_pyval(x, device),)

def _device_put_scalar(x, device):
  return _device_put_array(dtypes.coerce_to_array(x), device)

def _device_put_token(_, device):
  backend = xb.get_device_backend(device)
  return (backend.buffer_from_pyval(np.zeros((), dtype=np.dtype(np.bool_)),
                                    device),)

_scalar_types = dtypes.python_scalar_dtypes.keys()

device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]],
                                        Tuple[Any, ...]]] = {}
device_put_handlers.update((t, _device_put_array) for t in array_types)
device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
device_put_handlers[core.Token] = _device_put_token
device_put_handlers[core.BInt] = lambda x, d: _device_put_scalar(x.val, d)

def _device_put_device_array(x: Union[device_array.DeviceArrayProtocol,
                                      device_array._DeviceArray],
                             device: Optional[Device]):
  x = _copy_device_array_to_device(x, device)
  return (x.device_buffer,)
for t in device_array.device_array_types:
  device_put_handlers[t] = _device_put_device_array
device_put_handlers[core.PaddedArray] = lambda x, d: device_put(x._data, d)

def _copy_device_array_to_device(
    x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray],
    device: Optional[xc.Device]
  ) -> Union[device_array.DeviceArrayProtocol, device_array._DeviceArray]:
  if device is None:
    # no copying to be done because there's no target specified
    return x
  elif xb.get_device_backend(device).platform == x.device_buffer.platform():
    # source and target platforms are the same
    if x.device_buffer.device() == device:
      # no copying to be done because source equals target
      if x._device == device:
        return x
      else:
        moved_buf = x.device_buffer  # We need to change stickiness
    else:
      # move the buffer with a device-to-device copy
      moved_buf = x.device_buffer.copy_to_device(device)
  else:
    # buffers from different XLA backends are passed through the host.
    backend = xb.get_device_backend(device)
    moved_buf = backend.buffer_from_pyval(np.asarray(x.device_buffer), device)
  return device_array.make_device_array(x.aval, device, moved_buf)
def _copy_array_to_device(x: Array, device: Optional[xc.Device]) -> Array:
  """Copies `Array`s with SingleDeviceSharding to a different device."""
  from jax.experimental import array, sharding

  if device is None:
    # no copying to be done because there's no target specified
    return x

  buf = x._arrays[0]
  if xb.get_device_backend(device).platform == buf.platform():
    # source and target platforms are the same
    if x.device() == device:
      # no copying to be done because source equals target
      if x._committed:
        return x
      else:
        moved_buf = buf  # We need to change stickiness
    else:
      # move the buffer with a device-to-device copy
      moved_buf = buf.copy_to_device(device)
  else:
    # buffers from different XLA backends are passed through the host.
    backend = xb.get_device_backend(device)
    moved_buf = backend.buffer_from_pyval(np.asarray(buf), device)
  return array.Array(
      x.aval, sharding.SingleDeviceSharding(moved_buf.device()), [moved_buf],
      committed=(device is not None))


def _device_put_impl(x, device: Optional[Device] = None):
  from jax.experimental import array, sharding

  if device_array.type_is_device_array(x):
    return _copy_device_array_to_device(x, device)

  if type(x) is array.Array and isinstance(x.sharding,
                                           sharding.SingleDeviceSharding):
    return _copy_array_to_device(x, device)

  try:
    a = xla.abstractify(x)
  except TypeError as err:
    raise TypeError(
        f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
  return aval_to_result_handler(device, a)(None, *device_put(x, device))

device_put_p = core.Primitive('device_put')
device_put_p.def_impl(_device_put_impl)
device_put_p.def_abstract_eval(lambda x, device=None: x)
ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent])
batching.defvectorized(device_put_p)

def _device_put_lowering(ctx, x, *, device):
  return [x]

mlir.register_lowering(device_put_p, _device_put_lowering)
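# Editorial note: the public `jax.device_put` API bottoms out in device_put_p,
# whose impl rule is `_device_put_impl` above. Illustrative usage:
#
#   import jax, numpy as np
#   y = jax.device_put(np.arange(4), jax.devices()[0])  # commits y to device 0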