mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Remove _python_pjit
and make _cpp_pjit
the only function wrapper.
PiperOrigin-RevId: 617846352
This commit is contained in:
parent
74f4846d14
commit
dd574cbc74
@ -190,25 +190,9 @@ def _get_states(attrs_tracked):
|
||||
return [jax_getattr(obj, attr) for (obj, attr) in attrs_tracked]
|
||||
|
||||
|
||||
def _python_pjit(jit_info: PjitInfo):
|
||||
|
||||
fun = jit_info.fun
|
||||
|
||||
@wraps(fun)
|
||||
@api_boundary
|
||||
def wrapped(*args, **kwargs):
|
||||
if config.disable_jit.value:
|
||||
return fun(*args, **kwargs)
|
||||
return _python_pjit_helper(jit_info, *args, **kwargs)[0]
|
||||
|
||||
def _python_pjit_evict_fn():
|
||||
_create_pjit_jaxpr.evict_function(fun) # type: ignore
|
||||
wrapped.clear_cache = _python_pjit_evict_fn
|
||||
return wrapped
|
||||
|
||||
|
||||
def _get_fastpath_data(
|
||||
executable, out_tree, args_flat, out_flat, attrs_tracked, effects
|
||||
executable, out_tree, args_flat, out_flat, attrs_tracked, effects,
|
||||
abstracted_axes
|
||||
) -> Optional[pxla.MeshExecutableFastpathData]:
|
||||
out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
|
||||
|
||||
@ -221,6 +205,7 @@ def _get_fastpath_data(
|
||||
and not executable.unsafe_call.has_unordered_effects
|
||||
and not executable.unsafe_call.has_host_callbacks
|
||||
and all(isinstance(x, xc.ArrayImpl) for x in out_reflattened)
|
||||
and abstracted_axes is None
|
||||
# no attr state effects
|
||||
and not attrs_tracked
|
||||
# no ref state effects
|
||||
@ -289,7 +274,8 @@ def _cpp_pjit(jit_info: PjitInfo):
|
||||
jit_info, *args, **kwargs)
|
||||
executable = _read_most_recent_pjit_call_executable(jaxpr)
|
||||
maybe_fastpath_data = _get_fastpath_data(
|
||||
executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects)
|
||||
executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects,
|
||||
jit_info.abstracted_axes)
|
||||
return outs, maybe_fastpath_data
|
||||
|
||||
fun = jit_info.fun
|
||||
@ -394,10 +380,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
|
||||
|
||||
|
||||
def _make_jit_wrapper(jit_info: PjitInfo):
|
||||
if jit_info.abstracted_axes is None:
|
||||
wrapped = _cpp_pjit(jit_info)
|
||||
else:
|
||||
wrapped = _python_pjit(jit_info)
|
||||
wrapped = _cpp_pjit(jit_info)
|
||||
|
||||
@api_boundary
|
||||
def lower(*args, **kwargs):
|
||||
@ -1432,7 +1415,8 @@ def _pjit_call_impl(*args, jaxpr,
|
||||
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
|
||||
inline=inline)
|
||||
fastpath_data = _get_fastpath_data(
|
||||
compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects)
|
||||
compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects,
|
||||
None)
|
||||
return out_flat, fastpath_data
|
||||
|
||||
f = _get_jaxpr_as_fun(
|
||||
|
Loading…
x
Reference in New Issue
Block a user