mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 21:06:06 +00:00
factor AOT types out to a stages
module
This commit is contained in:
parent
5354a016e6
commit
047488446b
@ -62,7 +62,6 @@ from jax._src.api import (
|
||||
checkpoint as checkpoint,
|
||||
checkpoint_policies as checkpoint_policies,
|
||||
closure_convert as closure_convert,
|
||||
Compiled as Compiled,
|
||||
curry, # TODO(phawkins): update users to avoid this.
|
||||
custom_ivjp as custom_ivjp,
|
||||
custom_gradient as custom_gradient,
|
||||
@ -92,7 +91,6 @@ from jax._src.api import (
|
||||
jvp as jvp,
|
||||
local_device_count as local_device_count,
|
||||
local_devices as local_devices,
|
||||
Lowered as Lowered,
|
||||
linearize as linearize,
|
||||
linear_transpose as linear_transpose,
|
||||
make_jaxpr as make_jaxpr,
|
||||
@ -139,6 +137,7 @@ from jax import numpy as numpy
|
||||
from jax import ops as ops
|
||||
from jax import profiler as profiler
|
||||
from jax import random as random
|
||||
from jax import stages as stages
|
||||
from jax import tree_util as tree_util
|
||||
from jax import util as util
|
||||
|
||||
|
179
jax/_src/api.py
179
jax/_src/api.py
@ -47,21 +47,22 @@ from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
||||
tree_multimap, treedef_is_leaf, treedef_children,
|
||||
Partial, PyTreeDef, all_leaves, treedef_tuple)
|
||||
|
||||
from jax._src import device_array
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import source_info_util
|
||||
from jax._src import stages
|
||||
from jax._src import traceback_util
|
||||
from jax._src.api_util import (
|
||||
flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
|
||||
argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
|
||||
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
|
||||
shaped_abstractify, _ensure_str_tuple, argnames_partial_except)
|
||||
from jax._src import device_array
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import jax_jit
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import pmap_lib
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.tree_util import broadcast_prefix
|
||||
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
|
||||
@ -493,161 +494,6 @@ def _cpp_jit(
|
||||
return f_jitted
|
||||
|
||||
|
||||
class Lowered:
|
||||
"""Lowering of a function specialized to argument types and values.
|
||||
|
||||
A lowering is a computation ready for compilation. This class
|
||||
carries a lowering together with the remaining information needed to
|
||||
later compile and execute it. It also provides a common API for
|
||||
querying properties of lowered computations across JAX's various
|
||||
lowering paths (``jit``, ``pmap``, etc.).
|
||||
"""
|
||||
__slots__ = [
|
||||
"in_tree", "in_avals", "out_tree", "donate_argnums", "_lowering",
|
||||
"_no_kwargs"
|
||||
]
|
||||
|
||||
# The PyTreeDef of the (positional arguments, keyword arguments).
|
||||
#
|
||||
# To get the individual PyTreeDef for the positional an keyword arguments,
|
||||
# use `in_tree.children() which will return you a sequence of 2 PyTreeDef.
|
||||
in_tree: PyTreeDef
|
||||
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
|
||||
in_avals: Any
|
||||
out_tree: PyTreeDef
|
||||
donate_argnums: Tuple[int]
|
||||
_lowering: Union[dispatch.XlaComputation,
|
||||
pxla.MeshComputation,
|
||||
pxla.PmapComputation]
|
||||
_no_kwargs: bool
|
||||
|
||||
def __init__(self,
|
||||
lowering,
|
||||
in_tree: PyTreeDef,
|
||||
in_avals,
|
||||
out_tree: PyTreeDef,
|
||||
donate_argnums: Tuple[int],
|
||||
no_kwargs: bool = False):
|
||||
"""Initializer.
|
||||
|
||||
Args:
|
||||
in_tree: The `PyTreeDef` of (args, kwargs).
|
||||
out_tree: The `PyTreeDef` of the outputs.
|
||||
no_kwargs: If `True` the transformation, and the `Compiled` returned from
|
||||
this object will not support keyword arguments (an error will be raised
|
||||
if some are provided).
|
||||
"""
|
||||
self._lowering = lowering
|
||||
self.in_tree = in_tree
|
||||
self.in_avals = in_avals
|
||||
self.out_tree = out_tree
|
||||
self.donate_argnums = donate_argnums
|
||||
self._no_kwargs = no_kwargs
|
||||
|
||||
def compile(self) -> 'Compiled':
|
||||
return Compiled(
|
||||
self._lowering.compile(), self.in_tree, self.in_avals,
|
||||
self.out_tree, self.donate_argnums, self._no_kwargs)
|
||||
|
||||
def compiler_ir(self, dialect: Optional[str] = None):
|
||||
if dialect is None or dialect == "mhlo":
|
||||
return self._lowering.mhlo()
|
||||
elif dialect == "hlo":
|
||||
return self._lowering.hlo()
|
||||
else:
|
||||
raise ValueError(f"Unknown dialect {dialect}")
|
||||
|
||||
# TODO(frostig): remove this in favor of `compiler_ir`
|
||||
def _xla_computation(self):
|
||||
return self._lowering.hlo()
|
||||
|
||||
|
||||
class Compiled:
|
||||
"""Compiled representation of a function specialized to types/values.
|
||||
|
||||
A compiled computation is associated with an executable and the
|
||||
remaining information needed to execute it. It also provides a
|
||||
common API for querying properties of compiled computations across
|
||||
JAX's various compilation paths and backends.
|
||||
"""
|
||||
__slots__ = [
|
||||
"in_tree", "in_avals", "out_tree", "donate_argnums", "_executable",
|
||||
"_no_kwargs"
|
||||
]
|
||||
|
||||
|
||||
# The PyTreeDef of the (positional arguments, keyword arguments).
|
||||
in_tree: PyTreeDef
|
||||
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
|
||||
in_avals: Any
|
||||
out_tree: PyTreeDef
|
||||
donate_argnums: Tuple[int]
|
||||
_executable: Union[dispatch.XlaCompiledComputation,
|
||||
pxla.MeshExecutable,
|
||||
pxla.PmapExecutable]
|
||||
_no_kwargs: bool
|
||||
|
||||
def __init__(self, executable, in_tree, in_avals, out_tree, donate_argnums,
|
||||
no_kwargs=False):
|
||||
self._executable = executable
|
||||
self.in_tree = in_tree
|
||||
self.in_avals = in_avals
|
||||
self.out_tree = out_tree
|
||||
self.donate_argnums = donate_argnums
|
||||
self._no_kwargs = no_kwargs
|
||||
|
||||
def compiler_ir(self):
|
||||
"""Post-compilation IR.
|
||||
|
||||
Compilation typically involves code transformation and
|
||||
optimization. This method exists to reflect the compiler's
|
||||
representation of the program after such passes, whenever
|
||||
possible.
|
||||
"""
|
||||
return self._executable.xla_executable.hlo_modules()
|
||||
|
||||
def runtime_executable(self):
|
||||
return self._executable.xla_executable
|
||||
|
||||
def _xla_executable(self):
|
||||
# TODO(frostig): finalize API. For now, return the underlying
|
||||
# executable directly via this method.
|
||||
return self._executable.xla_executable
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self._no_kwargs and kwargs:
|
||||
kws = ', '.join(kwargs.keys())
|
||||
raise NotImplementedError(
|
||||
'function was compiled by a transformation that does not support '
|
||||
f"keyword arguments, but called with keyword arguments: {kws}")
|
||||
args_flat, in_tree = tree_flatten((args, kwargs))
|
||||
if in_tree != self.in_tree:
|
||||
# TODO(frostig): provide more info about the source function
|
||||
# and transformation
|
||||
raise TypeError(
|
||||
f'function compiled for {self.in_tree}, called with {in_tree}')
|
||||
try:
|
||||
out_flat = self._executable.call(*args_flat)
|
||||
except TypeError as e:
|
||||
# We can't transform ahead-of-time compiled calls, since we've
|
||||
# lowered and compiled for a fixed function signature, and JAX
|
||||
# transformations change signatures. We interpret a Tracer
|
||||
# argument as an indication of a transformation attempt. We
|
||||
# could check this before the executable call, but we'd rather
|
||||
# avoid isinstance checks on the call path. Seeing a TypeError
|
||||
# might mean that arguments have JAX-invalid types, which in
|
||||
# turn might mean some are Tracers.
|
||||
for arg in args_flat:
|
||||
if isinstance(arg, core.Tracer):
|
||||
raise TypeError(
|
||||
'Cannot apply JAX transformations to a function lowered and '
|
||||
'compiled for a particular signature. Detected argument of '
|
||||
f'Tracer type {type(arg)}.')
|
||||
else:
|
||||
raise
|
||||
return tree_unflatten(self.out_tree, out_flat)
|
||||
|
||||
|
||||
def _jit_lower(fun, static_argnums, static_argnames, device, backend,
|
||||
donate_argnums, inline):
|
||||
"""Make a ``lower`` method for jitted functions."""
|
||||
@ -664,7 +510,7 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
|
||||
return aval, None
|
||||
|
||||
@api_boundary
|
||||
def lower(*args, **kwargs) -> Lowered:
|
||||
def lower(*args, **kwargs) -> stages.Lowered:
|
||||
"""Lower this function for the given arguments.
|
||||
|
||||
A lowered function is staged out of Python and translated to a
|
||||
@ -687,8 +533,8 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
|
||||
computation = dispatch.lower_xla_callable(flat_fun, device, backend, name,
|
||||
donated_invars,
|
||||
*arg_specs_and_device)
|
||||
return Lowered(computation, in_tree, in_tree.unflatten(arg_specs),
|
||||
out_tree(), donate_argnums)
|
||||
return stages.Lowered(computation, in_tree, in_tree.unflatten(arg_specs),
|
||||
out_tree(), donate_argnums)
|
||||
|
||||
return lower
|
||||
|
||||
@ -2182,7 +2028,7 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
|
||||
# this might naturally be a method, with ``fun`` as a ``self`` and
|
||||
# all the other arguments stored as attributes.
|
||||
@api_boundary
|
||||
def lower(*args, **kwargs) -> Lowered:
|
||||
def lower(*args, **kwargs) -> stages.Lowered:
|
||||
"""Lower a parallel-mapped form of this function for the given arguments.
|
||||
|
||||
A parallel-mapped and lowered function is staged out of Python and
|
||||
@ -2208,8 +2054,9 @@ def _pmap_lower(fun, axis_name, in_axes, out_axes, static_broadcasted_tuple,
|
||||
donated_invars=p.donated_invars,
|
||||
global_arg_shapes=p.global_arg_shapes_flat,
|
||||
avals=abstract_args)
|
||||
return Lowered(computation, p.in_tree, p.in_tree.unflatten(abstract_args),
|
||||
p.out_tree(), donate_tuple)
|
||||
return stages.Lowered(
|
||||
computation, p.in_tree, p.in_tree.unflatten(abstract_args),
|
||||
p.out_tree(), donate_tuple)
|
||||
|
||||
return lower
|
||||
|
||||
|
187
jax/_src/stages.py
Normal file
187
jax/_src/stages.py
Normal file
@ -0,0 +1,187 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
|
||||
from jax import core
|
||||
from jax.interpreters import pxla
|
||||
from jax.tree_util import PyTreeDef, tree_flatten, tree_unflatten
|
||||
|
||||
from jax._src import dispatch
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
zip, unsafe_zip = util.safe_zip, zip
|
||||
|
||||
|
||||
class Lowered:
|
||||
"""Lowering of a function specialized to argument types and values.
|
||||
|
||||
A lowering is a computation ready for compilation. This class
|
||||
carries a lowering together with the remaining information needed to
|
||||
later compile and execute it. It also provides a common API for
|
||||
querying properties of lowered computations across JAX's various
|
||||
lowering paths (``jit``, ``pmap``, etc.).
|
||||
"""
|
||||
__slots__ = [
|
||||
"in_tree", "in_avals", "out_tree", "donate_argnums", "_lowering",
|
||||
"_no_kwargs"
|
||||
]
|
||||
|
||||
# The PyTreeDef of the (positional arguments, keyword arguments).
|
||||
#
|
||||
# To get the individual PyTreeDef for the positional an keyword arguments,
|
||||
# use `in_tree.children() which will return you a sequence of 2 PyTreeDef.
|
||||
in_tree: PyTreeDef
|
||||
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
|
||||
in_avals: Any
|
||||
out_tree: PyTreeDef
|
||||
donate_argnums: Tuple[int]
|
||||
_lowering: Union[dispatch.XlaComputation,
|
||||
pxla.MeshComputation,
|
||||
pxla.PmapComputation]
|
||||
_no_kwargs: bool
|
||||
|
||||
def __init__(self,
|
||||
lowering,
|
||||
in_tree: PyTreeDef,
|
||||
in_avals,
|
||||
out_tree: PyTreeDef,
|
||||
donate_argnums: Tuple[int],
|
||||
no_kwargs: bool = False):
|
||||
"""Initializer.
|
||||
|
||||
Args:
|
||||
in_tree: The `PyTreeDef` of (args, kwargs).
|
||||
out_tree: The `PyTreeDef` of the outputs.
|
||||
no_kwargs: If `True` the transformation, and the `Compiled` returned from
|
||||
this object will not support keyword arguments (an error will be raised
|
||||
if some are provided).
|
||||
"""
|
||||
self._lowering = lowering
|
||||
self.in_tree = in_tree
|
||||
self.in_avals = in_avals
|
||||
self.out_tree = out_tree
|
||||
self.donate_argnums = donate_argnums
|
||||
self._no_kwargs = no_kwargs
|
||||
|
||||
def compile(self) -> 'Compiled':
|
||||
return Compiled(
|
||||
self._lowering.compile(), self.in_tree, self.in_avals,
|
||||
self.out_tree, self.donate_argnums, self._no_kwargs)
|
||||
|
||||
def compiler_ir(self, dialect: Optional[str] = None):
|
||||
if dialect is None or dialect == "mhlo":
|
||||
return self._lowering.mhlo()
|
||||
elif dialect == "hlo":
|
||||
return self._lowering.hlo()
|
||||
else:
|
||||
raise ValueError(f"Unknown dialect {dialect}")
|
||||
|
||||
# TODO(frostig): remove this in favor of `compiler_ir`
|
||||
def _xla_computation(self):
|
||||
return self._lowering.hlo()
|
||||
|
||||
|
||||
class Compiled:
|
||||
"""Compiled representation of a function specialized to types/values.
|
||||
|
||||
A compiled computation is associated with an executable and the
|
||||
remaining information needed to execute it. It also provides a
|
||||
common API for querying properties of compiled computations across
|
||||
JAX's various compilation paths and backends.
|
||||
"""
|
||||
__slots__ = [
|
||||
"in_tree", "in_avals", "out_tree", "donate_argnums", "_executable",
|
||||
"_no_kwargs"
|
||||
]
|
||||
|
||||
|
||||
# The PyTreeDef of the (positional arguments, keyword arguments).
|
||||
in_tree: PyTreeDef
|
||||
# The nested input tree of `ShapedArray` abstract values of (args, kwargs).
|
||||
in_avals: Any
|
||||
out_tree: PyTreeDef
|
||||
donate_argnums: Tuple[int]
|
||||
_executable: Union[dispatch.XlaCompiledComputation,
|
||||
pxla.MeshExecutable,
|
||||
pxla.PmapExecutable]
|
||||
_no_kwargs: bool
|
||||
|
||||
def __init__(self, executable, in_tree, in_avals, out_tree, donate_argnums,
|
||||
no_kwargs=False):
|
||||
self._executable = executable
|
||||
self.in_tree = in_tree
|
||||
self.in_avals = in_avals
|
||||
self.out_tree = out_tree
|
||||
self.donate_argnums = donate_argnums
|
||||
self._no_kwargs = no_kwargs
|
||||
|
||||
def compiler_ir(self):
|
||||
"""Post-compilation IR.
|
||||
|
||||
Compilation typically involves code transformation and
|
||||
optimization. This method exists to reflect the compiler's
|
||||
representation of the program after such passes, whenever
|
||||
possible.
|
||||
"""
|
||||
return self._executable.xla_executable.hlo_modules()
|
||||
|
||||
def runtime_executable(self):
|
||||
return self._executable.xla_executable
|
||||
|
||||
def _xla_executable(self):
|
||||
# TODO(frostig): finalize API. For now, return the underlying
|
||||
# executable directly via this method.
|
||||
return self._executable.xla_executable
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self._no_kwargs and kwargs:
|
||||
kws = ', '.join(kwargs.keys())
|
||||
raise NotImplementedError(
|
||||
'function was compiled by a transformation that does not support '
|
||||
f"keyword arguments, but called with keyword arguments: {kws}")
|
||||
args_flat, in_tree = tree_flatten((args, kwargs))
|
||||
if in_tree != self.in_tree:
|
||||
# TODO(frostig): provide more info about the source function
|
||||
# and transformation
|
||||
raise TypeError(
|
||||
f'function compiled for {self.in_tree}, called with {in_tree}')
|
||||
try:
|
||||
out_flat = self._executable.call(*args_flat)
|
||||
except TypeError:
|
||||
# We can't transform ahead-of-time compiled calls, since we've
|
||||
# lowered and compiled for a fixed function signature, and JAX
|
||||
# transformations change signatures. We interpret a Tracer
|
||||
# argument as an indication of a transformation attempt. We
|
||||
# could check this before the executable call, but we'd rather
|
||||
# avoid isinstance checks on the call path. Seeing a TypeError
|
||||
# might mean that arguments have JAX-invalid types, which in
|
||||
# turn might mean some are Tracers.
|
||||
for arg in args_flat:
|
||||
if isinstance(arg, core.Tracer):
|
||||
raise TypeError(
|
||||
'Cannot apply JAX transformations to a function lowered and '
|
||||
'compiled for a particular signature. Detected argument of '
|
||||
f'Tracer type {type(arg)}.')
|
||||
else:
|
||||
raise
|
||||
return tree_unflatten(self.out_tree, out_flat)
|
@ -26,7 +26,8 @@ from enum import Enum
|
||||
from jax import numpy as jnp
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax._src.api import Lowered, _check_callable, _check_arg
|
||||
from jax import stages
|
||||
from jax._src.api import _check_callable, _check_arg
|
||||
from jax._src import dispatch
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map,
|
||||
tree_leaves, treedef_tuple)
|
||||
@ -662,7 +663,7 @@ def xmap(fun: Callable,
|
||||
|
||||
in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
|
||||
in_avals = in_tree.unflatten(avals_flat)
|
||||
return Lowered(
|
||||
return stages.Lowered(
|
||||
computation, in_tree, in_avals, out_tree(), donate_argnums,
|
||||
no_kwargs=True)
|
||||
|
||||
|
@ -23,7 +23,8 @@ from jax.experimental import maps
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray as GDA
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax._src.api import _check_callable, _check_arg, Lowered
|
||||
from jax import stages
|
||||
from jax._src.api import _check_callable, _check_arg
|
||||
from jax._src import dispatch
|
||||
from jax._src import source_info_util
|
||||
from jax._src.lib import xla_extension_version
|
||||
@ -282,8 +283,8 @@ def pjit(fun: Callable,
|
||||
|
||||
args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
|
||||
local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
|
||||
return Lowered(lowering, args_kwargs_in_tree, local_in_avals, out_tree,
|
||||
donate_argnums, no_kwargs=True)
|
||||
return stages.Lowered(lowering, args_kwargs_in_tree, local_in_avals,
|
||||
out_tree, donate_argnums, no_kwargs=True)
|
||||
|
||||
wrapped.lower = lower
|
||||
return wrapped
|
||||
|
19
jax/stages.py
Normal file
19
jax/stages.py
Normal file
@ -0,0 +1,19 @@
|
||||
# Copyright 2022 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from jax._src.stages import (
|
||||
Compiled as Compiled,
|
||||
Lowered as Lowered,
|
||||
)
|
Loading…
x
Reference in New Issue
Block a user