migrate internal dependencies from jax.core to jax._src.core

... in preparation for paring down `jax.core`'s exported symbols.

Also includes a few import fixups along the way, and a TODO comment to avoid an
import cycle in `_src/dtypes.py`.

PiperOrigin-RevId: 496024782
This commit is contained in:
Roy Frostig 2022-12-16 20:59:41 -08:00 committed by jax authors
parent 6f770c936e
commit d927a5dbf3
43 changed files with 139 additions and 127 deletions

View File

@ -62,7 +62,7 @@ from jax._src.config import (
transfer_guard_device_to_host as transfer_guard_device_to_host,
spmd_mode as spmd_mode,
)
from .core import eval_context as ensure_compile_time_eval
from jax._src.core import eval_context as ensure_compile_time_eval
from jax._src.environment_info import print_environment_info as print_environment_info
from jax._src.api import (
ad, # TODO(phawkins): update users to avoid this.

View File

@ -17,7 +17,7 @@ from functools import partial
import numpy as np
from jax._src import ad_util
from jax import core
from jax._src import core
from jax._src import dtypes
from jax._src import traceback_util

View File

@ -18,7 +18,6 @@ from typing import (Callable, Optional, List, Tuple, Sequence, Set, Union, Any,
import types
import jax
from jax import core
from jax import linear_util as lu
from jax.interpreters import ad
from jax.interpreters import batching
@ -27,6 +26,7 @@ from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src import core
from jax._src import util
from jax._src import source_info_util
from jax._src import traceback_util

View File

@ -15,10 +15,11 @@ from __future__ import annotations
from typing import Any, Callable, Dict, Type, Union
from jax import core
from jax.core import (lattice_join, Primitive, valid_jaxtype, raise_to_shaped,
get_aval)
from jax.tree_util import register_pytree_node
from jax._src import core
from jax._src.core import (lattice_join, Primitive, valid_jaxtype,
raise_to_shaped, get_aval)
from jax._src.util import safe_map
from jax._src import traceback_util

View File

@ -34,22 +34,21 @@ import numpy as np
from contextlib import contextmanager, ExitStack
import jax
from jax import core
from jax import linear_util as lu
from jax import stages
from jax.core import eval_jaxpr
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
tree_structure, tree_transpose, tree_leaves,
treedef_is_leaf, treedef_children,
Partial, PyTreeDef, all_leaves, treedef_tuple)
from jax._src import callback as jcb
from jax._src import core
from jax._src import device_array
from jax._src import dispatch
from jax._src import array
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.sharding import PmapSharding
from jax._src.core import eval_jaxpr
from jax._src.api_util import (
flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
@ -61,6 +60,7 @@ from jax._src.lib import jax_jit
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import pmap_lib
from jax._src.sharding import PmapSharding
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import broadcast_prefix, _generate_key_paths
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
@ -68,12 +68,12 @@ from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
wraps, HashableFunction, weakref_lru_cache)
# Unused imports to be exported
from jax._src.core import ShapedArray, raise_to_shaped
from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
local_devices, process_index,
process_count, host_id, host_ids,
host_count, default_backend)
from jax.ad_checkpoint import checkpoint_policies, checkpoint as new_checkpoint
from jax.core import ShapedArray, raise_to_shaped
from jax.custom_batching import custom_vmap
from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
custom_vjp, linear_call)

View File

