mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
migrate internal dependencies from jax.core
to jax._src.core
... in preparation for paring down `jax.core`'s exported symbols. Also includes a few import fixups along the way, and a TODO comment to avoid an import cycle in `_src/dtypes.py`. PiperOrigin-RevId: 496024782
This commit is contained in:
parent
6f770c936e
commit
d927a5dbf3
@ -62,7 +62,7 @@ from jax._src.config import (
|
||||
transfer_guard_device_to_host as transfer_guard_device_to_host,
|
||||
spmd_mode as spmd_mode,
|
||||
)
|
||||
from .core import eval_context as ensure_compile_time_eval
|
||||
from jax._src.core import eval_context as ensure_compile_time_eval
|
||||
from jax._src.environment_info import print_environment_info as print_environment_info
|
||||
from jax._src.api import (
|
||||
ad, # TODO(phawkins): update users to avoid this.
|
||||
|
@ -17,7 +17,7 @@ from functools import partial
|
||||
import numpy as np
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
|
||||
from jax._src import traceback_util
|
||||
|
@ -18,7 +18,6 @@ from typing import (Callable, Optional, List, Tuple, Sequence, Set, Union, Any,
|
||||
import types
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
@ -27,6 +26,7 @@ from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.tree_util import tree_flatten, tree_unflatten
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import util
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
|
@ -15,10 +15,11 @@ from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, Dict, Type, Union
|
||||
|
||||
from jax import core
|
||||
from jax.core import (lattice_join, Primitive, valid_jaxtype, raise_to_shaped,
|
||||
get_aval)
|
||||
from jax.tree_util import register_pytree_node
|
||||
|
||||
from jax._src import core
|
||||
from jax._src.core import (lattice_join, Primitive, valid_jaxtype,
|
||||
raise_to_shaped, get_aval)
|
||||
from jax._src.util import safe_map
|
||||
|
||||
from jax._src import traceback_util
|
||||
|
@ -34,22 +34,21 @@ import numpy as np
|
||||
from contextlib import contextmanager, ExitStack
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax import stages
|
||||
from jax.core import eval_jaxpr
|
||||
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
||||
tree_structure, tree_transpose, tree_leaves,
|
||||
treedef_is_leaf, treedef_children,
|
||||
Partial, PyTreeDef, all_leaves, treedef_tuple)
|
||||
from jax._src import callback as jcb
|
||||
from jax._src import core
|
||||
from jax._src import device_array
|
||||
from jax._src import dispatch
|
||||
from jax._src import array
|
||||
from jax._src import dtypes
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src.sharding import PmapSharding
|
||||
from jax._src.core import eval_jaxpr
|
||||
from jax._src.api_util import (
|
||||
flatten_fun, apply_flat_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2,
|
||||
argnums_partial, argnums_partial_except, flatten_axes, donation_vector,
|
||||
@ -61,6 +60,7 @@ from jax._src.lib import jax_jit
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import pmap_lib
|
||||
from jax._src.sharding import PmapSharding
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.tree_util import broadcast_prefix, _generate_key_paths
|
||||
from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
|
||||
@ -68,12 +68,12 @@ from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
|
||||
wraps, HashableFunction, weakref_lru_cache)
|
||||
|
||||
# Unused imports to be exported
|
||||
from jax._src.core import ShapedArray, raise_to_shaped
|
||||
from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
|
||||
local_devices, process_index,
|
||||
process_count, host_id, host_ids,
|
||||
host_count, default_backend)
|
||||
from jax.ad_checkpoint import checkpoint_policies, checkpoint as new_checkpoint
|
||||
from jax.core import ShapedArray, raise_to_shaped
|
||||
from jax.custom_batching import custom_vmap
|
||||
from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
|
||||
custom_vjp, linear_call)
|
||||
|
@ -21,7 +21,7 @@ import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.tree_util import (
|
||||
PyTreeDef, tree_flatten, tree_unflatten, tree_map, tree_structure,
|
||||
|
@ -19,10 +19,10 @@ import numpy as np
|
||||
import functools
|
||||
from typing import Sequence, Tuple, Callable, Union, Optional, cast, List
|
||||
|
||||
from jax import core
|
||||
from jax._src import abstract_arrays
|
||||
from jax._src import api_util
|
||||
from jax._src import basearray
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src.config import config
|
||||
|
@ -33,14 +33,14 @@ class Array(abc.ABC):
|
||||
{class}`DeviceArray` ✅ ❌
|
||||
{class}`ShardedDeviceArray` ✅ ❌
|
||||
{class}`GlobalDeviceArray` ✅ ❌
|
||||
{class}`~jax.core.Tracer` ✅ ✅
|
||||
{class}`~jax._src.core.Tracer` ✅ ✅
|
||||
{class}`~jax.experimental.Array` ✅ ✅
|
||||
================================ ====================== =========================
|
||||
|
||||
In other words, ``isinstance(x, jax.Array)`` will return True for any of these types,
|
||||
whereas annotations such as ``x : jax.Array`` will only type-check correctly for
|
||||
instances of {class}`~jax.core.Tracer` and {class}`jax.experimental.Array`, and not
|
||||
for the other soon-to-be-deprecated array types.
|
||||
instances of {class}`~jax._src.core.Tracer` and {class}`jax.experimental.Array`, and
|
||||
not for the other soon-to-be-deprecated array types.
|
||||
"""
|
||||
# Note: no abstract methods are defined in this base class; the associated pyi
|
||||
# file contains the type signature for static type checking.
|
||||
|
@ -18,8 +18,8 @@ import functools
|
||||
|
||||
from typing import Any, Callable, Sequence
|
||||
|
||||
from jax import core
|
||||
from jax import tree_util
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import util
|
||||
from jax._src import dispatch
|
||||
|
@ -19,9 +19,9 @@ import types
|
||||
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable, Type, Set, List
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax._src import core
|
||||
from jax._src import prng
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
|
@ -17,7 +17,6 @@ import operator
|
||||
from typing import Callable, Optional
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax import tree_util
|
||||
from jax.interpreters import ad
|
||||
@ -28,6 +27,7 @@ from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.tree_util import (tree_flatten, tree_map, tree_structure,
|
||||
tree_unflatten, treedef_tuple)
|
||||
from jax._src import core
|
||||
from jax._src import custom_api_util
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
|
@ -17,18 +17,18 @@ import inspect
|
||||
from typing import (Callable, Generic, Optional, Sequence, Tuple, TypeVar, Set,
|
||||
Any)
|
||||
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax.custom_transpose import custom_transpose
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
|
||||
treedef_is_leaf, treedef_tuple,
|
||||
register_pytree_node_class, tree_leaves)
|
||||
from jax._src import core
|
||||
from jax._src import custom_api_util
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax import lax
|
||||
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
|
||||
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
|
||||
from jax.core import raise_to_shaped
|
||||
from jax._src.core import raise_to_shaped
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
@ -15,7 +15,6 @@
|
||||
import functools
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import mlir
|
||||
@ -25,6 +24,7 @@ from jax.tree_util import (tree_flatten, tree_leaves, tree_map,
|
||||
tree_structure, treedef_tuple, tree_unflatten)
|
||||
from jax._src import ad_util
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src import custom_api_util
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
|
@ -20,12 +20,10 @@ import weakref
|
||||
|
||||
from typing import Any, Dict, Callable, Optional, Sequence, Set, Tuple, Union
|
||||
|
||||
from jax import core
|
||||
from jax import tree_util
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax._src.sharding import Sharding, OpShardingSharding
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
@ -33,12 +31,14 @@ from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import pxla
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import core
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import util
|
||||
from jax._src.lax import control_flow as lcf
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.sharding import Sharding, OpShardingSharding
|
||||
import jax.numpy as jnp
|
||||
|
||||
import numpy as np
|
||||
|
@ -23,13 +23,13 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src.config import config
|
||||
from jax._src import core
|
||||
from jax._src import abstract_arrays
|
||||
from jax._src import dtypes
|
||||
from jax._src import profiler
|
||||
from jax._src import util
|
||||
from jax._src.config import config
|
||||
from jax._src.lib import xla_client as xc
|
||||
import jax._src.util as util
|
||||
from jax._src.typing import Array
|
||||
|
||||
### device-persistent data
|
||||
|
@ -32,7 +32,6 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax.monitoring import record_event_duration_secs
|
||||
@ -43,6 +42,7 @@ import jax.interpreters.xla as xla
|
||||
from jax.interpreters import pxla
|
||||
import jax.interpreters.partial_eval as pe
|
||||
from jax._src import array
|
||||
from jax._src import core
|
||||
from jax._src import device_array
|
||||
from jax._src import dtypes
|
||||
from jax._src import profiler
|
||||
|
@ -12,8 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax import core
|
||||
from jax import numpy as jnp
|
||||
from jax._src import core
|
||||
from jax._src import device_array
|
||||
from jax._src import dispatch
|
||||
from jax._src import array
|
||||
|
@ -25,7 +25,6 @@ from typing import cast, overload, Any, Dict, List, Literal, Optional, Set, Tupl
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax._src.config import flags, config
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.typing import DType, DTypeLike, OpaqueDType
|
||||
@ -92,7 +91,8 @@ def to_complex_dtype(dtype: DTypeLike) -> DType:
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _canonicalize_dtype(x64_enabled: bool, allow_opaque_dtype: bool, dtype: Any) -> Union[DType, OpaqueDType]:
|
||||
"""Convert from a dtype to a canonical dtype based on config.x64_enabled."""
|
||||
if jax.core.is_opaque_dtype(dtype):
|
||||
from jax._src import core # TODO(frostig): break this cycle
|
||||
if core.is_opaque_dtype(dtype):
|
||||
if not allow_opaque_dtype:
|
||||
raise ValueError(f"Internal: canonicalize_dtype called onopaque dtype {dtype} "
|
||||
"with allow_opaque_dtype=False")
|
||||
@ -450,13 +450,14 @@ def check_valid_dtype(dtype: DType) -> None:
|
||||
|
||||
def dtype(x: Any, *, canonicalize: bool = False) -> DType:
|
||||
"""Return the dtype object for a value or type, optionally canonicalized based on X64 mode."""
|
||||
from jax._src import core # TODO(frostig): break this cycle
|
||||
if x is None:
|
||||
raise ValueError(f"Invalid argument to dtype: {x}.")
|
||||
elif isinstance(x, type) and x in python_scalar_dtypes:
|
||||
dt = python_scalar_dtypes[x]
|
||||
elif type(x) in python_scalar_dtypes:
|
||||
dt = python_scalar_dtypes[type(x)]
|
||||
elif jax.core.is_opaque_dtype(getattr(x, 'dtype', None)):
|
||||
elif core.is_opaque_dtype(getattr(x, 'dtype', None)):
|
||||
dt = x.dtype
|
||||
else:
|
||||
dt = np.result_type(x)
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
|
||||
class _JAXErrorMixin:
|
||||
"""Mixin for JAX-specific errors"""
|
||||
|
@ -73,10 +73,11 @@ from functools import partial
|
||||
from typing import (Any, Tuple)
|
||||
|
||||
import numpy as np
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax._src import ad_util
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src import ad_util, dtypes
|
||||
|
||||
from jax.interpreters import ad, xla, batching
|
||||
|
||||
|
@ -19,15 +19,15 @@ from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax import lax
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import util
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
|
||||
_max = builtins.max
|
||||
|
||||
|
@ -18,18 +18,18 @@ from typing import Union, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax._src.api import jit, linear_transpose, ShapeDtypeStruct
|
||||
from jax.core import Primitive, is_constant_shape
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax._src.util import prod
|
||||
from jax import lax
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax._src.api import jit, linear_transpose, ShapeDtypeStruct
|
||||
from jax._src.core import Primitive, is_constant_shape
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import ducc_fft
|
||||
from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact
|
||||
from jax._src.util import prod
|
||||
|
||||
__all__ = [
|
||||
"fft",
|
||||
|
@ -25,7 +25,7 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src import core
|
||||
from jax._src import ad_util
|
||||
from jax._src import api
|
||||
from jax._src import api_util
|
||||
@ -38,8 +38,8 @@ from jax import tree_util
|
||||
from jax._src import source_info_util
|
||||
from jax._src.sharding import PmapSharding
|
||||
from jax._src.config import config
|
||||
from jax.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
|
||||
raise_to_shaped, abstract_token, canonicalize_shape)
|
||||
from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
|
||||
raise_to_shaped, abstract_token, canonicalize_shape)
|
||||
from jax._src.abstract_arrays import array_types
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import mlir
|
||||
|
@ -32,7 +32,8 @@ from jax.interpreters import xla
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax._src.util import prod
|
||||
from jax.core import Primitive, ShapedArray, raise_to_shaped, is_constant_shape
|
||||
from jax._src.core import (
|
||||
Primitive, ShapedArray, raise_to_shaped, is_constant_shape)
|
||||
from jax._src.lax.lax import (
|
||||
standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
|
||||
_input_dtype)
|
||||
|
@ -23,15 +23,15 @@ import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import core
|
||||
from jax import tree_util
|
||||
from jax.core import ShapedArray, AxisName, raise_to_shaped
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import batching
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.core import ShapedArray, AxisName, raise_to_shaped
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lax import slicing
|
||||
from jax._src.numpy import lax_numpy
|
||||
|
@ -28,9 +28,9 @@ import functools
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax._src import core
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
|
||||
|
||||
|
@ -20,8 +20,8 @@ import weakref
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
|
@ -34,12 +34,12 @@ https://epubs.siam.org/doi/abs/10.1137/090774999
|
||||
"""
|
||||
|
||||
import functools
|
||||
|
||||
from typing import Any, Sequence, Union
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax._src import core
|
||||
|
||||
|
||||
@functools.partial(jax.jit, static_argnums=(1, 2))
|
||||
|
@ -21,9 +21,9 @@ from functools import partial
|
||||
import operator
|
||||
from typing import Callable
|
||||
|
||||
from jax import core
|
||||
from jax._src import dtypes
|
||||
from jax.interpreters import xla
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.lib import xla_client
|
||||
|
||||
|
@ -23,19 +23,19 @@ from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax import core
|
||||
from jax.core import (ShapedArray, ConcreteArray)
|
||||
from jax import tree_util
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
import jax._src.lax.lax as lax
|
||||
import jax._src.lax.convolution as convolution
|
||||
import jax._src.lax.slicing as slicing
|
||||
from jax._src import util
|
||||
from jax._src.core import ShapedArray, ConcreteArray
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lax import convolution
|
||||
from jax._src.lax import slicing
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.numpy.ufuncs import logaddexp
|
||||
import jax._src.util as util
|
||||
|
||||
map = util.safe_map
|
||||
zip = util.safe_zip
|
||||
|
@ -22,7 +22,6 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import core
|
||||
from jax import numpy as jnp
|
||||
from jax.config import config
|
||||
from jax.dtypes import float0
|
||||
@ -35,14 +34,15 @@ from jax._src import basearray
|
||||
from jax._src.sharding import (
|
||||
NamedSharding, PmapSharding, OpShardingSharding)
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src.api import jit, vmap
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import utils as lax_utils
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.numpy import lax_numpy
|
||||
import jax._src.pretty_printer as pp
|
||||
from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
|
||||
from jax._src.lib import gpu_prng
|
||||
|
||||
|
@ -21,21 +21,21 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax import core
|
||||
from jax import numpy as jnp
|
||||
from jax._src import dtypes
|
||||
from jax._src import prng
|
||||
from jax.config import config
|
||||
from jax.core import NamedShape
|
||||
from jax._src.api import jit, vmap
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.numpy.lax_numpy import _arraylike, _check_arraylike, _convert_and_clip_integer, _promote_dtypes_inexact
|
||||
from jax.numpy.linalg import cholesky, svd, eigh
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.numpy.linalg import cholesky, svd, eigh
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import prng
|
||||
from jax._src.api import jit, vmap
|
||||
from jax._src.core import NamedShape
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.numpy.lax_numpy import _arraylike, _check_arraylike, _convert_and_clip_integer, _promote_dtypes_inexact
|
||||
from jax._src.typing import Array, ArrayLike, DTypeLike
|
||||
from jax._src.util import prod, canonicalize_axis
|
||||
|
||||
|
@ -36,10 +36,10 @@ from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, NamedTuple, Optional, Protocol, Sequence, Tuple
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import tree_util
|
||||
from jax.lib import xla_client as xc
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
|
@ -32,20 +32,21 @@ import numpy as np
|
||||
import numpy.random as npr
|
||||
|
||||
import jax
|
||||
from jax._src import api
|
||||
from jax import core
|
||||
from jax._src import dtypes as _dtypes
|
||||
from jax import lax
|
||||
from jax.interpreters import mlir
|
||||
from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten
|
||||
from jax._src import api
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes as _dtypes
|
||||
from jax._src.config import flags, bool_env, config
|
||||
from jax._src.numpy.lax_numpy import _promote_dtypes, _promote_dtypes_inexact
|
||||
from jax._src.util import prod, unzip2
|
||||
from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src import dispatch
|
||||
from jax._src.public_test_util import ( # noqa: F401
|
||||
_assert_numpy_allclose, _check_dtypes_match, _default_tolerance, _dtype, check_close, check_grads,
|
||||
check_jvp, check_vjp, default_gradient_tolerance, default_tolerance, device_under_test, tolerance)
|
||||
from jax.interpreters import mlir
|
||||
|
||||
|
||||
# This submodule includes private test utilities that are not exported to
|
||||
# jax.test_util. Functionality appearing here is for internal use only, and
|
||||
|
@ -15,7 +15,7 @@
|
||||
# TODO(phawkins): fix users of these aliases and delete this file.
|
||||
|
||||
from jax._src.abstract_arrays import array_types
|
||||
from jax.core import (
|
||||
from jax._src.core import (
|
||||
ShapedArray,
|
||||
raise_to_shaped,
|
||||
)
|
||||
|
@ -14,26 +14,28 @@
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Tuple, Sequence, Optional, Union
|
||||
|
||||
import jax
|
||||
from jax import linear_util as lu
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.config import config
|
||||
from jax import core
|
||||
from jax._src.dtypes import dtype, float0
|
||||
from jax.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
|
||||
raise_to_shaped)
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten,
|
||||
register_pytree_node, Partial)
|
||||
from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
zeros_like_aval, zeros_like_p, Zero)
|
||||
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
|
||||
from jax._src.core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
|
||||
raise_to_shaped)
|
||||
from jax._src.dtypes import dtype, float0
|
||||
from jax._src.util import (unzip2, safe_map, safe_zip, split_list, wrap_name,
|
||||
as_hashable_function, weakref_lru_cache,
|
||||
partition_list)
|
||||
from jax.tree_util import register_pytree_node
|
||||
from jax import linear_util as lu
|
||||
from jax._src.api_util import flatten_fun, flatten_fun_nokwargs
|
||||
from jax.tree_util import tree_flatten, tree_unflatten, Partial
|
||||
from jax._src import source_info_util
|
||||
|
||||
|
||||
zip = safe_zip
|
||||
map = safe_map
|
||||
|
@ -23,9 +23,9 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.config import config
|
||||
from jax import core
|
||||
from jax.core import raise_to_shaped, Trace, Tracer
|
||||
from jax._src import core
|
||||
from jax._src import source_info_util
|
||||
from jax._src.core import raise_to_shaped, Trace, Tracer
|
||||
from jax._src.tree_util import (tree_unflatten, tree_flatten,
|
||||
register_pytree_node)
|
||||
from jax._src.ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval,
|
||||
|
@ -27,25 +27,27 @@ from typing import (Any, Callable, Dict, Iterator, List, NamedTuple, Optional,
|
||||
Protocol, Sequence, Set, Tuple, Type, Union, FrozenSet)
|
||||
import warnings
|
||||
|
||||
from jax import core
|
||||
import numpy as np
|
||||
|
||||
from jax import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import device_array
|
||||
from jax._src import dtypes
|
||||
from jax._src.lib import mlir_api_version, xla_extension_version
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src.lib import (can_execute_with_token, mlir_api_version,
|
||||
xla_extension_version)
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib import can_execute_with_token
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src import source_info_util
|
||||
import jax._src.util as util
|
||||
from jax.config import config
|
||||
import jax.interpreters.ad as ad
|
||||
import jax.interpreters.partial_eval as pe
|
||||
import jax.interpreters.xla as xla
|
||||
import numpy as np
|
||||
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
zip, unsafe_zip = util.safe_zip, zip
|
||||
|
@ -26,24 +26,27 @@ from weakref import ref
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import profiler
|
||||
from jax._src import source_info_util
|
||||
from jax._src.api_util import flattened_fun_in_tree, flatten_fun_nokwargs
|
||||
from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval,
|
||||
AbstractValue, ClosedJaxpr, new_jaxpr_eqn,
|
||||
ConcreteArray, Var, DropVar,
|
||||
raise_to_shaped, Atom, JaxprEqn, Primitive,
|
||||
ShapedArray, DShapedArray, mapped_aval,
|
||||
unmapped_aval, DBIdx, InDBIdx, OutDBIdx,
|
||||
InputType, OutputType, get_referent)
|
||||
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
|
||||
tree_leaves)
|
||||
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
|
||||
merge_lists, partition_list, OrderedSet,
|
||||
as_hashable_function, weakref_lru_cache)
|
||||
from jax.core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
|
||||
ClosedJaxpr, new_jaxpr_eqn, ConcreteArray, Var, DropVar,
|
||||
raise_to_shaped, Atom, JaxprEqn, Primitive, ShapedArray,
|
||||
DShapedArray, mapped_aval, unmapped_aval, DBIdx, InDBIdx,
|
||||
OutDBIdx, InputType, OutputType, get_referent)
|
||||
from jax._src import source_info_util
|
||||
from jax.config import config
|
||||
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
@ -47,9 +47,7 @@ from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax.core import ConcreteArray, ShapedArray
|
||||
from jax.errors import JAXTypeError
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
@ -61,6 +59,7 @@ from jax.tree_util import tree_flatten, tree_map
|
||||
from jax._src import abstract_arrays
|
||||
from jax._src import api_util
|
||||
from jax._src import basearray
|
||||
from jax._src import core
|
||||
from jax._src import device_array
|
||||
from jax._src import dtypes
|
||||
from jax._src import source_info_util
|
||||
@ -72,6 +71,7 @@ from jax._src import sharding as sharding_internal
|
||||
from jax._src.abstract_arrays import array_types
|
||||
from jax._src.config import config
|
||||
from jax._src.config import flags
|
||||
from jax._src.core import ConcreteArray, ShapedArray
|
||||
from jax._src.lib import can_execute_with_token
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
@ -83,6 +83,7 @@ from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
|
||||
tuple_insert, tuple_delete, distributed_debug_log,
|
||||
unzip2, HashableFunction)
|
||||
|
||||
|
||||
# Built in Python lists don't support weak refs but subclasses of lists do.
|
||||
class WeakRefList(list):
|
||||
pass
|
||||
|
@ -27,13 +27,16 @@ from typing import (Any, Callable, Dict, List, NamedTuple, Optional,
|
||||
import numpy as np
|
||||
|
||||
from jax.config import config
|
||||
from jax import core
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import ad
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import device_array
|
||||
from jax._src import dtypes
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src import source_info_util
|
||||
from jax._src.abstract_arrays import numpy_scalar_types
|
||||
from jax.core import (ConcreteArray, ShapedArray, str_eqn_compact)
|
||||
import jax._src.pretty_printer as pp
|
||||
from jax._src.core import ConcreteArray, ShapedArray, str_eqn_compact
|
||||
from jax._src.util import (prod, new_name_stack, safe_zip, safe_map,
|
||||
partition_list)
|
||||
|
||||
@ -45,8 +48,6 @@ from jax._src.typing import Shape
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import ad
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
@ -21,10 +21,10 @@ import json
|
||||
import types
|
||||
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Tuple
|
||||
|
||||
from jax import core
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src import core
|
||||
from jax._src import util
|
||||
from jax._src import source_info_util
|
||||
from jax._src.lib import xla_client
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
zip, unsafe_zip = util.safe_zip, zip
|
||||
|
@ -67,13 +67,11 @@ from functools import partial
|
||||
from typing import Any, Tuple, Callable
|
||||
import weakref
|
||||
|
||||
from jax import core
|
||||
from jax._src.util import curry
|
||||
from jax.tree_util import tree_map
|
||||
|
||||
from jax._src import traceback_util
|
||||
|
||||
from jax.config import config
|
||||
from jax._src import core
|
||||
from jax._src import traceback_util
|
||||
from jax._src.util import curry
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user