migrate internal dependencies from jax.interpreters.ad to jax._src.interpreters.ad

... in preparation for paring down `jax.interpreters.ad`'s exported symbols.

Includes some import fixups along the way.

PiperOrigin-RevId: 507684262
This commit is contained in:
Roy Frostig 2023-02-06 22:51:50 -08:00 committed by jax authors
parent c252162821
commit 219723c738
27 changed files with 196 additions and 170 deletions

View File

@ -18,8 +18,6 @@ from typing import (Callable, Optional, List, Tuple, Sequence, Set, Union, Any,
import types
import jax
from jax._src import linear_util as lu
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
@ -27,10 +25,12 @@ from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src import core
from jax._src import util
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.api_util import flatten_fun, shaped_abstractify
from jax._src.interpreters import ad
from jax._src.lax import lax as lax_internal
from jax._src.lax import convolution as lax_convolution
from jax._src.lib.mlir.dialects import hlo

View File

@ -68,12 +68,8 @@ from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
extend_name_stack, new_name_stack, wrap_name, cache,
wraps, HashableFunction, weakref_lru_cache)
# Unused imports to be exported
from jax._src.core import ShapedArray, raise_to_shaped
from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
local_devices, process_index,
process_count, host_id, host_ids,
host_count, default_backend)
from jax.ad_checkpoint import checkpoint_policies, checkpoint as new_checkpoint
from jax.custom_batching import custom_vmap
from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
@ -82,8 +78,6 @@ from jax.custom_transpose import custom_transpose
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.interpreters import pxla
from jax.interpreters import ad
from jax.interpreters import batching
from jax._src.config import (
@ -94,6 +88,13 @@ from jax._src.config import (
_thread_local_state as config_thread_local_state,
explicit_device_put_scope as config_explicit_device_put_scope,
explicit_device_get_scope as config_explicit_device_get_scope)
from jax._src.core import ShapedArray, raise_to_shaped
from jax._src.interpreters import ad
from jax._src.interpreters import pxla
from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
local_devices, process_index,
process_count, host_id, host_ids,
host_count, default_backend)
traceback_util.register_exclusion(__file__)

View File

@ -15,19 +15,20 @@
from __future__ import annotations
import functools
from typing import Any, Callable, Sequence
import numpy as np
from jax import tree_util
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src import dispatch
from jax._src.interpreters import ad
from jax._src.lib import xla_client as xc
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
import numpy as np
# `pure_callback_p` is the main primitive for staging out Python pure callbacks.
pure_callback_p = core.Primitive("pure_callback")

View File

@ -15,10 +15,25 @@
import dataclasses
import functools
import itertools as it
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable, Type, Set, List, Sequence, Any
from typing import (Union, Optional, Callable, Dict, Tuple, TypeVar,
FrozenSet, Iterable, Type, Set, List, Sequence, Any)
import numpy as np
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import lax
from jax.api_util import flatten_fun
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten
from jax.tree_util import tree_map
from jax.tree_util import tree_unflatten
from jax._src import linear_util as lu
from jax._src import core
from jax._src import custom_derivatives
@ -26,24 +41,12 @@ from jax._src import prng
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.lax import control_flow as cf
from jax._src.sharding import OpShardingSharding
from jax._src.typing import Array
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
unzip3, weakref_lru_cache)
from jax.api_util import flatten_fun
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten
from jax.tree_util import tree_map
from jax.tree_util import tree_unflatten
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)

View File

@ -17,9 +17,7 @@ import operator
from typing import Callable, Optional
import jax
from jax._src import linear_util as lu
from jax import tree_util
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters.batching import not_mapped
from jax.interpreters import mlir
@ -29,10 +27,12 @@ from jax.tree_util import (tree_flatten, tree_map, tree_structure,
tree_unflatten, treedef_tuple)
from jax._src import core
from jax._src import custom_api_util
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.api_util import flatten_fun_nokwargs
from jax._src.interpreters import ad
source_info_util.register_exclusion(__file__)

