2018-11-17 18:03:33 -08:00
|
|
|
# Copyright 2018 Google LLC
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2020-07-10 08:11:48 -07:00
|
|
|
# Set default logging level before any logging happens. `setdefault` only
# writes TF_CPP_MIN_LOG_LEVEL when it is absent, so a value the user already
# exported in the environment is left untouched.
from os import environ as _environ
_environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
# Keep the package namespace clean: the alias was only needed for this call.
del _environ
|
|
|
|
|
2021-03-05 14:57:36 -08:00
|
|
|
# Set Cloud TPU env vars if necessary before transitively loading C++ backend.
from .cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init
try:
  _cloud_tpu_init()
except Exception as exc:
  # Defensively swallow any exceptions to avoid making jax unimportable; the
  # failure is still surfaced to the user as a warning instead of being lost.
  from warnings import warn as _warn
  _warn(f"cloud_tpu_init failed: {exc!r}\n This is a JAX bug; please report "
        "an issue at https://github.com/google/jax/issues")
  del _warn
# Drop the private alias so it does not leak into the package namespace.
del _cloud_tpu_init
|
|
|
|
|
2020-06-06 10:51:34 -07:00
|
|
|
# flake8: noqa: F401
|
2021-04-19 08:52:48 -07:00
|
|
|
|
|
|
|
# Confusingly there are two things named "config": the `config` submodule of
# this package and the `config` object re-exported below from `._src.config`.
# We want the exported name to refer to that object, not the submodule.
# Importing a submodule makes the import system bind it as an attribute of the
# parent package, so we import the submodule FIRST: the later
# `from ._src.config import config` then overwrites that binding, and no
# subsequent import of the submodule can clobber the exported object.
from . import config as _config_module
# Only the import's side effect is needed; drop the private alias so it does
# not remain in the package namespace.
del _config_module
|
|
|
|
|
|
|
|
from ._src.config import (
|
|
|
|
config, enable_checks, check_tracer_leaks, checking_leaks,
|
|
|
|
debug_nans, debug_infs, log_compiles, default_matmul_precision,
|
|
|
|
numpy_rank_promotion
|
|
|
|
)
|
2021-04-13 09:42:54 -07:00
|
|
|
from ._src.api import (
|
2020-05-07 17:24:19 -04:00
|
|
|
ad, # TODO(phawkins): update users to avoid this.
|
|
|
|
checkpoint,
|
2021-01-26 09:44:48 -08:00
|
|
|
closure_convert,
|
2020-05-07 17:24:19 -04:00
|
|
|
curry, # TODO(phawkins): update users to avoid this.
|
Initial version of invertible AD implementation (#3232)
This is a prototype implementation of the memory-efficient VJP method
for invertible functions. The general idea is that, thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with gradient computation.
The API is such that the only thing a user has to do is decorate a
function with `@invertible`, which makes AD apply a more memory-efficient
transpose than usual.
The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- An "invertible" function is defined as one that produces a jaxpr
that can be inverted correctly simply by iterating over its equations
in reverse. This is a bit strict, because users generally don't have
too much control over that, and there are functions that produce
jaxprs which will be treated as invertible when one topological
ordering of equations is used, while they will be considered
non-invertible for other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
that zero argument pruning in JVPTrace makes it pretty much impossible
to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- Invertible reverse-mode implementation (`rev_backward_pass`) assumes
that all the VJPs of primal primitives are jittable (not sure if
that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
`custom_ivjp` inefficient if it is being staged out.
2020-06-15 12:35:06 +02:00
|
|
|
custom_ivjp,
|
2020-05-07 17:24:19 -04:00
|
|
|
custom_gradient,
|
|
|
|
custom_jvp,
|
|
|
|
custom_vjp,
|
2021-02-04 11:56:41 +00:00
|
|
|
default_backend,
|
2020-05-07 17:24:19 -04:00
|
|
|
device_count,
|
|
|
|
device_get,
|
|
|
|
device_put,
|
2020-12-04 12:53:36 -08:00
|
|
|
device_put_sharded,
|
|
|
|
device_put_replicated,
|
2020-05-07 17:24:19 -04:00
|
|
|
devices,
|
|
|
|
disable_jit,
|
|
|
|
eval_shape,
|
|
|
|
flatten_fun_nokwargs, # TODO(phawkins): update users to avoid this.
|
2020-09-24 16:29:57 +01:00
|
|
|
float0,
|
2020-05-07 17:24:19 -04:00
|
|
|
grad,
|
|
|
|
hessian,
|
|
|
|
host_count,
|
|
|
|
host_id,
|
|
|
|
host_ids,
|
Initial version of invertible AD implementation (#3232)
This is a prototype implementation of the memory-efficient VJP method
for invertible functions. The general idea is that, thanks to
invertibility, we don't have to memoize any intermediate primal values,
but can simply reconstruct them in lock-step with gradient computation.
The API is such that the only thing a user has to do is decorate a
function with `@invertible`, which makes AD apply a more memory-efficient
transpose than usual.
The current version is expressive enough to support e.g. the Reversible
ResNet, but there are still some caveats:
- An "invertible" function is defined as one that produces a jaxpr
that can be inverted correctly simply by iterating over its equations
in reverse. This is a bit strict, because users generally don't have
too much control over that, and there are functions that produce
jaxprs which will be treated as invertible when one topological
ordering of equations is used, while they will be considered
non-invertible for other valid orderings.
- It doesn't follow the usual jvp + transpose path, and it turns out
that zero argument pruning in JVPTrace makes it pretty much impossible
to implement correctly.
- `custom_ivjp` is an initial-style primitive.
- Invertible reverse-mode implementation (`rev_backward_pass`) assumes
that all the VJPs of primal primitives are jittable (not sure if
that's a problem, but worth pointing out).
- Not having a dedicated linearization pass makes the JVP of
`custom_ivjp` inefficient if it is being staged out.
2020-06-15 12:35:06 +02:00
|
|
|
invertible,
|
2020-05-07 17:24:19 -04:00
|
|
|
jacobian,
|
|
|
|
jacfwd,
|
|
|
|
jacrev,
|
|
|
|
jit,
|
|
|
|
jvp,
|
|
|
|
local_device_count,
|
|
|
|
local_devices,
|
|
|
|
linearize,
|
2020-09-16 20:29:19 -07:00
|
|
|
linear_transpose,
|
2020-05-07 17:24:19 -04:00
|
|
|
make_jaxpr,
|
|
|
|
mask,
|
2020-11-04 21:01:42 -08:00
|
|
|
named_call,
|
2020-05-07 17:24:19 -04:00
|
|
|
partial, # TODO(phawkins): update callers to use functools.partial.
|
|
|
|
pmap,
|
2021-04-20 17:56:41 -07:00
|
|
|
process_count,
|
|
|
|
process_index,
|
2020-05-07 17:24:19 -04:00
|
|
|
pxla, # TODO(phawkins): update users to avoid this.
|
|
|
|
remat,
|
|
|
|
shapecheck,
|
|
|
|
ShapedArray,
|
|
|
|
ShapeDtypeStruct,
|
|
|
|
# TODO(phawkins): hide tree* functions from jax, update callers to use
|
|
|
|
# jax.tree_util.
|
|
|
|
treedef_is_leaf,
|
|
|
|
tree_flatten,
|
|
|
|
tree_leaves,
|
|
|
|
tree_map,
|
|
|
|
tree_multimap,
|
|
|
|
tree_structure,
|
|
|
|
tree_transpose,
|
|
|
|
tree_unflatten,
|
|
|
|
value_and_grad,
|
|
|
|
vjp,
|
|
|
|
vmap,
|
|
|
|
xla, # TODO(phawkins): update users to avoid this.
|
|
|
|
xla_computation,
|
|
|
|
)
|
2021-01-26 19:38:40 -08:00
|
|
|
from .experimental.maps import soft_pmap
|
2020-05-19 20:40:03 +01:00
|
|
|
from .version import __version__
|
2020-05-07 17:24:19 -04:00
|
|
|
|
2020-05-19 20:40:03 +01:00
|
|
|
# These submodules are separate because they are in an import cycle with
|
|
|
|
# jax and rely on the names imported above.
|
2021-04-13 09:42:54 -07:00
|
|
|
from . import api
|
2021-04-07 19:35:17 -07:00
|
|
|
from . import dtypes
|
2021-03-02 09:29:59 -08:00
|
|
|
from . import errors
|
2020-07-10 09:57:59 -04:00
|
|
|
from . import image
|
2020-05-19 20:40:03 +01:00
|
|
|
from . import lax
|
|
|
|
from . import nn
|
2020-07-12 04:44:16 +01:00
|
|
|
from . import profiler
|
2020-05-19 20:40:03 +01:00
|
|
|
from . import random
|
2021-01-11 14:20:32 -08:00
|
|
|
from . import util
|
2020-05-07 17:24:19 -04:00
|
|
|
|
|
|
|
def _init():
  """Import the sibling ``numpy`` submodule purely for its side effects.

  Per the original note, importing it sets up operator overloads. Nothing from
  the module is used here, so the local alias is discarded immediately.
  """
  from . import numpy as _jnp  # side-effecting import sets up operator overloads
  del _jnp


_init()
# The helper is a one-shot initializer; remove it from the package namespace.
del _init
|