mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #11978 from hawkinsp:import
PiperOrigin-RevId: 468478116
This commit is contained in:
commit
d933c8c427
@ -33,7 +33,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -60,7 +60,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -235,7 +235,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -265,7 +265,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -288,7 +288,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -303,7 +303,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -319,7 +319,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -342,7 +342,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -358,7 +358,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -371,7 +371,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -384,7 +384,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -399,7 +399,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -416,7 +416,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -451,7 +451,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -466,7 +466,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -482,7 +482,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -493,7 +493,7 @@
|
||||
"# Download animation.\n",
|
||||
"try:\n",
|
||||
" from google.colab import files\n",
|
||||
"except ImportError:\n",
|
||||
"except ModuleNotFoundError:\n",
|
||||
" pass\n",
|
||||
"else:\n",
|
||||
" files.download('wave_movie.gif')"
|
||||
|
@ -39,7 +39,7 @@ def cloud_tpu_init():
|
||||
import libtpu
|
||||
# pytype: enable=import-error
|
||||
# pylint: enable=import-outside-toplevel
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
# We assume libtpu is installed iff we're in a correctly-configured Cloud
|
||||
# TPU environment. Exit early if we're not running on Cloud TPU.
|
||||
return
|
||||
|
@ -33,7 +33,7 @@ if colab_lib.IS_COLAB_ENABLED:
|
||||
try:
|
||||
import pygments
|
||||
IS_PYGMENTS_ENABLED = True
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
IS_PYGMENTS_ENABLED = False
|
||||
# pytype: enable=import-error
|
||||
# pylint: enable=g-import-not-at-top
|
||||
|
@ -34,7 +34,7 @@ from jax.config import config
|
||||
|
||||
try:
|
||||
import colorama # pytype: disable=import-error
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
colorama = None
|
||||
|
||||
def _can_use_color() -> bool:
|
||||
|
@ -24,12 +24,12 @@ import jax
|
||||
try:
|
||||
from tensorflow.python.profiler import profiler_v2 as profiler
|
||||
from tensorflow.python.profiler import profiler_client
|
||||
except ImportError:
|
||||
raise ImportError("This script requires `tensorflow` to be installed.")
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError("This script requires `tensorflow` to be installed.")
|
||||
try:
|
||||
from tensorboard_plugin_profile.convert import raw_to_tool_data as convert
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(
|
||||
"This script requires `tensorboard_plugin_profile` to be installed.")
|
||||
# pytype: enable=import-error
|
||||
|
||||
|
@ -32,7 +32,7 @@ import numpy as np
|
||||
|
||||
try:
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
tf = None
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -25,7 +25,7 @@ from jax._src.public_test_util import (
|
||||
# pytype: disable=import-error
|
||||
try:
|
||||
import jax._src.test_util as _private_test_util
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
else:
|
||||
del _private_test_util
|
||||
@ -35,7 +35,7 @@ else:
|
||||
def __getattr__(attr):
|
||||
try:
|
||||
from jax._src import test_util
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
raise AttributeError(f"module {__name__} has no attribute {attr}")
|
||||
if attr in ['cases_from_list', 'check_close', 'check_eq', 'device_under_test',
|
||||
'format_shape_dtype_string', 'rand_uniform', 'skip_on_devices',
|
||||
|
@ -78,13 +78,13 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
try:
|
||||
from jax.experimental import jax2tf
|
||||
except ImportError:
|
||||
import jax.experimental.jax2tf as jax2tf
|
||||
except ModuleNotFoundError:
|
||||
jax2tf = None # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
tf = None # type: ignore
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
@ -26,14 +26,14 @@ try:
|
||||
from .cuda import _cuda_linalg
|
||||
for _name, _value in _cuda_linalg.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_cuda_linalg = None
|
||||
|
||||
try:
|
||||
from .rocm import _hip_linalg
|
||||
for _name, _value in _hip_linalg.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_hip_linalg = None
|
||||
|
||||
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
|
||||
|
@ -28,14 +28,14 @@ try:
|
||||
from .cuda import _cuda_prng
|
||||
for _name, _value in _cuda_prng.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_cuda_prng = None
|
||||
|
||||
try:
|
||||
from .rocm import _hip_prng
|
||||
for _name, _value in _hip_prng.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_hip_prng = None
|
||||
|
||||
_prod = lambda xs: functools.reduce(operator.mul, xs, 1)
|
||||
|
@ -30,14 +30,14 @@ try:
|
||||
from .cuda import _cublas
|
||||
for _name, _value in _cublas.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_cublas = None
|
||||
|
||||
try:
|
||||
from .cuda import _cusolver
|
||||
for _name, _value in _cusolver.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_cusolver = None
|
||||
|
||||
|
||||
@ -45,14 +45,14 @@ try:
|
||||
from .rocm import _hipblas
|
||||
for _name, _value in _hipblas.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_hipblas = None
|
||||
|
||||
try:
|
||||
from .rocm import _hipsolver
|
||||
for _name, _value in _hipsolver.registrations().items():
|
||||
xla_client.register_custom_call_target(_name, _value, platform="ROCM")
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_hipsolver = None
|
||||
|
||||
|
||||
|
@ -27,7 +27,7 @@ from .mhlo_helpers import custom_call
|
||||
|
||||
try:
|
||||
from .cuda import _cusparse
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_cusparse = None
|
||||
else:
|
||||
for _name, _value in _cusparse.registrations().items():
|
||||
@ -35,7 +35,7 @@ else:
|
||||
|
||||
try:
|
||||
from .rocm import _hipsparse
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
_hipsparse = None
|
||||
else:
|
||||
for _name, _value in _hipsparse.registrations().items():
|
||||
|
@ -32,12 +32,12 @@ config.parse_flags_with_absl()
|
||||
try:
|
||||
import torch
|
||||
import torch.utils.dlpack
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
torch = None
|
||||
|
||||
try:
|
||||
import cupy
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
cupy = None
|
||||
|
||||
try:
|
||||
|
@ -29,7 +29,7 @@ from jax._src import test_util as jtu
|
||||
|
||||
try:
|
||||
import portpicker
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
portpicker = None
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -34,7 +34,7 @@ import numpy as np
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
tf = None
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -37,12 +37,12 @@ try:
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning,
|
||||
message=".*is deprecated and will be removed in Pillow 10.*")
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
tf = None
|
||||
|
||||
try:
|
||||
from PIL import Image as PIL_Image
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
PIL_Image = None
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -22,7 +22,7 @@ from jax.config import config
|
||||
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
tf = None # type: ignore
|
||||
|
||||
|
||||
|
@ -32,7 +32,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
try:
|
||||
import numpy_dispatch
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
numpy_dispatch = None
|
||||
|
||||
import jax
|
||||
|
@ -21,7 +21,7 @@ from absl.testing import parameterized
|
||||
|
||||
try:
|
||||
import cloudpickle
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
cloudpickle = None
|
||||
|
||||
import jax
|
||||
|
@ -30,13 +30,13 @@ import jax._src.test_util as jtu
|
||||
|
||||
try:
|
||||
import portpicker
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
portpicker = None
|
||||
|
||||
try:
|
||||
from tensorflow.python.profiler import profiler_client
|
||||
from tensorflow.python.profiler import profiler_v2 as tf_profiler
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
profiler_client = None
|
||||
tf_profiler = None
|
||||
|
||||
@ -45,7 +45,7 @@ try:
|
||||
import tensorboard_plugin_profile
|
||||
del tensorboard_plugin_profile
|
||||
TBP_ENABLED = True
|
||||
except ImportError:
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
Loading…
x
Reference in New Issue
Block a user