@ -21,7 +21,7 @@ import warnings
import numpy as np
from jax import core
from jax._src import core
from jax._src import dtypes
from jax._src.tree_util import (
PyTreeDef, tree_flatten, tree_unflatten, tree_map, tree_structure,

View File

@ -19,10 +19,10 @@ import numpy as np
import functools
from typing import Sequence, Tuple, Callable, Union, Optional, cast, List
from jax import core
from jax._src import abstract_arrays
from jax._src import api_util
from jax._src import basearray
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src.config import config

View File

@ -33,14 +33,14 @@ class Array(abc.ABC):
{class}`DeviceArray`
{class}`ShardedDeviceArray`
{class}`GlobalDeviceArray`
{class}`~jax.core.Tracer`
{class}`~jax._src.core.Tracer`
{class}`~jax.experimental.Array`
================================ ====================== =========================
In other words, ``isinstance(x, jax.Array)`` will return True for any of these types,
whereas annotations such as ``x : jax.Array`` will only type-check correctly for
instances of {class}`~jax.core.Tracer` and {class}`jax.experimental.Array`, and not
for the other soon-to-be-deprecated array types.
instances of {class}`~jax._src.core.Tracer` and {class}`jax.experimental.Array`, and
not for the other soon-to-be-deprecated array types.
"""
# Note: no abstract methods are defined in this base class; the associated pyi
# file contains the type signature for static type checking.

View File

@ -18,8 +18,8 @@ import functools
from typing import Any, Callable, Sequence
from jax import core
from jax import tree_util
from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src import dispatch

View File

@ -19,9 +19,9 @@ import types
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable, Type, Set, List
import jax
from jax import core
from jax import lax
from jax import linear_util as lu
from jax._src import core
from jax._src import prng
from jax._src import source_info_util
from jax._src import traceback_util

View File

@ -17,7 +17,6 @@ import operator
from typing import Callable, Optional
import jax
from jax import core
from jax import linear_util as lu
from jax import tree_util
from jax.interpreters import ad
@ -28,6 +27,7 @@ from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import (tree_flatten, tree_map, tree_structure,
tree_unflatten, treedef_tuple)
from jax._src import core
from jax._src import custom_api_util
from jax._src import source_info_util
from jax._src import traceback_util

View File

@ -17,18 +17,18 @@ import inspect
from typing import (Callable, Generic, Optional, Sequence, Tuple, TypeVar, Set,
Any)
from jax import core
from jax import linear_util as lu
from jax.custom_transpose import custom_transpose
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
treedef_is_leaf, treedef_tuple,
register_pytree_node_class, tree_leaves)
from jax._src import core
from jax._src import custom_api_util
from jax._src import dtypes
from jax._src.lax import lax
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
from jax.core import raise_to_shaped
from jax._src.core import raise_to_shaped
from jax.errors import UnexpectedTracerError
from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p
from jax.interpreters import partial_eval as pe

View File

@ -15,7 +15,6 @@
import functools
from typing import Any, Callable, Optional, Tuple
from jax import core
from jax import linear_util as lu
from jax.interpreters import ad
from jax.interpreters import mlir
@ -25,6 +24,7 @@ from jax.tree_util import (tree_flatten, tree_leaves, tree_map,
tree_structure, treedef_tuple, tree_unflatten)
from jax._src import ad_util
from jax._src import api_util
from jax._src import core
from jax._src import custom_api_util
from jax._src import source_info_util
from jax._src import traceback_util

View File

@ -20,12 +20,10 @@ import weakref
from typing import Any, Dict, Callable, Optional, Sequence, Set, Tuple, Union
from jax import core
from jax import tree_util
from jax import lax
from jax import linear_util as lu
from jax.config import config
from jax._src.sharding import Sharding, OpShardingSharding
from jax.experimental import pjit
from jax.interpreters import ad
from jax.interpreters import batching
@ -33,12 +31,14 @@ from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla
from jax._src import ad_checkpoint
from jax._src import core
from jax._src import custom_derivatives
from jax._src import util
from jax._src.lax import control_flow as lcf
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import Sharding, OpShardingSharding
import jax.numpy as jnp
import numpy as np

View File

@ -23,13 +23,13 @@ import warnings
import numpy as np
import jax
from jax import core
from jax._src.config import config
from jax._src import core
from jax._src import abstract_arrays
from jax._src import dtypes
from jax._src import profiler
from jax._src import util
from jax._src.config import config
from jax._src.lib import xla_client as xc
import jax._src.util as util
from jax._src.typing import Array
### device-persistent data

View File

@ -32,7 +32,6 @@ import warnings
import numpy as np
import jax
from jax import core
from jax import linear_util as lu
from jax.errors import UnexpectedTracerError
from jax.monitoring import record_event_duration_secs
@ -43,6 +42,7 @@ import jax.interpreters.xla as xla
from jax.interpreters import pxla
import jax.interpreters.partial_eval as pe
from jax._src import array
from jax._src import core
from jax._src import device_array
from jax._src import dtypes
from jax._src import profiler

View File

@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from jax import core
from jax import numpy as jnp
from jax._src import core
from jax._src import device_array
from jax._src import dispatch
from jax._src import array

View File

@ -25,7 +25,6 @@ from typing import cast, overload, Any, Dict, List, Literal, Optional, Set, Tupl
import numpy as np
import jax
from jax._src.config import flags, config
from jax._src.lib import xla_client
from jax._src.typing import DType, DTypeLike, OpaqueDType
@ -92,7 +91,8 @@ def to_complex_dtype(dtype: DTypeLike) -> DType:
@functools.lru_cache(maxsize=None)
def _canonicalize_dtype(x64_enabled: bool, allow_opaque_dtype: bool, dtype: Any) -> Union[DType, OpaqueDType]:
"""Convert from a dtype to a canonical dtype based on config.x64_enabled."""
if jax.core.is_opaque_dtype(dtype):
from jax._src import core # TODO(frostig): break this cycle
if core.is_opaque_dtype(dtype):
if not allow_opaque_dtype:
raise ValueError(f"Internal: canonicalize_dtype called onopaque dtype {dtype} "
"with allow_opaque_dtype=False")
@ -450,13 +450,14 @@ def check_valid_dtype(dtype: DType) -> None:
def dtype(x: Any, *, canonicalize: bool = False) -> DType:
"""Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
from jax._src import core # TODO(frostig): break this cycle
if x is None:
raise ValueError(f"Invalid argument to dtype: {x}.")
elif isinstance(x, type) and x in python_scalar_dtypes:
dt = python_scalar_dtypes[x]
elif type(x) in python_scalar_dtypes:
dt = python_scalar_dtypes[type(x)]
elif jax.core.is_opaque_dtype(getattr(x, 'dtype', None)):
elif core.is_opaque_dtype(getattr(x, 'dtype', None)):
dt = x.dtype
else:
dt = np.result_type(x)

View File

@ -13,7 +13,7 @@
# limitations under the License.
from __future__ import annotations
from jax import core
from jax._src import core
class _JAXErrorMixin:
"""Mixin for JAX-specific errors"""

View File

@ -73,10 +73,11 @@ from functools import partial
from typing import (Any, Tuple)
import numpy as np
from jax import core
from jax._src import core
from jax._src import ad_util
from jax._src import dtypes
from jax._src.lax import lax
from jax._src.lib import xla_client as xc
from jax._src import ad_util, dtypes
from jax.interpreters import ad, xla, batching

View File

@ -19,15 +19,15 @@ from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
from jax import core
from jax._src import dtypes
from jax._src.lax import lax
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src.lib.mlir.dialects import hlo
from jax._src.lax import lax
from jax._src.lib import xla_client
from jax._src.lib.mlir.dialects import hlo
_max = builtins.max

View File

@ -18,18 +18,18 @@ from typing import Union, Sequence
import numpy as np
from jax._src.api import jit, linear_transpose, ShapeDtypeStruct
from jax.core import Primitive, is_constant_shape
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.util import prod
from jax import lax
from jax.interpreters import ad
from jax.interpreters import batching
from jax._src.api import jit, linear_transpose, ShapeDtypeStruct
from jax._src.core import Primitive, is_constant_shape
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_client
from jax._src.lib import ducc_fft
from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact
from jax._src.util import prod
__all__ = [
"fft",

View File

@ -25,7 +25,7 @@ import warnings
import numpy as np
import jax
from jax import core
from jax._src import core
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
@ -38,8 +38,8 @@ from jax import tree_util
from jax._src import source_info_util
from jax._src.sharding import PmapSharding
from jax._src.config import config
from jax.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
raise_to_shaped, abstract_token, canonicalize_shape)
from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
raise_to_shaped, abstract_token, canonicalize_shape)
from jax._src.abstract_arrays import array_types
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir

