# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Primitive dispatch and jit dispatch.
from __future__ import annotations

import atexit
import collections
import contextlib
from functools import partial
import itertools
import time
from typing import (
    Any, Callable, Dict, Optional, Sequence, Set, Tuple, List, Type, Union,
    TYPE_CHECKING)
from typing_extensions import Protocol
import os
import re
import threading
import warnings

from absl import logging
import numpy as np

import jax
from jax import core
from jax import linear_util as lu
from jax.errors import UnexpectedTracerError
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
import jax.interpreters.xla as xla
import jax.interpreters.partial_eval as pe
from jax._src import device_array
from jax._src import dtypes
from jax._src import profiler
from jax._src import stages
from jax._src import traceback_util
from jax._src.abstract_arrays import array_types
from jax._src.config import config, flags
from jax._src.lib.mlir import ir
from jax._src.lib import can_execute_with_token
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
import jax._src.util as util
from jax._src.util import flatten, unflatten
from etils import epath

if TYPE_CHECKING:
  from jax.experimental.array import Array


FLAGS = flags.FLAGS

flags.DEFINE_string(
    'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
    help="Path to which HLO/MHLO IR that is emitted by JAX as input to the "
         "compiler should be dumped as text files. Optional. If omitted, JAX "
         "will not dump IR.")


traceback_util.register_exclusion(__file__)

MYPY = False  # Are we currently type checking with mypy?

xe = xc._xla

Backend = xe.Client
Device = xc.Device
Buffer = xe.Buffer

XlaExecutable = xc.Executable

map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

# This flag is set on exit; no logging should be attempted
_on_exit = False

### op-by-op execution

ArgSpec = Tuple[core.AbstractValue, Optional[Device]]


def arg_spec(x: Any) -> ArgSpec:
  from jax.experimental.sharding import PmapSharding

  aval = xla.abstractify(x)
  try:
    if config.jax_array:
      if isinstance(x.sharding, PmapSharding):
        return aval, None
      return aval, (x.sharding if x._committed else None)
    else:
      return aval, x._device
  except:
    return aval, None


def apply_primitive(prim, *args, **params):
  """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
  compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
                                        **params)
  return compiled_fun(*args)

# TODO(phawkins,frostig,mattjj): update code referring to
# xla.apply_primitive to point here, or use simple_impl if that's why
# it is using apply_primitive to begin with
xla.apply_primitive = apply_primitive

def simple_impl(prim):
  prim.def_impl(partial(apply_primitive, prim))
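
# Usage sketch (illustrative, not from the original source; `my_prim` is a
# placeholder name): a primitive with no dedicated impl rule can register the
# generic XLA-compiled path with
#   simple_impl(my_prim)
# which is just shorthand for my_prim.def_impl(partial(apply_primitive, my_prim)).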

RuntimeToken = Any

class RuntimeTokenSet(threading.local):
  tokens: Dict[core.Effect, Tuple[RuntimeToken, Device]]
  output_tokens: Dict[Device, RuntimeToken]
  output_runtime_tokens: Dict[Device, RuntimeToken]

  def __init__(self):
    self.tokens = {}
    # TODO(sharadmv): remove redundant output token dictionary when minimum
    # jaxlib version is bumped to 0.3.16.
    self.output_tokens = {}
    self.output_runtime_tokens = {}

  def get_token(self, eff: core.Effect, device: Device) -> RuntimeToken:
    if eff not in self.tokens:
      self.tokens[eff] = device_put(np.zeros(0, np.bool_), device), device
    elif self.tokens[eff][1] != device:
      (old_token,), _ = self.tokens[eff]
      old_token.aval = core.ShapedArray((0,), np.bool_)
      self.tokens[eff] = device_put(old_token, device), device
    return self.tokens[eff][0]

  def update_token(self, eff: core.Effect, token: RuntimeToken):
    self.tokens[eff] = token, self.tokens[eff][1]

  def set_output_token(self, device: Device, token: RuntimeToken):
    # We're free to clobber the previous output token because on each
    # device we have a total ordering of computations. Only the token
    # from the latest computation matters. If this weren't the case
    # we'd need to store a set of output tokens.
    self.output_tokens[device] = token

  def set_output_runtime_token(self, device: Device, token: RuntimeToken):
    # TODO(sharadmv): remove this method when minimum jaxlib version is bumped
    self.output_runtime_tokens[device] = token

  def clear(self):
    self.tokens = {}
    self.output_tokens = {}
    self.output_runtime_tokens = {}

  def block_until_ready(self):
    for token, _ in self.tokens.values():
      token[0].block_until_ready()
    for token in self.output_tokens.values():
      token[0].block_until_ready()
    for token in self.output_runtime_tokens.values():
      token.block_until_ready()
    self.clear()

runtime_tokens: RuntimeTokenSet = RuntimeTokenSet()

@atexit.register
def wait_for_tokens():
  runtime_tokens.block_until_ready()
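
# Note (added for clarity): the token plumbing above is how ordered effects are
# sequenced at runtime. `_add_tokens` (defined later in this file) pulls a
# per-effect token from `runtime_tokens.get_token`, threads it through the
# compiled executable, and writes the updated token back via `update_token`;
# the atexit hook then blocks on all outstanding tokens so effectful
# computations finish before the process exits.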


@util.cache()
def xla_primitive_callable(prim, *arg_specs: ArgSpec, **params):
  _, arg_devices = util.unzip2(arg_specs)
  donated_invars = (False,) * len(arg_specs)
  if config.jax_array:
    # This will be resolved in sharded_lowering.
    device = None
  else:
    device = _device_from_arg_devices(arg_devices)
  def prim_fun(*args):
    out = prim.bind(*args, **params)
    if prim.multiple_results:
      return out
    else:
      return out,
  compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
                                    prim.name, donated_invars, False, *arg_specs)
  if not prim.multiple_results:
    return lambda *args, **kw: compiled(*args, **kw)[0]
  else:
    return compiled
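
# Caching note (illustrative sketch, not from the original source): because
# xla_primitive_callable is wrapped in util.cache(), op-by-op dispatch pays for
# lowering and compilation only once per (primitive, abstract values, params)
# combination. For example, under eager execution
#   jax.numpy.add(1., 2.)   # first call: lowers and compiles the add primitive
#   jax.numpy.add(3., 4.)   # same avals: reuses the cached executable
# and both calls funnel through apply_primitive above.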


def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[Device]:
  """Given devices of inputs, determine where to perform a computation.

  Args:
    devices: list where each element is either a `Device` instance or `None`.
  Returns:
    A `Device` instance or None.
  Raises:
    ValueError if input devices are inconsistent.
  """
  try:
    device, = {d for d in devices if d is not None} or (None,)
    return device
  except ValueError as err:
    msg = "primitive arguments must be colocated on the same device, got {}"
    raise ValueError(msg.format(", ".join(map(str, devices)))) from err
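
# Behavior sketch (illustrative; `d0` and `d1` are hypothetical Device objects):
#   _device_from_arg_devices([None, None])  # -> None: no argument is committed
#   _device_from_arg_devices([d0, None])    # -> d0: follow the committed input
#   _device_from_arg_devices([d0, d1])      # -> ValueError: inputs not colocated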


# JIT execution

def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
                   donated_invars, inline, keep_unused: bool):
  del inline  # Only used at tracing time
  if fun.in_type is None:
    arg_specs = unsafe_map(arg_spec, args)
  else:
    # fun.in_type is used for dynamic shapes.
    if config.jax_array:
      raise NotImplementedError('Dynamic shapes do not work with Array.')
    arg_specs = [(None, getattr(x, '_device', None)) for x in args]
  compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
                              keep_unused, *arg_specs)
  try:
    return compiled_fun(*args)
  except FloatingPointError:
    assert config.jax_debug_nans or config.jax_debug_infs  # compiled_fun can only raise in this case
    print("Invalid value encountered in the output of a jit-decorated function. "
          "Calling the de-optimized version.")
    # We want to run the wrapped function again (after xla_callable already ran
    # it), but linear_util.WrappedFun instances are meant to be run only once.
    # In addition to re-executing the Python code, which is usually undesirable
    # but which config.jax_debug_nans is meant to opt into, we'll be
    # re-executing any linear_util.py-style side effects, i.e. re-populating
    # Stores created by any transformation_with_aux's applied to fun. Since this
    # is intentional here, to avoid "Store occupied" errors we clone the
    # WrappedFun with empty stores.
    stores = [lu.Store() for _ in fun.stores]
    clone = lu.WrappedFun(fun.f, fun.transforms, stores, fun.params,
                          fun.in_type)

    with core.new_sublevel():
      _ = clone.call_wrapped(*args)  # may raise, not return

    # If control reaches this line, we got a NaN on the output of `compiled_fun`
    # but not `clone.call_wrapped` on the same arguments. Let's tell the user.
    fun_info = pe.fun_sourceinfo(fun.f)
    msg = ("An invalid value was encountered in the output of the "
           f"`jit`-decorated function {fun_info}. Because "
           "config.jax_debug_nans and/or config.jax_debug_infs is set, the "
           "de-optimized function (i.e., the function as if the `jit` "
           "decorator were removed) was called in an attempt to get a more "
           "precise error message. However, the de-optimized function did not "
           "produce invalid values during its execution. This behavior can "
           "result from `jit` optimizations causing the invalid value to be "
           "produced. It may also arise from having nan/inf constants as "
           "outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. "
           "\n\n"
           "It may be possible to avoid the invalid value by removing the "
           "`jit` decorator, at the cost of losing optimizations. "
           "\n\n"
           "If you see this error, consider opening a bug report at "
           "https://github.com/google/jax.")
    raise FloatingPointError(msg)

xla.xla_call_p.def_impl(_xla_call_impl)


# TODO(yashkatariya,mattjj): Try to handle this in api.py via a device_put and
# don't pass the device and backend argument to `_xla_callable_uncached`.
def not_none_device_or_backend_on_jit(backend, device, num_ins):
  """This is to support the backend and device argument on jit. It's a feature
  that's deprecated but needs to be supported for feature parity and so that we
  can delete the non-Array paths when Array is switched on.
  """
  # TODO(yashkatariya): Remove this entire function when backend and device are
  # removed as arguments on jit.

  from jax.experimental import sharding

  if device is not None and backend is not None:
    raise ValueError("can't specify both a device and a backend for jit, "
                     "got device={} and backend={}".format(device, backend))

  if backend is not None:
    da = [xb.get_backend(backend).get_default_device_assignment(1)[0]]
  else:
    assert device is not None
    da = [device]

  assert len(da) == 1
  # Set committed to True for this path because it simulates a device_put on
  # behalf of a user.
  committed = True
  # in_shardings will be marked as replicated regardless of whatever the input
  # had. Given that only a single device is allowed above, this is correct.
  in_shardings = [sharding.OpShardingSharding.get_replicated(da)] * num_ins
  return committed, da, in_shardings
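
# Context (illustrative sketch, not from the original source): this path is
# exercised by calls such as
#   jax.jit(f, device=jax.devices()[0])(x)
#   jax.jit(f, backend="cpu")(x)
# where the single requested device becomes the whole device assignment and
# every input is treated as replicated and committed onto it.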


def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
                     keep_unused, *arg_specs):
  # TODO(yashkatariya): Remove the local imports from here when the functions
  # in pxla.py move to dispatch.py or a utils file.
  from jax.interpreters import pxla
  from jax.experimental import pjit, sharding

  in_avals, in_shardings = util.unzip2(arg_specs)

  if backend is not None or device is not None:
    committed, da, in_shardings = not_none_device_or_backend_on_jit(
        backend, device, len(in_shardings))
  else:
    committed = any(i is not None for i in in_shardings)
    da = pjit._get_and_check_device_assignment(
        (i for i in in_shardings if i is not None), pxla.EMPTY_ENV.physical_mesh)
    in_shardings = [sharding.OpShardingSharding.get_replicated(da) if i is None else i
                    for i in in_shardings]

  process_index = xb.process_index()
  local_da = [d for d in da if d.process_index == process_index]
  if len(local_da) != len(da):
    warnings.warn(
        "Running operations on `Array`s that are not fully addressable by this "
        "process (i.e. `Array`s with data sharded across multiple devices and "
        "processes) is dangerous. It’s very important that all processes run "
        "the same cross-process computations in the same order, otherwise it "
        "can lead to hangs.\n"
        "If you’re not already familiar with JAX’s multi-process "
        "programming model, please read "
        "https://jax.readthedocs.io/en/latest/multi_process.html.")

  if not in_shardings:
    inp_device_assignment = da
  else:
    inp_device_assignment = None

  # Pass in a singleton `_UNSPECIFIED` for out_shardings because we don't know
  # the number of output avals at this stage. lower_sharding_computation will
  # apply it to all out_avals.
  return pxla.lower_sharding_computation(
      fun, 'jit', name, in_shardings, pjit._UNSPECIFIED,
      donated_invars, in_avals,
      in_is_global=(True,) * len(arg_specs), keep_unused=keep_unused,
      committed=committed, always_lower=always_lower,
      inp_device_assignment=inp_device_assignment)


def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
                           donated_invars, keep_unused, *arg_specs):
  if config.jax_array:
    computation = sharded_lowering(fun, device, backend, name, donated_invars,
                                   False, keep_unused, *arg_specs)
    return computation.compile(_allow_propagation_to_outputs=True).unsafe_call
  else:
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
                              keep_unused, *arg_specs).compile().unsafe_call

xla_callable = lu.cache(_xla_callable_uncached)


def is_single_device_sharding(sharding) -> bool:
  from jax.experimental.sharding import PmapSharding
  # Special case PmapSharding here because PmapSharding maps away an axis
  # and needs to be handled separately.
  return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)


@contextlib.contextmanager
def log_elapsed_time(fmt: str):
  if _on_exit:
    yield
  else:
    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    start_time = time.time()
    yield
    elapsed_time = time.time() - start_time
    logging.log(log_priority, fmt.format(elapsed_time=elapsed_time))
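
# Usage sketch (matches how lower_xla_callable uses it below): the format
# string must contain an `{elapsed_time}` placeholder, e.g.
#   with log_elapsed_time("Finished tracing f in {elapsed_time} sec"):
#     ...  # traced work goes here
# The message is logged at WARNING level when jax_log_compiles is set and at
# DEBUG level otherwise.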


def should_tuple_args(num_args: int, platform: str):
  # CPU does not need a tuple as it uses a buffer table
  # TPU only needs a tuple for very long lists
  if platform == "cpu":
    return False
  elif platform == "tpu":
    return num_args > 2000
  else:
    return num_args > 100
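
# For example: should_tuple_args(5000, "cpu") is False, should_tuple_args(5000,
# "tpu") is True, and on any other platform (e.g. "gpu") the threshold is 100
# arguments.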


def raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, name, jaxpr):
  if nreps > 1:
    warnings.warn(
        f"The jitted function {name} includes a pmap. Using "
        "jit-of-pmap can lead to inefficient data movement, as the outer jit "
        "does not preserve sharded data representations and instead collects "
        "input and output arrays onto a single device. "
        "Consider removing the outer jit unless you know what you're doing. "
        "See https://github.com/google/jax/issues/2926.")

  if nreps > xb.device_count(backend):
    raise ValueError(
        f"compiling computation `{name}` that requires {nreps} replicas, but "
        f"only {xb.device_count(backend)} XLA devices are available.")

  if xb.process_count() > 1 and (nreps > 1 or
                                 jaxpr_has_primitive(jaxpr, "xla_pmap")):
    raise NotImplementedError(
        "jit of multi-host pmap not implemented (and jit-of-pmap can cause "
        "extra data movement anyway, so maybe you don't want it after all).")


@profiler.annotate_function
def lower_xla_callable(
    fun: lu.WrappedFun, device, backend, name, donated_invars,
    always_lower: bool, keep_unused: bool, *arg_specs):
  """Lower into XLA.

  Args:
    always_lower: If `True`, even trivial programs (not doing any computation
      such as lambda x: x) will be lowered into an XLA program.
    keep_unused: If `False` (the default), arguments that JAX determines to be
      unused by `fun` *may* be dropped from resulting compiled XLA executables.
      Such arguments will not be transferred to the device nor provided to the
      underlying executable. If `True`, unused arguments will not be pruned.
  """
  if device is not None and backend is not None:
    raise ValueError("can't specify both a device and a backend for jit, "
                     "got device={} and backend={}".format(device, backend))
  abstract_args, arg_devices = util.unzip2(arg_specs)
  if fun.in_type is None:
    # Add an annotation inferred from the arguments; no dynamic axes here.
    in_type = tuple(unsafe_zip(abstract_args, itertools.repeat(True)))
    fun = lu.annotate(fun, in_type)
  else:
    assert abstract_args == (None,) * len(abstract_args)
    abstract_args = [aval for aval, _ in fun.in_type]

  with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                        "for jit in {elapsed_time} sec"):
    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
        fun, pe.debug_info_final(fun, "jit"))
  out_avals, kept_outputs = util.unzip2(out_type)

  if any(isinstance(c, core.Tracer) for c in consts):
    raise UnexpectedTracerError("Encountered an unexpected tracer.")

  if config.jax_dynamic_shapes:
    keep_unused = True
    has_outfeed = False
    donated_invars = [False] * len(fun.in_type)
  else:
    has_outfeed = core.jaxpr_uses_outfeed(jaxpr)
    jaxpr = apply_outfeed_rewriter(jaxpr)

  if not keep_unused:
    jaxpr, kept_const_idx, kept_var_idx = _prune_unused_inputs(jaxpr)
    consts = [c for i, c in enumerate(consts) if i in kept_const_idx]
    abstract_args, arg_devices = util.unzip2(
        [a for i, a in enumerate(arg_specs) if i in kept_var_idx])
    donated_invars = [x for i, x in enumerate(donated_invars)
                      if i in kept_var_idx]
    del kept_const_idx
  else:
    kept_var_idx = set(range(len(fun.in_type)))

  nreps = jaxpr_replicas(jaxpr)
  device = _xla_callable_device(nreps, backend, device, arg_devices)
  backend = xb.get_device_backend(device) if device else xb.get_backend(backend)

  if config.jax_dynamic_shapes and jaxpr_has_bints(jaxpr):
    jaxpr, consts = pe.pad_jaxpr(jaxpr, consts)

  map(prefetch, itertools.chain(consts, jaxpr_literals(jaxpr)))

  # Computations that only produce constants and/or only rearrange their inputs,
  # which are often produced from partial evaluation, don't need compilation,
  # and don't need to evaluate their arguments.
  if (not always_lower and not (jaxpr.effects or has_outfeed) and
      (not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars)):
    return XlaComputation(
        name, None, True, None, None, None, jaxpr=jaxpr, consts=consts,
        device=device, in_avals=abstract_args, out_avals=out_avals,
        has_unordered_effects=False, ordered_effects=[],
        kept_var_idx=kept_var_idx, keepalive=None, host_callbacks=[])

  if not _on_exit:
    log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
    if len(abstract_args) > 10:
      msg = f"Compiling {fun.__name__} ({id(fun)}) for {len(abstract_args)} args."
    else:
      msg = f"Compiling {fun.__name__} ({id(fun)}) for args {abstract_args}."
    logging.log(log_priority, msg)

  raise_warnings_or_errors_for_jit_of_pmap(nreps, backend, name, jaxpr)

  # pass long arg lists as tuple for TPU
  tuple_args = should_tuple_args(len(abstract_args), backend.platform)
  axis_env = xla.AxisEnv(nreps, (), ())
  name_stack = util.new_name_stack(util.wrap_name(name, 'jit'))
  closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
  module_name = f"jit_{fun.__name__}"
  unordered_effects = [eff for eff in closed_jaxpr.effects
                       if eff not in core.ordered_effects]
  ordered_effects = [eff for eff in closed_jaxpr.effects
                     if eff in core.ordered_effects]
  lowering_result = mlir.lower_jaxpr_to_module(
      module_name, closed_jaxpr, unordered_effects,
      ordered_effects, backend, backend.platform,
      mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
  module, keepalive, host_callbacks = (
      lowering_result.module, lowering_result.keepalive,
      lowering_result.host_callbacks)
  return XlaComputation(
      name, module, False, donated_invars, fun.in_type, out_type, nreps=nreps,
      device=device, backend=backend, tuple_args=tuple_args,
      in_avals=abstract_args, out_avals=out_avals,
      has_unordered_effects=bool(unordered_effects),
      ordered_effects=ordered_effects, kept_var_idx=kept_var_idx,
      keepalive=keepalive, host_callbacks=host_callbacks)


def _backend_supports_unbounded_dynamic_shapes(backend: Backend) -> bool:
  return backend.platform == 'iree'


def prefetch(x):
  if isinstance(x, device_array.DeviceArray):
    x.copy_to_host_async()
  return x


def jaxpr_literals(jaxpr):
  """Generates all the literals inside a jaxpr, including nested subjaxprs."""
  for eqn in jaxpr.eqns:
    for v in eqn.invars:
      if type(v) is core.Literal:
        yield v.val
  for subjaxpr in core.subjaxprs(jaxpr):
    yield from jaxpr_literals(subjaxpr)


def jaxpr_has_primitive(jaxpr, prim_name: str):
  """Whether there is a primitive given by user anywhere inside a Jaxpr."""
  for eqn in jaxpr.eqns:
    if prim_name in eqn.primitive.name:
      return True
  for subjaxpr in core.subjaxprs(jaxpr):
    if jaxpr_has_primitive(subjaxpr, prim_name):
      return True
  return False
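
# For example, raise_warnings_or_errors_for_jit_of_pmap above calls
# jaxpr_has_primitive(jaxpr, "xla_pmap") to detect a pmap nested anywhere
# inside a jitted computation. Note the check is a substring match on the
# primitive name, not an exact equality test.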

def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:
  return (any(type(v.aval) is core.AbstractBInt for v in jaxpr.invars) or
          any(type(v.aval) is core.AbstractBInt
              for j in itertools.chain([jaxpr], core.subjaxprs(jaxpr))
              for e in j.eqns for v in e.outvars))

def _prune_unused_inputs(
    jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
  used_outputs = [True] * len(jaxpr.outvars)
  new_jaxpr, used_consts, used_inputs = pe.dce_jaxpr_consts(jaxpr, used_outputs)
  kept_const_idx = {i for i, b in enumerate(used_consts) if b}
  kept_var_idx = {i for i, b in enumerate(used_inputs) if b}
  return new_jaxpr, kept_const_idx, kept_var_idx
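
# Sketch of the effect (illustrative, not from the original source): for
#   def f(x, y):
#     return x * 2.0   # y is never used
# dead-code elimination keeps only input 0, so _prune_unused_inputs returns a
# jaxpr with a single invar together with kept_var_idx == {0}; the caller in
# lower_xla_callable then drops y from the arguments actually transferred to
# the device (unless keep_unused=True).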


# We can optionally set a Jaxpr rewriter that can be applied just before
# compilation. This mechanism is used for compiling id_tap, we can
# remove it once we bring the id_tap implementation into the core.
outfeed_rewriter: Optional[Callable[[core.Jaxpr], core.Jaxpr]] = None
def apply_outfeed_rewriter(jaxpr: core.Jaxpr) -> core.Jaxpr:
  if outfeed_rewriter is not None:
    return outfeed_rewriter(jaxpr)
  else:
    return jaxpr


def jaxpr_replicas(jaxpr) -> int:
  """The number of replicas needed for a jaxpr.

  For an eqn, multiply the `axis_size` with the `jaxpr_replicas` of the
  subjaxprs. For a list of eqns, take the maximum number of replicas.
  """
  if isinstance(jaxpr, core.ClosedJaxpr):
    jaxpr = jaxpr.jaxpr
  return max(unsafe_map(eqn_replicas, jaxpr.eqns), default=1)
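
# Worked example (illustrative): a jaxpr containing a single pmap over 8
# devices has one eqn with axis_size == 8 and a trivial inner jaxpr, so
# jaxpr_replicas returns 8; nesting another 2-way pmap inside it would give
# 8 * 2 == 16 replicas, and a jaxpr with no mapped eqns returns the default 1.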

# TODO(mattjj): this function assumes that only pmap has a parameter named
# axis_size, and that it corresponds to cross-replica mapping
def eqn_replicas(eqn):
  call_jaxpr = eqn.params.get("call_jaxpr")
  if call_jaxpr:
    return eqn.params.get('axis_size', 1) * jaxpr_replicas(call_jaxpr)
  elif eqn.primitive in xla._initial_style_primitives:
    return initial_style_primitive_replicas(eqn.params)
  else:
    return 1

def initial_style_primitive_replicas(params):
  return max(core.traverse_jaxpr_params(jaxpr_replicas, params).values(), default=1)


def _xla_callable_device(nreps, backend, device,
                         arg_devices) -> Optional[Device]:
  if nreps > 1:
    if device is not None or backend is not None:
      raise ValueError(f"can't specify device or backend for jit-of-pmap, "
                       f"got device={device} and backend={backend}")
    return None
  else:
    # TODO(skye): dedup with C++ jit logic for determining jit device?
    if device is not None:
      assert backend is None
      return device

    if backend is not None:
      return xb.get_backend(backend).get_default_device_assignment(1)[0]

    arg_device = _device_from_arg_devices(arg_devices)
    if arg_device is not None:
      return arg_device

    return config.jax_default_device


# Argument and result handlers

num_buffers_handlers: Dict[Type[core.AbstractValue],
                           Callable[[core.AbstractValue], int]] = {}

def aval_to_num_buffers(aval: core.AbstractValue) -> int:
  """Returns the number of buffers in the runtime representation of `aval`.

  In general this may differ from the number of buffers in the compiler-IR
  representation of the same value.
  """
  try:
    return num_buffers_handlers[type(aval)](aval)
  except KeyError as err:
    raise TypeError(f"No num_buffers handler for type: {type(aval)}") from err

num_buffers_handlers[core.AbstractToken] = lambda _: 1
num_buffers_handlers[core.ShapedArray] = lambda _: 1
num_buffers_handlers[core.DShapedArray] = lambda _: 1
num_buffers_handlers[core.ConcreteArray] = lambda _: 1
num_buffers_handlers[core.AbstractBInt] = lambda _: 1


def _input_handler(backend: Backend,
                   in_type: Optional[pe.InputType],
                   out_type: Optional[pe.OutputType],
                   ) -> Optional[Callable]:
  if in_type is None:
    assert out_type is None
    return None
  in_avals, which_explicit = util.unzip2(in_type)
  # Check whether we actually need an input_handler.
  needs_implicit = which_explicit and not all(which_explicit)
  needs_out_handling = any(type(d) is core.InDBIdx for a, _ in out_type or []
                           if type(a) is core.DShapedArray for d in a.shape)

  if not needs_implicit and not needs_out_handling:
    return None
  assert config.jax_dynamic_shapes

  # Precompute how to grab implicit inputs from explicit inputs' axis sizes.
  which_explicit = which_explicit or [True] * len(in_avals)
  implicit_idxs = {i for i, ex in enumerate(which_explicit) if not ex}
  implicit_args_from_axes: List[Tuple[int, int, int]] = []
  for arg_idx, aval in enumerate(in_avals):
    if isinstance(aval, core.DShapedArray):
      for axis_idx, d in enumerate(aval.shape):
        if isinstance(d, core.DBIdx) and d.val in implicit_idxs:
          implicit_args_from_axes.append((d.val, arg_idx, axis_idx))
  assert {i for i, _, _ in implicit_args_from_axes} == implicit_idxs

  # Precompute which input values are needed for output types.
  inputs_needed_for_out_types = out_type and [
      d.val for aval, _ in out_type if type(aval) is core.DShapedArray  # type: ignore
      for d in aval.shape if type(d) is core.InDBIdx]

  def elaborate(explicit_args: Sequence[Any]) -> Tuple[Tuple, Optional[Tuple]]:
    if needs_implicit:
      # Build full argument list, leaving Nones for implicit arguments.
      explicit_args_ = iter(explicit_args)
      args = [next(explicit_args_) if ex else None for ex in which_explicit]
      assert next(explicit_args_, None) is None
      # Populate implicit arguments.
      for i, j, k in implicit_args_from_axes:
        if args[i] is None:
          args[i] = args[j].shape[k]  # type: ignore
        else:
          if args[i] != args[j].shape[k]:
            raise Exception("inconsistent argument axis sizes for type")
    else:
      args = list(explicit_args)

    if needs_out_handling:
      # Make a list of inputs needed by output types, leaving unneeded as None.
      out_type_env = [None] * len(args)
      for i in inputs_needed_for_out_types or []:
        out_type_env[i] = args[i]
    else:
      out_type_env = None  # type: ignore

    return tuple(args), out_type_env and tuple(out_type_env)  # type: ignore
  return elaborate

def _result_handler(backend: Backend,
                    sticky_device: Optional[Device],
                    out_type: Optional[pe.OutputType]
                    ) -> Callable:
  out_avals, kept_outputs = util.unzip2(out_type)
  handlers = map(partial(aval_to_result_handler, sticky_device), out_avals)
  dyn_outs = any(type(aval) is core.DShapedArray and
                 any(type(d) in (core.InDBIdx, core.OutDBIdx) for d in aval.shape)
                 for aval in out_avals)
  if not dyn_outs:
    return SimpleResultHandler(handlers)
  assert config.jax_dynamic_shapes

  def result_handler(input_env, lists_of_bufs):
    results = []
    for handler, bufs in unsafe_zip(handlers, lists_of_bufs):
      results.append(handler((input_env, results), *bufs))
    return [r for r, keep in unsafe_zip(results, kept_outputs) if keep]
  return result_handler

class SimpleResultHandler:
  handlers: Sequence[ResultHandler]
  def __init__(self, handlers): self.handlers = handlers
  def __iter__(self): return iter(self.handlers)
  def __len__(self): return len(self.handlers)
  def __call__(self, env, lists_of_bufs):
    return tuple(h(env, *bs) for h, bs in zip(self.handlers, lists_of_bufs))


def maybe_create_array_from_da(buf, aval, device):
  if config.jax_array:
    from jax.experimental.array import Array
    from jax.experimental.sharding import SingleDeviceSharding
    return Array(aval, SingleDeviceSharding(buf.device()), [buf],
                 committed=(device is not None), _skip_checks=True)
  else:
    return device_array.make_device_array(aval, device, buf)


if MYPY:
  ResultHandler = Any
else:
  class ResultHandler(Protocol):
    def __call__(self, env: Optional[Sequence[Any]], *args: xla.Buffer) -> Any:
      """Boxes raw buffers into their user-facing representation."""

def aval_to_result_handler(sticky_device: Optional[Device],
                           aval: core.AbstractValue) -> ResultHandler:
  try:
    return result_handlers[type(aval)](sticky_device, aval)
  except KeyError as err:
    raise TypeError(f"No result handler for type: {type(aval)}") from err

def array_result_handler(sticky_device: Optional[Device],
                         aval: core.ShapedArray):
  if aval.dtype == dtypes.float0:
    return lambda _, __: np.zeros(aval.shape, dtypes.float0)
  aval = core.raise_to_shaped(aval)
  if core.is_opaque_dtype(aval.dtype):
    return aval.dtype._rules.result_handler(sticky_device, aval)
  handler = lambda _, b: maybe_create_array_from_da(b, aval, sticky_device)
  handler.args = aval, sticky_device  # for C++ dispatch path in api.py
  return handler

def dynamic_array_result_handler(sticky_device: Optional[Device],
                                 aval: core.DShapedArray):
  if aval.dtype == dtypes.float0:
    return lambda _: np.zeros(aval.shape, dtypes.float0)  # type: ignore
  else:
    return partial(_dynamic_array_result_handler, sticky_device, aval)

def _dynamic_array_result_handler(sticky_device, aval, env, buf):
  in_env, out_env = env or (None, None)
  shape = [in_env[d.val] if type(d) is core.InDBIdx else
           out_env[d.val] if type(d) is core.OutDBIdx else d
           for d in aval.shape]
  if all(type(d) is int for d in shape):
    aval = core.ShapedArray(tuple(shape), aval.dtype)
    return maybe_create_array_from_da(buf, aval, sticky_device)
  elif any(type(d) is core.BInt for d in shape):
    padded_shape = [d.bound if type(d) is core.BInt else d for d in shape]
    buf_aval = core.ShapedArray(tuple(padded_shape), aval.dtype, aval.weak_type)
    data = maybe_create_array_from_da(buf, buf_aval, sticky_device)
    return core.PaddedArray(aval.update(shape=tuple(shape)), data)
  else:
    aval = core.ShapedArray(tuple(shape), aval.dtype)
    return maybe_create_array_from_da(buf, aval, sticky_device)


result_handlers: Dict[
    Type[core.AbstractValue],
    Callable[[Optional[Device], Any], ResultHandler]] = {}
result_handlers[core.AbstractToken] = lambda _, __: lambda _, __: core.token
result_handlers[core.ShapedArray] = array_result_handler
result_handlers[core.DShapedArray] = dynamic_array_result_handler
result_handlers[core.ConcreteArray] = array_result_handler
result_handlers[core.AbstractBInt] = \
    lambda _, a: lambda _, b: core.BInt(int(b), a.bound)


def needs_check_special():
  return config.jax_debug_infs or config.jax_debug_nans

def check_special(name, bufs):
  if needs_check_special():
    for buf in bufs:
      _check_special(name, buf.xla_shape(), buf)

def _check_special(name, xla_shape, buf):
  assert not xla_shape.is_tuple()
  if dtypes.issubdtype(xla_shape.element_type(), np.inexact):
    if config.jax_debug_nans and np.any(np.isnan(np.asarray(buf))):
      raise FloatingPointError(f"invalid value (nan) encountered in {name}")
    if config.jax_debug_infs and np.any(np.isinf(np.asarray(buf))):
      raise FloatingPointError(f"invalid value (inf) encountered in {name}")
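
# These checks power the debug flags consumed by _xla_call_impl above; for
# example a user can opt in with
#   jax.config.update("jax_debug_nans", True)
# after which every compiled call's output buffers are scanned and a
# FloatingPointError is raised as soon as a nan shows up.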

def _add_tokens(has_unordered_effects: bool, ordered_effects: List[core.Effect],
                has_host_callbacks: bool, device: Device, input_bufs):
  tokens = [runtime_tokens.get_token(eff, device) for eff in ordered_effects]
  tokens_flat = flatten(tokens)
  input_bufs = [*tokens_flat, *input_bufs]
  def _remove_tokens(output_bufs, runtime_token):
    # TODO(sharadmv): simplify when minimum jaxlib version is bumped
    num_output_tokens = len(ordered_effects) + (not can_execute_with_token and
                                                has_unordered_effects)
    token_bufs, output_bufs = util.split_list(output_bufs, [num_output_tokens])
    if has_unordered_effects or has_host_callbacks:
      if can_execute_with_token:
        runtime_tokens.set_output_runtime_token(device, runtime_token)
      else:
        output_token_buf, *token_bufs = token_bufs
        runtime_tokens.set_output_token(device, output_token_buf)
    for eff, token_buf in zip(ordered_effects, token_bufs):
      runtime_tokens.update_token(eff, token_buf)
    return output_bufs
  return input_bufs, _remove_tokens


def _execute_compiled(name: str, compiled: XlaExecutable,
                      input_handler: Optional[Callable],
                      output_buffer_counts: Sequence[int],
                      result_handler: Callable,
                      has_unordered_effects: bool,
                      ordered_effects: List[core.Effect],
                      kept_var_idx, has_host_callbacks: bool, *args):
  device, = compiled.local_devices()
  args, env = input_handler(args) if input_handler else (args, None)
  in_flat = flatten(device_put(x, device) for i, x in enumerate(args)
                    if i in kept_var_idx)
  if has_unordered_effects or ordered_effects or has_host_callbacks:
    in_flat, token_handler = _add_tokens(
        has_unordered_effects, ordered_effects, has_host_callbacks, device,
        in_flat)
    if can_execute_with_token:
      out_flat, runtime_token = compiled.execute_with_token(in_flat)
    else:
      out_flat = compiled.execute(in_flat)
      runtime_token = None
  else:
    out_flat = compiled.execute(in_flat)
  check_special(name, out_flat)
  out_bufs = unflatten(out_flat, output_buffer_counts)
  if ordered_effects or has_unordered_effects or has_host_callbacks:
    out_bufs = token_handler(out_bufs, runtime_token)
  return result_handler(env, out_bufs)


def _execute_replicated(name: str, compiled: XlaExecutable,
                        input_handler: Optional[Callable],
                        output_buffer_counts: Sequence[int],
                        result_handler: Callable,
                        has_unordered_effects: bool,
                        ordered_effects: List[core.Effect],
                        kept_var_idx, has_host_callbacks: bool,
                        *args, from_lower_sharding_computation: bool = False):
  if has_unordered_effects or ordered_effects:
    # TODO(sharadmv): support jit-of-pmap with effects
    raise NotImplementedError(
        "Cannot execute replicated computation with effects.")
  if input_handler: raise NotImplementedError  # TODO(mattjj, dougalm)
  input_bufs = [flatten(device_put(x, device) for i, x in enumerate(args)
                        if i in kept_var_idx)
                for device in compiled.local_devices()]
  input_bufs_flip = list(unsafe_zip(*input_bufs))
  out_bufs_flat_rep = compiled.execute_sharded_on_local_devices(input_bufs_flip)
  out_flat = [bufs[0] for bufs in out_bufs_flat_rep]
  check_special(name, out_flat)
  out_bufs = unflatten(out_flat, output_buffer_counts)
  if from_lower_sharding_computation:
    return result_handler(out_bufs)
  return result_handler(None, out_bufs)


def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers,
                     has_unordered_effects: bool,
                     ordered_effects: List[core.Effect], kept_var_idx,
                     host_callbacks, *args):
  env: Dict[core.Var, Any] = {}
  pruned_args = (x for i, x in enumerate(args) if i in kept_var_idx)
  map(env.setdefault, jaxpr.invars, pruned_args)
  map(env.setdefault, jaxpr.constvars, consts)
  outs = [xla.canonicalize_dtype(v.val) if type(v) is core.Literal else env[v]
          for v in jaxpr.outvars]
  return [_copy_device_array_to_device(x, device) if device_array.type_is_device_array(x)
          else h(None, *device_put(x, device)) for h, x in zip(handlers, outs)]


class XlaComputation(stages.XlaLowering):
  name: str
  _is_trivial: bool
  _executable: Optional[XlaCompiledComputation]
  _donated_invars: Optional[Sequence[bool]]

  def __init__(self, name: str, hlo, is_trivial: bool,
               donated_invars: Optional[Sequence[bool]],
               in_type: Optional[pe.InputType],
               out_type: Optional[pe.OutputType],
               **compile_args):
    self.name = name
    self._hlo = hlo
    self._is_trivial = is_trivial
    self._donated_invars = donated_invars
    self._in_type = in_type
    self._out_type = out_type
    self._executable = None
    self.compile_args = compile_args

  def is_trivial(self):
    return self._is_trivial

  # -- stages.XlaLowering overrides

  def hlo(self) -> xc.XlaComputation:
    if self.is_trivial():
      raise ValueError("A trivial computation has no HLO")
    if isinstance(self._hlo, xc.XlaComputation):
      return self._hlo
    return xe.mlir.mlir_module_to_xla_computation(
        mlir.module_to_string(self._hlo),
        use_tuple_args=self.compile_args["tuple_args"])

  def mhlo(self) -> ir.Module:
    if self.is_trivial():
      raise ValueError("A trivial computation has no MHLO")
    if isinstance(self._hlo, xc.XlaComputation):
      module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
      with mlir.make_ir_context():
        return ir.Module.parse(module_str)
    return self._hlo

  def compile(self) -> XlaCompiledComputation:
    if self._executable is None:
      if self.is_trivial():
        self._executable = XlaCompiledComputation.from_trivial_jaxpr(
            **self.compile_args)
      else:
        self._executable = XlaCompiledComputation.from_xla_computation(
            self.name, self._hlo, self._in_type, self._out_type,
            **self.compile_args)

    return self._executable
|
|
|
|
|
|
2021-12-06 15:13:01 -08:00
|
|
|
|
@profiler.annotate_function
|
2022-07-06 20:52:08 -07:00
|
|
|
|
def backend_compile(backend, built_c, options, host_callbacks):
|
2021-11-22 08:22:10 -08:00
|
|
|
|
# we use a separate function call to ensure that XLA compilation appears
|
|
|
|
|
# separately in Python profiling results
|
2022-07-06 20:52:08 -07:00
|
|
|
|
if host_callbacks:
|
|
|
|
|
return backend.compile(built_c, compile_options=options,
|
|
|
|
|
host_callbacks=host_callbacks)
|
|
|
|
|
# Some backends don't have `host_callbacks` option yet
|
|
|
|
|
# TODO(sharadmv): remove this fallback when all backends allow `compile`
|
|
|
|
|
# to take in `host_callbacks`
|
2021-11-22 08:22:10 -08:00
|
|
|
|
return backend.compile(built_c, compile_options=options)
|
|
|
|
|
|
|
|
|
|
# TODO(phawkins): update users.
|
|
|
|
|
xla.backend_compile = backend_compile
|
|
|
|
|
|
2021-12-14 17:43:40 -08:00
|
|
|
|
_ir_dump_counter = itertools.count()
|
|
|
|
|
|
|
|
|
|
def _make_string_safe_for_filename(s: str) -> str:
|
|
|
|
|
return re.sub(r'[^\w.)( -]', '', s)
|
|
|
|
|
|
|
|
|
|
def _dump_ir_to_file(name: str, ir: str):
|
|
|
|
|
id = next(_ir_dump_counter)
|
|
|
|
|
name = f"jax_ir{id}_{_make_string_safe_for_filename(name)}.mlir"
|
2022-05-31 12:46:54 -07:00
|
|
|
|
name = epath.Path(FLAGS.jax_dump_ir_to) / name
|
|
|
|
|
name.write_text(ir)
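As a usage note (a sketch; the directory is an arbitrary example), IR dumping is driven by the jax_dump_ir_to flag, whose default comes from the JAX_DUMP_IR_TO environment variable and therefore needs to be set before JAX is imported:
```
import os
os.environ["JAX_DUMP_IR_TO"] = "/tmp/jax_ir_dumps"  # example path; --jax_dump_ir_to also works

import jax, jax.numpy as jnp
jax.jit(jnp.sin)(1.0)  # compilation writes jax_ir<N>_<module name>.mlir files into the directory
```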
|
2021-12-14 17:43:40 -08:00
|
|
|
|
|
|
|
|
|
|
2022-09-07 16:31:28 -04:00
|
|
|
|
def compile_or_get_cached(backend, computation: ir.Module, compile_options,
|
2022-07-06 20:52:08 -07:00
|
|
|
|
host_callbacks):
|
2021-12-10 14:56:10 -08:00
|
|
|
|
# Avoid import cycle between jax and jax.experimental
|
|
|
|
|
from jax.experimental.compilation_cache import compilation_cache as cc
|
|
|
|
|
|
2022-09-07 16:31:28 -04:00
|
|
|
|
sym_name = computation.operation.attributes['sym_name']
|
|
|
|
|
module_name = ir.StringAttr(sym_name).value
|
|
|
|
|
# Convert ir.Module to a string representation, unless the
|
|
|
|
|
# back-end explicitly flags the ability to handle a module directly
|
|
|
|
|
# (avoiding the overhead of back and forth conversions)
|
|
|
|
|
serialized_computation: Union[str, bytes, ir.Module]
|
|
|
|
|
if getattr(backend, "needs_str_ir", True):
|
|
|
|
|
if xc.mlir_api_version >= 34:
|
|
|
|
|
serialized_computation = mlir.module_to_bytecode(computation)
|
|
|
|
|
else:
|
|
|
|
|
serialized_computation = mlir.module_to_string(computation)
|
2021-12-10 14:56:10 -08:00
|
|
|
|
else:
|
2022-09-07 16:31:28 -04:00
|
|
|
|
serialized_computation = computation
|
2021-12-10 14:56:10 -08:00
|
|
|
|
|
|
|
|
|
# Persistent compilation cache only implemented on TPU.
|
|
|
|
|
# TODO(skye): add warning when initializing cache on unsupported default platform
|
2022-09-08 11:50:01 -07:00
|
|
|
|
supported_platforms = ["tpu"]
|
|
|
|
|
# GPU caching can be enabled if JitRt is enabled.
|
|
|
|
|
# TODO(b/232263664): Remove check when JitRt is enabled by default.
|
|
|
|
|
if "--xla_gpu_enable_xla_runtime_executable=true" in os.environ.get("XLA_FLAGS", ""):
|
|
|
|
|
supported_platforms.append("gpu")
|
|
|
|
|
if cc.is_initialized() and backend.platform in supported_platforms:
|
2022-09-07 16:31:28 -04:00
|
|
|
|
cached_executable = cc.get_executable(serialized_computation,
|
|
|
|
|
compile_options, backend)
|
2021-12-14 17:43:40 -08:00
|
|
|
|
if cached_executable is not None:
|
|
|
|
|
logging.info('Persistent compilation cache hit for %s.', module_name)
|
|
|
|
|
return cached_executable
|
|
|
|
|
else:
|
2022-09-07 16:31:28 -04:00
|
|
|
|
compiled = backend_compile(backend, serialized_computation,
|
|
|
|
|
compile_options, host_callbacks)
|
|
|
|
|
cc.put_executable(module_name, serialized_computation, compile_options,
|
|
|
|
|
compiled, backend)
|
2021-12-14 17:43:40 -08:00
|
|
|
|
return compiled
|
|
|
|
|
|
|
|
|
|
if FLAGS.jax_dump_ir_to:
|
2022-09-07 16:31:28 -04:00
|
|
|
|
_dump_ir_to_file(module_name, mlir.module_to_string(computation))
|
|
|
|
|
return backend_compile(backend, serialized_computation, compile_options,
|
|
|
|
|
host_callbacks)
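For reference, a minimal sketch of enabling the persistent cache consulted here (the cache directory is an arbitrary example, and the platform must be one of the supported ones above):
```
from jax.experimental.compilation_cache import compilation_cache as cc
cc.initialize_cache("/tmp/jax_compilation_cache")  # example path

import jax, jax.numpy as jnp
jax.jit(jnp.tanh)(1.0)  # first run compiles and writes to the cache; later runs may hit it
```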
|
2021-11-22 08:22:10 -08:00
|
|
|
|
|
|
|
|
|
|
2022-08-30 10:45:29 -07:00
|
|
|
|
def get_buffer_counts(out_avals, ordered_effects, has_unordered_effects):
|
|
|
|
|
buffer_counts = [aval_to_num_buffers(aval) for aval in out_avals]
|
|
|
|
|
if ordered_effects or has_unordered_effects:
|
|
|
|
|
num_output_tokens = len(ordered_effects)
|
|
|
|
|
# TODO(sharadmv): remove check when minimum jaxlib version is bumped
|
|
|
|
|
if not can_execute_with_token:
|
|
|
|
|
num_output_tokens += has_unordered_effects
|
|
|
|
|
buffer_counts = ([1] * num_output_tokens) + buffer_counts
|
|
|
|
|
return buffer_counts
|
|
|
|
|
|
|
|
|
|
|
2022-07-01 17:35:17 -07:00
|
|
|
|
class XlaCompiledComputation(stages.XlaExecutable):
|
2022-04-14 14:18:31 -07:00
|
|
|
|
def __init__(self, xla_executable, in_avals, kept_var_idx, unsafe_call,
|
|
|
|
|
keepalive: Any):
|
2021-11-22 08:22:10 -08:00
|
|
|
|
self._xla_executable = xla_executable
|
|
|
|
|
self.in_avals = in_avals
|
|
|
|
|
self._kept_var_idx = kept_var_idx
|
|
|
|
|
self.unsafe_call = unsafe_call
|
2022-04-14 14:18:31 -07:00
|
|
|
|
# Only the `unsafe_call` function is cached, so to avoid the `keepalive`
|
|
|
|
|
# being garbage collected we attach it to `unsafe_call`.
|
|
|
|
|
self.unsafe_call.keepalive = keepalive
|
2021-11-22 08:22:10 -08:00
|
|
|
|
|
|
|
|
|
@staticmethod
|
2022-07-06 20:52:08 -07:00
|
|
|
|
def from_xla_computation(name: str, xla_computation: Optional[ir.Module],
|
|
|
|
|
in_type: Optional[pe.InputType],
|
|
|
|
|
out_type: Optional[pe.OutputType], nreps: int,
|
|
|
|
|
device: Optional[Device], backend: Backend,
|
|
|
|
|
tuple_args: bool,
|
|
|
|
|
in_avals: Sequence[core.AbstractValue],
|
|
|
|
|
out_avals: Sequence[core.AbstractValue],
|
|
|
|
|
has_unordered_effects: bool,
|
|
|
|
|
ordered_effects: List[core.Effect],
|
|
|
|
|
kept_var_idx: Set[int], keepalive: Optional[Any],
|
|
|
|
|
host_callbacks: List[Any]) -> XlaCompiledComputation:
|
2021-11-29 12:39:19 -08:00
|
|
|
|
sticky_device = device
|
2022-05-18 17:26:10 -07:00
|
|
|
|
input_handler = _input_handler(backend, in_type, out_type)
|
2022-05-26 23:21:09 -07:00
|
|
|
|
result_handler = _result_handler(backend, sticky_device, out_type)
|
2021-11-22 08:22:10 -08:00
|
|
|
|
options = xb.get_compile_options(
|
2022-04-09 10:56:14 -07:00
|
|
|
|
num_replicas=nreps, num_partitions=1,
|
2022-01-25 16:27:09 -08:00
|
|
|
|
device_assignment=(sticky_device,) if sticky_device else None)
|
2021-11-22 08:22:10 -08:00
|
|
|
|
options.parameter_is_tupled_arguments = tuple_args
|
2021-12-13 21:51:08 -08:00
|
|
|
|
with log_elapsed_time(f"Finished XLA compilation of {name} "
|
|
|
|
|
"in {elapsed_time} sec"):
|
2022-07-06 20:52:08 -07:00
|
|
|
|
compiled = compile_or_get_cached(backend, xla_computation, options,
|
|
|
|
|
host_callbacks)
|
2022-08-30 10:45:29 -07:00
|
|
|
|
buffer_counts = get_buffer_counts(out_avals, ordered_effects,
|
|
|
|
|
has_unordered_effects)
|
2021-11-22 08:22:10 -08:00
|
|
|
|
execute = _execute_compiled if nreps == 1 else _execute_replicated
|
2022-05-16 18:55:52 -07:00
|
|
|
|
unsafe_call = partial(execute, name, compiled, input_handler, buffer_counts, # type: ignore # noqa: F811
|
2022-05-26 23:21:09 -07:00
|
|
|
|
result_handler, has_unordered_effects,
|
2022-08-17 10:43:50 -07:00
|
|
|
|
ordered_effects, kept_var_idx, bool(host_callbacks))
|
2022-04-14 14:18:31 -07:00
|
|
|
|
return XlaCompiledComputation(compiled, in_avals, kept_var_idx, unsafe_call,
|
|
|
|
|
keepalive)
|
2021-11-22 08:22:10 -08:00
|
|
|
|
|
|
|
|
|
def is_trivial(self):
|
|
|
|
|
return self._xla_executable is None
|
|
|
|
|
|
2022-01-13 15:42:17 -08:00
|
|
|
|
@property
|
2021-11-22 08:22:10 -08:00
|
|
|
|
def xla_executable(self):
|
2022-03-22 12:16:03 -07:00
|
|
|
|
# TODO(frostig): remove in favor of runtime_executable?
|
2021-11-22 08:22:10 -08:00
|
|
|
|
if self.is_trivial():
|
|
|
|
|
raise ValueError("A trivial compiled computation has no XLA executable")
|
|
|
|
|
return self._xla_executable
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
2022-07-06 20:52:08 -07:00
|
|
|
|
def from_trivial_jaxpr(jaxpr, consts, device, in_avals, out_avals,
|
|
|
|
|
has_unordered_effects, ordered_effects, kept_var_idx,
|
|
|
|
|
keepalive: Optional[Any],
|
|
|
|
|
host_callbacks: List[Any]) -> XlaCompiledComputation:
|
2022-04-14 14:18:31 -07:00
|
|
|
|
assert keepalive is None
|
2021-11-22 08:22:10 -08:00
|
|
|
|
result_handlers = map(partial(aval_to_result_handler, device), out_avals)
|
2022-07-06 20:52:08 -07:00
|
|
|
|
unsafe_call = partial(_execute_trivial, jaxpr, device, consts, out_avals,
|
|
|
|
|
result_handlers, has_unordered_effects,
|
2022-08-17 10:43:50 -07:00
|
|
|
|
ordered_effects, kept_var_idx, bool(host_callbacks))
|
2022-04-14 14:18:31 -07:00
|
|
|
|
return XlaCompiledComputation(None, in_avals, kept_var_idx, unsafe_call,
|
|
|
|
|
keepalive)
|
2021-11-22 08:22:10 -08:00
|
|
|
|
|
2022-07-01 17:35:17 -07:00
|
|
|
|
# -- stages.XlaExecutable overrides
|
2022-03-22 12:16:03 -07:00
|
|
|
|
|
2022-07-01 17:35:17 -07:00
|
|
|
|
def xla_extension_executable(self):
|
2022-03-22 12:16:03 -07:00
|
|
|
|
return self.xla_executable
|
|
|
|
|
|
2021-11-22 08:22:10 -08:00
|
|
|
|
def call(self, *args):
|
|
|
|
|
arg_specs = unsafe_map(arg_spec, args)
|
|
|
|
|
arg_avals = [spec[0] for i, spec in enumerate(arg_specs)
|
|
|
|
|
if i in self._kept_var_idx]
|
|
|
|
|
check_arg_avals_for_call(self.in_avals, arg_avals)
|
|
|
|
|
return self.unsafe_call(*args)
|
|
|
|
|
|
|
|
|
|
def check_arg_avals_for_call(ref_avals, arg_avals):
|
|
|
|
|
if len(ref_avals) != len(arg_avals):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
f"Computation compiled for {len(ref_avals)} inputs "
|
|
|
|
|
f"but called with {len(arg_avals)}")
|
|
|
|
|
for ref_aval, arg_aval in zip(ref_avals, arg_avals):
|
|
|
|
|
if not core.typematch(ref_aval, arg_aval):
|
|
|
|
|
ref_avals_fmt = ', '.join(str(a) for a in ref_avals)
|
|
|
|
|
arg_avals_fmt = ', '.join(str(a) for a in arg_avals)
|
|
|
|
|
raise TypeError(
|
|
|
|
|
f"Computation compiled for input types:\n {ref_avals_fmt}\n"
|
|
|
|
|
f"called with:\n {arg_avals_fmt}")
|
|
|
|
|
|
|
|
|
|
|
2022-08-15 19:12:50 -04:00
|
|
|
|
def device_put(x, device: Optional[Device] = None) -> Tuple[Any, ...]:
|
2021-11-22 08:22:10 -08:00
|
|
|
|
x = xla.canonicalize_dtype(x)
|
|
|
|
|
try:
|
|
|
|
|
return device_put_handlers[type(x)](x, device)
|
|
|
|
|
except KeyError as err:
|
|
|
|
|
raise TypeError(f"No device_put handler for type: {type(x)}") from err
|
|
|
|
|
|
|
|
|
|
# TODO(phawkins): update users.
|
|
|
|
|
xla.device_put = device_put
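A short usage sketch of the public wrapper that ultimately reaches these handlers (the array and device choice are illustrative):
```
import jax
import numpy as np

x = np.arange(4.0)
y = jax.device_put(x, jax.devices()[0])  # commits the result to the chosen device
print(y.device())                        # the device y is committed to
```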
|
|
|
|
|
|
|
|
|
|
def _device_put_array(x, device: Optional[Device]):
|
|
|
|
|
backend = xb.get_device_backend(device)
|
2022-06-23 11:46:20 -07:00
|
|
|
|
if x.dtype == dtypes.float0:
|
2021-11-22 08:22:10 -08:00
|
|
|
|
x = np.zeros(x.shape, dtype=np.dtype(bool))
|
|
|
|
|
return (backend.buffer_from_pyval(x, device),)
|
|
|
|
|
|
|
|
|
|
def _device_put_scalar(x, device):
|
|
|
|
|
return _device_put_array(dtypes.coerce_to_array(x), device)
|
|
|
|
|
|
2021-11-29 12:39:19 -08:00
|
|
|
|
def _device_put_token(_, device):
|
|
|
|
|
backend = xb.get_device_backend(device)
|
|
|
|
|
return (backend.buffer_from_pyval(np.zeros((), dtype=np.dtype(np.bool_)),
|
2021-11-22 08:22:10 -08:00
|
|
|
|
device),)
|
|
|
|
|
|
|
|
|
|
_scalar_types = dtypes.python_scalar_dtypes.keys()
|
|
|
|
|
|
2022-08-15 19:12:50 -04:00
|
|
|
|
device_put_handlers: Dict[Any, Callable[[Any, Optional[Device]],
|
|
|
|
|
Tuple[Any, ...]]] = {}
|
2021-11-22 08:22:10 -08:00
|
|
|
|
device_put_handlers.update((t, _device_put_array) for t in array_types)
|
|
|
|
|
device_put_handlers.update((t, _device_put_scalar) for t in _scalar_types)
|
2021-11-29 12:39:19 -08:00
|
|
|
|
device_put_handlers[core.Token] = _device_put_token
|
2022-06-29 13:55:30 -07:00
|
|
|
|
device_put_handlers[core.BInt] = lambda x, d: _device_put_scalar(x.val, d)
|
2021-11-22 08:22:10 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _device_put_device_array(x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray], device: Optional[Device]):
|
|
|
|
|
x = _copy_device_array_to_device(x, device)
|
|
|
|
|
return (x.device_buffer,)
|
|
|
|
|
for t in device_array.device_array_types:
|
|
|
|
|
device_put_handlers[t] = _device_put_device_array
|
2022-06-29 13:55:30 -07:00
|
|
|
|
device_put_handlers[core.PaddedArray] = lambda x, d: device_put(x._data, d)
|
2021-11-22 08:22:10 -08:00
|
|
|
|
|
2022-05-26 23:21:09 -07:00
|
|
|
|
def _copy_device_array_to_device(
|
|
|
|
|
x: Union[device_array.DeviceArrayProtocol, device_array._DeviceArray],
|
|
|
|
|
device: Optional[xc.Device]
|
|
|
|
|
) -> Union[device_array.DeviceArrayProtocol, device_array._DeviceArray]:
|
2021-11-22 08:22:10 -08:00
|
|
|
|
if device is None:
|
|
|
|
|
# no copying to be done because there's no target specified
|
|
|
|
|
return x
|
|
|
|
|
elif xb.get_device_backend(device).platform == x.device_buffer.platform():
|
|
|
|
|
# source and target platforms are the same
|
|
|
|
|
if x.device_buffer.device() == device:
|
|
|
|
|
# no copying to be done because source equals target
|
|
|
|
|
if x._device == device:
|
|
|
|
|
return x
|
|
|
|
|
else:
|
|
|
|
|
moved_buf = x.device_buffer # We need to change stickiness
|
|
|
|
|
else:
|
|
|
|
|
# move the buffer with a device-to-device copy
|
|
|
|
|
moved_buf = x.device_buffer.copy_to_device(device)
|
|
|
|
|
else:
|
|
|
|
|
# buffers from different XLA backends are passed through the host.
|
|
|
|
|
backend = xb.get_device_backend(device)
|
2022-08-25 07:27:54 -07:00
|
|
|
|
moved_buf = backend.buffer_from_pyval(np.asarray(x.device_buffer), device)
|
2021-11-22 08:22:10 -08:00
|
|
|
|
return device_array.make_device_array(x.aval, device, moved_buf)
|
|
|
|
|
|
|
|
|
|
|
2022-08-19 10:03:43 -07:00
|
|
|
|
def _copy_array_to_device(x: Array, device: Optional[xc.Device]) -> Array:
|
|
|
|
|
"""Copies `Array`s with SingleDeviceSharding to a different device."""
|
|
|
|
|
from jax.experimental import array, sharding
|
|
|
|
|
|
|
|
|
|
if device is None:
|
|
|
|
|
# no copying to be done because there's no target specified
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
buf = x._arrays[0]
|
|
|
|
|
if xb.get_device_backend(device).platform == buf.platform():
|
|
|
|
|
# source and target platforms are the same
|
|
|
|
|
if x.device() == device:
|
|
|
|
|
# no copying to be done because source equals target
|
|
|
|
|
if x._committed:
|
|
|
|
|
return x
|
|
|
|
|
else:
|
|
|
|
|
moved_buf = buf # We need to change stickiness
|
|
|
|
|
else:
|
|
|
|
|
# move the buffer with a device-to-device copy
|
|
|
|
|
moved_buf = buf.copy_to_device(device)
|
|
|
|
|
else:
|
|
|
|
|
# buffers from different XLA backends are passed through the host.
|
|
|
|
|
backend = xb.get_device_backend(device)
|
2022-08-25 07:27:54 -07:00
|
|
|
|
moved_buf = backend.buffer_from_pyval(np.asarray(buf), device)
|
2022-08-19 10:03:43 -07:00
|
|
|
|
return array.Array(
|
|
|
|
|
x.aval, sharding.SingleDeviceSharding(moved_buf.device()), [moved_buf],
|
|
|
|
|
committed=(device is not None))
|
|
|
|
|
|
|
|
|
|
|
2021-11-22 08:22:10 -08:00
|
|
|
|
def _device_put_impl(x, device: Optional[Device] = None):
|
2022-08-19 10:03:43 -07:00
|
|
|
|
from jax.experimental import array, sharding
|
|
|
|
|
|
2021-11-22 08:22:10 -08:00
|
|
|
|
if device_array.type_is_device_array(x):
|
|
|
|
|
return _copy_device_array_to_device(x, device)
|
|
|
|
|
|
2022-08-19 10:03:43 -07:00
|
|
|
|
if type(x) is array.Array and isinstance(x.sharding, sharding.SingleDeviceSharding):
|
|
|
|
|
return _copy_array_to_device(x, device)
|
|
|
|
|
|
2021-11-22 08:22:10 -08:00
|
|
|
|
try:
|
|
|
|
|
a = xla.abstractify(x)
|
|
|
|
|
except TypeError as err:
|
|
|
|
|
raise TypeError(
|
|
|
|
|
f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
|
2022-06-24 10:04:31 -07:00
|
|
|
|
return aval_to_result_handler(device, a)(None, *device_put(x, device))
|
|
|
|
|
|
2021-11-22 08:22:10 -08:00
|
|
|
|
|
|
|
|
|
device_put_p = core.Primitive('device_put')
|
|
|
|
|
device_put_p.def_impl(_device_put_impl)
|
|
|
|
|
device_put_p.def_abstract_eval(lambda x, device=None: x)
|
|
|
|
|
ad.deflinear2(device_put_p, lambda cotangent, _, **kwargs: [cotangent])
|
|
|
|
|
batching.defvectorized(device_put_p)
|
2021-11-23 18:57:45 -08:00
|
|
|
|
|
[MLIR] Change signature of lowering rules.
Refactoring only, no functional changes intended.
Previously the MLIR lowering rule signature was
```
def rule(ctx, avals_in, avals_out, *args, **jaxpr_params):
```
where `ctx` was a module-wide context.
Change it to
```
def rule(ctx, *args, **jaxpr_params)
```
where `ctx` is a per-rule context object. The previous parameters are now available as `ctx.module_context`, `ctx.avals_in`, and `ctx.avals_out`.
This change makes it easier to add new per-rule context information without having to refactor all of the lowering rules to accept a new argument. One example is a shape environment for dynamic shapes. Another example, which motivated this work, is that I want to include the primitive name as part of the rule context.
PiperOrigin-RevId: 416698663
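To make the new contract concrete, a minimal sketch of a rule written against it; `my_prim` and `some_param` are hypothetical, and the body is an identity lowering for illustration only:
```
def _my_prim_lowering(ctx, x, *, some_param):
  aval_in, = ctx.avals_in    # previously the explicit avals_in argument
  aval_out, = ctx.avals_out  # previously the explicit avals_out argument
  del aval_in, aval_out, some_param
  return [x]

# mlir.register_lowering(my_prim, _my_prim_lowering)  # my_prim: a hypothetical core.Primitive
```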
2021-12-15 19:06:26 -08:00
|
|
|
|
def _device_put_lowering(ctx, x, *, device):
|
2021-11-23 18:57:45 -08:00
|
|
|
|
return [x]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
mlir.register_lowering(device_put_p, _device_put_lowering)
|
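As a closing usage note (a sketch; the jitted function is arbitrary), the identity lowering above means device_put inside jit-compiled code introduces no transfer op, which can be confirmed by inspecting the lowered module:
```
import jax

f = jax.jit(lambda x: jax.device_put(x) + 1.0)
print(f.lower(1.0).compiler_ir(dialect='mhlo'))  # no device-transfer op appears in the module
```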