Merge pull request #11978 from hawkinsp:import

PiperOrigin-RevId: 468478116
This commit is contained in:
jax authors 2022-08-18 09:34:57 -07:00
commit d933c8c427
20 changed files with 52 additions and 52 deletions

View File

@ -33,7 +33,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -60,7 +60,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -235,7 +235,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -265,7 +265,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -288,7 +288,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -303,7 +303,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -319,7 +319,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -342,7 +342,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -358,7 +358,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -371,7 +371,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -384,7 +384,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -399,7 +399,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -416,7 +416,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -451,7 +451,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -466,7 +466,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -482,7 +482,7 @@
},
{
"cell_type": "code",
"execution_count": 0,
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
@ -493,7 +493,7 @@
"# Download animation.\n",
"try:\n",
" from google.colab import files\n",
"except ImportError:\n",
"except ModuleNotFoundError:\n",
" pass\n",
"else:\n",
" files.download('wave_movie.gif')"

View File

@ -39,7 +39,7 @@ def cloud_tpu_init():
import libtpu
# pytype: enable=import-error
# pylint: enable=import-outside-toplevel
except ImportError:
except ModuleNotFoundError:
# We assume libtpu is installed iff we're in a correctly-configured Cloud
# TPU environment. Exit early if we're not running on Cloud TPU.
return

View File

@ -33,7 +33,7 @@ if colab_lib.IS_COLAB_ENABLED:
try:
import pygments
IS_PYGMENTS_ENABLED = True
except ImportError:
except ModuleNotFoundError:
IS_PYGMENTS_ENABLED = False
# pytype: enable=import-error
# pylint: enable=g-import-not-at-top

View File

@ -34,7 +34,7 @@ from jax.config import config
try:
import colorama # pytype: disable=import-error
except ImportError:
except ModuleNotFoundError:
colorama = None
def _can_use_color() -> bool:

View File

@ -24,12 +24,12 @@ import jax
try:
from tensorflow.python.profiler import profiler_v2 as profiler
from tensorflow.python.profiler import profiler_client
except ImportError:
raise ImportError("This script requires `tensorflow` to be installed.")
except ModuleNotFoundError:
raise ModuleNotFoundError("This script requires `tensorflow` to be installed.")
try:
from tensorboard_plugin_profile.convert import raw_to_tool_data as convert
except ImportError:
raise ImportError(
except ModuleNotFoundError:
raise ModuleNotFoundError(
"This script requires `tensorboard_plugin_profile` to be installed.")
# pytype: enable=import-error

View File

@ -32,7 +32,7 @@ import numpy as np
try:
import tensorflow as tf # type: ignore[import]
except ImportError:
except ModuleNotFoundError:
tf = None
config.parse_flags_with_absl()

View File

@ -25,7 +25,7 @@ from jax._src.public_test_util import (
# pytype: disable=import-error
try:
import jax._src.test_util as _private_test_util
except ImportError:
except ModuleNotFoundError:
pass
else:
del _private_test_util
@ -35,7 +35,7 @@ else:
def __getattr__(attr):
try:
from jax._src import test_util
except ImportError:
except ModuleNotFoundError:
raise AttributeError(f"module {__name__} has no attribute {attr}")
if attr in ['cases_from_list', 'check_close', 'check_eq', 'device_under_test',
'format_shape_dtype_string', 'rand_uniform', 'skip_on_devices',

View File

@ -78,13 +78,13 @@ import jax
import jax.numpy as jnp
try:
from jax.experimental import jax2tf
except ImportError:
import jax.experimental.jax2tf as jax2tf
except ModuleNotFoundError:
jax2tf = None # type: ignore[assignment]
try:
import tensorflow as tf
except ImportError:
except ModuleNotFoundError:
tf = None # type: ignore
FLAGS = flags.FLAGS

View File

@ -26,14 +26,14 @@ try:
from .cuda import _cuda_linalg
for _name, _value in _cuda_linalg.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
except ModuleNotFoundError:
_cuda_linalg = None
try:
from .rocm import _hip_linalg
for _name, _value in _hip_linalg.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
except ModuleNotFoundError:
_hip_linalg = None
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)

View File

@ -28,14 +28,14 @@ try:
from .cuda import _cuda_prng
for _name, _value in _cuda_prng.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
except ModuleNotFoundError:
_cuda_prng = None
try:
from .rocm import _hip_prng
for _name, _value in _hip_prng.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
except ModuleNotFoundError:
_hip_prng = None
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)

View File

@ -30,14 +30,14 @@ try:
from .cuda import _cublas
for _name, _value in _cublas.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
except ModuleNotFoundError:
_cublas = None
try:
from .cuda import _cusolver
for _name, _value in _cusolver.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
except ModuleNotFoundError:
_cusolver = None
@ -45,14 +45,14 @@ try:
from .rocm import _hipblas
for _name, _value in _hipblas.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
except ModuleNotFoundError:
_hipblas = None
try:
from .rocm import _hipsolver
for _name, _value in _hipsolver.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
except ImportError:
except ModuleNotFoundError:
_hipsolver = None

View File

@ -27,7 +27,7 @@ from .mhlo_helpers import custom_call
try:
from .cuda import _cusparse
except ImportError:
except ModuleNotFoundError:
_cusparse = None
else:
for _name, _value in _cusparse.registrations().items():
@ -35,7 +35,7 @@ else:
try:
from .rocm import _hipsparse
except ImportError:
except ModuleNotFoundError:
_hipsparse = None
else:
for _name, _value in _hipsparse.registrations().items():

View File

@ -32,12 +32,12 @@ config.parse_flags_with_absl()
try:
import torch
import torch.utils.dlpack
except ImportError:
except ModuleNotFoundError:
torch = None
try:
import cupy
except ImportError:
except ModuleNotFoundError:
cupy = None
try:

View File

@ -29,7 +29,7 @@ from jax._src import test_util as jtu
try:
import portpicker
except ImportError:
except ModuleNotFoundError:
portpicker = None
config.parse_flags_with_absl()

View File

@ -34,7 +34,7 @@ import numpy as np
try:
import tensorflow as tf
except ImportError:
except ModuleNotFoundError:
tf = None
config.parse_flags_with_absl()

View File

@ -37,12 +37,12 @@ try:
warnings.filterwarnings('ignore', category=DeprecationWarning,
message=".*is deprecated and will be removed in Pillow 10.*")
import tensorflow as tf
except ImportError:
except ModuleNotFoundError:
tf = None
try:
from PIL import Image as PIL_Image
except ImportError:
except ModuleNotFoundError:
PIL_Image = None
config.parse_flags_with_absl()

View File

@ -22,7 +22,7 @@ from jax.config import config
try:
import tensorflow as tf
except ImportError:
except ModuleNotFoundError:
tf = None # type: ignore

View File

@ -32,7 +32,7 @@ from absl.testing import parameterized
import numpy as np
try:
import numpy_dispatch
except ImportError:
except ModuleNotFoundError:
numpy_dispatch = None
import jax

View File

@ -21,7 +21,7 @@ from absl.testing import parameterized
try:
import cloudpickle
except ImportError:
except ModuleNotFoundError:
cloudpickle = None
import jax

View File

@ -30,13 +30,13 @@ import jax._src.test_util as jtu
try:
import portpicker
except ImportError:
except ModuleNotFoundError:
portpicker = None
try:
from tensorflow.python.profiler import profiler_client
from tensorflow.python.profiler import profiler_v2 as tf_profiler
except ImportError:
except ModuleNotFoundError:
profiler_client = None
tf_profiler = None
@ -45,7 +45,7 @@ try:
import tensorboard_plugin_profile
del tensorboard_plugin_profile
TBP_ENABLED = True
except ImportError:
except ModuleNotFoundError:
pass
config.parse_flags_with_absl()