mirror of https://github.com/ROCm/jax.git
roll forward with the fix: Make params arg in Compiled.call() position-only so that it does not conflict with the keyword args.
PiperOrigin-RevId: 481666211
parent 4cfa01f1cf
commit 504b3c1b25
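The fix named in the message is the position-only trick visible in the diff below: if `params` were a named parameter of `call()`, a caller's keyword argument that also happened to be named `params` would collide with it; peeling it off `*args` instead keeps the name out of the binding entirely. A minimal standalone sketch of the idea (toy class and names, not the actual JAX implementation):

class Toy:
  """Toy stand-in illustrating why `params` is taken position-only."""

  @staticmethod
  def call_named(params, *args, **kwargs):
    # Passing a keyword argument also named `params` raises:
    # TypeError: call_named() got multiple values for argument 'params'
    return params, args, kwargs

  @staticmethod
  def call_positional(*args, **kwargs):
    # `params` is peeled off *args, so a user keyword named `params`
    # simply lands in **kwargs and never conflicts.
    params, args = args[0], args[1:]
    return params, args, kwargs


internal = object()  # stands in for the internal params object
print(Toy.call_positional(internal, 1, 2, params="user keyword"))  # works
# Toy.call_named(internal, 1, 2, params="user keyword")            # TypeError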
@@ -629,7 +629,7 @@ def bench_slicing_compilation2(state):
    jax.jit(lambda x: (x[:1], x[1:2], x[2:3])).lower(x).compile()


def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit):
  spec = pjit_lib.PartitionSpec('x')
  mesh = jtu.create_global_mesh((num_devices,), ('x',))
  s = sharding.MeshPspecSharding(mesh, spec)
@@ -649,9 +649,6 @@ def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False):
      in_axis_resources=in_axis_resources,
      out_axis_resources=out_axis_resources)

  if use_aot:
    f = f.lower(x).compile()

  x = f(x)

  while state:
@@ -700,59 +697,5 @@ def pjit_simple_4000_device(state):
      state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1))


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_aot_1_device(state):
  pjit_simple_benchmark(
      state,
      num_devices=1,
      num_args=state.range(0),
      cpp_jit=state.range(1),
      use_aot=True)


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_aot_4_device(state):
  pjit_simple_benchmark(
      state,
      num_devices=4,
      num_args=state.range(0),
      cpp_jit=state.range(1),
      use_aot=True)


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_aot_4000_device(state):
  pjit_simple_benchmark(
      state,
      num_devices=4000,
      num_args=state.range(0),
      cpp_jit=state.range(1),
      use_aot=True)


if __name__ == "__main__":
  google_benchmark.main()
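For reference, the ahead-of-time path these benchmarks exercise is the public staging API: lower for concrete inputs, compile, then call the compiled object directly. A small usage example (the function and shapes are arbitrary, chosen only for illustration):

import jax
import jax.numpy as jnp

def f(x):
  return x * 2 + 1

x = jnp.arange(4.0)
lowered = jax.jit(f).lower(x)   # trace and lower for this input shape/dtype
compiled = lowered.compile()    # yields a Compiled stage
print(compiled(x))              # call the AOT-compiled function directly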
@@ -33,7 +33,7 @@ from __future__ import annotations
import warnings

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple
from typing import Any, Dict, List, Optional, Sequence, Tuple
from typing_extensions import Protocol

import jax
@@ -306,12 +306,6 @@ def make_args_info(in_tree, in_avals, donate_argnums):
      ArgInfo(aval, i in donate_argnums)
      for i, aval in enumerate(flat_avals)])

class CompiledCallParams(NamedTuple):
  executable: Executable
  no_kwargs: bool
  in_tree: tree_util.PyTreeDef
  out_tree: tree_util.PyTreeDef


class Compiled(Stage):
  """Compiled representation of a function specialized to types/values.
@@ -328,19 +322,11 @@ class Compiled(Stage):
  _executable: Executable
  _no_kwargs: bool

  def __init__(self, executable, args_info, out_tree, no_kwargs=False, create_cpp_call=None):
  def __init__(self, executable, args_info, out_tree, no_kwargs=False):
    self._executable = executable
    self._no_kwargs = no_kwargs
    self.args_info = args_info
    self.out_tree = out_tree
    self._params = CompiledCallParams(self._executable, self._no_kwargs,
                                      self.in_tree, self.out_tree)
    # TODO(chky): Remove this conditional statement once we implement the fast
    # path in C++ for all AOT paths.
    if create_cpp_call is not None:
      self._cpp_call = create_cpp_call(self._params)
    else:
      self._cpp_call = None

  def compiler_ir(self):
    """Post-compilation IR.
@@ -429,25 +415,22 @@ class Compiled(Stage):
    shardings_flat = self._executable.output_shardings()
    return tree_util.tree_unflatten(self.out_tree, shardings_flat)  # pytype: disable=attribute-error

  @staticmethod
  def call(*args, **kwargs):
    params = args[0]
    args = args[1:]
  def __call__(self, *args, **kwargs):
    if jax.config.jax_dynamic_shapes:
      raise NotImplementedError
    if params.no_kwargs and kwargs:
    if self._no_kwargs and kwargs:
      kws = ', '.join(kwargs.keys())
      raise NotImplementedError(
          "function was compiled by a transformation that does not support "
          f"keyword arguments, but called with keyword arguments: {kws}")
    args_flat, in_tree = tree_util.tree_flatten((args, kwargs))
    if in_tree != params.in_tree:
    if in_tree != self.in_tree:
      # TODO(frostig): provide more info about the source function
      # and transformation
      raise TypeError(
          f"function compiled for {params.in_tree}, called with {in_tree}")
          f"function compiled for {self.in_tree}, called with {in_tree}")
    try:
      out_flat = params.executable.call(*args_flat)
      out_flat = self._executable.call(*args_flat)
    except TypeError as e:
      # We can't transform ahead-of-time compiled calls, since we've
      # lowered and compiled for a fixed function signature, and JAX
@@ -465,15 +448,7 @@ class Compiled(Stage):
          f"Tracer type {type(arg)}.") from e
      else:
        raise
    outs = tree_util.tree_unflatten(params.out_tree, out_flat)
    return outs, out_flat

  def __call__(self, *args, **kwargs):
    if self._cpp_call is not None:
      return self._cpp_call(*args, **kwargs)

    outs, _ = Compiled.call(self._params, *args, **kwargs)
    return outs
    return tree_util.tree_unflatten(self.out_tree, out_flat)


class Lowered(Stage):
@@ -491,18 +466,16 @@ class Lowered(Stage):
  out_tree: tree_util.PyTreeDef
  _lowering: XlaLowering
  _no_kwargs: bool
  _create_cpp_call: Optional[Callable]

  def __init__(self,
               lowering: XlaLowering,
               args_info,  # PyTree of ArgInfo
               out_tree: tree_util.PyTreeDef,
               no_kwargs: bool = False, create_cpp_call: Optional[Callable] = None):
               no_kwargs: bool = False):
    self._lowering = lowering
    self._no_kwargs = no_kwargs
    self.args_info = args_info
    self.out_tree = out_tree
    self._create_cpp_call = create_cpp_call

  @classmethod
  def from_flat_info(cls,
@@ -511,7 +484,7 @@ class Lowered(Stage):
                     in_avals,
                     donate_argnums: Tuple[int, ...],
                     out_tree: tree_util.PyTreeDef,
                     no_kwargs: bool = False, create_cpp_call: Optional[Callable] = None):
                     no_kwargs: bool = False):
    """Initialize from flat info (``in_avals`` etc.) and an input PyTreeDef.

    Args:
@@ -522,7 +495,7 @@ class Lowered(Stage):
        arguments (an error will be raised if some are provided).
    """
    return cls(lowering, make_args_info(in_tree, in_avals, donate_argnums),
               out_tree, no_kwargs=no_kwargs, create_cpp_call=create_cpp_call)
               out_tree, no_kwargs=no_kwargs)

  def compile(self) -> Compiled:
    """Compile, returning a corresponding ``Compiled`` instance."""
@@ -536,7 +509,7 @@ class Lowered(Stage):
      kw = {}

    return Compiled(self._lowering.compile(**kw), self.args_info, self.out_tree,
                    no_kwargs=self._no_kwargs, create_cpp_call=self._create_cpp_call)
                    no_kwargs=self._no_kwargs)

  def as_text(self, dialect: Optional[str] = None) -> str:
    """A human-readable text representation of this lowering.
@@ -30,7 +30,6 @@ from jax import core
from jax import linear_util as lu
from jax import stages
from jax._src import array
from jax._src.stages import CompiledCallParams
from jax._src.api import _check_callable, _check_arg, FLAGS, device_put
from jax._src.config import config
from jax._src import dispatch
@@ -171,45 +170,6 @@ def _cpp_pjit(fun: Callable, infer_params, static_argnums):

  return wraps(fun)(cpp_pjit_f)

class _CppPjitAotCall:
  def __init__(self, fun: Callable, static_argnums: Any):
    self._fun = fun
    self._static_argnums = static_argnums

  def __call__(self, params: CompiledCallParams):

    def aot_cache_miss(*args, **kwargs):
      # The first invocation for a signature will use the python path. If it is
      # a correct signature, the later invocations will use the C++ path.
      # Otherwise, if the signature is wrong, it will be caught by the python
      # path during the first invocation.
      outs, out_flat = stages.Compiled.call(params, *args, **kwargs)

      executable = params.executable

      use_fastpath = (
          isinstance(executable, pxla.MeshExecutable) and
          isinstance(executable.unsafe_call, pxla.ExecuteReplicated) and
          not executable.unsafe_call.has_unordered_effects and
          not executable.unsafe_call.has_host_callbacks and
          all(isinstance(x, xc.ArrayImpl) for x in out_flat))

      if use_fastpath:
        out_avals = [o.aval for o in out_flat]
        out_committed = [o._committed for o in out_flat]
        fastpath_data = _PjitFastpathData(executable.xla_executable, params.out_tree,
                                          executable._in_shardings,
                                          executable._out_shardings,
                                          out_avals, out_committed)
      else:
        fastpath_data = None

      return outs, fastpath_data

    self._cpp_aot_pjit_f = xc._xla.pjit(self._fun, aot_cache_miss,
                                        self._static_argnums)
    return self._cpp_aot_pjit_f


# TODO(yashkatariya): Add pjit microbenchmarks.
# in_axis_resources and out_axis_resources can't be None as the default value
@@ -470,16 +430,11 @@ def pjit(fun: Callable,
        params['resource_env'], params['donated_invars'], params['name'],
        in_is_global, always_lower=True)

    if FLAGS.experimental_cpp_pjit and xc._version >= 96:
      create_cpp_call = _CppPjitAotCall(fun, static_argnums)
    else:
      create_cpp_call = None

    args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
    local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
    return stages.Lowered.from_flat_info(
        lowering, args_kwargs_in_tree, local_in_avals, donate_argnums, out_tree,
        no_kwargs=True, create_cpp_call=create_cpp_call)
        no_kwargs=True)

  wrapped.lower = lower
  return wrapped
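The aot_cache_miss comment above captures the general shape of the fast-path protocol: the first invocation for a signature runs the Python path, which returns both the outputs and, when the outputs qualify, the data a faster path needs for later calls. A rough pure-Python sketch of that dispatch pattern (FastpathData and make_dispatcher are invented names for illustration, not the real xc._xla.pjit machinery):

from typing import Any, Callable, Dict, NamedTuple, Optional, Tuple


class FastpathData(NamedTuple):
  # Toy stand-in for the data a native fast path would need.
  fast_call: Callable[..., Any]


def make_dispatcher(cache_miss: Callable[..., Tuple[Any, Optional[FastpathData]]]):
  """Call `cache_miss` once per signature, then switch to the fast path."""
  fast_calls: Dict[Tuple[type, ...], Callable[..., Any]] = {}

  def dispatch(*args):
    key = tuple(type(a) for a in args)         # crude stand-in for a signature key
    fast = fast_calls.get(key)
    if fast is not None:
      return fast(*args)                       # later invocations: fast path
    outs, fastpath_data = cache_miss(*args)    # first invocation: Python path
    if fastpath_data is not None:
      fast_calls[key] = fastpath_data.fast_call
    return outs

  return dispatch


def slow_square(x):
  # The slow path computes the result and decides whether a fast path applies.
  data = FastpathData(fast_call=lambda y: y * y) if isinstance(x, int) else None
  return x * x, data


square = make_dispatcher(slow_square)
print(square(3))   # Python path; installs the fast path for ints
print(square(4))   # fast path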