Use imports relative to the jax package consistently, rather than .-relative imports.

This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types.

PiperOrigin-RevId: 412057514
This commit is contained in:
Peter Hawkins 2021-11-24 07:47:48 -08:00 committed by jax authors
parent b10a306266
commit 4e21922055
47 changed files with 287 additions and 281 deletions

View File

@ -18,7 +18,7 @@ _os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
del _os
# Set Cloud TPU env vars if necessary before transitively loading C++ backend
from .cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init
from jax._src.cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init
try:
_cloud_tpu_init()
except Exception as exc:
@ -34,10 +34,10 @@ del _cloud_tpu_init
# Confusingly there are two things named "config": the module and the class.
# We want the exported object to be the class, so we first import the module
# to make sure a later import doesn't overwrite the class.
from . import config as _config_module
from jax import config as _config_module
del _config_module
from ._src.config import (
from jax._src.config import (
config as config,
enable_checks as enable_checks,
check_tracer_leaks as check_tracer_leaks,
@ -50,7 +50,7 @@ from ._src.config import (
default_prng_impl as default_prng_impl,
numpy_rank_promotion as numpy_rank_promotion,
)
from ._src.api import (
from jax._src.api import (
ad, # TODO(phawkins): update users to avoid this.
checkpoint as checkpoint,
checkpoint_policies as checkpoint_policies,
@ -115,24 +115,24 @@ from ._src.api import (
xla, # TODO(phawkins): update users to avoid this.
xla_computation as xla_computation,
)
from .experimental.maps import soft_pmap as soft_pmap
from .version import __version__ as __version__
from jax.experimental.maps import soft_pmap as soft_pmap
from jax.version import __version__ as __version__
# These submodules are separate because they are in an import cycle with
# jax and rely on the names imported above.
from . import abstract_arrays as abstract_arrays
from . import api_util as api_util
from . import distributed as distributed
from . import dtypes as dtypes
from . import errors as errors
from . import image as image
from . import lax as lax
from . import nn as nn
from . import numpy as numpy
from . import ops as ops
from . import profiler as profiler
from . import random as random
from . import tree_util as tree_util
from . import util as util
from jax import abstract_arrays as abstract_arrays
from jax import api_util as api_util
from jax import distributed as distributed
from jax import dtypes as dtypes
from jax import errors as errors
from jax import image as image
from jax import lax as lax
from jax import nn as nn
from jax import numpy as numpy
from jax import ops as ops
from jax import profiler as profiler
from jax import random as random
from jax import tree_util as tree_util
from jax import util as util
import jax.lib # TODO(phawkins): remove this export.

View File

@ -39,23 +39,24 @@ import numpy as np
from contextlib import contextmanager, ExitStack
import jax
from .. import core
from .. import linear_util as lu
from . import dtypes
from ..core import eval_jaxpr
from .api_util import (flatten_fun, apply_flat_fun, flatten_fun_nokwargs,
flatten_fun_nokwargs2, argnums_partial,
argnums_partial_except, flatten_axes, donation_vector,
from jax import core
from jax import linear_util as lu
from jax._src import dtypes
from jax.core import eval_jaxpr
from jax._src.api_util import (
flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
shaped_abstractify, _ensure_str_tuple,
argnames_partial_except)
from . import traceback_util
from .traceback_util import api_boundary
from ..tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure,
tree_transpose, tree_leaves, tree_multimap,
treedef_is_leaf, treedef_children, Partial, PyTreeDef)
from .util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
extend_name_stack, wrap_name, cache, wraps, HashableFunction)
shaped_abstractify, _ensure_str_tuple, argnames_partial_except)
from jax._src import traceback_util
from jax._src.traceback_util import api_boundary
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
tree_structure, tree_transpose, tree_leaves,
tree_multimap, treedef_is_leaf, treedef_children,
Partial, PyTreeDef)
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
extend_name_stack, wrap_name, cache, wraps,
HashableFunction)
from jax._src import device_array
from jax._src import dispatch
from jax._src.lib import jax_jit
@ -65,22 +66,24 @@ from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
# Unused imports to be exported
from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
local_devices, process_index, process_count,
host_id, host_ids, host_count, default_backend)
from ..core import ConcreteArray, ShapedArray, raise_to_shaped
from ..interpreters import partial_eval as pe
from ..interpreters import xla
from ..interpreters import pxla
from ..interpreters import ad
from ..interpreters import batching
from ..interpreters import masking
from ..interpreters import invertible_ad as iad
from ..interpreters.invertible_ad import custom_ivjp
from ..custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
local_devices, process_index,
process_count, host_id, host_ids,
host_count, default_backend)
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import masking
from jax.interpreters import invertible_ad as iad
from jax.interpreters.invertible_ad import custom_ivjp
from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
custom_vjp, linear_call)
from ..ad_checkpoint import checkpoint_policies
from jax.ad_checkpoint import checkpoint_policies
from .._src.config import (flags, config, bool_env, disable_jit as _disable_jit,
from jax._src.config import (flags, config, bool_env,
disable_jit as _disable_jit,
debug_nans as config_debug_nans,
debug_infs as config_debug_infs,
_thread_local_state as config_thread_local_state)

View File

@ -18,16 +18,17 @@ from typing import Any, Dict, Iterable, Tuple, Union, Optional
import numpy as np
from .. import core
from . import dtypes
from .tree_util import (PyTreeDef, tree_flatten, tree_unflatten, tree_multimap,
tree_structure, treedef_children, treedef_is_leaf)
from .tree_util import _replace_nones
from .. import linear_util as lu
from .util import safe_map, WrapKwArgs, Hashable, Unhashable
from ..core import unit
from jax import core
from jax._src import dtypes
from jax._src.tree_util import (
PyTreeDef, tree_flatten, tree_unflatten, tree_multimap, tree_structure,
treedef_children, treedef_is_leaf)
from jax._src.tree_util import _replace_nones
from jax import linear_util as lu
from jax._src.util import safe_map, WrapKwArgs, Hashable, Unhashable
from jax.core import unit
from . import traceback_util
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
map = safe_map

View File

@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax import core
from jax import core
class _JAXErrorMixin:
"""Mixin for JAX-specific errors"""

View File

@ -32,8 +32,8 @@ logging._warn_preinit_stderr = 0
import jax._src.lib
from jax._src.config import flags, bool_env
from . import tpu_driver_client
from . import xla_client
from jax._src.lib import tpu_driver_client
from jax._src.lib import xla_client
from jax._src import util, traceback_util
import numpy as np

View File

@ -24,7 +24,7 @@ from jax._src import dtypes
from jax import lax
from jax import core
from jax.core import AxisName
from .. import util
from jax._src import util
from jax.scipy.special import expit
from jax.scipy.special import logsumexp as _logsumexp
import jax.numpy as jnp

View File

@ -19,8 +19,8 @@ import numpy as np
from jax import lax
from jax._src.lib import xla_client
from jax._src.util import safe_zip
from .util import _wraps
from . import lax_numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy import lax_numpy as jnp
def _fft_norm(s, func_name, norm):

View File

@ -39,8 +39,8 @@ import opt_einsum
import jax
from jax import jit, custom_jvp
from .vectorize import vectorize
from .util import _wraps
from jax._src.numpy.vectorize import vectorize
from jax._src.numpy.util import _wraps
from jax import core
from jax._src import dtypes
from jax._src.api_util import _ensure_index_tuple
@ -1698,7 +1698,7 @@ def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
# scale lhs to improve condition number and solve
scale = sqrt((lhs*lhs).sum(axis=0))
lhs /= scale[newaxis,:]
from . import linalg
from jax._src.numpy import linalg
c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond)
c = (c.T/scale).T # broadcast scale coefficients
@ -4547,7 +4547,7 @@ def poly(seq_of_zeros):
sh = seq_of_zeros.shape
if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0:
# import at runtime to avoid circular import
from . import linalg
from jax._src.numpy import linalg
seq_of_zeros = linalg.eigvals(seq_of_zeros)
if seq_of_zeros.ndim != 1:

