Add logging if we get a C++ cache miss

PiperOrigin-RevId: 531555996
This commit is contained in:
Yash Katariya 2023-05-12 11:14:53 -07:00 committed by jax authors
parent 0bc3136fbc
commit 559b837ba5

View File

@ -14,6 +14,7 @@
import dataclasses
import inspect
import logging
import numpy as np
from typing import (Callable, Sequence, Tuple, Union, cast, List, Optional,
Iterable, NamedTuple, Any)
@ -22,6 +23,7 @@ from functools import partial, lru_cache
import threading
import warnings
import jax
from jax._src import core
from jax._src import stages
from jax._src import dispatch
@ -76,6 +78,9 @@ PjitShardingMinusUnspecified = Union[GSPMDSharding, AUTOAxisResource]
MeshSharding = Union[NamedSharding, UnspecifiedValue, AUTOAxisResource]
MeshShardingMinusUnspecified = Union[NamedSharding, AUTOAxisResource]
logger = logging.getLogger(__name__)
def _try_infer_args(f, tree):
dummy_args = tree_unflatten(tree, [False] * tree.num_leaves)
try:
@ -205,6 +210,10 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
@api_boundary
def cache_miss(*args, **kwargs):
log_priority = logging.WARNING if jax.config.jax_log_compiles else logging.DEBUG
if logger.isEnabledFor(log_priority):
logger.log(log_priority, "C++ fastpath cache miss")
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
fun, infer_params_fn, *args, **kwargs)