Make the _pjit_jaxpr cache hit more often by not depending on the out_shardings. If the out_shardings argument of pjit changes, it should not affect the jaxpr created, because jaxpr creation is not dependent on out_shardings.

PiperOrigin-RevId: 510488544
This commit is contained in:
Yash Katariya 2023-02-17 12:01:50 -08:00 committed by jax authors
parent f888e4814c
commit 031d15ed2d
2 changed files with 17 additions and 5 deletions

View File

@ -3570,7 +3570,7 @@ def clear_backends():
_cpp_jit_cache.clear()
jax_jit.CompiledFunctionCache.clear_all()
pjit._pjit_lower_cached.cache_clear()
pjit._pjit_jaxpr.cache_clear()
pjit._create_pjit_jaxpr.cache_clear()
if xla_extension_version >= 124:
pjit._cpp_pjit_cache.clear()
xc._xla.PjitFunctionCache.clear_all()

View File

@ -203,7 +203,7 @@ def _python_pjit(fun: Callable, infer_params_fn):
return _python_pjit_helper(fun, infer_params_fn, *args, **kwargs)[0]
def _python_pjit_evict_fn():
_pjit_jaxpr.evict_function(fun) # type: ignore
_create_pjit_jaxpr.evict_function(fun) # type: ignore
wrapped.clear_cache = _python_pjit_evict_fn
return wrapped
@ -222,7 +222,7 @@ def _read_most_recent_pjit_call_executable():
def _cpp_pjit_evict_fn(self):
self._clear_cache()
_pjit_jaxpr.evict_function(self._fun) # type: ignore
_create_pjit_jaxpr.evict_function(self._fun) # type: ignore
if xla_extension_version >= 124:
@ -922,7 +922,7 @@ def _process_in_axis_resources(in_shardings_thunk, local_in_avals,
@lu.cache
def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree, api_name):
def _create_pjit_jaxpr(fun, global_in_avals, api_name):
prev_positional_val = pxla.positional_semantics.val
try:
pxla.positional_semantics.val = pxla._PositionalSemantics.GLOBAL
@ -941,7 +941,12 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree, api_name):
else:
jaxpr = core.ClosedJaxpr(jaxpr, consts)
final_consts = []
return _ListWithW([jaxpr, final_consts, global_out_avals])
@lru_cache(maxsize=4096)
def _check_and_canonicalize_out_shardings(
out_shardings_thunk, out_tree, global_out_avals):
orig_out_shardings = out_shardings_thunk()
# TODO(yashkatariya): Remove the if branch and fix flatten_axis_resources
# instead. This condition exists because flatten_axis_resources passes in an
@ -962,9 +967,16 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree, api_name):
o if _is_unspecified(o) or is_auto(o) else to_op_sharding_sharding(o, aval.ndim)
for o, aval in safe_zip(out_shardings_flat, global_out_avals)
)
return canonicalized_out_shardings_flat
def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree, api_name):
jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(
fun, global_in_avals, api_name)
canonicalized_out_shardings_flat = _check_and_canonicalize_out_shardings(
out_shardings_thunk, out_tree, tuple(global_out_avals))
# lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
return _ListWithW([jaxpr, final_consts, canonicalized_out_shardings_flat])
return jaxpr, final_consts, canonicalized_out_shardings_flat
def pjit_check_aval_sharding(