mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Catch ModuleNotFoundError instead of ImportError.
We frequently use the pattern try: import m except ImportError: # do something else. This suppresses errors when the module can be found but does not import successfully for any reason. Instead, catch only ModuleNotFoundError so missing modules are allowed but buggy modules still report errors.
This commit is contained in:
parent
fe665b3a64
commit
1e241dcf16
@ -33,7 +33,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -60,7 +60,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -235,7 +235,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -265,7 +265,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -288,7 +288,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -303,7 +303,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -319,7 +319,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -342,7 +342,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -358,7 +358,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -371,7 +371,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -384,7 +384,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -399,7 +399,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -416,7 +416,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -451,7 +451,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -466,7 +466,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -482,7 +482,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -493,7 +493,7 @@
|
||||
"# Download animation.\n",
|
||||
"try:\n",
|
||||
" from google.colab import files\n",
|
||||
"except ImportError:\n",
|
||||
"except ModuleNotFoundError:\n",
|
||||
" pass\n",
|
||||
"else:\n",
|
||||
" files.download('wave_movie.gif')"
|
||||
|
@ -39,7 +39,7 @@ def cloud_tpu_init():
|
||||
import libtpu
|
||||
# pytype: enable=import-error
|
||||
# pylint: enable=import-outside-toplevel
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
# We assume libtpu is installed iff we're in a correctly-configured Cloud
|
||||
# TPU environment. Exit early if we're not running on Cloud TPU.
|
||||
return
|
||||
|
@ -33,7 +33,7 @@ if colab_lib.IS_COLAB_ENABLED:
|
||||
try:
|
||||
import pygments
|
||||
IS_PYGMENTS_ENABLED = True
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
IS_PYGMENTS_ENABLED = False
|
||||
# pytype: enable=import-error
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
@ -34,7 +34,7 @@ from jax.config import config
|
||||
|
||||
try:
|
||||
import colorama # pytype: disable=import-error
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
colorama = None
|
||||
|
||||
def _can_use_color() -> bool:
|
||||
|
@ -24,12 +24,12 @@ import jax
|
||||
try:
|
||||
from tensorflow.python.profiler import profiler_v2 as profiler
|
||||
from tensorflow.python.profiler import profiler_client
|
||||
except ImportError:
|
||||
raise ImportError("This script requires `tensorflow` to be installed.")
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError("This script requires `tensorflow` to be installed.")
|
||||
try:
|
||||
from tensorboard_plugin_profile.convert import raw_to_tool_data as convert
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(
|
||||
"This script requires `tensorboard_plugin_profile` to be installed.")
|
||||
# pytype: enable=import-error
|
||||
|
||||
|
@ -32,7 +32,7 @@ import numpy as np
|
||||
|
||||
try:
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
tf = None
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -25,7 +25,7 @@ from jax._src.public_test_util import (
|
||||
# pytype: disable=import-error
|
||||
try:
|
||||
import jax._src.test_util as _private_test_util
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
else:
|
||||
del _private_test_util
|
||||
@ -35,7 +35,7 @@ else:
|
||||
def __getattr__(attr):
|
||||
try:
|
||||
from jax._src import test_util
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
raise AttributeError(f"module {__name__} has no attribute {attr}")
|
||||
if attr in ['cases_from_list', 'check_close', 'check_eq', 'device_under_test',
|
||||
'format_shape_dtype_string', 'rand_uniform', 'skip_on_devices',
|
||||
|
@ -78,13 +78,13 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
try:
|
||||
from jax.experimental import jax2tf
|
||||
except ImportError:
|
||||
import jax.experimental.jax2tf as jax2tf
|
||||
except ModuleNotFoundError:
|
||||
jax2tf = None # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
tf = None # type: ignore
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
@ -26,14 +26,14 @@ try:
|
||||
from .cuda import _cuda_linalg
|
||||
for _name, _value in _cuda_linalg.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_cuda_linalg = None
|
||||
|
||||
try:
|
||||
from .rocm import _hip_linalg
|
||||
for _name, _value in _hip_linalg.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_hip_linalg = None
|
||||
|
||||
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
|
||||
|
@ -28,14 +28,14 @@ try:
|
||||
from .cuda import _cuda_prng
|
||||
for _name, _value in _cuda_prng.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_cuda_prng = None
|
||||
|
||||
try:
|
||||
from .rocm import _hip_prng
|
||||
for _name, _value in _hip_prng.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_hip_prng = None
|
||||
|
||||
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
|
||||
|
@ -30,14 +30,14 @@ try:
|
||||
from .cuda import _cublas
|
||||
for _name, _value in _cublas.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_cublas = None
|
||||
|
||||
try:
|
||||
from .cuda import _cusolver
|
||||
for _name, _value in _cusolver.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_cusolver = None
|
||||
|
||||
|
||||
@ -45,14 +45,14 @@ try:
|
||||
from .rocm import _hipblas
|
||||
for _name, _value in _hipblas.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_hipblas = None
|
||||
|
||||
try:
|
||||
from .rocm import _hipsolver
|
||||
for _name, _value in _hipsolver.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_hipsolver = None
|
||||
|
||||
|
||||
|
@ -27,7 +27,7 @@ from .mhlo_helpers import custom_call
|
||||
|
||||
try:
|
||||
from .cuda import _cusparse
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_cusparse = None
|
||||
else:
|
||||
for _name, _value in _cusparse.registrations().items():
|
||||
@ -35,7 +35,7 @@ else:
|
||||
|
||||
try:
|
||||
from .rocm import _hipsparse
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_hipsparse = None
|
||||
else:
|
||||
for _name, _value in _hipsparse.registrations().items():
|
||||
|
@ -32,12 +32,12 @@ config.parse_flags_with_absl()
|
||||
try:
|
||||
import torch
|
||||
import torch.utils.dlpack
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
torch = None
|
||||
|
||||
try:
|
||||
import cupy
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
cupy = None
|
||||
|
||||
try:
|
||||
|
@ -29,7 +29,7 @@ from jax._src import test_util as jtu
|
||||
|
||||
try:
|
||||
import portpicker
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
portpicker = None
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -34,7 +34,7 @@ import numpy as np
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
tf = None
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -37,12 +37,12 @@ try:
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning,
|
||||
message=".*is deprecated and will be removed in Pillow 10.*")
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
tf = None
|
||||
|
||||
try:
|
||||
from PIL import Image as PIL_Image
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
PIL_Image = None
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -22,7 +22,7 @@ from jax.config import config
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
tf = None # type: ignore
|
||||
|
||||
|
||||
|
@ -32,7 +32,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
try:
|
||||
import numpy_dispatch
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
numpy_dispatch = None
|
||||
|
||||
import jax
|
||||
|
@ -21,7 +21,7 @@ from absl.testing import parameterized
|
||||
|
||||
try:
|
||||
import cloudpickle
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
cloudpickle = None
|
||||
|
||||
import jax
|
||||
|
@ -30,13 +30,13 @@ import jax._src.test_util as jtu
|
||||
|
||||
try:
|
||||
import portpicker
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
portpicker = None
|
||||
|
||||
try:
|
||||
from tensorflow.python.profiler import profiler_client
|
||||
from tensorflow.python.profiler import profiler_v2 as tf_profiler
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
profiler_client = None
|
||||
tf_profiler = None
|
||||
|
||||
@ -45,7 +45,7 @@ try:
|
||||
import tensorboard_plugin_profile
|
||||
del tensorboard_plugin_profile
|
||||
TBP_ENABLED = True
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
Loading…
x
Reference in New Issue
Block a user