Roll forward with the fix: make the params arg in Compiled.call() positional-only so that it does not conflict with keyword arguments.

PiperOrigin-RevId: 481666211
jax authors 2022-10-17 09:50:27 -07:00
parent 4cfa01f1cf
commit 504b3c1b25
3 changed files with 14 additions and 143 deletions
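The conflict the commit message refers to: if the first parameter of Compiled.call() has a name, a caller who happens to pass a keyword argument with that same name gets a "got multiple values" TypeError. Consuming the first argument from *args, as the diff's Compiled.call does, makes it effectively positional-only. A minimal illustration with hypothetical names (call_named and call_positional_only are not JAX APIs):

# Illustrative only; the function names here are hypothetical, not JAX APIs.

# A named first parameter collides with a same-named keyword argument:
def call_named(params, *args, **kwargs):
    return params, args, kwargs

# call_named(object(), params=1)  # TypeError: got multiple values for 'params'

# Consuming the first argument from *args makes it effectively positional-only,
# so a keyword argument named "params" simply lands in **kwargs:
def call_positional_only(*args, **kwargs):
    params, args = args[0], args[1:]
    return params, args, kwargs

print(call_positional_only("p", params=1))  # ('p', (), {'params': 1})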


@@ -629,7 +629,7 @@ def bench_slicing_compilation2(state):
jax.jit(lambda x: (x[:1], x[1:2], x[2:3])).lower(x).compile()
def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit):
spec = pjit_lib.PartitionSpec('x')
mesh = jtu.create_global_mesh((num_devices,), ('x',))
s = sharding.MeshPspecSharding(mesh, spec)
@@ -649,9 +649,6 @@ def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
in_axis_resources=in_axis_resources,
out_axis_resources=out_axis_resources)
if use_aot:
f = f.lower(x).compile()
x = f(x)
while state:
@@ -700,59 +697,5 @@ def pjit_simple_4000_device(state):
state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1))
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_aot_1_device(state):
pjit_simple_benchmark(
state,
num_devices=1,
num_args=state.range(0),
cpp_jit=state.range(1),
use_aot=True)
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_aot_4_device(state):
pjit_simple_benchmark(
state,
num_devices=4,
num_args=state.range(0),
cpp_jit=state.range(1),
use_aot=True)
@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_aot_4000_device(state):
pjit_simple_benchmark(
state,
num_devices=4000,
num_args=state.range(0),
cpp_jit=state.range(1),
use_aot=True)
if __name__ == "__main__":
google_benchmark.main()
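
For reference, the ahead-of-time (AOT) flow the use_aot branch above exercises looks like the sketch below. It uses jax.jit instead of pjit so it runs on a single device; the function and input are illustrative, not taken from the benchmark.

import jax
import jax.numpy as jnp

x = jnp.arange(8.0)
f = jax.jit(lambda a: a + 1)

compiled = f.lower(x).compile()  # trace, lower, and compile once, outside any timed loop
y = compiled(x)                  # later calls dispatch the precompiled executable
print(y)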


@@ -33,7 +33,7 @@ from __future__ import annotations
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing_extensions import Protocol
import jax
@@ -306,12 +306,6 @@ def make_args_info(in_tree, in_avals, donate_argnums):
ArgInfo(aval, i in donate_argnums)
for i, aval in enumerate(flat_avals)])
class CompiledCallParams(NamedTuple):
executable: Executable
no_kwargs: bool
in_tree: tree_util.PyTreeDef
out_tree: tree_util.PyTreeDef
class Compiled(Stage):
"""Compiled representation of a function specialized to types/values.
@@ -328,19 +322,11 @@ class Compiled(Stage):
_executable: Executable
_no_kwargs: bool
def __init__(self, executable, args_info, out_tree, no_kwargs=False, create_cpp_call=None):
def __init__(self, executable, args_info, out_tree, no_kwargs=False):
self._executable = executable
self._no_kwargs = no_kwargs
self.args_info = args_info
self.out_tree = out_tree
self._params = CompiledCallParams(self._executable, self._no_kwargs,
self.in_tree, self.out_tree)
# TODO(chky): Remove this conditional statement once we implement the fast
# path in C++ for all AOT paths.
if create_cpp_call is not None:
self._cpp_call = create_cpp_call(self._params)
else:
self._cpp_call = None
def compiler_ir(self):
"""Post-compilation IR.
@@ -429,25 +415,22 @@ class Compiled(Stage):
shardings_flat = self._executable.output_shardings()
return tree_util.tree_unflatten(self.out_tree, shardings_flat) # pytype: disable=attribute-error
@staticmethod
def call(*args, **kwargs):
params = args[0]
args = args[1:]
def __call__(self, *args, **kwargs):
if jax.config.jax_dynamic_shapes:
raise NotImplementedError
if params.no_kwargs and kwargs:
if self._no_kwargs and kwargs:
kws = ', '.join(kwargs.keys())
raise NotImplementedError(
"function was compiled by a transformation that does not support "
f"keyword arguments, but called with keyword arguments: {kws}")
args_flat, in_tree = tree_util.tree_flatten((args, kwargs))
if in_tree != params.in_tree:
if in_tree != self.in_tree:
# TODO(frostig): provide more info about the source function
# and transformation
raise TypeError(
f"function compiled for {params.in_tree}, called with {in_tree}")
f"function compiled for {self.in_tree}, called with {in_tree}")
try:
out_flat = params.executable.call(*args_flat)
out_flat = self._executable.call(*args_flat)
except TypeError as e:
# We can't transform ahead-of-time compiled calls, since we've
# lowered and compiled for a fixed function signature, and JAX
@@ -465,15 +448,7 @@
f"Tracer type {type(arg)}.") from e
else:
raise
outs = tree_util.tree_unflatten(params.out_tree, out_flat)
return outs, out_flat
def __call__(self, *args, **kwargs):
if self._cpp_call is not None:
return self._cpp_call(*args, **kwargs)
outs, _ = Compiled.call(self._params, *args, **kwargs)
return outs
return tree_util.tree_unflatten(self.out_tree, out_flat)
class Lowered(Stage):
@@ -491,18 +466,16 @@ class Lowered(Stage):
out_tree: tree_util.PyTreeDef
_lowering: XlaLowering
_no_kwargs: bool
_create_cpp_call: Optional[Callable]
def __init__(self,
lowering: XlaLowering,
args_info, # PyTreee of ArgInfo
out_tree: tree_util.PyTreeDef,
no_kwargs: bool = False, create_cpp_call: Optional[Callable] = None):
no_kwargs: bool = False):
self._lowering = lowering
self._no_kwargs = no_kwargs
self.args_info = args_info
self.out_tree = out_tree
self._create_cpp_call = create_cpp_call
@classmethod
def from_flat_info(cls,
@@ -511,7 +484,7 @@ class Lowered(Stage):
in_avals,
donate_argnums: Tuple[int, ...],
out_tree: tree_util.PyTreeDef,
no_kwargs: bool = False, create_cpp_call: Optional[Callable] = None):
no_kwargs: bool = False):
"""Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef.
Args:
@@ -522,7 +495,7 @@ class Lowered(Stage):
arguments (an error will be raised if some are provided).
"""
return cls(lowering, make_args_info(in_tree, in_avals, donate_argnums),
out_tree, no_kwargs=no_kwargs, create_cpp_call=create_cpp_call)
out_tree, no_kwargs=no_kwargs)
def compile(self) -> Compiled:
"""Compile, returning a corresponding ``Compiled`` instance."""
@@ -536,7 +509,7 @@ class Lowered(Stage):
kw = {}
return Compiled(self._lowering.compile(**kw), self.args_info, self.out_tree,
no_kwargs=self._no_kwargs, create_cpp_call=self._create_cpp_call)
no_kwargs=self._no_kwargs)
def as_text(self, dialect: Optional[str] = None) -> str:
"""A human-readable text representation of this lowering.


@@ -30,7 +30,6 @@ from jax import core
from jax import linear_util as lu
from jax import stages
from jax._src import array
from jax._src.stages import CompiledCallParams
from jax._src.api import _check_callable, _check_arg, FLAGS, device_put
from jax._src.config import config
from jax._src import dispatch
@@ -171,45 +170,6 @@ def _cpp_pjit(fun: Callable, infer_params, static_argnums):
return wraps(fun)(cpp_pjit_f)
class _CppPjitAotCall:
def __init__(self, fun: Callable, static_argnums: Any):
self._fun = fun
self._static_argnums = static_argnums
def __call__(self, params: CompiledCallParams):
def aot_cache_miss(*args, **kwargs):
# The first invocation for a signature will use the python path. If it is
# a correct signature, the later invocations will use the C++ path.
# Otherwise, if the signature is wrong, it will be caught by the python
# path during the first invocation.
outs, out_flat = stages.Compiled.call(params, *args, **kwargs)
executable = params.executable
use_fastpath = (
isinstance(executable, pxla.MeshExecutable) and
isinstance(executable.unsafe_call, pxla.ExecuteReplicated) and
not executable.unsafe_call.has_unordered_effects and
not executable.unsafe_call.has_host_callbacks and
all(isinstance(x, xc.ArrayImpl) for x in out_flat))
if use_fastpath:
out_avals = [o.aval for o in out_flat]
out_committed = [o._committed for o in out_flat]
fastpath_data = _PjitFastpathData(executable.xla_executable, params.out_tree,
executable._in_shardings,
executable._out_shardings,
out_avals, out_committed)
else:
fastpath_data = None
return outs, fastpath_data
self._cpp_aot_pjit_f = xc._xla.pjit(self._fun, aot_cache_miss,
self._static_argnums)
return self._cpp_aot_pjit_f
# TODO(yashkatariya): Add pjit microbenchmarks.
# in_axis_resources and out_axis_resources can't be None as the default value
@@ -470,16 +430,11 @@ def pjit(fun: Callable,
params['resource_env'], params['donated_invars'], params['name'],
in_is_global, always_lower=True)
if FLAGS.experimental_cpp_pjit and xc._version >= 96:
create_cpp_call = _CppPjitAotCall(fun, static_argnums)
else:
create_cpp_call = None
args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
return stages.Lowered.from_flat_info(
lowering, args_kwargs_in_tree, local_in_avals, donate_argnums, out_tree,
no_kwargs=True, create_cpp_call=create_cpp_call)
no_kwargs=True)
wrapped.lower = lower
return wrapped
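
The _CppPjitAotCall class in the diff above follows a cache-miss pattern: the first call for a given signature goes through the Python path (stages.Compiled.call), and only if the result qualifies does it hand the C++ dispatcher enough data (fastpath_data) to skip that work on later calls. A pure-Python analogy of that control flow, not the actual xc._xla.pjit interface:

from typing import Any, Callable, Optional, Tuple

class FastpathDispatcher:
    """Pure-Python analogy: run the slow path once, then reuse recorded fast-path data."""

    def __init__(self, slow_call: Callable[..., Tuple[Any, Optional[Callable]]]):
        # slow_call returns (outputs, fast_call or None); fast_call stands in for
        # the fastpath_data the real code passes to the C++ dispatcher.
        self._slow_call = slow_call
        self._fast_call: Optional[Callable] = None

    def __call__(self, *args, **kwargs):
        if self._fast_call is not None:
            return self._fast_call(*args, **kwargs)          # later calls: fast path
        outs, fast_call = self._slow_call(*args, **kwargs)   # first call: full Python path
        if fast_call is not None:
            self._fast_call = fast_call                      # record for subsequent calls
        return outs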