Remove _python_pjit and make _cpp_pjit the only function wrapper.

PiperOrigin-RevId: 617846352
This commit is contained in:
Yash Katariya 2024-03-21 08:09:37 -07:00 committed by jax authors
parent 74f4846d14
commit dd574cbc74

View File

@ -190,25 +190,9 @@ def _get_states(attrs_tracked):
return [jax_getattr(obj, attr) for (obj, attr) in attrs_tracked]
def _python_pjit(jit_info: PjitInfo):
fun = jit_info.fun
@wraps(fun)
@api_boundary
def wrapped(*args, **kwargs):
if config.disable_jit.value:
return fun(*args, **kwargs)
return _python_pjit_helper(jit_info, *args, **kwargs)[0]
def _python_pjit_evict_fn():
_create_pjit_jaxpr.evict_function(fun) # type: ignore
wrapped.clear_cache = _python_pjit_evict_fn
return wrapped
def _get_fastpath_data(
executable, out_tree, args_flat, out_flat, attrs_tracked, effects
executable, out_tree, args_flat, out_flat, attrs_tracked, effects,
abstracted_axes
) -> Optional[pxla.MeshExecutableFastpathData]:
out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
@ -221,6 +205,7 @@ def _get_fastpath_data(
and not executable.unsafe_call.has_unordered_effects
and not executable.unsafe_call.has_host_callbacks
and all(isinstance(x, xc.ArrayImpl) for x in out_reflattened)
and abstracted_axes is None
# no attr state effects
and not attrs_tracked
# no ref state effects
@ -289,7 +274,8 @@ def _cpp_pjit(jit_info: PjitInfo):
jit_info, *args, **kwargs)
executable = _read_most_recent_pjit_call_executable(jaxpr)
maybe_fastpath_data = _get_fastpath_data(
executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects)
executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects,
jit_info.abstracted_axes)
return outs, maybe_fastpath_data
fun = jit_info.fun
@ -394,10 +380,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
def _make_jit_wrapper(jit_info: PjitInfo):
if jit_info.abstracted_axes is None:
wrapped = _cpp_pjit(jit_info)
else:
wrapped = _python_pjit(jit_info)
wrapped = _cpp_pjit(jit_info)
@api_boundary
def lower(*args, **kwargs):
@ -1432,7 +1415,8 @@ def _pjit_call_impl(*args, jaxpr,
donated_invars=donated_invars, name=name, keep_unused=keep_unused,
inline=inline)
fastpath_data = _get_fastpath_data(
compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects)
compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects,
None)
return out_flat, fastpath_data
f = _get_jaxpr_as_fun(