factor AOT types out to a stages module

This commit is contained in:
Roy Frostig 2022-03-14 19:38:23 -07:00
parent 5354a016e6
commit 047488446b
6 changed files with 227 additions and 173 deletions

View File

@ -62,7 +62,6 @@ from jax._src.api import (
checkpoint as checkpoint,
checkpoint_policies as checkpoint_policies,
closure_convert as closure_convert,
Compiled as Compiled,
curry, # TODO(phawkins): update users to avoid this.
custom_ivjp as custom_ivjp,
custom_gradient as custom_gradient,
@ -92,7 +91,6 @@ from jax._src.api import (
jvp as jvp,
local_device_count as local_device_count,
local_devices as local_devices,
Lowered as Lowered,
linearize as linearize,
linear_transpose as linear_transpose,
make_jaxpr as make_jaxpr,
@ -139,6 +137,7 @@ from jax import numpy as numpy
from jax import ops as ops
from jax import profiler as profiler
from jax import random as random
from jax import stages as stages
from jax import tree_util as tree_util
from jax import util as util

View File

@ -47,21 +47,22 @@ from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
tree_multimap, treedef_is_leaf, treedef_children,
Partial, PyTreeDef, all_leaves, treedef_tuple)
from jax._src import device_array
from jax._src import dispatch
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import stages
from jax._src import traceback_util
from jax._src.api_util import (
flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
shaped_abstractify, _ensure_str_tuple, argnames_partial_except)
from jax._src import device_array
from jax._src import dispatch
from jax._src import dtypes
from jax._src.lax import lax as lax_internal
from jax._src.lib import jax_jit
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import broadcast_prefix
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
@ -493,161 +494,6 @@ def _cpp_jit(
return f_jitted
class Lowered:
"""Lowering of a function specialized to argument types and values.
A lowering is a computation ready for compilation. This class
carries a lowering together with the remaining information needed to
later compile and execute it. It also provides a common API for
querying properties of lowered computations across JAX's various
lowering paths (``jit``, ``pmap``, etc.).
"""
__slots__ = [
"in_tree", "in_avals", "out_tree", "donate_argnums", "_lowering",
"_no_kwargs"
]
# The PyTreeDef of the (positional arguments, keyword arguments).
#
# To get the individual PyTreeDef for the positional an keyword arguments,
# use `in_tree.children() which will return you a sequence of 2 PyTreeDef.
in_tree: PyTreeDef
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
in_avals: Any
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_lowering: Union[dispatch.XlaComputation,
pxla.MeshComputation,
pxla.PmapComputation]
_no_kwargs: bool
def __init__(self,
lowering,
in_tree: PyTreeDef,
in_avals,
out_tree: PyTreeDef,
donate_argnums: Tuple[int],
no_kwargs: bool = False):
"""Initializer.
Args:
in_tree: The `PyTreeDef` of (args, kwargs).
out_tree: The `PyTreeDef` of the outputs.
no_kwargs: If `True` the transformation, and the `Compiled` returned from
this object will not support keyword arguments (an error will be raised
if some are provided).
"""
self._lowering = lowering
self.in_tree = in_tree
self.in_avals = in_avals
self.out_tree = out_tree
self.donate_argnums = donate_argnums
self._no_kwargs = no_kwargs
def compile(self) -> 'Compiled':
return Compiled(
self._lowering.compile(), self.in_tree, self.in_avals,
self.out_tree, self.donate_argnums, self._no_kwargs)
def compiler_ir(self, dialect: Optional[str] = None):
if dialect is None or dialect == "mhlo":
return self._lowering.mhlo()
elif dialect == "hlo":
return self._lowering.hlo()
else:
raise ValueError(f"Unknown dialect {dialect}")
# TODO(frostig): remove this in favor of `compiler_ir`
def _xla_computation(self):
return self._lowering.hlo()
class Compiled:
"""Compiled representation of a function specialized to types/values.
A compiled computation is associated with an executable and the
remaining information needed to execute it. It also provides a
common API for querying properties of compiled computations across
JAX's various compilation paths and backends.
"""
__slots__ = [
"in_tree", "in_avals", "out_tree", "donate_argnums", "_executable",
"_no_kwargs"
]
# The PyTreeDef of the (positional arguments, keyword arguments).
in_tree: PyTreeDef
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
in_avals: Any
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_executable: Union[dispatch.XlaCompiledComputation,
pxla.MeshExecutable,
pxla.PmapExecutable]
_no_kwargs: bool
def __init__(self, executable, in_tree, in_avals, out_tree, donate_argnums,
no_kwargs=False):
self._executable = executable
self.in_tree = in_tree
self.in_avals = in_avals
self.out_tree = out_tree
self.donate_argnums = donate_argnums
self._no_kwargs = no_kwargs
def compiler_ir(self):
"""Post-compilation IR.
Compilation typically involves code transformation and
optimization. This method exists to reflect the compiler's
representation of the program after such passes, whenever
possible.
"""
return self._executable.xla_executable.hlo_modules()
def runtime_executable(self):
return self._executable.xla_executable
def _xla_executable(self):
# TODO(frostig): finalize API. For now, return the underlying
# executable directly via this method.
return self._executable.xla_executable
def __call__(self, *args, **kwargs):
if self._no_kwargs and kwargs:
kws = ', '.join(kwargs.keys())
raise NotImplementedError(
'function was compiled by a transformation that does not support '
f"keyword arguments, but called with keyword arguments: {kws}")
args_flat, in_tree = tree_flatten((args, kwargs))
if in_tree != self.in_tree:
# TODO(frostig): provide more info about the source function
# and transformation
raise TypeError(
f'function compiled for {self.in_tree}, called with {in_tree}')
try:
out_flat = self._executable.call(*args_flat)
except TypeError as e:
# We can't transform ahead-of-time compiled calls, since we've
# lowered and compiled for a fixed function signature, and JAX
# transformations change signatures. We interpret a Tracer
# argument as an indication of a transformation attempt. We
# could check this before the executable call, but we'd rather
# avoid isinstance checks on the call path. Seeing a TypeError
# might mean that arguments have JAX-invalid types, which in
# turn might mean some are Tracers.
for arg in args_flat:
if isinstance(arg, core.Tracer):
raise TypeError(
'Cannot apply JAX transformations to a function lowered and '
'compiled for a particular signature. Detected argument of '
f'Tracer type {type(arg)}.')
else:
raise
return tree_unflatten(self.out_tree, out_flat)
def _jit_lower(fun, static_argnums, static_argnames, device, backend,
donate_argnums, inline):
"""Make a ``lower`` method for jitted functions."""
@ -664,7 +510,7 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
return aval, None
@api_boundary
def lower(*args, **kwargs) -> Lowered:
def lower(*args, **kwargs) -> stages.Lowered:
"""Lower this function for the given arguments.
A lowered function is staged out of Python and translated to a
@ -687,8 +533,8 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
computation = dispatch.lower_xla_callable(flat_fun, device, backend, name,
donated_invars,
*arg_specs_and_device)
return Lowered(computation, in_tree, in_tree.unflatten(arg_specs),
out_tree(), donate_argnums)
return stages.Lowered(computation, in_tree, in_tree.unflatten(arg_specs),
out_tree(), donate_argnums)
return lower
@ -2182,7 +2028,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
# this might naturally be a method, with ``fun`` as a ``self`` and
# all the other arguments stored as attributes.
@api_boundary
def lower(*args, **kwargs) -> Lowered:
def lower(*args, **kwargs) -> stages.Lowered:
"""Lower a parallel-mapped form of this function for the given arguments.
A parallel-mapped and lowered function is staged out of Python and
@ -2208,8 +2054,9 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
donated_invars=p.donated_invars,
global_arg_shapes=p.global_arg_shapes_flat,
avals=abstract_args)
return Lowered(computation, p.in_tree, p.in_tree.unflatten(abstract_args),
p.out_tree(), donate_tuple)
return stages.Lowered(
computation, p.in_tree, p.in_tree.unflatten(abstract_args),
p.out_tree(), donate_tuple)
return lower

187
jax/_src/stages.py Normal file
View File

@ -0,0 +1,187 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Tuple, Union
from jax import core
from jax.interpreters import pxla
from jax.tree_util import PyTreeDef, tree_flatten, tree_unflatten
from jax._src import dispatch
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
class Lowered:
"""Lowering of a function specialized to argument types and values.
A lowering is a computation ready for compilation. This class
carries a lowering together with the remaining information needed to
later compile and execute it. It also provides a common API for
querying properties of lowered computations across JAX's various
lowering paths (``jit``, ``pmap``, etc.).
"""
__slots__ = [
"in_tree", "in_avals", "out_tree", "donate_argnums", "_lowering",
"_no_kwargs"
]
# The PyTreeDef of the (positional arguments, keyword arguments).
#
# To get the individual PyTreeDef for the positional an keyword arguments,
# use `in_tree.children() which will return you a sequence of 2 PyTreeDef.
in_tree: PyTreeDef
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
in_avals: Any
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_lowering: Union[dispatch.XlaComputation,
pxla.MeshComputation,
pxla.PmapComputation]
_no_kwargs: bool
def __init__(self,
lowering,
in_tree: PyTreeDef,
in_avals,
out_tree: PyTreeDef,
donate_argnums: Tuple[int],
no_kwargs: bool = False):
"""Initializer.
Args:
in_tree: The `PyTreeDef` of (args, kwargs).
out_tree: The `PyTreeDef` of the outputs.
no_kwargs: If `True` the transformation, and the `Compiled` returned from
this object will not support keyword arguments (an error will be raised
if some are provided).
"""
self._lowering = lowering
self.in_tree = in_tree
self.in_avals = in_avals
self.out_tree = out_tree
self.donate_argnums = donate_argnums
self._no_kwargs = no_kwargs
def compile(self) -> 'Compiled':
return Compiled(
self._lowering.compile(), self.in_tree, self.in_avals,
self.out_tree, self.donate_argnums, self._no_kwargs)
def compiler_ir(self, dialect: Optional[str] = None):
if dialect is None or dialect == "mhlo":
return self._lowering.mhlo()
elif dialect == "hlo":
return self._lowering.hlo()
else:
raise ValueError(f"Unknown dialect {dialect}")
# TODO(frostig): remove this in favor of `compiler_ir`
def _xla_computation(self):
return self._lowering.hlo()
class Compiled:
"""Compiled representation of a function specialized to types/values.
A compiled computation is associated with an executable and the
remaining information needed to execute it. It also provides a
common API for querying properties of compiled computations across
JAX's various compilation paths and backends.
"""
__slots__ = [
"in_tree", "in_avals", "out_tree", "donate_argnums", "_executable",
"_no_kwargs"
]
# The PyTreeDef of the (positional arguments, keyword arguments).
in_tree: PyTreeDef
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
in_avals: Any
out_tree: PyTreeDef
donate_argnums: Tuple[int]
_executable: Union[dispatch.XlaCompiledComputation,
pxla.MeshExecutable,
pxla.PmapExecutable]
_no_kwargs: bool
def __init__(self, executable, in_tree, in_avals, out_tree, donate_argnums,
no_kwargs=False):
self._executable = executable
self.in_tree = in_tree
self.in_avals = in_avals
self.out_tree = out_tree
self.donate_argnums = donate_argnums
self._no_kwargs = no_kwargs
def compiler_ir(self):
"""Post-compilation IR.
Compilation typically involves code transformation and
optimization. This method exists to reflect the compiler's
representation of the program after such passes, whenever
possible.
"""
return self._executable.xla_executable.hlo_modules()
def runtime_executable(self):
return self._executable.xla_executable
def _xla_executable(self):
# TODO(frostig): finalize API. For now, return the underlying
# executable directly via this method.
return self._executable.xla_executable
def __call__(self, *args, **kwargs):
if self._no_kwargs and kwargs:
kws = ', '.join(kwargs.keys())
raise NotImplementedError(
'function was compiled by a transformation that does not support '
f"keyword arguments, but called with keyword arguments: {kws}")
args_flat, in_tree = tree_flatten((args, kwargs))
if in_tree != self.in_tree:
# TODO(frostig): provide more info about the source function
# and transformation
raise TypeError(
f'function compiled for {self.in_tree}, called with {in_tree}')
try:
out_flat = self._executable.call(*args_flat)
except TypeError:
# We can't transform ahead-of-time compiled calls, since we've
# lowered and compiled for a fixed function signature, and JAX
# transformations change signatures. We interpret a Tracer
# argument as an indication of a transformation attempt. We
# could check this before the executable call, but we'd rather
# avoid isinstance checks on the call path. Seeing a TypeError
# might mean that arguments have JAX-invalid types, which in
# turn might mean some are Tracers.
for arg in args_flat:
if isinstance(arg, core.Tracer):
raise TypeError(
'Cannot apply JAX transformations to a function lowered and '
'compiled for a particular signature. Detected argument of '
f'Tracer type {type(arg)}.')
else:
raise
return tree_unflatten(self.out_tree, out_flat)

View File

@ -26,7 +26,8 @@ from enum import Enum
from jax import numpy as jnp
from jax import core
from jax import linear_util as lu
from jax._src.api import Lowered, _check_callable, _check_arg
from jax import stages
from jax._src.api import _check_callable, _check_arg
from jax._src import dispatch
from jax.tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map,
tree_leaves, treedef_tuple)
@ -662,7 +663,7 @@ def xmap(fun: Callable,
in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
in_avals = in_tree.unflatten(avals_flat)
return Lowered(
return stages.Lowered(
computation, in_tree, in_avals, out_tree(), donate_argnums,
no_kwargs=True)

View File

@ -23,7 +23,8 @@ from jax.experimental import maps
from jax.experimental.global_device_array import GlobalDeviceArray as GDA
from jax import core
from jax import linear_util as lu
from jax._src.api import _check_callable, _check_arg, Lowered
from jax import stages
from jax._src.api import _check_callable, _check_arg
from jax._src import dispatch
from jax._src import source_info_util
from jax._src.lib import xla_extension_version
@ -282,8 +283,8 @@ def pjit(fun: Callable,
args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
return Lowered(lowering, args_kwargs_in_tree, local_in_avals, out_tree,
donate_argnums, no_kwargs=True)
return stages.Lowered(lowering, args_kwargs_in_tree, local_in_avals,
out_tree, donate_argnums, no_kwargs=True)
wrapped.lower = lower
return wrapped

19
jax/stages.py Normal file
View File

@ -0,0 +1,19 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa: F401
from jax._src.stages import (
Compiled as Compiled,
Lowered as Lowered,
)