mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
migrate internal dependencies from jax.interpreters.ad
to jax._src.interpreters.ad
... in preparation for paring down `jax.interpreters.ad`'s exported symbols. Includes some import fixups along the way. PiperOrigin-RevId: 507684262
This commit is contained in:
parent
c252162821
commit
219723c738
@ -18,8 +18,6 @@ from typing import (Callable, Optional, List, Tuple, Sequence, Set, Union, Any,
|
||||
import types
|
||||
|
||||
import jax
|
||||
from jax._src import linear_util as lu
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
@ -27,10 +25,12 @@ from jax.interpreters import xla
|
||||
from jax.tree_util import tree_flatten, tree_unflatten
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import util
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.api_util import flatten_fun, shaped_abstractify
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import convolution as lax_convolution
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
|
@ -68,12 +68,8 @@ from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list,
|
||||
extend_name_stack, new_name_stack, wrap_name, cache,
|
||||
wraps, HashableFunction, weakref_lru_cache)
|
||||
|
||||
|
||||
# Unused imports to be exported
|
||||
from jax._src.core import ShapedArray, raise_to_shaped
|
||||
from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
|
||||
local_devices, process_index,
|
||||
process_count, host_id, host_ids,
|
||||
host_count, default_backend)
|
||||
from jax.ad_checkpoint import checkpoint_policies, checkpoint as new_checkpoint
|
||||
from jax.custom_batching import custom_vmap
|
||||
from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,
|
||||
@ -82,8 +78,6 @@ from jax.custom_transpose import custom_transpose
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
|
||||
from jax._src.config import (
|
||||
@ -94,6 +88,13 @@ from jax._src.config import (
|
||||
_thread_local_state as config_thread_local_state,
|
||||
explicit_device_put_scope as config_explicit_device_put_scope,
|
||||
explicit_device_get_scope as config_explicit_device_get_scope)
|
||||
from jax._src.core import ShapedArray, raise_to_shaped
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
|
||||
local_devices, process_index,
|
||||
process_count, host_id, host_ids,
|
||||
host_count, default_backend)
|
||||
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
@ -15,19 +15,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
|
||||
from typing import Any, Callable, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import tree_util
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import util
|
||||
from jax._src import dispatch
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
import numpy as np
|
||||
|
||||
# `pure_callback_p` is the main primitive for staging out Python pure callbacks.
|
||||
pure_callback_p = core.Primitive("pure_callback")
|
||||
|
@ -15,10 +15,25 @@
|
||||
import dataclasses
|
||||
import functools
|
||||
import itertools as it
|
||||
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable, Type, Set, List, Sequence, Any
|
||||
from typing import (Union, Optional, Callable, Dict, Tuple, TypeVar,
|
||||
FrozenSet, Iterable, Type, Set, List, Sequence, Any)
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax.tree_util as jtu
|
||||
from jax import lax
|
||||
from jax.api_util import flatten_fun
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.tree_util import tree_flatten
|
||||
from jax.tree_util import tree_map
|
||||
from jax.tree_util import tree_unflatten
|
||||
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import core
|
||||
from jax._src import custom_derivatives
|
||||
@ -26,24 +41,12 @@ from jax._src import prng
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src.config import config
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax import control_flow as cf
|
||||
from jax._src.sharding import OpShardingSharding
|
||||
from jax._src.typing import Array
|
||||
from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip,
|
||||
unzip3, weakref_lru_cache)
|
||||
from jax.api_util import flatten_fun
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.tree_util import tree_flatten
|
||||
from jax.tree_util import tree_map
|
||||
from jax.tree_util import tree_unflatten
|
||||
import jax.numpy as jnp
|
||||
import jax.tree_util as jtu
|
||||
import numpy as np
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
@ -17,9 +17,7 @@ import operator
|
||||
from typing import Callable, Optional
|
||||
|
||||
import jax
|
||||
from jax._src import linear_util as lu
|
||||
from jax import tree_util
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters.batching import not_mapped
|
||||
from jax.interpreters import mlir
|
||||
@ -29,10 +27,12 @@ from jax.tree_util import (tree_flatten, tree_map, tree_structure,
|
||||
tree_unflatten, treedef_tuple)
|
||||
from jax._src import core
|
||||
from jax._src import custom_api_util
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.api_util import flatten_fun_nokwargs
|
||||
from jax._src.interpreters import ad
|
||||
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
|
@ -17,29 +17,31 @@ import inspect
|
||||
from typing import (Callable, Generic, Optional, Sequence, Tuple, TypeVar, Set,
|
||||
Any)
|
||||
|
||||
from jax._src import linear_util as lu
|
||||
from jax.custom_transpose import custom_transpose
|
||||
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
|
||||
treedef_is_leaf, treedef_tuple,
|
||||
register_pytree_node_class, tree_leaves)
|
||||
from jax._src import core
|
||||
from jax._src import custom_api_util
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax import lax
|
||||
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
|
||||
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
|
||||
from jax._src.core import raise_to_shaped
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters.batching import not_mapped
|
||||
from jax.config import config
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import custom_api_util
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import traceback_util
|
||||
from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p
|
||||
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
|
||||
from jax._src.core import raise_to_shaped
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax import lax
|
||||
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
|
||||
|
||||
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
map = safe_map
|
||||
|
@ -15,8 +15,6 @@
|
||||
import functools
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
|
||||
from jax._src import linear_util as lu
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
@ -26,9 +24,11 @@ from jax._src import ad_util
|
||||
from jax._src import api_util
|
||||
from jax._src import core
|
||||
from jax._src import custom_api_util
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
|
@ -12,36 +12,38 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Module for JAX debugging primitives and related functionality."""
|
||||
|
||||
import enum
|
||||
import functools
|
||||
import string
|
||||
import sys
|
||||
from typing import Any, Dict, Callable, Optional, Sequence, Set, Tuple, Union
|
||||
import weakref
|
||||
|
||||
from typing import Any, Dict, Callable, Optional, Sequence, Set, Tuple, Union
|
||||
import numpy as np
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import tree_util
|
||||
from jax import lax
|
||||
from jax._src import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import pxla
|
||||
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import core
|
||||
from jax._src import custom_derivatives
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax import control_flow as lcf
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.sharding import Sharding, OpShardingSharding
|
||||
import jax.numpy as jnp
|
||||
|
||||
import numpy as np
|
||||
# pytype: disable=import-error
|
||||
try:
|
||||
import rich
|
||||
|
@ -32,34 +32,36 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax._src import linear_util as lu
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax.monitoring import record_event_duration_secs
|
||||
import jax.interpreters.ad as ad
|
||||
import jax.interpreters.batching as batching
|
||||
import jax.interpreters.mlir as mlir
|
||||
import jax.interpreters.xla as xla
|
||||
from jax.interpreters import pxla
|
||||
import jax.interpreters.partial_eval as pe
|
||||
|
||||
from jax._src import array
|
||||
from jax._src import core
|
||||
from jax._src import device_array
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import path
|
||||
from jax._src import profiler
|
||||
from jax._src import stages
|
||||
from jax._src import traceback_util
|
||||
from jax._src.sharding import (PmapSharding, SingleDeviceSharding,
|
||||
OpShardingSharding, NamedSharding, PartitionSpec,
|
||||
Sharding)
|
||||
from jax._src import util
|
||||
from jax._src.abstract_arrays import array_types
|
||||
from jax._src.config import config, flags
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import use_stablehlo
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
import jax._src.util as util
|
||||
from jax._src.sharding import (PmapSharding, SingleDeviceSharding,
|
||||
OpShardingSharding, NamedSharding, PartitionSpec,
|
||||
Sharding)
|
||||
from jax._src.util import flatten, unflatten
|
||||
from jax._src import path
|
||||
|
||||
|
||||
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
|
||||
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
|
||||
|
@ -42,12 +42,11 @@ import sys
|
||||
import threading
|
||||
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
|
||||
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast)
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax._src import linear_util as lu
|
||||
from jax.errors import JAXTypeError
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
@ -59,17 +58,19 @@ from jax._src import api_util
|
||||
from jax._src import basearray
|
||||
from jax._src import core
|
||||
from jax._src import device_array
|
||||
from jax._src import dtypes
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import profiler
|
||||
from jax._src import stages
|
||||
from jax._src import sharding as sharding_internal
|
||||
from jax._src import source_info_util
|
||||
from jax._src import stages
|
||||
from jax._src import util
|
||||
from jax._src.abstract_arrays import array_types
|
||||
from jax._src.config import config
|
||||
from jax._src.config import flags
|
||||
from jax._src.core import ConcreteArray, ShapedArray
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
@ -73,13 +73,17 @@ from functools import partial
|
||||
from typing import (Any, Tuple)
|
||||
|
||||
import numpy as np
|
||||
from jax._src import core
|
||||
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import batching
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
||||
from jax.interpreters import ad, xla, batching
|
||||
|
||||
Array = Any
|
||||
|
||||
|
@ -19,15 +19,16 @@ from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
|
||||
_max = builtins.max
|
||||
|
||||
Array = Any
|
||||
|
@ -18,13 +18,14 @@ from typing import Union, Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax import lax
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax import lax
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
|
||||
from jax._src.api import jit, linear_transpose, ShapeDtypeStruct
|
||||
from jax._src.core import Primitive, is_constant_shape
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import ducc_fft
|
||||
|
@ -25,40 +25,33 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax._src import core
|
||||
from jax import tree_util
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters.batching import ConcatAxis
|
||||
from jax.tree_util import tree_map
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import api
|
||||
from jax._src import api_util
|
||||
from jax._src import array
|
||||
from jax._src import core
|
||||
from jax._src import device_array
|
||||
from jax._src import dispatch
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import dtypes
|
||||
from jax import tree_util
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src import source_info_util
|
||||
from jax._src.sharding import PmapSharding
|
||||
from jax._src import util
|
||||
from jax._src.abstract_arrays import array_types
|
||||
from jax._src.config import config
|
||||
from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
|
||||
raise_to_shaped, abstract_token, canonicalize_shape)
|
||||
from jax._src.abstract_arrays import array_types
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters.batching import ConcatAxis
|
||||
import jax._src.pretty_printer as pp
|
||||
from jax._src import util
|
||||
from jax._src.util import (cache, prod, safe_zip, safe_map, canonicalize_axis,
|
||||
split_list)
|
||||
from jax.tree_util import tree_map
|
||||
from jax._src.lib import pytree
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import chlo
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax import slicing
|
||||
from jax._src.lax.utils import (
|
||||
_input_dtype,
|
||||
standard_abstract_eval,
|
||||
@ -67,8 +60,16 @@ from jax._src.lax.utils import (
|
||||
standard_primitive,
|
||||
standard_translate,
|
||||
)
|
||||
from jax._src.lax import slicing
|
||||
from jax._src.lib import pytree
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import chlo
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.sharding import PmapSharding
|
||||
from jax._src.typing import Array, ArrayLike, DTypeLike, Shape
|
||||
from jax._src.util import (cache, prod, safe_zip, safe_map, canonicalize_axis,
|
||||
split_list)
|
||||
|
||||
xb = xla_bridge
|
||||
xc = xla_client
|
||||
|
@ -21,38 +21,36 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src import ad_util
|
||||
from jax._src import api
|
||||
from jax import lax
|
||||
from jax._src import dtypes
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax._src.util import prod
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import api
|
||||
from jax._src import dtypes
|
||||
from jax._src.core import (
|
||||
Primitive, ShapedArray, raise_to_shaped, is_constant_shape)
|
||||
from jax._src.lax.lax import (
|
||||
standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
|
||||
_input_dtype)
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax import control_flow
|
||||
from jax._src.lax import eigh as lax_eigh
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import svd as lax_svd
|
||||
from jax._src.lib import lapack
|
||||
|
||||
from jax._src.lax.lax import (
|
||||
standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex,
|
||||
_input_dtype)
|
||||
from jax._src.lib import gpu_linalg
|
||||
from jax._src.lib import gpu_solver
|
||||
from jax._src.lib import gpu_sparse
|
||||
|
||||
from jax._src.lib import lapack
|
||||
from jax._src.lib import xla_client
|
||||
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import chlo
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.numpy import lax_numpy as jnp
|
||||
from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax._src.util import prod
|
||||
|
||||
xops = xla_client.ops
|
||||
|
||||
|
@ -24,21 +24,23 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
from jax import tree_util
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import util
|
||||
from jax._src.core import ShapedArray, AxisName, raise_to_shaped
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lax import slicing
|
||||
from jax._src.numpy import lax_numpy
|
||||
import jax._src.util as util
|
||||
from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo, use_stablehlo
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax._src.util import (
|
||||
unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis)
|
||||
|
||||
unsafe_map, map = map, safe_map # type: ignore
|
||||
|
||||
|
@ -20,24 +20,25 @@ import weakref
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lax.utils import (
|
||||
_argnum_weak_type,
|
||||
_input_dtype,
|
||||
standard_primitive,
|
||||
)
|
||||
from jax._src.lax import lax
|
||||
from jax._src import util
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.typing import Array, ArrayLike, Shape
|
||||
from jax._src.util import safe_map, safe_zip
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
@ -18,18 +18,17 @@ import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from jax.interpreters import ad
|
||||
from jax import tree_util
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax import tree_util
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import util
|
||||
from jax._src.core import ShapedArray, ConcreteArray
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lax import convolution
|
||||
from jax._src.lax import slicing
|
||||
|
@ -24,41 +24,42 @@ import threading
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax import stages
|
||||
from jax.errors import JAXTypeError
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray as GDA
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters.pxla import PartitionSpec
|
||||
from jax.tree_util import (
|
||||
tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure,
|
||||
treedef_tuple)
|
||||
|
||||
from jax._src.sharding import (
|
||||
NamedSharding, Sharding, XLACompatibleSharding, OpShardingSharding,
|
||||
XLADeviceAssignment, SingleDeviceSharding, PmapSharding)
|
||||
from jax import core
|
||||
from jax._src import linear_util as lu
|
||||
from jax import stages
|
||||
from jax._src import array
|
||||
from jax._src.config import config
|
||||
from jax._src import dispatch
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.api_util import (argnums_partial_except, flatten_axes,
|
||||
flatten_fun, flatten_fun_nokwargs,
|
||||
donation_vector, shaped_abstractify,
|
||||
check_callable, argnames_partial_except,
|
||||
resolve_argnums, FLAGS)
|
||||
from jax.errors import JAXTypeError
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters.pxla import PartitionSpec
|
||||
from jax._src import util
|
||||
from jax._src.api_util import (
|
||||
argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs,
|
||||
donation_vector, shaped_abstractify, check_callable,
|
||||
argnames_partial_except, resolve_argnums, FLAGS)
|
||||
from jax._src.config import config
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
|
||||
treedef_is_leaf, tree_structure, treedef_tuple)
|
||||
from jax._src.traceback_util import api_boundary
|
||||
from jax._src.tree_util import prefix_errors
|
||||
from jax._src import util
|
||||
from jax._src.util import (
|
||||
HashableFunction, safe_map, safe_zip, wraps,
|
||||
distributed_debug_log, split_list, tuple_insert, weakref_lru_cache,
|
||||
|
@ -25,26 +25,26 @@ from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax.config import config
|
||||
from jax.dtypes import float0
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax._src import basearray
|
||||
from jax._src.sharding import (
|
||||
NamedSharding, PmapSharding, OpShardingSharding)
|
||||
|
||||
from jax._src import basearray
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src.api import jit, vmap
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import utils as lax_utils
|
||||
from jax._src.lib import gpu_prng
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.numpy import lax_numpy
|
||||
from jax._src.sharding import (
|
||||
NamedSharding, PmapSharding, OpShardingSharding)
|
||||
from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip
|
||||
from jax._src.lib import gpu_prng
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
@ -24,18 +24,21 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax.config import config
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.numpy.linalg import cholesky, svd, eigh
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import prng
|
||||
from jax._src.api import jit, vmap
|
||||
from jax._src.core import NamedShape
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.numpy.lax_numpy import _arraylike, _check_arraylike, _convert_and_clip_integer, _promote_dtypes_inexact
|
||||
from jax._src.numpy.lax_numpy import (
|
||||
_arraylike, _check_arraylike, _convert_and_clip_integer,
|
||||
_promote_dtypes_inexact)
|
||||
from jax._src.typing import Array, ArrayLike, DTypeLike
|
||||
from jax._src.util import prod, canonicalize_axis
|
||||
|
||||
|
@ -18,21 +18,26 @@ import operator
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union, cast
|
||||
from typing import (
|
||||
Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union,
|
||||
cast)
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax import config
|
||||
from jax import core, custom_derivatives
|
||||
from jax._src import linear_util as lu
|
||||
from jax import random, tree_util
|
||||
from jax import core
|
||||
from jax import custom_derivatives
|
||||
from jax import random
|
||||
from jax import numpy as jnp
|
||||
from jax import tree_util
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
from jax._src import sharding
|
||||
from jax.interpreters import ad
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||
from jax.experimental.jax2tf import shape_poly
|
||||
from jax.experimental.jax2tf import impl_no_xla
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
@ -43,10 +48,13 @@ from jax._src import api
|
||||
from jax._src import api_util
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import prng
|
||||
from jax._src import random as random_internal
|
||||
from jax._src import sharding
|
||||
from jax._src import source_info_util
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax import control_flow as lax_control_flow
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import linalg as lax_linalg
|
||||
@ -55,12 +63,6 @@ from jax._src.lax import windowed_reductions as lax_windowed_reductions
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.numpy.ufuncs import logaddexp
|
||||
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||
from jax.experimental.jax2tf import shape_poly
|
||||
from jax.experimental.jax2tf import impl_no_xla
|
||||
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
|
||||
# These don't have public equivalents.
|
||||
|
@ -42,10 +42,11 @@ from jax.experimental.sparse.bcsr import BCSR
|
||||
from jax.experimental.sparse.coo import COO
|
||||
from jax.experimental.sparse.csr import CSR, CSC
|
||||
from jax.experimental.sparse.util import _coo_extract
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
|
||||
from jax._src import dtypes
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.typing import Array, DTypeLike, Shape
|
||||
|
||||
|
||||
|
@ -39,9 +39,9 @@ from jax.interpreters import batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import mlir
|
||||
import jax.numpy as jnp
|
||||
from jax.interpreters import ad
|
||||
from jax.util import safe_zip, unzip2, split_list
|
||||
from jax._src import api_util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax.lax import (
|
||||
_const, ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
|
||||
DotDimensionNumbers)
|
||||
|
@ -24,11 +24,11 @@ import numpy as np
|
||||
|
||||
from jax import core
|
||||
from jax import lax
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import mlir
|
||||
from jax.experimental.sparse._base import JAXSparse
|
||||
from jax.experimental.sparse.util import _coo_extract, CuSparseEfficiencyWarning
|
||||
from jax import tree_util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax.lax import _const
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib import gpu_sparse
|
||||
|
@ -22,13 +22,13 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
from jax import core
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import mlir
|
||||
from jax.experimental.sparse._base import JAXSparse
|
||||
from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo
|
||||
from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, CuSparseEfficiencyWarning
|
||||
from jax import lax
|
||||
from jax import tree_util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.lax.lax import _const
|
||||
from jax._src.lib import gpu_sparse
|
||||
from jax._src.numpy.lax_numpy import _promote_dtypes
|
||||
|
@ -28,7 +28,6 @@ import numpy as np
|
||||
|
||||
from jax.config import config
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import ad
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import device_array
|
||||
@ -37,6 +36,7 @@ from jax._src import pretty_printer as pp
|
||||
from jax._src import source_info_util
|
||||
from jax._src.abstract_arrays import numpy_scalar_types
|
||||
from jax._src.core import ConcreteArray, ShapedArray, str_eqn_compact
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.util import (prod, new_name_stack, safe_zip, safe_map,
|
||||
partition_list)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user