mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Use imports relative to the jax
package consistently, rather than .
-relative imports.
This is more consistent, since currently we use a mix of both styles. It may also help pytype yield more accurate types. PiperOrigin-RevId: 412057514
This commit is contained in:
parent
b10a306266
commit
4e21922055
@ -18,7 +18,7 @@ _os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1')
|
||||
del _os
|
||||
|
||||
# Set Cloud TPU env vars if necessary before transitively loading C++ backend
|
||||
from .cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init
|
||||
from jax._src.cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init
|
||||
try:
|
||||
_cloud_tpu_init()
|
||||
except Exception as exc:
|
||||
@ -34,10 +34,10 @@ del _cloud_tpu_init
|
||||
# Confusingly there are two things named "config": the module and the class.
|
||||
# We want the exported object to be the class, so we first import the module
|
||||
# to make sure a later import doesn't overwrite the class.
|
||||
from . import config as _config_module
|
||||
from jax import config as _config_module
|
||||
del _config_module
|
||||
|
||||
from ._src.config import (
|
||||
from jax._src.config import (
|
||||
config as config,
|
||||
enable_checks as enable_checks,
|
||||
check_tracer_leaks as check_tracer_leaks,
|
||||
@ -50,7 +50,7 @@ from ._src.config import (
|
||||
default_prng_impl as default_prng_impl,
|
||||
numpy_rank_promotion as numpy_rank_promotion,
|
||||
)
|
||||
from ._src.api import (
|
||||
from jax._src.api import (
|
||||
ad, # TODO(phawkins): update users to avoid this.
|
||||
checkpoint as checkpoint,
|
||||
checkpoint_policies as checkpoint_policies,
|
||||
@ -115,24 +115,24 @@ from ._src.api import (
|
||||
xla, # TODO(phawkins): update users to avoid this.
|
||||
xla_computation as xla_computation,
|
||||
)
|
||||
from .experimental.maps import soft_pmap as soft_pmap
|
||||
from .version import __version__ as __version__
|
||||
from jax.experimental.maps import soft_pmap as soft_pmap
|
||||
from jax.version import __version__ as __version__
|
||||
|
||||
# These submodules are separate because they are in an import cycle with
|
||||
# jax and rely on the names imported above.
|
||||
from . import abstract_arrays as abstract_arrays
|
||||
from . import api_util as api_util
|
||||
from . import distributed as distributed
|
||||
from . import dtypes as dtypes
|
||||
from . import errors as errors
|
||||
from . import image as image
|
||||
from . import lax as lax
|
||||
from . import nn as nn
|
||||
from . import numpy as numpy
|
||||
from . import ops as ops
|
||||
from . import profiler as profiler
|
||||
from . import random as random
|
||||
from . import tree_util as tree_util
|
||||
from . import util as util
|
||||
from jax import abstract_arrays as abstract_arrays
|
||||
from jax import api_util as api_util
|
||||
from jax import distributed as distributed
|
||||
from jax import dtypes as dtypes
|
||||
from jax import errors as errors
|
||||
from jax import image as image
|
||||
from jax import lax as lax
|
||||
from jax import nn as nn
|
||||
from jax import numpy as numpy
|
||||
from jax import ops as ops
|
||||
from jax import profiler as profiler
|
||||
from jax import random as random
|
||||
from jax import tree_util as tree_util
|
||||
from jax import util as util
|
||||
|
||||
import jax.lib # TODO(phawkins): remove this export.
|
||||
|
@ -39,23 +39,24 @@ import numpy as np
|
||||
from contextlib import contextmanager, ExitStack
|
||||
|
||||
import jax
|
||||
from .. import core
|
||||
from .. import linear_util as lu
|
||||
from . import dtypes
|
||||
from ..core import eval_jaxpr
|
||||
from .api_util import (flatten_fun, apply_flat_fun, flatten_fun_nokwargs,
|
||||
flatten_fun_nokwargs2, argnums_partial,
|
||||
argnums_partial_except, flatten_axes, donation_vector,
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax._src import dtypes
|
||||
from jax.core import eval_jaxpr
|
||||
from jax._src.api_util import (
|
||||
flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
|
||||
argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
|
||||
rebase_donate_argnums, _ensure_index, _ensure_index_tuple,
|
||||
shaped_abstractify, _ensure_str_tuple,
|
||||
argnames_partial_except)
|
||||
from . import traceback_util
|
||||
from .traceback_util import api_boundary
|
||||
from ..tree_util import (tree_map, tree_flatten, tree_unflatten, tree_structure,
|
||||
tree_transpose, tree_leaves, tree_multimap,
|
||||
treedef_is_leaf, treedef_children, Partial, PyTreeDef)
|
||||
from .util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
|
||||
extend_name_stack, wrap_name, cache, wraps, HashableFunction)
|
||||
shaped_abstractify, _ensure_str_tuple, argnames_partial_except)
|
||||
from jax._src import traceback_util
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
||||
tree_structure, tree_transpose, tree_leaves,
|
||||
tree_multimap, treedef_is_leaf, treedef_children,
|
||||
Partial, PyTreeDef)
|
||||
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
|
||||
extend_name_stack, wrap_name, cache, wraps,
|
||||
HashableFunction)
|
||||
from jax._src import device_array
|
||||
from jax._src import dispatch
|
||||
from jax._src.lib import jax_jit
|
||||
@ -65,22 +66,24 @@ from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import pmap_lib
|
||||
# Unused imports to be exported
|
||||
from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
|
||||
local_devices, process_index, process_count,
|
||||
host_id, host_ids, host_count, default_backend)
|
||||
from ..core import ConcreteArray, ShapedArray, raise_to_shaped
|
||||
from ..interpreters import partial_eval as pe
|
||||
from ..interpreters import xla
|
||||
from ..interpreters import pxla
|
||||
from ..interpreters import ad
|
||||
from ..interpreters import batching
|
||||
from ..interpreters import masking
|
||||
from ..interpreters import invertible_ad as iad
|
||||
from ..interpreters.invertible_ad import custom_ivjp
|
||||
from ..custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
|
||||
local_devices, process_index,
|
||||
process_count, host_id, host_ids,
|
||||
host_count, default_backend)
|
||||
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import masking
|
||||
from jax.interpreters import invertible_ad as iad
|
||||
from jax.interpreters.invertible_ad import custom_ivjp
|
||||
from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
|
||||
custom_vjp, linear_call)
|
||||
from ..ad_checkpoint import checkpoint_policies
|
||||
from jax.ad_checkpoint import checkpoint_policies
|
||||
|
||||
from .._src.config import (flags, config, bool_env, disable_jit as _disable_jit,
|
||||
from jax._src.config import (flags, config, bool_env,
|
||||
disable_jit as _disable_jit,
|
||||
debug_nans as config_debug_nans,
|
||||
debug_infs as config_debug_infs,
|
||||
_thread_local_state as config_thread_local_state)
|
||||
|
@ -18,16 +18,17 @@ from typing import Any, Dict, Iterable, Tuple, Union, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import core
|
||||
from . import dtypes
|
||||
from .tree_util import (PyTreeDef, tree_flatten, tree_unflatten, tree_multimap,
|
||||
tree_structure, treedef_children, treedef_is_leaf)
|
||||
from .tree_util import _replace_nones
|
||||
from .. import linear_util as lu
|
||||
from .util import safe_map, WrapKwArgs, Hashable, Unhashable
|
||||
from ..core import unit
|
||||
from jax import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.tree_util import (
|
||||
PyTreeDef, tree_flatten, tree_unflatten, tree_multimap, tree_structure,
|
||||
treedef_children, treedef_is_leaf)
|
||||
from jax._src.tree_util import _replace_nones
|
||||
from jax import linear_util as lu
|
||||
from jax._src.util import safe_map, WrapKwArgs, Hashable, Unhashable
|
||||
from jax.core import unit
|
||||
|
||||
from . import traceback_util
|
||||
from jax._src import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
map = safe_map
|
||||
|
@ -11,8 +11,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from jax import core
|
||||
|
||||
from jax import core
|
||||
|
||||
class _JAXErrorMixin:
|
||||
"""Mixin for JAX-specific errors"""
|
||||
|
@ -32,8 +32,8 @@ logging._warn_preinit_stderr = 0
|
||||
|
||||
import jax._src.lib
|
||||
from jax._src.config import flags, bool_env
|
||||
from . import tpu_driver_client
|
||||
from . import xla_client
|
||||
from jax._src.lib import tpu_driver_client
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src import util, traceback_util
|
||||
import numpy as np
|
||||
|
||||
|
@ -24,7 +24,7 @@ from jax._src import dtypes
|
||||
from jax import lax
|
||||
from jax import core
|
||||
from jax.core import AxisName
|
||||
from .. import util
|
||||
from jax._src import util
|
||||
from jax.scipy.special import expit
|
||||
from jax.scipy.special import logsumexp as _logsumexp
|
||||
import jax.numpy as jnp
|
||||
|
@ -19,8 +19,8 @@ import numpy as np
|
||||
from jax import lax
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.util import safe_zip
|
||||
from .util import _wraps
|
||||
from . import lax_numpy as jnp
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
|
||||
|
||||
def _fft_norm(s, func_name, norm):
|
||||
|
@ -39,8 +39,8 @@ import opt_einsum
|
||||
|
||||
import jax
|
||||
from jax import jit, custom_jvp
|
||||
from .vectorize import vectorize
|
||||
from .util import _wraps
|
||||
from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.api_util import _ensure_index_tuple
|
||||
@ -1698,7 +1698,7 @@ def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
|
||||
# scale lhs to improve condition number and solve
|
||||
scale = sqrt((lhs*lhs).sum(axis=0))
|
||||
lhs /= scale[newaxis,:]
|
||||
from . import linalg
|
||||
from jax._src.numpy import linalg
|
||||
c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond)
|
||||
c = (c.T/scale).T # broadcast scale coefficients
|
||||
|
||||
@ -4547,7 +4547,7 @@ def poly(seq_of_zeros):
|
||||
sh = seq_of_zeros.shape
|
||||
if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0:
|
||||
# import at runtime to avoid circular import
|
||||
from . import linalg
|
||||
from jax._src.numpy import linalg
|
||||
seq_of_zeros = linalg.eigvals(seq_of_zeros)
|
||||
|
||||
if seq_of_zeros.ndim != 1:
|
||||
|
@ -24,8 +24,8 @@ from jax import jit, custom_jvp
|
||||
from jax import lax
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
from jax._src import dtypes
|
||||
from .util import _wraps
|
||||
from . import lax_numpy as jnp
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.util import canonicalize_axis
|
||||
|
||||
_T = lambda x: jnp.swapaxes(x, -1, -2)
|
||||
|
@ -15,11 +15,11 @@
|
||||
|
||||
import numpy as np
|
||||
from jax import lax
|
||||
from . import lax_numpy as jnp
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
|
||||
from jax import jit
|
||||
from .util import _wraps
|
||||
from .linalg import eigvals as _eigvals
|
||||
from jax._src.numpy.util import _wraps
|
||||
from jax._src.numpy.linalg import eigvals as _eigvals
|
||||
|
||||
|
||||
def _to_inexact_type(type):
|
||||
|
@ -18,7 +18,7 @@ from typing import Any, Callable, Dict, List, Tuple
|
||||
|
||||
from jax._src import api
|
||||
from jax import lax
|
||||
from . import lax_numpy as jnp
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.util import safe_map as map, safe_zip as zip
|
||||
|
||||
|
||||
|
@ -18,7 +18,7 @@ from functools import partial
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from .line_search import line_search
|
||||
from jax._src.scipy.optimize.line_search import line_search
|
||||
|
||||
_dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)
|
||||
|
||||
|
@ -18,7 +18,7 @@ from typing import Callable, NamedTuple, Optional, Union
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from .line_search import line_search
|
||||
from jax._src.scipy.optimize.line_search import line_search
|
||||
|
||||
|
||||
class _BFGSResults(NamedTuple):
|
||||
|
@ -12,8 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Callable, Mapping, Optional, Tuple, Union
|
||||
from .bfgs import minimize_bfgs
|
||||
from ._lbfgs import _minimize_lbfgs
|
||||
from jax._src.scipy.optimize.bfgs import minimize_bfgs
|
||||
from jax._src.scipy.optimize._lbfgs import _minimize_lbfgs
|
||||
from typing import NamedTuple
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
@ -21,9 +21,9 @@ from typing import (Any, Callable, Hashable, Iterable, Optional, Tuple, Type,
|
||||
|
||||
from jax._src.lib import pytree
|
||||
|
||||
from .._src.util import safe_zip, unzip2
|
||||
from jax._src.util import safe_zip, unzip2
|
||||
|
||||
from .._src import traceback_util
|
||||
from jax._src import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
14
jax/core.py
14
jax/core.py
@ -31,20 +31,20 @@ from weakref import ref
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ._src import dtypes
|
||||
from ._src import config as jax_config
|
||||
from ._src.config import FLAGS, config
|
||||
from .errors import (ConcretizationTypeError, TracerArrayConversionError,
|
||||
from jax._src import dtypes
|
||||
from jax._src import config as jax_config
|
||||
from jax._src.config import FLAGS, config
|
||||
from jax.errors import (ConcretizationTypeError, TracerArrayConversionError,
|
||||
TracerIntegerConversionError, UnexpectedTracerError)
|
||||
from . import linear_util as lu
|
||||
from jax import linear_util as lu
|
||||
|
||||
from jax._src import source_info_util
|
||||
from ._src.util import (safe_zip, safe_map, curry, prod, tuple_insert,
|
||||
from jax._src.util import (safe_zip, safe_map, curry, prod, tuple_insert,
|
||||
tuple_delete, cache, as_hashable_function,
|
||||
HashableFunction)
|
||||
import jax._src.pretty_printer as pp
|
||||
|
||||
from ._src import traceback_util
|
||||
from jax._src import traceback_util
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from ._src.errors import (
|
||||
from jax._src.errors import (
|
||||
JAXTypeError as JAXTypeError,
|
||||
JAXIndexError as JAXIndexError,
|
||||
ConcretizationTypeError as ConcretizationTypeError,
|
||||
|
@ -13,12 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from ..interpreters.sharded_jit import (
|
||||
from jax.interpreters.sharded_jit import (
|
||||
sharded_jit as sharded_jit,
|
||||
PartitionSpec as PartitionSpec,
|
||||
with_sharding_constraint as with_sharding_constraint,
|
||||
)
|
||||
from .x64_context import (
|
||||
from jax.experimental.x64_context import (
|
||||
enable_x64 as enable_x64,
|
||||
disable_x64 as disable_x64,
|
||||
)
|
||||
|
@ -18,14 +18,14 @@ import dataclasses
|
||||
import numpy as np
|
||||
from typing import Callable, Sequence, Tuple, Union, Mapping, Optional, List, Dict
|
||||
|
||||
from . import maps
|
||||
from .. import core
|
||||
from jax.experimental import maps
|
||||
from jax import core
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from ..interpreters import pxla, xla
|
||||
from .._src.util import prod, safe_zip
|
||||
from .._src.api import device_put
|
||||
from ..interpreters.sharded_jit import PartitionSpec
|
||||
from jax.interpreters import pxla, xla
|
||||
from jax._src.util import prod, safe_zip
|
||||
from jax._src.api import device_put
|
||||
from jax.interpreters.sharded_jit import PartitionSpec
|
||||
|
||||
Shape = Tuple[int, ...]
|
||||
MeshAxes = Sequence[Union[str, Tuple[str], None]]
|
||||
@ -49,7 +49,7 @@ class _HashableIndex:
|
||||
def get_shard_indices(global_shape: Shape, global_mesh: pxla.Mesh,
|
||||
mesh_axes: MeshAxes) -> Mapping[Device, Index]:
|
||||
# Import here to avoid cyclic import error when importing gsda in pjit.py.
|
||||
from .pjit import get_array_mapping, _prepare_axis_resources
|
||||
from jax.experimental.pjit import get_array_mapping, _prepare_axis_resources
|
||||
|
||||
if not isinstance(mesh_axes, PartitionSpec):
|
||||
pspec = PartitionSpec(*mesh_axes)
|
||||
|
@ -13,5 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from .jax2tf import convert, dtype_of_val, split_to_logical_devices, PolyShape
|
||||
from .call_tf import call_tf
|
||||
from jax.experimental.jax2tf.jax2tf import (convert, dtype_of_val,
|
||||
split_to_logical_devices, PolyShape)
|
||||
from jax.experimental.jax2tf.call_tf import call_tf
|
||||
|
@ -36,7 +36,7 @@ from jax._src import ad_util
|
||||
from jax._src.lax.lax import _device_put_raw
|
||||
from jax.interpreters import xla
|
||||
from jax._src.lib import xla_client
|
||||
from . import jax2tf as jax2tf_internal
|
||||
from jax.experimental.jax2tf import jax2tf as jax2tf_internal
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
|
@ -24,7 +24,7 @@ from jax._src.lax import slicing as lax_slicing
|
||||
from jax._src import dtypes
|
||||
from jax._src import util
|
||||
|
||||
from . import jax2tf
|
||||
from jax.experimental.jax2tf import jax2tf
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
|
@ -48,9 +48,9 @@ from jax.interpreters import sharded_jit
|
||||
from jax.interpreters import xla
|
||||
from jax._src.lib import xla_client
|
||||
|
||||
from . import shape_poly
|
||||
from . import shape_poly_tf
|
||||
from . import impl_no_xla
|
||||
from jax.experimental.jax2tf import shape_poly
|
||||
from jax.experimental.jax2tf import shape_poly_tf
|
||||
from jax.experimental.jax2tf import impl_no_xla
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
|
@ -21,7 +21,7 @@ from typing import Any, Optional, Sequence, Tuple, Union
|
||||
import numpy as np
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
|
||||
from . import shape_poly
|
||||
from jax.experimental.jax2tf import shape_poly
|
||||
|
||||
TfVal = Any
|
||||
|
||||
|
@ -23,32 +23,32 @@ from warnings import warn
|
||||
from functools import wraps, partial, partialmethod
|
||||
from enum import Enum
|
||||
|
||||
from .. import numpy as jnp
|
||||
from .. import core
|
||||
from .. import linear_util as lu
|
||||
from .._src.api import _check_callable, _check_arg
|
||||
from jax import numpy as jnp
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax._src.api import _check_callable, _check_arg
|
||||
from jax._src import dispatch
|
||||
from ..tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map,
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, all_leaves, tree_map,
|
||||
tree_leaves)
|
||||
from .._src.tree_util import _replace_nones
|
||||
from .._src.api_util import (flatten_fun_nokwargs, flatten_axes,
|
||||
from jax._src.tree_util import _replace_nones
|
||||
from jax._src.api_util import (flatten_fun_nokwargs, flatten_axes,
|
||||
_ensure_index_tuple, donation_vector,
|
||||
shaped_abstractify)
|
||||
from .._src import source_info_util
|
||||
from .._src.config import config
|
||||
from ..errors import JAXTypeError
|
||||
from ..interpreters import partial_eval as pe
|
||||
from ..interpreters import pxla
|
||||
from ..interpreters import xla
|
||||
from ..interpreters import batching
|
||||
from ..interpreters import ad
|
||||
from jax._src import source_info_util
|
||||
from jax._src.config import config
|
||||
from jax.errors import JAXTypeError
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import ad
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from .._src.util import (safe_map, safe_zip, HashableFunction,
|
||||
from jax._src.util import (safe_map, safe_zip, HashableFunction,
|
||||
as_hashable_function, unzip2, distributed_debug_log,
|
||||
tuple_insert, moveaxis, split_list, wrap_name)
|
||||
from .._src.lax.parallel import _build_axis_index_lowering
|
||||
from .. import lax
|
||||
from jax._src.lax.parallel import _build_axis_index_lowering
|
||||
from jax import lax
|
||||
|
||||
class _PositionalSemantics(Enum):
|
||||
"""Indicates whether the positional shapes of inputs should be interpreted as
|
||||
|
@ -20,27 +20,27 @@ from warnings import warn
|
||||
import itertools as it
|
||||
from functools import partial
|
||||
|
||||
from . import maps
|
||||
from .gsda import GlobalShardedDeviceArray as GSDA
|
||||
from .. import core
|
||||
from .. import linear_util as lu
|
||||
from .._src.api import _check_callable, _check_arg, Lowered
|
||||
from .._src import source_info_util
|
||||
from .._src.api_util import (argnums_partial_except, flatten_axes,
|
||||
from jax.experimental import maps
|
||||
from jax.experimental.gsda import GlobalShardedDeviceArray as GSDA
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax._src.api import _check_callable, _check_arg, Lowered
|
||||
from jax._src import source_info_util
|
||||
from jax._src.api_util import (argnums_partial_except, flatten_axes,
|
||||
flatten_fun_nokwargs, _ensure_index_tuple,
|
||||
donation_vector, rebase_donate_argnums,
|
||||
shaped_abstractify)
|
||||
from ..errors import JAXTypeError
|
||||
from ..interpreters import ad
|
||||
from ..interpreters import pxla
|
||||
from ..interpreters import xla
|
||||
from ..interpreters import batching
|
||||
from ..interpreters import partial_eval as pe
|
||||
from ..interpreters.sharded_jit import PartitionSpec
|
||||
from jax.errors import JAXTypeError
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters.sharded_jit import PartitionSpec
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from ..tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves
|
||||
from .._src.util import (extend_name_stack, HashableFunction, safe_zip,
|
||||
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_leaves
|
||||
from jax._src.util import (extend_name_stack, HashableFunction, safe_zip,
|
||||
wrap_name, wraps, distributed_debug_log,
|
||||
split_list, cache, tuple_insert)
|
||||
xops = xc._xla.ops
|
||||
|
@ -183,11 +183,11 @@ To fit the same model on sparse data, we can apply the :func:`sparsify` transfor
|
||||
"""
|
||||
|
||||
# flake8: noqa: F401
|
||||
from .ad import (
|
||||
from jax.experimental.sparse.ad import (
|
||||
grad as grad,
|
||||
value_and_grad as value_and_grad,
|
||||
)
|
||||
from .bcoo import (
|
||||
from jax.experimental.sparse.bcoo import (
|
||||
bcoo_dot_general as bcoo_dot_general,
|
||||
bcoo_dot_general_p as bcoo_dot_general_p,
|
||||
bcoo_dot_general_sampled as bcoo_dot_general_sampled,
|
||||
@ -207,7 +207,7 @@ from .bcoo import (
|
||||
BCOO as BCOO,
|
||||
)
|
||||
|
||||
from .ops import (
|
||||
from jax.experimental.sparse.ops import (
|
||||
coo_fromdense as coo_fromdense,
|
||||
coo_fromdense_p as coo_fromdense_p,
|
||||
coo_matmat as coo_matmat,
|
||||
@ -232,8 +232,8 @@ from .ops import (
|
||||
CSR as CSR,
|
||||
)
|
||||
|
||||
from .random import random_bcoo as random_bcoo
|
||||
from .transform import (
|
||||
from jax.experimental.sparse.random import random_bcoo as random_bcoo
|
||||
from jax.experimental.sparse.transform import (
|
||||
sparsify as sparsify,
|
||||
SparseTracer as SparseTracer,
|
||||
)
|
||||
|
@ -23,7 +23,7 @@ from jax._src.api_util import _ensure_index, _ensure_index_tuple
|
||||
from jax.util import safe_zip
|
||||
from jax._src.util import wraps
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from .bcoo import BCOO
|
||||
from jax.experimental.sparse.bcoo import BCOO
|
||||
|
||||
|
||||
def value_and_grad(fun: Callable,
|
||||
|
@ -35,7 +35,7 @@ from jax._src.lax.lax import (
|
||||
ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
|
||||
DotDimensionNumbers)
|
||||
from jax._src.numpy.lax_numpy import _unique
|
||||
from . import ops
|
||||
from jax.experimental.sparse import ops
|
||||
|
||||
Dtype = Any
|
||||
Shape = Tuple[int, ...]
|
||||
|
@ -757,7 +757,7 @@ def _todense_transpose(ct, *bufs, tree):
|
||||
|
||||
standin = object()
|
||||
obj = tree_util.tree_unflatten(tree, [standin] * len(bufs))
|
||||
from . import BCOO, bcoo_extract
|
||||
from jax.experimental.sparse import BCOO, bcoo_extract
|
||||
if obj is standin:
|
||||
return (ct,)
|
||||
elif isinstance(obj, BCOO):
|
||||
|
@ -24,7 +24,7 @@
|
||||
# uniformity
|
||||
|
||||
from contextlib import contextmanager
|
||||
from .._src.config import enable_x64 as _jax_enable_x64
|
||||
from jax._src.config import enable_x64 as _jax_enable_x64
|
||||
|
||||
@contextmanager
|
||||
def enable_x64(new_val: bool = True):
|
||||
|
@ -19,21 +19,21 @@ import itertools as it
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
import jax
|
||||
from . import partial_eval as pe
|
||||
from ..config import config
|
||||
from .. import core
|
||||
from .._src.dtypes import dtype, float0
|
||||
from ..core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.config import config
|
||||
from jax import core
|
||||
from jax._src.dtypes import dtype, float0
|
||||
from jax.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
|
||||
raise_to_shaped)
|
||||
from .._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
zeros_like_aval, zeros_like_p, Zero)
|
||||
from .._src.util import (unzip2, safe_map, safe_zip, split_list,
|
||||
from jax._src.util import (unzip2, safe_map, safe_zip, split_list,
|
||||
wrap_name, as_hashable_function)
|
||||
from ..tree_util import register_pytree_node
|
||||
from .. import linear_util as lu
|
||||
from .._src.api_util import flatten_fun, flatten_fun_nokwargs
|
||||
from ..tree_util import tree_flatten, tree_unflatten, Partial
|
||||
from .._src import source_info_util
|
||||
from jax.tree_util import register_pytree_node
|
||||
from jax import linear_util as lu
|
||||
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
|
||||
from jax.tree_util import tree_flatten, tree_unflatten, Partial
|
||||
from jax._src import source_info_util
|
||||
|
||||
zip = safe_zip
|
||||
map = safe_map
|
||||
|
@ -19,17 +19,17 @@ from typing import (Any, Callable, Dict, Set, Optional, Tuple, Union, Iterable,
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from ..config import config
|
||||
from .. import core
|
||||
from ..core import raise_to_shaped, Trace, Tracer
|
||||
from jax.config import config
|
||||
from jax import core
|
||||
from jax.core import raise_to_shaped, Trace, Tracer
|
||||
from jax._src.tree_util import tree_unflatten, tree_flatten
|
||||
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
zeros_like_p, Zero)
|
||||
from .. import linear_util as lu
|
||||
from .._src.util import (unzip2, safe_map, safe_zip, wrap_name, split_list,
|
||||
from jax import linear_util as lu
|
||||
from jax._src.util import (unzip2, safe_map, safe_zip, wrap_name, split_list,
|
||||
canonicalize_axis, moveaxis, as_hashable_function,
|
||||
curry, memoize)
|
||||
from . import partial_eval as pe
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
||||
map = safe_map
|
||||
|
||||
|
@ -19,13 +19,13 @@ from typing import Dict, Any, Callable
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from . import ad
|
||||
from . import partial_eval as pe
|
||||
from ..core import raise_to_shaped, get_aval, Literal, Jaxpr
|
||||
from .._src.api_util import flatten_fun_nokwargs
|
||||
from ..tree_util import tree_flatten, tree_unflatten, register_pytree_node
|
||||
from .._src.util import safe_map, safe_zip, split_list
|
||||
from .._src import custom_derivatives
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.core import raise_to_shaped, get_aval, Literal, Jaxpr
|
||||
from jax._src.api_util import flatten_fun_nokwargs
|
||||
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
|
||||
from jax._src.util import safe_map, safe_zip, split_list
|
||||
from jax._src import custom_derivatives
|
||||
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
|
@ -22,12 +22,12 @@ from typing import Callable, Dict, Optional, Sequence, Union, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import core
|
||||
from .._src import dtypes
|
||||
from ..tree_util import tree_unflatten
|
||||
from ..core import ShapedArray, Trace, Tracer
|
||||
from .._src.util import safe_map, safe_zip, unzip2, prod, wrap_name
|
||||
from .. import linear_util as lu
|
||||
from jax import core
|
||||
from jax._src import dtypes
|
||||
from jax.tree_util import tree_unflatten
|
||||
from jax.core import ShapedArray, Trace, Tracer
|
||||
from jax._src.util import safe_map, safe_zip, unzip2, prod, wrap_name
|
||||
from jax import linear_util as lu
|
||||
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
|
@ -25,21 +25,21 @@ from weakref import ref
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .. import core
|
||||
from .._src import dtypes
|
||||
from .. import linear_util as lu
|
||||
from jax import core
|
||||
from jax._src import dtypes
|
||||
from jax import linear_util as lu
|
||||
from jax._src.ad_util import Zero
|
||||
from .._src.api_util import flattened_fun_in_tree
|
||||
from .._src.tree_util import PyTreeDef, tree_unflatten, tree_leaves
|
||||
from .._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
|
||||
from jax._src.api_util import flattened_fun_in_tree
|
||||
from jax._src.tree_util import PyTreeDef, tree_unflatten, tree_leaves
|
||||
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
|
||||
partition_list, cache, OrderedSet,
|
||||
as_hashable_function)
|
||||
from ..core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
|
||||
from jax.core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
|
||||
unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
|
||||
ConcreteArray, raise_to_shaped, Var, Atom,
|
||||
JaxprEqn, Primitive)
|
||||
from jax._src import source_info_util
|
||||
from ..config import config
|
||||
from jax.config import config
|
||||
|
||||
map = safe_map
|
||||
zip = safe_zip
|
||||
|
@ -42,26 +42,26 @@ import sys
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
|
||||
from .._src.config import config
|
||||
from .. import core
|
||||
from .. import linear_util as lu
|
||||
from jax._src.config import config
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax._src.abstract_arrays import array_types
|
||||
from ..core import ConcreteArray, ShapedArray
|
||||
from jax.core import ConcreteArray, ShapedArray
|
||||
from jax._src import device_array
|
||||
from .._src import source_info_util
|
||||
from .._src.util import (unzip3, prod, safe_map, safe_zip,
|
||||
from jax._src import source_info_util
|
||||
from jax._src.util import (unzip3, prod, safe_map, safe_zip,
|
||||
extend_name_stack, wrap_name, assert_unreachable,
|
||||
tuple_insert, tuple_delete, distributed_debug_log)
|
||||
from ..errors import JAXTypeError
|
||||
from jax.errors import JAXTypeError
|
||||
from jax._src import dispatch
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import pmap_lib
|
||||
from ..tree_util import tree_flatten, tree_map
|
||||
from . import batching
|
||||
from . import partial_eval as pe
|
||||
from . import xla
|
||||
from . import ad
|
||||
from jax.tree_util import tree_flatten, tree_map
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import ad
|
||||
|
||||
# Built in Python lists don't support weak refs but subclasses of lists do.
|
||||
class WeakRefList(list):
|
||||
@ -458,7 +458,7 @@ def sda_array_result_handler(sharding_spec, indices, aval: ShapedArray):
|
||||
indices)
|
||||
|
||||
def gsda_array_result_handler(global_aval, global_mesh, out_axis_resources):
|
||||
from ..experimental.gsda import GlobalShardedDeviceArray
|
||||
from jax.experimental.gsda import GlobalShardedDeviceArray
|
||||
|
||||
return lambda bufs: GlobalShardedDeviceArray(
|
||||
global_aval.shape, global_mesh, out_axis_resources, bufs)
|
||||
|
@ -18,21 +18,22 @@ from typing import Callable, Iterable, Optional, Tuple, Union
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
|
||||
from .. import core
|
||||
from . import ad
|
||||
from . import partial_eval as pe
|
||||
from jax import core
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import partial_eval as pe
|
||||
# TODO(skye): separate pmap into it's own module?
|
||||
from . import pxla
|
||||
from . import xla
|
||||
from .. import linear_util as lu
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax import linear_util as lu
|
||||
from jax._src import dispatch
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from .._src.api_util import argnums_partial, flatten_axes, flatten_fun, _ensure_index_tuple
|
||||
from ..tree_util import tree_flatten, tree_unflatten
|
||||
from .._src.util import (extend_name_stack, wrap_name, wraps, safe_zip,
|
||||
from jax._src.api_util import (argnums_partial, flatten_axes, flatten_fun,
|
||||
_ensure_index_tuple)
|
||||
from jax.tree_util import tree_flatten, tree_unflatten
|
||||
from jax._src.util import (extend_name_stack, wrap_name, wraps, safe_zip,
|
||||
HashableFunction)
|
||||
from .._src.config import config
|
||||
from jax._src.config import config
|
||||
|
||||
xops = xc._xla.ops
|
||||
|
||||
|
@ -28,24 +28,24 @@ from typing_extensions import Protocol
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..config import config
|
||||
from .. import core
|
||||
from jax.config import config
|
||||
from jax import core
|
||||
from jax._src import ad_util
|
||||
from jax._src import device_array
|
||||
from jax._src import dtypes
|
||||
from .. import linear_util as lu
|
||||
from jax import linear_util as lu
|
||||
from jax._src import source_info_util
|
||||
from jax._src.abstract_arrays import (make_shaped_array, array_types)
|
||||
from ..core import (ConcreteArray, ShapedArray,
|
||||
from jax.core import (ConcreteArray, ShapedArray,
|
||||
Literal, pp_eqn_compact, JaxprPpContext,
|
||||
abstract_token)
|
||||
import jax._src.pretty_printer as pp
|
||||
from .._src.util import (prod, extend_name_stack, wrap_name,
|
||||
from jax._src.util import (prod, extend_name_stack, wrap_name,
|
||||
safe_zip, safe_map, partition_list)
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from . import partial_eval as pe
|
||||
from . import ad
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import ad
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
@ -367,4 +367,4 @@ from jax._src.lax.other import (
|
||||
conv_general_dilated_patches as conv_general_dilated_patches
|
||||
)
|
||||
from jax._src.ad_util import stop_gradient_p as stop_gradient_p
|
||||
from . import linalg as linalg
|
||||
from jax.lax import linalg as linalg
|
||||
|
@ -17,4 +17,4 @@ from jax._src.lib import (
|
||||
xla_client as xla_client,
|
||||
xla_extension as xla_extension,
|
||||
)
|
||||
from . import xla_bridge as xla_bridge
|
||||
from jax.lib import xla_bridge as xla_bridge
|
||||
|
@ -67,13 +67,13 @@ from functools import partial
|
||||
from typing import Any, Tuple, Callable
|
||||
import weakref
|
||||
|
||||
from . import core
|
||||
from ._src.util import curry
|
||||
from .tree_util import tree_map
|
||||
from jax import core
|
||||
from jax._src.util import curry
|
||||
from jax.tree_util import tree_map
|
||||
|
||||
from ._src import traceback_util
|
||||
from jax._src import traceback_util
|
||||
|
||||
from .config import config
|
||||
from jax.config import config
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
|
@ -17,7 +17,7 @@
|
||||
# flake8: noqa: F401
|
||||
|
||||
from jax.numpy import tanh as tanh
|
||||
from . import initializers as initializers
|
||||
from jax.nn import initializers as initializers
|
||||
from jax._src.nn.functions import (
|
||||
celu as celu,
|
||||
elu as elu,
|
||||
|
@ -16,8 +16,8 @@
|
||||
# See PEP 484 & https://github.com/google/jax/issues/7570
|
||||
|
||||
# flake8: noqa: F401
|
||||
from . import fft as fft
|
||||
from . import linalg as linalg
|
||||
from jax.numpy import fft as fft
|
||||
from jax.numpy import linalg as linalg
|
||||
|
||||
from jax._src.device_array import DeviceArray as DeviceArray
|
||||
|
||||
|
@ -13,10 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from . import linalg as linalg
|
||||
from . import ndimage as ndimage
|
||||
from . import signal as signal
|
||||
from . import sparse as sparse
|
||||
from . import special as special
|
||||
from . import stats as stats
|
||||
from . import fft as fft
|
||||
from jax.scipy import linalg as linalg
|
||||
from jax.scipy import ndimage as ndimage
|
||||
from jax.scipy import signal as signal
|
||||
from jax.scipy import sparse as sparse
|
||||
from jax.scipy import special as special
|
||||
from jax.scipy import stats as stats
|
||||
from jax.scipy import fft as fft
|
||||
|
@ -13,4 +13,4 @@
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from . import linalg as linalg
|
||||
from jax.scipy.sparse import linalg as linalg
|
||||
|
@ -13,21 +13,21 @@
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from . import bernoulli as bernoulli
|
||||
from . import beta as beta
|
||||
from . import cauchy as cauchy
|
||||
from . import dirichlet as dirichlet
|
||||
from . import expon as expon
|
||||
from . import gamma as gamma
|
||||
from . import geom as geom
|
||||
from . import laplace as laplace
|
||||
from . import logistic as logistic
|
||||
from . import multivariate_normal as multivariate_normal
|
||||
from . import nbinom as nbinom
|
||||
from . import norm as norm
|
||||
from . import pareto as pareto
|
||||
from . import poisson as poisson
|
||||
from . import t as t
|
||||
from . import uniform as uniform
|
||||
from . import chi2 as chi2
|
||||
from . import betabinom as betabinom
|
||||
from jax.scipy.stats import bernoulli as bernoulli
|
||||
from jax.scipy.stats import beta as beta
|
||||
from jax.scipy.stats import cauchy as cauchy
|
||||
from jax.scipy.stats import dirichlet as dirichlet
|
||||
from jax.scipy.stats import expon as expon
|
||||
from jax.scipy.stats import gamma as gamma
|
||||
from jax.scipy.stats import geom as geom
|
||||
from jax.scipy.stats import laplace as laplace
|
||||
from jax.scipy.stats import logistic as logistic
|
||||
from jax.scipy.stats import multivariate_normal as multivariate_normal
|
||||
from jax.scipy.stats import nbinom as nbinom
|
||||
from jax.scipy.stats import norm as norm
|
||||
from jax.scipy.stats import pareto as pareto
|
||||
from jax.scipy.stats import poisson as poisson
|
||||
from jax.scipy.stats import t as t
|
||||
from jax.scipy.stats import uniform as uniform
|
||||
from jax.scipy.stats import chi2 as chi2
|
||||
from jax.scipy.stats import betabinom as betabinom
|
||||
|
Loading…
x
Reference in New Issue
Block a user