Catch ModuleNotFoundError instead of ImportError.

We frequently use the pattern
try:
  import m
except ImportError:
  # do something else.

This suppresses errors when the module can be found but does not import
successfully for any reason. Instead, catch only ModuleNotFoundError so
missing modules are allowed but buggy modules still report errors.
This commit is contained in:
Peter Hawkins 2022-08-18 12:55:49 +00:00
parent fe665b3a64
commit 1e241dcf16
20 changed files with 52 additions and 52 deletions

View File

@ -33,7 +33,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -60,7 +60,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -235,7 +235,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -265,7 +265,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -288,7 +288,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -303,7 +303,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -319,7 +319,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -342,7 +342,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -358,7 +358,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -371,7 +371,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -384,7 +384,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -399,7 +399,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -416,7 +416,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -451,7 +451,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -466,7 +466,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -482,7 +482,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -493,7 +493,7 @@
"# Download animation.\n",
"try:\n",
" from google.colab import files\n",
"except ImportError:\n",
"except ModuleNotFoundError:\n",
" pass\n",
"else:\n",
" files.download('wave_movie.gif')"

View File

@ -39,7 +39,7 @@ def cloud_tpu_init():
import libtpu
# pytype: enable=import-error
# pylint: enable=import-outside-toplevel
except ImportError:
except ModuleNotFoundError:
# We assume libtpu is installed iff we're in a correctly-configured Cloud
# TPU environment. Exit early if we're not running on Cloud TPU.
return

View File

@ -33,7 +33,7 @@ if colab_lib.IS_COLAB_ENABLED:
try:
import pygments
IS_PYGMENTS_ENABLED = True
except ImportError:
except ModuleNotFoundError:
IS_PYGMENTS_ENABLED = False
# pytype: enable=import-error
# pylint: enable=g-import-not-at-top

View File

@ -34,7 +34,7 @@ from jax.config import config
try:
import colorama # pytype: disable=import-error
except ImportError:
except ModuleNotFoundError:
colorama = None
def _can_use_color() -> bool:

View File

