mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
require keyword arguments to Traced.lower
Minor: also move the `Traced` definition to between `Wrapped` and `Lowered`, since it is the stage between these two. PiperOrigin-RevId: 651211125
This commit is contained in:
parent
23922ce4bc
commit
21fd50749a
@ -421,6 +421,7 @@ def make_args_info(in_tree, in_avals, donate_argnums):
|
||||
ArgInfo(aval, i in donate_argnums)
|
||||
for i, aval in enumerate(flat_avals)])
|
||||
|
||||
|
||||
class CompiledCallParams(NamedTuple):
|
||||
executable: Executable
|
||||
no_kwargs: bool
|
||||
@ -428,37 +429,6 @@ class CompiledCallParams(NamedTuple):
|
||||
out_tree: tree_util.PyTreeDef
|
||||
|
||||
|
||||
class Traced(Stage):
|
||||
__slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable",
|
||||
"_args_flat", "_arg_names", "_num_consts"]
|
||||
|
||||
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
|
||||
lower_callable, args_flat=None, arg_names=None,
|
||||
num_consts: int = 0):
|
||||
self.jaxpr = jaxpr
|
||||
self.args_info = args_info
|
||||
self.fun_name = fun_name
|
||||
self._out_tree = out_tree
|
||||
self._lower_callable = lower_callable
|
||||
self._args_flat = args_flat
|
||||
self._arg_names = arg_names
|
||||
self._num_consts = num_consts
|
||||
|
||||
@property
|
||||
def out_info(self):
|
||||
return self._out_tree.unflatten(
|
||||
[OutInfo(o.shape, o.dtype) for o in self.jaxpr.out_avals])
|
||||
|
||||
def lower(self, lowering_platforms: tuple[str, ...] | None = None,
|
||||
_private_parameters: mlir.LoweringParameters | None = None):
|
||||
if _private_parameters is None:
|
||||
_private_parameters = mlir.LoweringParameters()
|
||||
new_callable = functools.partial(
|
||||
self._lower_callable, lowering_platforms=lowering_platforms,
|
||||
lowering_parameters=_private_parameters)
|
||||
return Lowered(new_callable(), self.args_info, self._out_tree)
|
||||
|
||||
|
||||
class Compiled(Stage):
|
||||
"""Compiled representation of a function specialized to types/values.
|
||||
|
||||
@ -756,6 +726,37 @@ class Lowered(Stage):
|
||||
return None
|
||||
|
||||
|
||||
class Traced(Stage):
|
||||
__slots__ = ["jaxpr", "args_info", "fun_name", "_out_tree", "_lower_callable",
|
||||
"_args_flat", "_arg_names", "_num_consts"]
|
||||
|
||||
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
|
||||
lower_callable, args_flat=None, arg_names=None,
|
||||
num_consts: int = 0):
|
||||
self.jaxpr = jaxpr
|
||||
self.args_info = args_info
|
||||
self.fun_name = fun_name
|
||||
self._out_tree = out_tree
|
||||
self._lower_callable = lower_callable
|
||||
self._args_flat = args_flat
|
||||
self._arg_names = arg_names
|
||||
self._num_consts = num_consts
|
||||
|
||||
@property
|
||||
def out_info(self):
|
||||
return self._out_tree.unflatten(
|
||||
[OutInfo(o.shape, o.dtype) for o in self.jaxpr.out_avals])
|
||||
|
||||
def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
|
||||
_private_parameters: mlir.LoweringParameters | None = None):
|
||||
if _private_parameters is None:
|
||||
_private_parameters = mlir.LoweringParameters()
|
||||
new_callable = functools.partial(
|
||||
self._lower_callable, lowering_platforms=lowering_platforms,
|
||||
lowering_parameters=_private_parameters)
|
||||
return Lowered(new_callable(), self.args_info, self._out_tree)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Wrapped(Protocol):
|
||||
"""A function ready to be traced, lowered, and compiled.
|
||||
|
Loading…
x
Reference in New Issue
Block a user