mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add logging if we get a C++ cache miss
PiperOrigin-RevId: 531555996
This commit is contained in:
parent
0bc3136fbc
commit
559b837ba5
@ -14,6 +14,7 @@
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
import logging
|
||||
import numpy as np
|
||||
from typing import (Callable, Sequence, Tuple, Union, cast, List, Optional,
|
||||
Iterable, NamedTuple, Any)
|
||||
@ -22,6 +23,7 @@ from functools import partial, lru_cache
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
from jax._src import core
|
||||
from jax._src import stages
|
||||
from jax._src import dispatch
|
||||
@ -76,6 +78,9 @@ PjitShardingMinusUnspecified = Union[GSPMDSharding, AUTOAxisResource]
|
||||
MeshSharding = Union[NamedSharding, UnspecifiedValue, AUTOAxisResource]
|
||||
MeshShardingMinusUnspecified = Union[NamedSharding, AUTOAxisResource]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _try_infer_args(f, tree):
|
||||
dummy_args = tree_unflatten(tree, [False] * tree.num_leaves)
|
||||
try:
|
||||
@ -205,6 +210,10 @@ def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
|
||||
|
||||
@api_boundary
|
||||
def cache_miss(*args, **kwargs):
|
||||
log_priority = logging.WARNING if jax.config.jax_log_compiles else logging.DEBUG
|
||||
if logger.isEnabledFor(log_priority):
|
||||
logger.log(log_priority, "C++ fastpath cache miss")
|
||||
|
||||
outs, out_flat, out_tree, args_flat = _python_pjit_helper(
|
||||
fun, infer_params_fn, *args, **kwargs)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user