View File

@ -32,7 +32,8 @@ from jax.interpreters import xla
from jax.interpreters import ad
from jax.interpreters import batching
from jax._src.util import prod
from jax.core import Primitive, ShapedArray, raise_to_shaped, is_constant_shape
from jax._src.core import (
Primitive, ShapedArray, raise_to_shaped, is_constant_shape)
from jax._src.lax.lax import (
standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
_input_dtype)

View File

@ -23,15 +23,15 @@ import warnings
import numpy as np
from jax import core
from jax import tree_util
from jax.core import ShapedArray, AxisName, raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import batching
from jax._src import core
from jax._src import dtypes
from jax._src.core import ShapedArray, AxisName, raise_to_shaped
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.numpy import lax_numpy

View File

@ -28,9 +28,9 @@ import functools
from typing import Optional, Tuple
import jax
from jax import core
from jax import lax
import jax.numpy as jnp
from jax import lax
from jax._src import core
from jax._src.lax import linalg as lax_linalg

View File

@ -20,8 +20,8 @@ import weakref
import numpy as np
import jax
from jax import core
from jax._src import ad_util
from jax._src import core
from jax._src import dtypes
from jax.interpreters import ad
from jax.interpreters import batching

View File

@ -34,12 +34,12 @@ https://epubs.siam.org/doi/abs/10.1137/090774999
"""
import functools
from typing import Any, Sequence, Union
import jax
from jax import core
from jax import lax
import jax.numpy as jnp
from jax import lax
from jax._src import core
@functools.partial(jax.jit, static_argnums=(1, 2))

View File

@ -21,9 +21,9 @@ from functools import partial
import operator
from typing import Callable
from jax import core
from jax._src import dtypes
from jax.interpreters import xla
from jax._src import core
from jax._src import dtypes
from jax._src.util import safe_zip
from jax._src.lib import xla_client

View File

@ -23,19 +23,19 @@ from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import xla
from jax import core
from jax.core import (ShapedArray, ConcreteArray)
from jax import tree_util
from jax._src import ad_util
from jax._src import core
from jax._src import dtypes
import jax._src.lax.lax as lax
import jax._src.lax.convolution as convolution
import jax._src.lax.slicing as slicing
from jax._src import util
from jax._src.core import ShapedArray, ConcreteArray
from jax._src.lax import lax
from jax._src.lax import convolution
from jax._src.lax import slicing
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy.ufuncs import logaddexp
import jax._src.util as util
map = util.safe_map
zip = util.safe_zip

View File

@ -22,7 +22,6 @@ import numpy as np
import jax
from jax import lax
from jax import core
from jax import numpy as jnp
from jax.config import config
from jax.dtypes import float0
@ -35,14 +34,15 @@ from jax._src import basearray
from jax._src.sharding import (
NamedSharding, PmapSharding, OpShardingSharding)
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src.api import jit, vmap
from jax._src.lax import lax as lax_internal
from jax._src.lax import utils as lax_utils
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy
import jax._src.pretty_printer as pp
from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
from jax._src.lib import gpu_prng

View File

@ -21,21 +21,21 @@ import warnings
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax import core
from jax import numpy as jnp
from jax._src import dtypes
from jax._src import prng
from jax.config import config
from jax.core import NamedShape
from jax._src.api import jit, vmap
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_bridge
from jax._src.numpy.lax_numpy import _arraylike, _check_arraylike, _convert_and_clip_integer, _promote_dtypes_inexact
from jax.numpy.linalg import cholesky, svd, eigh
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.numpy.linalg import cholesky, svd, eigh
from jax._src import core
from jax._src import dtypes
from jax._src import prng
from jax._src.api import jit, vmap
from jax._src.core import NamedShape
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_bridge
from jax._src.numpy.lax_numpy import _arraylike, _check_arraylike, _convert_and_clip_integer, _promote_dtypes_inexact
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.util import prod, canonicalize_axis

View File

@ -36,10 +36,10 @@ from dataclasses import dataclass
from typing import Any, Dict, List, NamedTuple, Optional, Protocol, Sequence, Tuple
import jax
from jax import core
from jax import tree_util
from jax.lib import xla_client as xc
from jax._src import core
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util

View File

@ -32,20 +32,21 @@ import numpy as np
import numpy.random as npr
import jax
from jax._src import api
from jax import core
from jax._src import dtypes as _dtypes
from jax import lax
from jax.interpreters import mlir
from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten
from jax._src import api
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes as _dtypes
from jax._src.config import flags, bool_env, config
from jax._src.numpy.lax_numpy import _promote_dtypes, _promote_dtypes_inexact
from jax._src.util import prod, unzip2
from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten
from jax._src.lib import xla_bridge
from jax._src import dispatch
from jax._src.public_test_util import ( # noqa: F401
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, device_under_test, tolerance)
from jax.interpreters import mlir
# This submodule includes private test utilities that are not exported to
# jax.test_util. Functionality appearing here is for internal use only, and

View File

@ -15,7 +15,7 @@
# TODO(phawkins): fix users of these aliases and delete this file.
from jax._src.abstract_arrays import array_types
from jax.core import (
from jax._src.core import (
ShapedArray,
raise_to_shaped,
)

View File

@ -14,26 +14,28 @@
import contextlib
import functools
from functools import partial
import itertools as it
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Sequence, Optional, Union
import jax
from jax import linear_util as lu
from jax.interpreters import partial_eval as pe
from jax.config import config
from jax import core
from jax._src.dtypes import dtype, float0
from jax.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
raise_to_shaped)
from jax.tree_util import (tree_flatten, tree_unflatten,
register_pytree_node, Partial)
from jax._src import core
from jax._src import source_info_util
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
zeros_like_aval, zeros_like_p, Zero)
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
raise_to_shaped)
from jax._src.dtypes import dtype, float0
from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name,
as_hashable_function, weakref_lru_cache,
partition_list)
from jax.tree_util import register_pytree_node
from jax import linear_util as lu
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
from jax.tree_util import tree_flatten, tree_unflatten, Partial
from jax._src import source_info_util
zip = safe_zip
map = safe_map

View File

@ -23,9 +23,9 @@ import numpy as np
import jax
from jax.config import config
from jax import core
from jax.core import raise_to_shaped, Trace, Tracer
from jax._src import core
from jax._src import source_info_util
from jax._src.core import raise_to_shaped, Trace, Tracer
from jax._src.tree_util import (tree_unflatten, tree_flatten,
register_pytree_node)
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,

View File

@ -27,25 +27,27 @@ from typing import (Any, Callable, Dict, Iterator, List, NamedTuple, Optional,
Protocol, Sequence, Set, Tuple, Type, Union, FrozenSet)
import warnings
from jax import core
import numpy as np
from jax import linear_util as lu
from jax.config import config
from jax.interpreters import ad
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src import ad_util
from jax._src import core
from jax._src import device_array
from jax._src import dtypes
from jax._src.lib import mlir_api_version, xla_extension_version
from jax._src import source_info_util
from jax._src import util
from jax._src.lib import (can_execute_with_token, mlir_api_version,
xla_extension_version)
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import can_execute_with_token
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src import source_info_util
import jax._src.util as util
from jax.config import config
import jax.interpreters.ad as ad
import jax.interpreters.partial_eval as pe
import jax.interpreters.xla as xla
import numpy as np
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

View File

@ -26,24 +26,27 @@ from weakref import ref
import numpy as np
from jax import core
from jax import linear_util as lu
from jax.config import config
from jax._src import api_util
from jax._src import core
from jax._src import dtypes
from jax._src import profiler
from jax._src import source_info_util
from jax._src.api_util import flattened_fun_in_tree, flatten_fun_nokwargs
from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval,
AbstractValue, ClosedJaxpr, new_jaxpr_eqn,
ConcreteArray, Var, DropVar,
raise_to_shaped, Atom, JaxprEqn, Primitive,
ShapedArray, DShapedArray, mapped_aval,
unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
InputType, OutputType, get_referent)
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
tree_leaves)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
merge_lists, partition_list, OrderedSet,
as_hashable_function, weakref_lru_cache)
from jax.core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
ClosedJaxpr, new_jaxpr_eqn, ConcreteArray, Var, DropVar,
raise_to_shaped, Atom, JaxprEqn, Primitive, ShapedArray,
DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx,
OutDBIdx, InputType, OutputType, get_referent)
from jax._src import source_info_util
from jax.config import config
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

View File

@ -47,9 +47,7 @@ from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
import numpy as np
import jax
from jax import core
from jax import linear_util as lu
from jax.core import ConcreteArray, ShapedArray
from jax.errors import JAXTypeError
from jax.interpreters import ad
from jax.interpreters import batching
@ -61,6 +59,7 @@ from jax.tree_util import tree_flatten, tree_map
from jax._src import abstract_arrays
from jax._src import api_util
from jax._src import basearray
from jax._src import core
from jax._src import device_array
from jax._src import dtypes
from jax._src import source_info_util
@ -72,6 +71,7 @@ from jax._src import sharding as sharding_internal
from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src.config import flags
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.lib import can_execute_with_token
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
@ -83,6 +83,7 @@ from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
tuple_insert, tuple_delete, distributed_debug_log,
unzip2, HashableFunction)
# Built in Python lists don't support weak refs but subclasses of lists do.
class WeakRefList(list):
pass

View File

@ -27,13 +27,16 @@ from typing import (Any, Callable, Dict, List, NamedTuple, Optional,
import numpy as np
from jax.config import config
from jax import core
from jax.interpreters import partial_eval as pe
from jax.interpreters import ad
from jax._src import core
from jax._src import device_array
from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src import source_info_util
from jax._src.abstract_arrays import numpy_scalar_types
from jax.core import (ConcreteArray, ShapedArray, str_eqn_compact)
import jax._src.pretty_printer as pp
from jax._src.core import ConcreteArray, ShapedArray, str_eqn_compact
from jax._src.util import (prod, new_name_stack, safe_zip, safe_map,
partition_list)
@ -45,8 +48,6 @@ from jax._src.typing import Shape
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax.interpreters import partial_eval as pe
from jax.interpreters import ad
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

View File

@ -21,10 +21,10 @@ import json
import types
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Tuple
from jax import core
from jax._src.lib import xla_client
from jax._src import core
from jax._src import util
from jax._src import source_info_util
from jax._src.lib import xla_client
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip

View File

@ -67,13 +67,11 @@ from functools import partial
from typing import Any, Tuple, Callable
import weakref
from jax import core
from jax._src.util import curry
from jax.tree_util import tree_map
from jax._src import traceback_util
from jax.config import config
from jax._src import core
from jax._src import traceback_util
from jax._src.util import curry
traceback_util.register_exclusion(__file__)