View File

@ -17,29 +17,31 @@ import inspect
from typing import (Callable, Generic, Optional, Sequence, Tuple, TypeVar, Set,
Any)
from jax._src import linear_util as lu
from jax.custom_transpose import custom_transpose
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
treedef_is_leaf, treedef_tuple,
register_pytree_node_class, tree_leaves)
from jax._src import core
from jax._src import custom_api_util
from jax._src import dtypes
from jax._src.lax import lax
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
from jax._src.core import raise_to_shaped
from jax.errors import UnexpectedTracerError
from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters.batching import not_mapped
from jax.config import config
from jax._src import core
from jax._src import custom_api_util
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import traceback_util
from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
from jax._src.core import raise_to_shaped
from jax._src.interpreters import ad
from jax._src.lax import lax
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
traceback_util.register_exclusion(__file__)
map = safe_map

View File

@ -15,8 +15,6 @@
import functools
from typing import Any, Callable, Optional, Tuple
from jax._src import linear_util as lu
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
@ -26,9 +24,11 @@ from jax._src import ad_util
from jax._src import api_util
from jax._src import core
from jax._src import custom_api_util
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.interpreters import ad
source_info_util.register_exclusion(__file__)

View File

@ -12,36 +12,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module for JAX debugging primitives and related functionality."""
import enum
import functools
import string
import sys
from typing import Any, Dict, Callable, Optional, Sequence, Set, Tuple, Union
import weakref
from typing import Any, Dict, Callable, Optional, Sequence, Set, Tuple, Union
import numpy as np
import jax.numpy as jnp
from jax import tree_util
from jax import lax
from jax._src import linear_util as lu
from jax.config import config
from jax.experimental import pjit
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla
from jax._src import ad_checkpoint
from jax._src import core
from jax._src import custom_derivatives
from jax._src import linear_util as lu
from jax._src import util
from jax._src.interpreters import ad
from jax._src.lax import control_flow as lcf
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import Sharding, OpShardingSharding
import jax.numpy as jnp
import numpy as np
# pytype: disable=import-error
try:
import rich

View File

@ -32,34 +32,36 @@ import warnings
import numpy as np
import jax
from jax._src import linear_util as lu
from jax.errors import UnexpectedTracerError
from jax.monitoring import record_event_duration_secs
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
import jax.interpreters.xla as xla
from jax.interpreters import pxla
import jax.interpreters.partial_eval as pe
from jax._src import array
from jax._src import core
from jax._src import device_array
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import path
from jax._src import profiler
from jax._src import stages
from jax._src import traceback_util
from jax._src.sharding import (PmapSharding, SingleDeviceSharding,
OpShardingSharding, NamedSharding, PartitionSpec,
Sharding)
from jax._src import util
from jax._src.abstract_arrays import array_types
from jax._src.config import config, flags
from jax._src.interpreters import ad
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import use_stablehlo
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
import jax._src.util as util
from jax._src.sharding import (PmapSharding, SingleDeviceSharding,
OpShardingSharding, NamedSharding, PartitionSpec,
Sharding)
from jax._src.util import flatten, unflatten
from jax._src import path
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"

View File

@ -42,12 +42,11 @@ import sys
import threading
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast)
import numpy as np
import jax
from jax._src import linear_util as lu
from jax.errors import JAXTypeError
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
@ -59,17 +58,19 @@ from jax._src import api_util
from jax._src import basearray
from jax._src import core
from jax._src import device_array
from jax._src import dtypes
from jax._src import source_info_util
from jax._src import util
from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import profiler
from jax._src import stages
from jax._src import sharding as sharding_internal
from jax._src import source_info_util
from jax._src import stages
from jax._src import util
from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src.config import flags
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.interpreters import ad
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version

View File

@ -73,13 +73,17 @@ from functools import partial
from typing import (Any, Tuple)
import numpy as np
from jax._src import core
from jax.interpreters import xla
from jax.interpreters import batching
from jax._src import ad_util
from jax._src import core
from jax._src import dtypes
from jax._src.interpreters import ad
from jax._src.lax import lax
from jax._src.lib import xla_client as xc
from jax.interpreters import ad, xla, batching
Array = Any

View File

@ -19,15 +19,16 @@ from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src.interpreters import ad
from jax._src.lax import lax
from jax._src.lib.mlir.dialects import hlo
from jax.interpreters import batching
from jax.interpreters import mlir
_max = builtins.max
Array = Any

View File

@ -18,13 +18,14 @@ from typing import Union, Sequence
import numpy as np
from jax import lax
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import xla
from jax import lax
from jax.interpreters import ad
from jax.interpreters import batching
from jax._src.api import jit, linear_transpose, ShapeDtypeStruct
from jax._src.core import Primitive, is_constant_shape
from jax._src.interpreters import ad
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_client
from jax._src.lib import ducc_fft

View File

@ -25,40 +25,33 @@ import warnings
import numpy as np
import jax
from jax._src import core
from jax import tree_util
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla
from jax.interpreters import xla
from jax.interpreters.batching import ConcatAxis
from jax.tree_util import tree_map
from jax._src import ad_util
from jax._src import api
from jax._src import api_util
from jax._src import array
from jax._src import core
from jax._src import device_array
from jax._src import dispatch
from jax._src import linear_util as lu
from jax._src import dtypes
from jax import tree_util
from jax._src import linear_util as lu
from jax._src import pretty_printer as pp
from jax._src import source_info_util
from jax._src.sharding import PmapSharding
from jax._src import util
from jax._src.abstract_arrays import array_types
from jax._src.config import config
from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
raise_to_shaped, abstract_token, canonicalize_shape)
from jax._src.abstract_arrays import array_types
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters.batching import ConcatAxis
import jax._src.pretty_printer as pp
from jax._src import util
from jax._src.util import (cache, prod, safe_zip, safe_map, canonicalize_axis,
split_list)
from jax.tree_util import tree_map
from jax._src.lib import pytree
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.interpreters import ad
from jax._src.lax import slicing
from jax._src.lax.utils import (
_input_dtype,
standard_abstract_eval,
@ -67,8 +60,16 @@ from jax._src.lax.utils import (
standard_primitive,
standard_translate,
)
from jax._src.lax import slicing
from jax._src.lib import pytree
from jax._src.lib import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding import PmapSharding
from jax._src.typing import Array, ArrayLike, DTypeLike, Shape
from jax._src.util import (cache, prod, safe_zip, safe_map, canonicalize_axis,
split_list)
xb = xla_bridge
xc = xla_client

View File

@ -21,38 +21,36 @@ import warnings
import numpy as np
import jax
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.vectorize import vectorize
from jax._src import ad_util
from jax._src import api
from jax import lax
from jax._src import dtypes
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import ad
from jax.interpreters import batching
from jax._src.util import prod
from jax._src import ad_util
from jax._src import api
from jax._src import dtypes
from jax._src.core import (
Primitive, ShapedArray, raise_to_shaped, is_constant_shape)
from jax._src.lax.lax import (
standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
_input_dtype)
from jax._src.interpreters import ad
from jax._src.lax import control_flow
from jax._src.lax import eigh as lax_eigh
from jax._src.lax import lax as lax_internal
from jax._src.lax import svd as lax_svd
from jax._src.lib import lapack
from jax._src.lax.lax import (
standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
_input_dtype)
from jax._src.lib import gpu_linalg
from jax._src.lib import gpu_solver
from jax._src.lib import gpu_sparse
from jax._src.lib import lapack
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.vectorize import vectorize
from jax._src.typing import Array, ArrayLike
from jax._src.util import prod
xops = xla_client.ops

View File

@ -24,21 +24,23 @@ import warnings
import numpy as np
from jax import tree_util
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import pxla
from jax.interpreters import xla
from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src.core import ShapedArray, AxisName, raise_to_shaped
from jax._src.interpreters import ad
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.numpy import lax_numpy
import jax._src.util as util
from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo, use_stablehlo
from jax._src.numpy import lax_numpy
from jax._src.util import (
unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis)
unsafe_map, map = map, safe_map # type: ignore

View File

@ -20,24 +20,25 @@ import weakref
import numpy as np
import jax
from jax._src import ad_util
from jax._src import core
from jax._src import dtypes
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax._src import ad_util
from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src.interpreters import ad
from jax._src.lax import lax
from jax._src.lax.utils import (
_argnum_weak_type,
_input_dtype,
standard_primitive,
)
from jax._src.lax import lax
from jax._src import util
from jax._src.util import safe_map, safe_zip
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.typing import Array, ArrayLike, Shape
from jax._src.util import safe_map, safe_zip
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

View File

@ -18,18 +18,17 @@ import warnings
import numpy as np
from jax.interpreters import ad
from jax import tree_util
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import xla
from jax import tree_util
from jax._src import ad_util
from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src.core import ShapedArray, ConcreteArray
from jax._src.interpreters import ad
from jax._src.lax import lax
from jax._src.lax import convolution
from jax._src.lax import slicing

View File

@ -24,41 +24,42 @@ import threading
import warnings
import jax
from jax import core
from jax import stages
from jax.errors import JAXTypeError
from jax.experimental.global_device_array import GlobalDeviceArray as GDA
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla
from jax.interpreters import xla
from jax.interpreters.pxla import PartitionSpec
from jax.tree_util import (
tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
treedef_tuple)
from jax._src.sharding import (
NamedSharding, Sharding, XLACompatibleSharding, OpShardingSharding,
XLADeviceAssignment, SingleDeviceSharding, PmapSharding)
from jax import core
from jax._src import linear_util as lu
from jax import stages
from jax._src import array
from jax._src.config import config
from jax._src import dispatch
from jax._src import linear_util as lu
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.traceback_util import api_boundary
from jax._src.api_util import (argnums_partial_except, flatten_axes,
flatten_fun, flatten_fun_nokwargs,
donation_vector, shaped_abstractify,
check_callable, argnames_partial_except,
resolve_argnums, FLAGS)
from jax.errors import JAXTypeError
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import pxla
from jax.interpreters import xla
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters.pxla import PartitionSpec
from jax._src import util
from jax._src.api_util import (
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
donation_vector, shaped_abstractify, check_callable,
argnames_partial_except, resolve_argnums, FLAGS)
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
treedef_is_leaf, tree_structure, treedef_tuple)
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import prefix_errors
from jax._src import util
from jax._src.util import (
HashableFunction, safe_map, safe_zip, wraps,
distributed_debug_log, split_list, tuple_insert, weakref_lru_cache,

View File

@ -25,26 +25,26 @@ from jax import lax
from jax import numpy as jnp
from jax.config import config
from jax.dtypes import float0
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import pxla
from jax.interpreters import xla
from jax._src import basearray
from jax._src.sharding import (
NamedSharding, PmapSharding, OpShardingSharding)
from jax._src import basearray
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src.api import jit, vmap
from jax._src.interpreters import ad
from jax._src.lax import lax as lax_internal
from jax._src.lax import utils as lax_utils
from jax._src.lib import gpu_prng
from jax._src.lib.mlir.dialects import hlo
from jax._src.numpy import lax_numpy
from jax._src.sharding import (
NamedSharding, PmapSharding, OpShardingSharding)
from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
from jax._src.lib import gpu_prng
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

View File

@ -24,18 +24,21 @@ import jax
import jax.numpy as jnp
from jax import lax
from jax.config import config
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.numpy.linalg import cholesky, svd, eigh
from jax._src import core
from jax._src import dtypes
from jax._src import prng
from jax._src.api import jit, vmap
from jax._src.core import NamedShape
from jax._src.interpreters import ad
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_bridge
from jax._src.numpy.lax_numpy import _arraylike, _check_arraylike, _convert_and_clip_integer, _promote_dtypes_inexact
from jax._src.numpy.lax_numpy import (
_arraylike, _check_arraylike, _convert_and_clip_integer,
_promote_dtypes_inexact)
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.util import prod, canonicalize_axis

View File

@ -18,21 +18,26 @@ import operator
import os
import re
import threading
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union, cast
from typing import (
Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union,
cast)
from absl import logging
import numpy as np
import jax
from jax import lax
from jax import config
from jax import core, custom_derivatives
from jax._src import linear_util as lu
from jax import random, tree_util
from jax import core
from jax import custom_derivatives
from jax import random
from jax import numpy as jnp
from jax import tree_util
from jax.experimental import maps
from jax.experimental import pjit
from jax._src import sharding
from jax.interpreters import ad
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.jax2tf import shape_poly
from jax.experimental.jax2tf import impl_no_xla
from jax.interpreters import mlir
from jax.interpreters import pxla
from jax.interpreters import xla
@ -43,10 +48,13 @@ from jax._src import api
from jax._src import api_util
from jax._src import dispatch
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import prng
from jax._src import random as random_internal
from jax._src import sharding
from jax._src import source_info_util
from jax._src import util
from jax._src.interpreters import ad
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax import lax as lax_internal
from jax._src.lax import linalg as lax_linalg
@ -55,12 +63,6 @@ from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src.lib import xla_client
from jax._src.numpy.ufuncs import logaddexp
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.jax2tf import shape_poly
from jax.experimental.jax2tf import impl_no_xla
import numpy as np
import tensorflow as tf # type: ignore[import]
# These don't have public equivalents.

View File

@ -42,10 +42,11 @@ from jax.experimental.sparse.bcsr import BCSR
from jax.experimental.sparse.coo import COO
from jax.experimental.sparse.csr import CSR, CSC
from jax.experimental.sparse.util import _coo_extract
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src import dtypes
from jax._src.interpreters import ad
from jax._src.typing import Array, DTypeLike, Shape

View File

@ -39,9 +39,9 @@ from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
import jax.numpy as jnp
from jax.interpreters import ad
from jax.util import safe_zip, unzip2, split_list
from jax._src import api_util
from jax._src.interpreters import ad
from jax._src.lax.lax import (
_const, ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
DotDimensionNumbers)

View File

@ -24,11 +24,11 @@ import numpy as np
from jax import core
from jax import lax
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.util import _coo_extract, CuSparseEfficiencyWarning
from jax import tree_util
from jax._src.interpreters import ad
from jax._src.lax.lax import _const
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import gpu_sparse

View File

@ -22,13 +22,13 @@ import warnings
import numpy as np
from jax import core
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.experimental.sparse._base import JAXSparse
from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo
from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, CuSparseEfficiencyWarning
from jax import lax
from jax import tree_util
from jax._src.interpreters import ad
from jax._src.lax.lax import _const
from jax._src.lib import gpu_sparse
from jax._src.numpy.lax_numpy import _promote_dtypes

View File

@ -28,7 +28,6 @@ import numpy as np
from jax.config import config
from jax.interpreters import partial_eval as pe
from jax.interpreters import ad
from jax._src import core
from jax._src import device_array
@ -37,6 +36,7 @@ from jax._src import pretty_printer as pp
from jax._src import source_info_util
from jax._src.abstract_arrays import numpy_scalar_types
from jax._src.core import ConcreteArray, ShapedArray, str_eqn_compact
from jax._src.interpreters import ad
from jax._src.util import (prod, new_name_stack, safe_zip, safe_map,
partition_list)