2022-09-22 12:26:48 -07:00
|
|
|
# Copyright 2018 The JAX Authors.
|
2018-11-17 18:03:33 -08:00
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
Add `jax_debug_log_modules` config option.
This can be used to enable debug logging for specific modules
(e.g. `JAX_DEBUG_LOG_MODULES="jax._src.xla_bridge,jax._src.dispatch"`)
or for all of JAX (`JAX_DEBUG_LOG_MODULES="jax"`).
Example output:
```
$ JAX_DEBUG_LOG_MODULES=jax python3 -c "import jax; jax.numpy.add(1,1)"
DEBUG:2023-06-07 00:27:57,399:jax._src.xla_bridge:352: No jax_plugins namespace packages available
DEBUG:2023-06-07 00:27:57,488:jax._src.path:29: etils.epath found. Using etils.epath for file I/O.
DEBUG:2023-06-07 00:27:57,663:jax._src.dispatch:272: Finished tracing + transforming fn for pjit in 0.0005719661712646484 sec
DEBUG:2023-06-07 00:27:57,664:jax._src.xla_bridge:590: Initializing backend 'tpu'
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:602: Backend 'tpu' initialized
DEBUG:2023-06-07 00:28:00,502:jax._src.xla_bridge:590: Initializing backend 'cpu'
DEBUG:2023-06-07 00:28:00,542:jax._src.xla_bridge:602: Backend 'cpu' initialized
DEBUG:2023-06-07 00:28:00,544:jax._src.interpreters.pxla:1890: Compiling fn for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:2023-06-07 00:28:00,547:jax._src.dispatch:272: Finished jaxpr to MLIR module conversion jit(fn) in 0.0023522377014160156 sec
DEBUG:2023-06-07 00:28:00,547:jax._src.xla_bridge:140: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]]
DEBUG:2023-06-07 00:28:00,571:jax._src.dispatch:272: Finished XLA compilation of jit(fn) in 0.023587703704833984 sec
```
2023-06-07 00:20:32 +00:00
|
|
|
# Set default C++ logging level before any logging happens.
# Only fill in TF_CPP_MIN_LOG_LEVEL when it is absent from the environment, so
# a value the user exported explicitly is never overridden. The temporary
# module-level alias is deleted afterwards to keep the `jax` namespace clean.
from os import environ as _environ
if 'TF_CPP_MIN_LOG_LEVEL' not in _environ:
  _environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
del _environ
|
|
|
|
|
2023-09-12 12:27:20 -07:00
|
|
|
# Import version first, because other submodules may reference it.
|
|
|
|
from jax.version import __version__ as __version__
|
|
|
|
from jax.version import __version_info__ as __version_info__
|
|
|
|
|
2021-03-05 14:57:36 -08:00
|
|
|
# Set Cloud TPU env vars if necessary before transitively loading C++ backend.
# This must run before anything below imports the C++ backend, so it sits near
# the top of the module. Any failure is reported as a warning rather than
# raised, so that a broken Cloud TPU setup never makes `import jax` fail.
from jax._src.cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init
try:
  _cloud_tpu_init()
except Exception as exc:
  # Defensively swallow any exceptions to avoid making jax unimportable.
  # `warn` is imported lazily: it is only needed on this failure path.
  from warnings import warn as _warn
  _warn(f"cloud_tpu_init failed: {exc!r}\n This is a JAX bug; please report "
        "an issue at https://github.com/google/jax/issues")
  del _warn
