mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Enable colors when we are using a terminal or IPython
This commit is contained in:
parent
41e1635ac1
commit
b8a523f977
@ -476,7 +476,7 @@ flags.DEFINE_integer(
|
||||
|
||||
flags.DEFINE_bool(
|
||||
'jax_pprint_use_color',
|
||||
bool_env('JAX_PPRINT_USE_COLOR', False),
|
||||
bool_env('JAX_PPRINT_USE_COLOR', True),
|
||||
help='Enable jaxpr pretty-printing with colorful syntax highlighting.'
|
||||
)
|
||||
|
||||
|
@ -27,6 +27,7 @@
|
||||
|
||||
import abc
|
||||
import enum
|
||||
import sys
|
||||
from functools import partial
|
||||
from typing import List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
from jax.config import config
|
||||
@ -36,13 +37,31 @@ try:
|
||||
except ImportError:
|
||||
colorama = None
|
||||
|
||||
def _can_use_color() -> bool:
|
||||
try:
|
||||
# Check if we're in IPython or Colab
|
||||
ipython = get_ipython() # type: ignore[name-defined]
|
||||
shell = ipython.__class__.__name__
|
||||
if shell == "ZMQInteractiveShell":
|
||||
# Jupyter Notebook
|
||||
return True
|
||||
elif "colab" in str(ipython.__class__):
|
||||
# Google Colab (external or internal)
|
||||
return True
|
||||
except NameError:
|
||||
pass
|
||||
# Otherwise check if we're in a terminal
|
||||
return sys.stdout.isatty()
|
||||
|
||||
CAN_USE_COLOR = _can_use_color()
|
||||
|
||||
class Doc(abc.ABC):
|
||||
__slots__ = ()
|
||||
|
||||
def format(self, width: int = 80, use_color: bool = False,
|
||||
def format(self, width: int = 80, use_color: Optional[bool] = None,
|
||||
annotation_prefix=" # ") -> str:
|
||||
use_color = use_color or config.FLAGS.jax_pprint_use_color
|
||||
if use_color is None:
|
||||
use_color = CAN_USE_COLOR and config.FLAGS.jax_pprint_use_color
|
||||
return _format(self, width, use_color=use_color,
|
||||
annotation_prefix=annotation_prefix)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user