require keyword arguments to Traced.lower

Minor: also move the `Traced` definition to between `Wrapped` and `Lowered`, since it is the stage between these two.
PiperOrigin-RevId: 651211125
This commit is contained in:
Roy Frostig 2024-07-10 17:58:44 -07:00 committed by jax authors
parent 23922ce4bc
commit 21fd50749a

View File

@ -421,6 +421,7 @@ def make_args_info(in_tree, in_avals, donate_argnums):
ArgInfo(aval, i in donate_argnums)
for i, aval in enumerate(flat_avals)])
class CompiledCallParams(NamedTuple):
executable: Executable
no_kwargs: bool
@ -428,37 +429,6 @@ class CompiledCallParams(NamedTuple):
out_tree: tree_util.PyTreeDef
class Traced(Stage):
__slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable",
"_args_flat", "_arg_names", "_num_consts"]
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
lower_callable, args_flat=None, arg_names=None,
num_consts: int = 0):
self.jaxpr = jaxpr
self.args_info = args_info
self.fun_name = fun_name
self._out_tree = out_tree
self._lower_callable = lower_callable
self._args_flat = args_flat
self._arg_names = arg_names
self._num_consts = num_consts
@property
def out_info(self):
return self._out_tree.unflatten(
[OutInfo(o.shape, o.dtype) for o in self.jaxpr.out_avals])
def lower(self, lowering_platforms: tuple[str, ...] | None = None,
_private_parameters: mlir.LoweringParameters | None = None):
if _private_parameters is None:
_private_parameters = mlir.LoweringParameters()
new_callable = functools.partial(
self._lower_callable, lowering_platforms=lowering_platforms,
lowering_parameters=_private_parameters)
return Lowered(new_callable(), self.args_info, self._out_tree)
class Compiled(Stage):
"""Compiled representation of a function specialized to types/values.
@ -756,6 +726,37 @@ class Lowered(Stage):
return None
class Traced(Stage):
__slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable",
"_args_flat", "_arg_names", "_num_consts"]
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
lower_callable, args_flat=None, arg_names=None,
num_consts: int = 0):
self.jaxpr = jaxpr
self.args_info = args_info
self.fun_name = fun_name
self._out_tree = out_tree
self._lower_callable = lower_callable
self._args_flat = args_flat
self._arg_names = arg_names
self._num_consts = num_consts
@property
def out_info(self):
return self._out_tree.unflatten(
[OutInfo(o.shape, o.dtype) for o in self.jaxpr.out_avals])
def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
_private_parameters: mlir.LoweringParameters | None = None):
if _private_parameters is None:
_private_parameters = mlir.LoweringParameters()
new_callable = functools.partial(
self._lower_callable, lowering_platforms=lowering_platforms,
lowering_parameters=_private_parameters)
return Lowered(new_callable(), self.args_info, self._out_tree)
@runtime_checkable
class Wrapped(Protocol):
"""A function ready to be traced, lowered, and compiled.