# Remove the helper alias so it does not leak into the public `jax` namespace.
# (`exc` is unbound automatically when the except clause ends.)
del _cloud_tpu_init
|
|
|
|
|
2023-02-14 23:00:40 -08:00
|
|
|
# Force early import, allowing use of `jax.core` after importing `jax`.
|
|
|
|
import jax.core as _core
|
|
|
|
del _core
|
|
|
|
|
2022-12-14 15:07:04 -08:00
|
|
|
# Note: import <name> as <name> is required for names to be exported.
|
|
|
|
# See PEP 484 & https://github.com/google/jax/issues/7570
|
|
|
|
|
2022-09-23 09:59:46 -07:00
|
|
|
from jax._src.basearray import Array as Array
|
2024-02-12 13:07:59 -08:00
|
|
|
from jax import tree as tree
|
2023-02-15 14:52:31 -08:00
|
|
|
from jax import typing as typing
|
2022-09-23 09:59:46 -07:00
|
|
|
|
2021-11-24 07:47:48 -08:00
|
|
|
from jax._src.config import (
|
2021-08-30 14:35:22 -07:00
|
|
|
config as config,
|
|
|
|
enable_checks as enable_checks,
|
2024-03-21 10:47:16 -07:00
|
|
|
debug_key_reuse as debug_key_reuse,
|
2021-08-30 14:35:22 -07:00
|
|
|
check_tracer_leaks as check_tracer_leaks,
|
|
|
|
checking_leaks as checking_leaks,
|
|
|
|
enable_custom_prng as enable_custom_prng,
|
2023-04-19 18:11:35 -07:00
|
|
|
softmax_custom_jvp as softmax_custom_jvp,
|
2022-03-28 17:17:33 -07:00
|
|
|
enable_custom_vjp_by_custom_transpose as enable_custom_vjp_by_custom_transpose,
|
2021-08-30 14:35:22 -07:00
|
|
|
debug_nans as debug_nans,
|
|
|
|
debug_infs as debug_infs,
|
|
|
|
log_compiles as log_compiles,
|
2024-08-23 21:21:55 +00:00
|
|
|
no_tracing as no_tracing,
|
2024-07-11 01:11:18 +00:00
|
|
|
explain_cache_misses as explain_cache_misses,
|
2022-06-02 10:33:53 -07:00
|
|
|
default_device as default_device,
|
2021-08-30 14:35:22 -07:00
|
|
|
default_matmul_precision as default_matmul_precision,
|
2021-10-07 19:15:43 -07:00
|
|
|
default_prng_impl as default_prng_impl,
|
2022-05-26 10:56:09 -07:00
|
|
|
numpy_dtype_promotion as numpy_dtype_promotion,
|
2021-08-30 14:35:22 -07:00
|
|
|
numpy_rank_promotion as numpy_rank_promotion,
|
2022-02-14 13:11:26 -08:00
|
|
|
jax2tf_associative_scan_reductions as jax2tf_associative_scan_reductions,
|
2023-08-22 15:08:51 -07:00
|
|
|
legacy_prng_key as legacy_prng_key,
|
2023-10-31 13:42:45 -07:00
|
|
|
threefry_partitionable as threefry_partitionable,
|
2022-02-14 13:11:26 -08:00
|
|
|
transfer_guard as transfer_guard,
|
|
|
|
transfer_guard_host_to_device as transfer_guard_host_to_device,
|
|
|
|
transfer_guard_device_to_device as transfer_guard_device_to_device,
|
|
|
|
transfer_guard_device_to_host as transfer_guard_device_to_host,
|
2022-10-31 09:46:46 -07:00
|
|
|
spmd_mode as spmd_mode,
|
2021-04-19 08:52:48 -07:00
|
|
|
)
|
2022-12-21 20:12:08 -05:00
|
|
|
from jax._src.core import ensure_compile_time_eval as ensure_compile_time_eval
|
2022-09-12 15:39:33 -07:00
|
|
|
from jax._src.environment_info import print_environment_info as print_environment_info
|
2023-03-01 09:19:06 -08:00
|
|
|
|
|
|
|
# Expose the XLA client's `Device` class as `jax.Device`. The module handle is
# bound only temporarily and then deleted, so only `Device` stays public.
from jax._src import lib as _jax_lib
Device = _jax_lib.xla_client.Device
del _jax_lib
|
|
|
|
|
2023-03-28 12:40:59 -07:00
|
|
|
from jax._src.api import effects_barrier as effects_barrier
|
2023-03-01 09:19:06 -08:00
|
|
|
from jax._src.api import block_until_ready as block_until_ready
|
2023-04-04 11:41:00 -07:00
|
|
|
from jax._src.ad_checkpoint import checkpoint_wrapper as checkpoint
|
2023-03-01 09:19:06 -08:00
|
|
|
from jax._src.ad_checkpoint import checkpoint_policies as checkpoint_policies
|
2024-03-18 14:22:35 -07:00
|
|
|
from jax._src.api import clear_backends as _deprecated_clear_backends
|
2023-04-07 12:09:26 -07:00
|
|
|
from jax._src.api import clear_caches as clear_caches
|
2023-03-01 09:19:06 -08:00
|
|
|
from jax._src.custom_derivatives import closure_convert as closure_convert
|
|
|
|
from jax._src.custom_derivatives import custom_gradient as custom_gradient
|
|
|
|
from jax._src.custom_derivatives import custom_jvp as custom_jvp
|
|
|
|
from jax._src.custom_derivatives import custom_vjp as custom_vjp
|
|
|
|
from jax._src.xla_bridge import default_backend as default_backend
|
|
|
|
from jax._src.xla_bridge import device_count as device_count
|
|
|
|
from jax._src.api import device_get as device_get
|
|
|
|
from jax._src.api import device_put as device_put
|
|
|
|
from jax._src.api import device_put_sharded as device_put_sharded
|
|
|
|
from jax._src.api import device_put_replicated as device_put_replicated
|
|
|
|
from jax._src.xla_bridge import devices as devices
|
|
|
|
from jax._src.api import disable_jit as disable_jit
|
|
|
|
from jax._src.api import eval_shape as eval_shape
|
|
|
|
from jax._src.dtypes import float0 as float0
|
|
|
|
from jax._src.api import grad as grad
|
|
|
|
from jax._src.api import hessian as hessian
|
|
|
|
from jax._src.xla_bridge import host_count as host_count
|
|
|
|
from jax._src.xla_bridge import host_id as host_id
|
|
|
|
from jax._src.xla_bridge import host_ids as host_ids
|
|
|
|
from jax._src.api import jacobian as jacobian
|
|
|
|
from jax._src.api import jacfwd as jacfwd
|
|
|
|
from jax._src.api import jacrev as jacrev
|
|
|
|
from jax._src.api import jit as jit
|
|
|
|
from jax._src.api import jvp as jvp
|
|
|
|
from jax._src.xla_bridge import local_device_count as local_device_count
|
|
|
|
from jax._src.xla_bridge import local_devices as local_devices
|
|
|
|
from jax._src.api import linearize as linearize
|
|
|
|
from jax._src.api import linear_transpose as linear_transpose
|
|
|
|
from jax._src.api import live_arrays as live_arrays
|
|
|
|
from jax._src.api import make_jaxpr as make_jaxpr
|
|
|
|
from jax._src.api import named_call as named_call
|
|
|
|
from jax._src.api import named_scope as named_scope
|
|
|
|
from jax._src.api import pmap as pmap
|
|
|
|
from jax._src.xla_bridge import process_count as process_count
|
|
|
|
from jax._src.xla_bridge import process_index as process_index
|
2024-08-12 10:29:15 -07:00
|
|
|
from jax._src.xla_bridge import process_indices as process_indices
|
2024-04-25 10:21:49 -07:00
|
|
|
from jax._src.callback import pure_callback as pure_callback
|
2023-04-04 11:41:00 -07:00
|
|
|
from jax._src.ad_checkpoint import checkpoint_wrapper as remat
|
2023-03-01 09:19:06 -08:00
|
|
|
from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct
|
|
|
|
from jax._src.api import value_and_grad as value_and_grad
|
|
|
|
from jax._src.api import vjp as vjp
|
|
|
|
from jax._src.api import vmap as vmap
|
2024-04-11 16:23:59 -07:00
|
|
|
from jax._src.sharding_impls import NamedSharding as NamedSharding
|
2024-09-03 14:30:37 -07:00
|
|
|
from jax._src.sharding_impls import make_mesh as make_mesh
|
2022-09-27 10:06:10 -07:00
|
|
|
|
2023-07-11 12:42:32 -07:00
|
|
|
# Force import, allowing jax.interpreters.* to be used after import jax.
|
|
|
|
from jax.interpreters import ad, batching, mlir, partial_eval, pxla, xla
|
|
|
|
del ad, batching, mlir, partial_eval, pxla, xla
|
2023-02-09 13:31:04 -08:00
|
|
|
|
2022-09-27 10:06:10 -07:00
|
|
|
from jax._src.array import (
|
|
|
|
make_array_from_single_device_arrays as make_array_from_single_device_arrays,
|
|
|
|
make_array_from_callback as make_array_from_callback,
|
2024-05-15 22:06:11 -07:00
|
|
|
make_array_from_process_local_data as make_array_from_process_local_data,
|
2022-09-27 10:06:10 -07:00
|
|
|
)
|
|
|
|
|
2022-07-07 11:41:46 -07:00
|
|
|
from jax._src.tree_util import (
|
2024-02-28 14:28:17 -08:00
|
|
|
tree_map as _deprecated_tree_map,
|
2023-07-20 12:58:17 -07:00
|
|
|
treedef_is_leaf as _deprecated_treedef_is_leaf,
|
|
|
|
tree_flatten as _deprecated_tree_flatten,
|
|
|
|
tree_leaves as _deprecated_tree_leaves,
|
|
|
|
tree_structure as _deprecated_tree_structure,
|
|
|
|
tree_transpose as _deprecated_tree_transpose,
|
|
|
|
tree_unflatten as _deprecated_tree_unflatten,
|
2022-07-07 11:41:46 -07:00
|
|
|
)
|
|
|
|
|
2020-05-19 20:40:03 +01:00
|
|
|
# These submodules are separate because they are in an import cycle with
|
|
|
|
# jax and rely on the names imported above.
|
2023-04-04 16:38:53 -07:00
|
|
|
from jax import custom_derivatives as custom_derivatives
|
|
|
|
from jax import custom_batching as custom_batching
|
|
|
|
from jax import custom_transpose as custom_transpose
|
2021-11-24 07:47:48 -08:00
|
|
|
from jax import api_util as api_util
|
|
|
|
from jax import distributed as distributed
|
2022-07-26 14:47:36 -07:00
|
|
|
from jax import debug as debug
|
2023-08-28 10:13:16 +05:00
|
|
|
from jax import dlpack as dlpack
|
2021-11-24 07:47:48 -08:00
|
|
|
from jax import dtypes as dtypes
|
|
|
|
from jax import errors as errors
|
|
|
|
from jax import image as image
|
|
|
|
from jax import lax as lax
|
2023-04-06 17:03:10 -07:00
|
|
|
from jax import monitoring as monitoring
|
2021-11-24 07:47:48 -08:00
|
|
|
from jax import nn as nn
|
|
|
|
from jax import numpy as numpy
|
|
|
|
from jax import ops as ops
|
|
|
|
from jax import profiler as profiler
|
|
|
|
from jax import random as random
|
2022-10-06 01:43:54 +00:00
|
|
|
from jax import scipy as scipy
|
2022-09-27 10:06:10 -07:00
|
|
|
from jax import sharding as sharding
|
2022-03-14 19:38:23 -07:00
|
|
|
from jax import stages as stages
|
2021-11-24 07:47:48 -08:00
|
|
|
from jax import tree_util as tree_util
|
|
|
|
from jax import util as util
|
2020-05-07 17:24:19 -04:00
|
|
|
|
2023-02-22 15:37:25 -08:00
|
|
|
# Also imported here, after the submodules above, because of a circular dependency.
|
|
|
|
from jax._src.array import Shard as Shard
|
|
|
|
|
2023-04-18 08:19:14 -07:00
|
|
|
import jax.experimental.compilation_cache.compilation_cache as _ccache
|
2023-04-17 21:28:06 -07:00
|
|
|
del _ccache
|
2023-03-17 12:57:18 -07:00
|
|
|
|
2023-03-28 12:40:59 -07:00
|
|
|
# Deprecated top-level attributes of `jax`, keyed by attribute name.
#
# Each value is a `(message, replacement)` pair:
#   * `message` is the text shown to users who access the deprecated name.
#   * `replacement` is the object the deprecated name still resolves to, or
#     `None` for an API that has already been deleted and keeps only its
#     message (presumably surfaced as an error by `deprecation_getattr` --
#     confirm in jax._src.deprecations).
# The `_deprecated_*` objects are the aliases imported from jax._src earlier in
# this module.
_deprecations = {
  # Added July 2023
  "treedef_is_leaf": (
    "jax.treedef_is_leaf is deprecated: use jax.tree_util.treedef_is_leaf.",
    _deprecated_treedef_is_leaf
  ),
  "tree_flatten": (
    "jax.tree_flatten is deprecated: use jax.tree.flatten (jax v0.4.25 or newer) "
    "or jax.tree_util.tree_flatten (any JAX version).",
    _deprecated_tree_flatten
  ),
  "tree_leaves": (
    "jax.tree_leaves is deprecated: use jax.tree.leaves (jax v0.4.25 or newer) "
    "or jax.tree_util.tree_leaves (any JAX version).",
    _deprecated_tree_leaves
  ),
  "tree_structure": (
    "jax.tree_structure is deprecated: use jax.tree.structure (jax v0.4.25 or newer) "
    "or jax.tree_util.tree_structure (any JAX version).",
    _deprecated_tree_structure
  ),
  "tree_transpose": (
    "jax.tree_transpose is deprecated: use jax.tree.transpose (jax v0.4.25 or newer) "
    "or jax.tree_util.tree_transpose (any JAX version).",
    _deprecated_tree_transpose
  ),
  "tree_unflatten": (
    "jax.tree_unflatten is deprecated: use jax.tree.unflatten (jax v0.4.25 or newer) "
    "or jax.tree_util.tree_unflatten (any JAX version).",
    _deprecated_tree_unflatten
  ),
  # Added Feb 28, 2024
  "tree_map": (
    "jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) "
    "or jax.tree_util.tree_map (any JAX version).",
    _deprecated_tree_map
  ),
  # Added Mar 18, 2024
  "clear_backends": (
    "jax.clear_backends is deprecated.",
    _deprecated_clear_backends
  ),
  # Remove after jax 0.4.35 release.
  # Already deleted: no replacement object, only the migration message.
  "xla_computation": (
    "jax.xla_computation is deleted. Please use the AOT APIs; see "
    "https://jax.readthedocs.io/en/latest/aot.html. For example, replace "
    "xla_computation(f)(*xs) with jit(f).lower(*xs).compiler_ir('hlo'). See "
    "CHANGELOG.md for 0.4.30 for more examples.", None
  ),
}
|
|
|
|
|
|
|
|
import typing as _typing

if not _typing.TYPE_CHECKING:
  # Runtime path: install a module-level `__getattr__` (PEP 562) built from the
  # `_deprecations` table. Names not found as ordinary module attributes are
  # routed through it.
  from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
  __getattr__ = _deprecation_getattr(__name__, _deprecations)
  del _deprecation_getattr
else:
  # Static-analysis path: type checkers only see this branch, so they get the
  # real definitions of the deprecated names. The `import x as x` form marks
  # each one as an explicit re-export (PEP 484).
  from jax._src.api import clear_backends as clear_backends
  from jax._src.tree_util import (
      treedef_is_leaf as treedef_is_leaf,
      tree_flatten as tree_flatten,
      tree_leaves as tree_leaves,
      tree_map as tree_map,
      tree_structure as tree_structure,
      tree_transpose as tree_transpose,
      tree_unflatten as tree_unflatten,
  )
del _typing
|
|
|
|
|
2021-09-23 06:33:25 -07:00
|
|
|
import jax.lib # TODO(phawkins): remove this export.
|
2022-09-09 07:05:30 -07:00
|
|
|
|
2023-02-25 07:17:18 -08:00
|
|
|
# trailer
|