@ -24,12 +24,12 @@ import jax
try:
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.profiler import profiler_client
except ImportError:
raise ImportError("This script requires `tensorflow` to be installed.")
except ModuleNotFoundError:
raise ModuleNotFoundError("This script requires `tensorflow` to be installed.")
try:
from tensorboard_plugin_profile.convert import raw_to_tool_data as convert
except ImportError:
raise ImportError(
except ModuleNotFoundError:
raise ModuleNotFoundError(
"This script requires `tensorboard_plugin_profile` to be installed.")
# pytype: enable=import-error

View File

@ -32,7 +32,7 @@ import numpy as np
try:
import tensorflow as tf # type: ignore[import]
except ImportError:
except ModuleNotFoundError:
tf = None
config.parse_flags_with_absl()

View File

@ -25,7 +25,7 @@ from jax._src.public_test_util import (
# pytype: disable=import-error
try:
import jax._src.test_util as _private_test_util
except ImportError:
except ModuleNotFoundError:
pass
else:
del _private_test_util
@ -35,7 +35,7 @@ else:
def __getattr__(attr):
try:
from jax._src import test_util
except ImportError:
except ModuleNotFoundError:
raise AttributeError(f"module {__name__} has no attribute {attr}")
if attr in ['cases_from_list', 'check_close', 'check_eq', 'device_under_test',
'format_shape_dtype_string', 'rand_uniform', 'skip_on_devices',

View File

@ -78,13 +78,13 @@ import jax
import jax.numpy as jnp
try:
from jax.experimental import jax2tf
except ImportError:
import jax.experimental.jax2tf as jax2tf
except ModuleNotFoundError:
jax2tf = None # type: ignore[assignment]
try:
import tensorflow as tf
except ImportError:
except ModuleNotFoundError:
tf = None # type: ignore
FLAGS = flags.FLAGS

View File

@ -26,14 +26,14 @@ try:
from .cuda import _cuda_linalg
for _name, _value in _cuda_linalg.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
except ModuleNotFoundError:
_cuda_linalg = None
try:
from .rocm import _hip_linalg
for _name, _value in _hip_linalg.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
except ModuleNotFoundError:
_hip_linalg = None
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)

View File

@ -28,14 +28,14 @@ try:
from .cuda import _cuda_prng
for _name, _value in _cuda_prng.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
except ModuleNotFoundError:
_cuda_prng = None
try:
from .rocm import _hip_prng
for _name, _value in _hip_prng.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
except ModuleNotFoundError:
_hip_prng = None
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)

View File

@ -30,14 +30,14 @@ try:
from .cuda import _cublas
for _name, _value in _cublas.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
except ModuleNotFoundError:
_cublas = None
try:
from .cuda import _cusolver
for _name, _value in _cusolver.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
except ModuleNotFoundError:
_cusolver = None
@ -45,14 +45,14 @@ try:
from .rocm import _hipblas
for _name, _value in _hipblas.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
except ModuleNotFoundError:
_hipblas = None
try:
from .rocm import _hipsolver
for _name, _value in _hipsolver.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
except ModuleNotFoundError:
_hipsolver = None

View File

@ -27,7 +27,7 @@ from .mhlo_helpers import custom_call
try:
from .cuda import _cusparse
except ImportError:
except ModuleNotFoundError:
_cusparse = None
else:
for _name, _value in _cusparse.registrations().items():
@ -35,7 +35,7 @@ else:
try:
from .rocm import _hipsparse
except ImportError:
except ModuleNotFoundError:
_hipsparse = None
else:
for _name, _value in _hipsparse.registrations().items():

View File

@ -32,12 +32,12 @@ config.parse_flags_with_absl()
try:
import torch
import torch.utils.dlpack
except ImportError:
except ModuleNotFoundError:
torch = None
try:
import cupy
except ImportError:
except ModuleNotFoundError:
cupy = None
try:

View File

@ -29,7 +29,7 @@ from jax._src import test_util as jtu
try:
import portpicker
except ImportError:
except ModuleNotFoundError:
portpicker = None
config.parse_flags_with_absl()

View File

@ -34,7 +34,7 @@ import numpy as np
try:
import tensorflow as tf
except ImportError:
except ModuleNotFoundError:
tf = None
config.parse_flags_with_absl()

View File

@ -37,12 +37,12 @@ try:
warnings.filterwarnings('ignore', category=DeprecationWarning,
message=".*is deprecated and will be removed in Pillow 10.*")
import tensorflow as tf
except ImportError:
except ModuleNotFoundError:
tf = None
try:
from PIL import Image as PIL_Image
except ImportError:
except ModuleNotFoundError:
PIL_Image = None
config.parse_flags_with_absl()

View File

@ -22,7 +22,7 @@ from jax.config import config
try:
import tensorflow as tf
except ImportError:
except ModuleNotFoundError:
tf = None # type: ignore

View File

@ -32,7 +32,7 @@ from absl.testing import parameterized
import numpy as np
try:
import numpy_dispatch
except ImportError:
except ModuleNotFoundError:
numpy_dispatch = None
import jax

View File

@ -21,7 +21,7 @@ from absl.testing import parameterized
try:
import cloudpickle
except ImportError:
except ModuleNotFoundError:
cloudpickle = None
import jax

View File

@ -30,13 +30,13 @@ import jax._src.test_util as jtu
try:
import portpicker
except ImportError:
except ModuleNotFoundError:
portpicker = None
try:
from tensorflow.python.profiler import profiler_client
from tensorflow.python.profiler import profiler_v2 as tf_profiler
except ImportError:
except ModuleNotFoundError:
profiler_client = None
tf_profiler = None
@ -45,7 +45,7 @@ try:
import tensorboard_plugin_profile
del tensorboard_plugin_profile
TBP_ENABLED = True
except ImportError:
except ModuleNotFoundError:
pass
config.parse_flags_with_absl()