mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Make the _pjit_jaxpr cache hit more often by not depending on out_shardings. If the out_shardings argument of pjit changes, it should not affect the jaxpr created, because jaxpr creation is not dependent on out_shardings.
PiperOrigin-RevId: 510488544
This commit is contained in:
parent
f888e4814c
commit
031d15ed2d
@ -3570,7 +3570,7 @@ def clear_backends():
|
||||
_cpp_jit_cache.clear()
|
||||
jax_jit.CompiledFunctionCache.clear_all()
|
||||
pjit._pjit_lower_cached.cache_clear()
|
||||
pjit._pjit_jaxpr.cache_clear()
|
||||
pjit._create_pjit_jaxpr.cache_clear()
|
||||
if xla_extension_version >= 124:
|
||||
pjit._cpp_pjit_cache.clear()
|
||||
xc._xla.PjitFunctionCache.clear_all()
|
||||
|
@ -203,7 +203,7 @@ def _python_pjit(fun: Callable, infer_params_fn):
|
||||
return _python_pjit_helper(fun, infer_params_fn, *args, **kwargs)[0]
|
||||
|
||||
def _python_pjit_evict_fn():
|
||||
_pjit_jaxpr.evict_function(fun) # type: ignore
|
||||
_create_pjit_jaxpr.evict_function(fun) # type: ignore
|
||||
wrapped.clear_cache = _python_pjit_evict_fn
|
||||
return wrapped
|
||||
|
||||
@ -222,7 +222,7 @@ def _read_most_recent_pjit_call_executable():
|
||||
|
||||
def _cpp_pjit_evict_fn(self):
|
||||
self._clear_cache()
|
||||
_pjit_jaxpr.evict_function(self._fun) # type: ignore
|
||||
_create_pjit_jaxpr.evict_function(self._fun) # type: ignore
|
||||
|
||||
|
||||
if xla_extension_version >= 124:
|
||||
@ -922,7 +922,7 @@ def _process_in_axis_resources(in_shardings_thunk, local_in_avals,
|
||||
|
||||
|
||||
@lu.cache
|
||||
def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree, api_name):
|
||||
def _create_pjit_jaxpr(fun, global_in_avals, api_name):
|
||||
prev_positional_val = pxla.positional_semantics.val
|
||||
try:
|
||||
pxla.positional_semantics.val = pxla._PositionalSemantics.GLOBAL
|
||||
@ -941,7 +941,12 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree, api_name):
|
||||
else:
|
||||
jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
||||
final_consts = []
|
||||
return _ListWithW([jaxpr, final_consts, global_out_avals])
|
||||
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
def _check_and_canonicalize_out_shardings(
|
||||
out_shardings_thunk, out_tree, global_out_avals):
|
||||
orig_out_shardings = out_shardings_thunk()
|
||||
# TODO(yashkatariya): Remove the if branch and fix flatten_axis_resources
|
||||
# instead. This condition exists because flatten_axis_resources passes in an
|
||||
@ -962,9 +967,16 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree, api_name):
|
||||
o if _is_unspecified(o) or is_auto(o) else to_op_sharding_sharding(o, aval.ndim)
|
||||
for o, aval in safe_zip(out_shardings_flat, global_out_avals)
|
||||
)
|
||||
return canonicalized_out_shardings_flat
|
||||
|
||||
|
||||
def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree, api_name):
|
||||
jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(
|
||||
fun, global_in_avals, api_name)
|
||||
canonicalized_out_shardings_flat = _check_and_canonicalize_out_shardings(
|
||||
out_shardings_thunk, out_tree, tuple(global_out_avals))
|
||||
# lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
|
||||
return _ListWithW([jaxpr, final_consts, canonicalized_out_shardings_flat])
|
||||
return jaxpr, final_consts, canonicalized_out_shardings_flat
|
||||
|
||||
|
||||
def pjit_check_aval_sharding(
|
||||
|
Loading…
x
Reference in New Issue
Block a user