# rocm_jax/jax/__init__.py
#
# 136 lines
# 4.3 KiB
# Python
# Raw Normal View History
#
# 2018-11-17 18:03:33 -08:00
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Give the C++ backend a default minimum log level of 1 before anything logs.
# A TF_CPP_MIN_LOG_LEVEL already present in the environment is left untouched,
# so users can still choose their own verbosity.
import os as _os
if 'TF_CPP_MIN_LOG_LEVEL' not in _os.environ:
  _os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
# Keep `os` out of the package namespace.
del _os
# Set Cloud TPU env vars if necessary before transitively loading C++ backend
from .cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init
try:
  _cloud_tpu_init()
except Exception as exc:
  # Defensively swallow any exceptions to avoid making jax unimportable: a
  # failure here is surfaced as a warning instead of aborting `import jax`.
  # (`exc` is unbound automatically when the except clause ends.)
  from warnings import warn as _warn
  _warn(f"cloud_tpu_init failed: {exc!r}\n This is a JAX bug; please report "
        "an issue at https://github.com/google/jax/issues")
  del _warn
# Keep the bootstrap helper out of the public `jax` namespace.
del _cloud_tpu_init
# 2020-06-06 10:51:34 -07:00
# flake8: noqa: F401
# Confusingly there are two things named "config": the module and the class.
# We want the exported object to be the class, so we first import the module
# to make sure a later import doesn't overwrite the class.
from . import config as _config_module
del _config_module
from ._src.config import (
config as config,
enable_checks as enable_checks,
check_tracer_leaks as check_tracer_leaks,
checking_leaks as checking_leaks,
enable_custom_prng as enable_custom_prng,
debug_nans as debug_nans,
debug_infs as debug_infs,
log_compiles as log_compiles,
default_matmul_precision as default_matmul_precision,
default_prng_impl as default_prng_impl,
numpy_rank_promotion as numpy_rank_promotion,
)
from ._src.api import (
ad, # TODO(phawkins): update users to avoid this.
checkpoint as checkpoint,
checkpoint_policies as checkpoint_policies,
closure_convert as closure_convert,
curry, # TODO(phawkins): update users to avoid this.
custom_ivjp as custom_ivjp,
custom_gradient as custom_gradient,
custom_jvp as custom_jvp,
custom_vjp as custom_vjp,
default_backend as default_backend,
device_count as device_count,
device_get as device_get,
device_put as device_put,
device_put_sharded as device_put_sharded,
device_put_replicated as device_put_replicated,
devices as devices,
disable_jit as disable_jit,
eval_shape as eval_shape,
flatten_fun_nokwargs, # TODO(phawkins): update users to avoid this.
float0 as float0,
grad as grad,
hessian as hessian,
host_count as host_count,
host_id as host_id,
host_ids as host_ids,
invertible as invertible,
jacobian as jacobian,
jacfwd as jacfwd,
jacrev as jacrev,
jit as jit,
jvp as jvp,
local_device_count as local_device_count,
local_devices as local_devices,
linearize as linearize,
linear_transpose as linear_transpose,
make_jaxpr as make_jaxpr,
mask as mask,
named_call as named_call,
pmap as pmap,
process_count as process_count,
process_index as process_index,
pxla, # TODO(phawkins): update users to avoid this.
remat as remat,
shapecheck as shapecheck,
ShapedArray as ShapedArray,
ShapeDtypeStruct as ShapeDtypeStruct,
# TODO(phawkins): hide tree* functions from jax, update callers to use
# jax.tree_util.
treedef_is_leaf,
tree_flatten,
tree_leaves,
tree_map,
tree_multimap,
tree_structure,
tree_transpose,
tree_unflatten,
    # 2021-09-14 13:55:55 -07:00
value_and_grad as value_and_grad,
vjp as vjp,
vmap as vmap,
xla, # TODO(phawkins): update users to avoid this.
    # 2021-09-14 13:55:55 -07:00
xla_computation as xla_computation,
)
from .experimental.maps import soft_pmap as soft_pmap
from .version import __version__ as __version__
# These submodules are separate because they are in an import cycle with
# jax and rely on the names imported above.
from . import abstract_arrays as abstract_arrays
from . import api_util as api_util
from . import dtypes as dtypes
from . import errors as errors
from . import image as image
from . import lax as lax
from . import nn as nn
from . import numpy as numpy
from . import ops as ops
from . import profiler as profiler
from . import random as random
from . import tree_util as tree_util
from . import util as util
import jax.lib # TODO(phawkins): remove this export.