View File

@ -24,8 +24,8 @@ from jax import jit, custom_jvp
from jax import lax
from jax._src.lax import linalg as lax_linalg
from jax._src import dtypes
from .util import _wraps
from . import lax_numpy as jnp
from jax._src.numpy.util import _wraps
from jax._src.numpy import lax_numpy as jnp
from jax._src.util import canonicalize_axis
_T = lambda x: jnp.swapaxes(x, -1, -2)

View File

@ -15,11 +15,11 @@
import numpy as np
from jax import lax
from . import lax_numpy as jnp
from jax._src.numpy import lax_numpy as jnp
from jax import jit
from .util import _wraps
from .linalg import eigvals as _eigvals
from jax._src.numpy.util import _wraps
from jax._src.numpy.linalg import eigvals as _eigvals
def _to_inexact_type(type):

View File

@ -18,7 +18,7 @@ from typing import Any, Callable, Dict, List, Tuple
from jax._src import api
from jax import lax
from . import lax_numpy as jnp
from jax._src.numpy import lax_numpy as jnp
from jax._src.util import safe_map as map, safe_zip as zip

View File

@ -18,7 +18,7 @@ from functools import partial
import jax
import jax.numpy as jnp
from jax import lax
from .line_search import line_search
from jax._src.scipy.optimize.line_search import line_search
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)

View File

@ -18,7 +18,7 @@ from typing import Callable, NamedTuple, Optional, Union
import jax
import jax.numpy as jnp
from jax import lax
from .line_search import line_search
from jax._src.scipy.optimize.line_search import line_search
class _BFGSResults(NamedTuple):

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, Optional, Tuple, Union
from .bfgs import minimize_bfgs
from ._lbfgs import _minimize_lbfgs
from jax._src.scipy.optimize.bfgs import minimize_bfgs
from jax._src.scipy.optimize._lbfgs import _minimize_lbfgs
from typing import NamedTuple
import jax.numpy as jnp

View File

@ -21,9 +21,9 @@ from typing import (Any, Callable, Hashable, Iterable, Optional, Tuple, Type,
from jax._src.lib import pytree
from .._src.util import safe_zip, unzip2
from jax._src.util import safe_zip, unzip2
from .._src import traceback_util
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
T = TypeVar("T")

View File

@ -31,20 +31,20 @@ from weakref import ref
import numpy as np
from ._src import dtypes
from ._src import config as jax_config
from ._src.config import FLAGS, config
from .errors import (ConcretizationTypeError, TracerArrayConversionError,
from jax._src import dtypes
from jax._src import config as jax_config
from jax._src.config import FLAGS, config
from jax.errors import (ConcretizationTypeError, TracerArrayConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
from . import linear_util as lu
from jax import linear_util as lu
from jax._src import source_info_util
from ._src.util import (safe_zip, safe_map, curry, prod, tuple_insert,
from jax._src.util import (safe_zip, safe_map, curry, prod, tuple_insert,
tuple_delete, cache, as_hashable_function,
HashableFunction)
import jax._src.pretty_printer as pp
from ._src import traceback_util
from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
zip, unsafe_zip = safe_zip, zip

View File

@ -13,7 +13,7 @@
# limitations under the License.
# flake8: noqa: F401
from ._src.errors import (
from jax._src.errors import (
JAXTypeError as JAXTypeError,
JAXIndexError as JAXIndexError,
ConcretizationTypeError as ConcretizationTypeError,

View File

@ -13,12 +13,12 @@
# limitations under the License.
# flake8: noqa: F401
from ..interpreters.sharded_jit import (
from jax.interpreters.sharded_jit import (
sharded_jit as sharded_jit,
PartitionSpec as PartitionSpec,
with_sharding_constraint as with_sharding_constraint,
)
from .x64_context import (
from jax.experimental.x64_context import (
enable_x64 as enable_x64,
disable_x64 as disable_x64,
)

View File

@ -18,14 +18,14 @@ import dataclasses
import numpy as np
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict
from . import maps
from .. import core
from jax.experimental import maps
from jax import core
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from ..interpreters import pxla, xla
from .._src.util import prod, safe_zip
from .._src.api import device_put
from ..interpreters.sharded_jit import PartitionSpec
from jax.interpreters import pxla, xla
from jax._src.util import prod, safe_zip
from jax._src.api import device_put
from jax.interpreters.sharded_jit import PartitionSpec
Shape = Tuple[int, ...]
MeshAxes = Sequence[Union[str, Tuple[str], None]]
@ -49,7 +49,7 @@ class _HashableIndex:
def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
mesh_axes: MeshAxes) -> Mapping[Device, Index]:
# Import here to avoid cyclic import error when importing gsda in pjit.py.
from .pjit import get_array_mapping, _prepare_axis_resources
from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources
if not isinstance(mesh_axes, PartitionSpec):
pspec = PartitionSpec(*mesh_axes)

View File

@ -13,5 +13,6 @@
# limitations under the License.
# flake8: noqa: F401
from .jax2tf import convert, dtype_of_val, split_to_logical_devices, PolyShape
from .call_tf import call_tf
from jax.experimental.jax2tf.jax2tf import (convert, dtype_of_val,
split_to_logical_devices, PolyShape)
from jax.experimental.jax2tf.call_tf import call_tf

View File

@ -36,7 +36,7 @@ from jax._src import ad_util
from jax._src.lax.lax import _device_put_raw
from jax.interpreters import xla
from jax._src.lib import xla_client
from . import jax2tf as jax2tf_internal
from jax.experimental.jax2tf import jax2tf as jax2tf_internal
import numpy as np
import tensorflow as tf # type: ignore[import]

View File

@ -24,7 +24,7 @@ from jax._src.lax import slicing as lax_slicing
from jax._src import dtypes
from jax._src import util
from . import jax2tf
from jax.experimental.jax2tf import jax2tf
import numpy as np
import tensorflow as tf # type: ignore[import]

View File

@ -48,9 +48,9 @@ from jax.interpreters import sharded_jit
from jax.interpreters import xla
from jax._src.lib import xla_client
from . import shape_poly
from . import shape_poly_tf
from . import impl_no_xla
from jax.experimental.jax2tf import shape_poly
from jax.experimental.jax2tf import shape_poly_tf
from jax.experimental.jax2tf import impl_no_xla
import numpy as np
import tensorflow as tf # type: ignore[import]

View File

@ -21,7 +21,7 @@ from typing import Any, Optional, Sequence, Tuple, Union
import numpy as np
import tensorflow as tf # type: ignore[import]
from . import shape_poly
from jax.experimental.jax2tf import shape_poly
TfVal = Any

View File

@ -23,32 +23,32 @@ from warnings import warn
from functools import wraps, partial, partialmethod
from enum import Enum
from .. import numpy as jnp
from .. import core
from .. import linear_util as lu
from .._src.api import _check_callable, _check_arg
from jax import numpy as jnp
from jax import core
from jax import linear_util as lu
from jax._src.api import _check_callable, _check_arg
from jax._src import dispatch
from ..tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map,
from jax.tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map,
tree_leaves)
from .._src.tree_util import _replace_nones
from .._src.api_util import (flatten_fun_nokwargs, flatten_axes,
from jax._src.tree_util import _replace_nones
from jax._src.api_util import (flatten_fun_nokwargs, flatten_axes,
_ensure_index_tuple, donation_vector,
shaped_abstractify)
from .._src import source_info_util
from .._src.config import config
from ..errors import JAXTypeError
from ..interpreters import partial_eval as pe
from ..interpreters import pxla
from ..interpreters import xla
from ..interpreters import batching
from ..interpreters import ad
from jax._src import source_info_util
from jax._src.config import config
from jax.errors import JAXTypeError
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla
from jax.interpreters import xla
from jax.interpreters import batching
from jax.interpreters import ad
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from .._src.util import (safe_map, safe_zip, HashableFunction,
from jax._src.util import (safe_map, safe_zip, HashableFunction,
as_hashable_function, unzip2, distributed_debug_log,
tuple_insert, moveaxis, split_list, wrap_name)
from .._src.lax.parallel import _build_axis_index_lowering
from .. import lax
from jax._src.lax.parallel import _build_axis_index_lowering
from jax import lax
class _PositionalSemantics(Enum):
"""Indicates whether the positional shapes of inputs should be interpreted as

View File

@ -20,27 +20,27 @@ from warnings import warn
import itertools as it
from functools import partial
from . import maps
from .gsda import GlobalShardedDeviceArray as GSDA
from .. import core
from .. import linear_util as lu
from .._src.api import _check_callable, _check_arg, Lowered
from .._src import source_info_util
from .._src.api_util import (argnums_partial_except, flatten_axes,
from jax.experimental import maps
from jax.experimental.gsda import GlobalShardedDeviceArray as GSDA
from jax import core
from jax import linear_util as lu
from jax._src.api import _check_callable, _check_arg, Lowered
from jax._src import source_info_util
from jax._src.api_util import (argnums_partial_except, flatten_axes,
flatten_fun_nokwargs, _ensure_index_tuple,
donation_vector, rebase_donate_argnums,
shaped_abstractify)
from ..errors import JAXTypeError
from ..interpreters import ad
from ..interpreters import pxla
from ..interpreters import xla
from ..interpreters import batching
from ..interpreters import partial_eval as pe
from ..interpreters.sharded_jit import PartitionSpec
from jax.errors import JAXTypeError
from jax.interpreters import ad
from jax.interpreters import pxla
from jax.interpreters import xla
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters.sharded_jit import PartitionSpec
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from ..tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves
from .._src.util import (extend_name_stack, HashableFunction, safe_zip,
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves
from jax._src.util import (extend_name_stack, HashableFunction, safe_zip,
wrap_name, wraps, distributed_debug_log,
split_list, cache, tuple_insert)
xops = xc._xla.ops

View File

@ -183,11 +183,11 @@ To fit the same model on sparse data, we can apply the :func:`sparsify` transfor
"""
# flake8: noqa: F401
from .ad import (
from jax.experimental.sparse.ad import (
grad as grad,
value_and_grad as value_and_grad,
)
from .bcoo import (
from jax.experimental.sparse.bcoo import (
bcoo_dot_general as bcoo_dot_general,
bcoo_dot_general_p as bcoo_dot_general_p,
bcoo_dot_general_sampled as bcoo_dot_general_sampled,
@ -207,7 +207,7 @@ from .bcoo import (
BCOO as BCOO,
)
from .ops import (
from jax.experimental.sparse.ops import (
coo_fromdense as coo_fromdense,
coo_fromdense_p as coo_fromdense_p,
coo_matmat as coo_matmat,
@ -232,8 +232,8 @@ from .ops import (
CSR as CSR,
)
from .random import random_bcoo as random_bcoo
from .transform import (
from jax.experimental.sparse.random import random_bcoo as random_bcoo
from jax.experimental.sparse.transform import (
sparsify as sparsify,
SparseTracer as SparseTracer,
)

View File

@ -23,7 +23,7 @@ from jax._src.api_util import _ensure_index, _ensure_index_tuple
from jax.util import safe_zip
from jax._src.util import wraps
from jax._src.traceback_util import api_boundary
from .bcoo import BCOO
from jax.experimental.sparse.bcoo import BCOO
def value_and_grad(fun: Callable,

View File

@ -35,7 +35,7 @@ from jax._src.lax.lax import (
ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
DotDimensionNumbers)
from jax._src.numpy.lax_numpy import _unique
from . import ops
from jax.experimental.sparse import ops
Dtype = Any
Shape = Tuple[int, ...]

View File

@ -757,7 +757,7 @@ def _todense_transpose(ct, *bufs, tree):
standin = object()
obj = tree_util.tree_unflatten(tree, [standin] * len(bufs))
from . import BCOO, bcoo_extract
from jax.experimental.sparse import BCOO, bcoo_extract
if obj is standin:
return (ct,)
elif isinstance(obj, BCOO):

View File

@ -24,7 +24,7 @@
# uniformity
from contextlib import contextmanager
from .._src.config import enable_x64 as _jax_enable_x64
from jax._src.config import enable_x64 as _jax_enable_x64
@contextmanager
def enable_x64(new_val: bool = True):

View File

@ -19,21 +19,21 @@ import itertools as it
from typing import Any, Callable, Dict
import jax
from . import partial_eval as pe
from ..config import config
from .. import core
from .._src.dtypes import dtype, float0
from ..core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
from jax.interpreters import partial_eval as pe
from jax.config import config
from jax import core
from jax._src.dtypes import dtype, float0
from jax.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
raise_to_shaped)
from .._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_aval, zeros_like_p, Zero)
from .._src.util import (unzip2, safe_map, safe_zip, split_list,
from jax._src.util import (unzip2, safe_map, safe_zip, split_list,
wrap_name, as_hashable_function)
from ..tree_util import register_pytree_node
from .. import linear_util as lu
from .._src.api_util import flatten_fun, flatten_fun_nokwargs
from ..tree_util import tree_flatten, tree_unflatten, Partial
from .._src import source_info_util
from jax.tree_util import register_pytree_node
from jax import linear_util as lu
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
from jax.tree_util import tree_flatten, tree_unflatten, Partial
from jax._src import source_info_util
zip = safe_zip
map = safe_map

View File

@ -19,17 +19,17 @@ from typing import (Any, Callable, Dict, Set, Optional, Tuple, Union, Iterable,
import numpy as np
import jax
from ..config import config
from .. import core
from ..core import raise_to_shaped, Trace, Tracer
from jax.config import config
from jax import core
from jax.core import raise_to_shaped, Trace, Tracer
from jax._src.tree_util import tree_unflatten, tree_flatten
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_p, Zero)
from .. import linear_util as lu
from .._src.util import (unzip2, safe_map, safe_zip, wrap_name, split_list,
from jax import linear_util as lu
from jax._src.util import (unzip2, safe_map, safe_zip, wrap_name, split_list,
canonicalize_axis, moveaxis, as_hashable_function,
curry, memoize)
from . import partial_eval as pe
from jax.interpreters import partial_eval as pe
map = safe_map

View File

@ -19,13 +19,13 @@ from typing import Dict, Any, Callable
import jax
from jax import core
from jax import linear_util as lu
from . import ad
from . import partial_eval as pe
from ..core import raise_to_shaped, get_aval, Literal, Jaxpr
from .._src.api_util import flatten_fun_nokwargs
from ..tree_util import tree_flatten, tree_unflatten, register_pytree_node
from .._src.util import safe_map, safe_zip, split_list
from .._src import custom_derivatives
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax.core import raise_to_shaped, get_aval, Literal, Jaxpr
from jax._src.api_util import flatten_fun_nokwargs
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
from jax._src.util import safe_map, safe_zip, split_list
from jax._src import custom_derivatives
map = safe_map
zip = safe_zip

View File

@ -22,12 +22,12 @@ from typing import Callable, Dict, Optional, Sequence, Union, Tuple
import numpy as np
from .. import core
from .._src import dtypes
from ..tree_util import tree_unflatten
from ..core import ShapedArray, Trace, Tracer
from .._src.util import safe_map, safe_zip, unzip2, prod, wrap_name
from .. import linear_util as lu
from jax import core
from jax._src import dtypes
from jax.tree_util import tree_unflatten
from jax.core import ShapedArray, Trace, Tracer
from jax._src.util import safe_map, safe_zip, unzip2, prod, wrap_name
from jax import linear_util as lu
map = safe_map
zip = safe_zip

View File

@ -25,21 +25,21 @@ from weakref import ref
import numpy as np
from .. import core
from .._src import dtypes
from .. import linear_util as lu
from jax import core
from jax._src import dtypes
from jax import linear_util as lu
from jax._src.ad_util import Zero
from .._src.api_util import flattened_fun_in_tree
from .._src.tree_util import PyTreeDef, tree_unflatten, tree_leaves
from .._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
from jax._src.api_util import flattened_fun_in_tree
from jax._src.tree_util import PyTreeDef, tree_unflatten, tree_leaves
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
partition_list, cache, OrderedSet,
as_hashable_function)
from ..core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
from jax.core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
ConcreteArray, raise_to_shaped, Var, Atom,
JaxprEqn, Primitive)
from jax._src import source_info_util
from ..config import config
from jax.config import config
map = safe_map
zip = safe_zip

View File

@ -42,26 +42,26 @@ import sys
from absl import logging
import numpy as np
from .._src.config import config
from .. import core
from .. import linear_util as lu
from jax._src.config import config
from jax import core
from jax import linear_util as lu
from jax._src.abstract_arrays import array_types
from ..core import ConcreteArray, ShapedArray
from jax.core import ConcreteArray, ShapedArray
from jax._src import device_array
from .._src import source_info_util
from .._src.util import (unzip3, prod, safe_map, safe_zip,
from jax._src import source_info_util
from jax._src.util import (unzip3, prod, safe_map, safe_zip,
extend_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log)
from ..errors import JAXTypeError
from jax.errors import JAXTypeError
from jax._src import dispatch
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from ..tree_util import tree_flatten, tree_map
from . import batching
from . import partial_eval as pe
from . import xla
from . import ad
from jax.tree_util import tree_flatten, tree_map
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters import ad
# Built in Python lists don't support weak refs but subclasses of lists do.
class WeakRefList(list):
@ -458,7 +458,7 @@ def sda_array_result_handler(sharding_spec, indices, aval: ShapedArray):
indices)
def gsda_array_result_handler(global_aval, global_mesh, out_axis_resources):
from ..experimental.gsda import GlobalShardedDeviceArray
from jax.experimental.gsda import GlobalShardedDeviceArray
return lambda bufs: GlobalShardedDeviceArray(
global_aval.shape, global_mesh, out_axis_resources, bufs)

View File

@ -18,21 +18,22 @@ from typing import Callable, Iterable, Optional, Tuple, Union
from absl import logging
import numpy as np
from .. import core
from . import ad
from . import partial_eval as pe
from jax import core
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
# TODO(skye): separate pmap into it's own module?
from . import pxla
from . import xla
from .. import linear_util as lu
from jax.interpreters import pxla
from jax.interpreters import xla
from jax import linear_util as lu
from jax._src import dispatch
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from .._src.api_util import argnums_partial, flatten_axes, flatten_fun, _ensure_index_tuple
from ..tree_util import tree_flatten, tree_unflatten
from .._src.util import (extend_name_stack, wrap_name, wraps, safe_zip,
from jax._src.api_util import (argnums_partial, flatten_axes, flatten_fun,
_ensure_index_tuple)
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src.util import (extend_name_stack, wrap_name, wraps, safe_zip,
HashableFunction)
from .._src.config import config
from jax._src.config import config
xops = xc._xla.ops

View File

@ -28,24 +28,24 @@ from typing_extensions import Protocol
import numpy as np
from ..config import config
from .. import core
from jax.config import config
from jax import core
from jax._src import ad_util
from jax._src import device_array
from jax._src import dtypes
from .. import linear_util as lu
from jax import linear_util as lu
from jax._src import source_info_util
from jax._src.abstract_arrays import (make_shaped_array, array_types)
from ..core import (ConcreteArray, ShapedArray,
from jax.core import (ConcreteArray, ShapedArray,
Literal, pp_eqn_compact, JaxprPpContext,
abstract_token)
import jax._src.pretty_printer as pp
from .._src.util import (prod, extend_name_stack, wrap_name,
from jax._src.util import (prod, extend_name_stack, wrap_name,
safe_zip, safe_map, partition_list)
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from . import partial_eval as pe
from . import ad
from jax.interpreters import partial_eval as pe
from jax.interpreters import ad
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

View File

@ -367,4 +367,4 @@ from jax._src.lax.other import (
conv_general_dilated_patches as conv_general_dilated_patches
)
from jax._src.ad_util import stop_gradient_p as stop_gradient_p
from . import linalg as linalg
from jax.lax import linalg as linalg

View File

@ -17,4 +17,4 @@ from jax._src.lib import (
xla_client as xla_client,
xla_extension as xla_extension,
)
from . import xla_bridge as xla_bridge
from jax.lib import xla_bridge as xla_bridge

View File

@ -67,13 +67,13 @@ from functools import partial
from typing import Any, Tuple, Callable
import weakref
from . import core
from ._src.util import curry
from .tree_util import tree_map
from jax import core
from jax._src.util import curry
from jax.tree_util import tree_map
from ._src import traceback_util
from jax._src import traceback_util
from .config import config
from jax.config import config
traceback_util.register_exclusion(__file__)

View File

@ -17,7 +17,7 @@
# flake8: noqa: F401
from jax.numpy import tanh as tanh
from . import initializers as initializers
from jax.nn import initializers as initializers
from jax._src.nn.functions import (
celu as celu,
elu as elu,

View File

@ -16,8 +16,8 @@
# See PEP 484 & https://github.com/google/jax/issues/7570
# flake8: noqa: F401
from . import fft as fft
from . import linalg as linalg
from jax.numpy import fft as fft
from jax.numpy import linalg as linalg
from jax._src.device_array import DeviceArray as DeviceArray

View File

@ -13,10 +13,10 @@
# limitations under the License.
# flake8: noqa: F401
from . import linalg as linalg
from . import ndimage as ndimage
from . import signal as signal
from . import sparse as sparse
from . import special as special
from . import stats as stats
from . import fft as fft
from jax.scipy import linalg as linalg
from jax.scipy import ndimage as ndimage
from jax.scipy import signal as signal
from jax.scipy import sparse as sparse
from jax.scipy import special as special
from jax.scipy import stats as stats
from jax.scipy import fft as fft

View File

@ -13,4 +13,4 @@
# limitations under the License.
# flake8: noqa: F401
from . import linalg as linalg
from jax.scipy.sparse import linalg as linalg

View File

@ -13,21 +13,21 @@
# limitations under the License.
# flake8: noqa: F401
from . import bernoulli as bernoulli
from . import beta as beta
from . import cauchy as cauchy
from . import dirichlet as dirichlet
from . import expon as expon
from . import gamma as gamma
from . import geom as geom
from . import laplace as laplace
from . import logistic as logistic
from . import multivariate_normal as multivariate_normal
from . import nbinom as nbinom
from . import norm as norm
from . import pareto as pareto
from . import poisson as poisson
from . import t as t
from . import uniform as uniform
from . import chi2 as chi2
from . import betabinom as betabinom
from jax.scipy.stats import bernoulli as bernoulli
from jax.scipy.stats import beta as beta
from jax.scipy.stats import cauchy as cauchy
from jax.scipy.stats import dirichlet as dirichlet
from jax.scipy.stats import expon as expon
from jax.scipy.stats import gamma as gamma
from jax.scipy.stats import geom as geom
from jax.scipy.stats import laplace as laplace
from jax.scipy.stats import logistic as logistic
from jax.scipy.stats import multivariate_normal as multivariate_normal
from jax.scipy.stats import nbinom as nbinom
from jax.scipy.stats import norm as norm
from jax.scipy.stats import pareto as pareto
from jax.scipy.stats import poisson as poisson
from jax.scipy.stats import t as t
from jax.scipy.stats import uniform as uniform
from jax.scipy.stats import chi2 as chi2
from jax.scipy.stats import betabinom as betabinom