mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
migrate internal dependencies from jax.interpreters.batching
to jax._src.interpreters.batching
... in preparation for paring down `jax.interpreters.batching`'s exported symbols. PiperOrigin-RevId: 508487887
This commit is contained in:
parent
12dc73dc6e
commit
1c84e4a753
@ -18,7 +18,6 @@ from typing import (Callable, Optional, List, Tuple, Sequence, Set, Union, Any,
|
||||
import types
|
||||
|
||||
import jax
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
@ -31,6 +30,7 @@ from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.api_util import flatten_fun, shaped_abstractify
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import convolution as lax_convolution
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
|
@ -78,7 +78,6 @@ from jax.custom_transpose import custom_transpose
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import batching
|
||||
|
||||
from jax._src.config import (
|
||||
config,
|
||||
@ -90,6 +89,7 @@ from jax._src.config import (
|
||||
explicit_device_get_scope as config_explicit_device_get_scope)
|
||||
from jax._src.core import ShapedArray, raise_to_shaped
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
|
||||
local_devices, process_index,
|
||||
|
@ -20,7 +20,6 @@ from typing import Any, Callable, Sequence
|
||||
import numpy as np
|
||||
|
||||
from jax import tree_util
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
|
||||
from jax._src import core
|
||||
@ -28,6 +27,7 @@ from jax._src import dtypes
|
||||
from jax._src import util
|
||||
from jax._src import dispatch
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
||||
# `pure_callback_p` is the main primitive for staging out Python pure callbacks.
|
||||
|
@ -27,7 +27,6 @@ from jax import lax
|
||||
from jax.api_util import flatten_fun
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.tree_util import tree_flatten
|
||||
@ -42,6 +41,7 @@ from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src.config import config
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.lax import control_flow as cf
|
||||
from jax._src.sharding import OpShardingSharding
|
||||
from jax._src.typing import Array
|
||||
|
@ -18,8 +18,6 @@ from typing import Callable, Optional
|
||||
|
||||
import jax
|
||||
from jax import tree_util
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters.batching import not_mapped
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
@ -33,6 +31,8 @@ from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.api_util import flatten_fun_nokwargs
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters.batching import not_mapped
|
||||
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
|
@ -23,10 +23,7 @@ from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
|
||||
register_pytree_node_class, tree_leaves)
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters.batching import not_mapped
|
||||
from jax.config import config
|
||||
|
||||
from jax._src import core
|
||||
@ -38,6 +35,9 @@ from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p
|
||||
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
|
||||
from jax._src.core import raise_to_shaped
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters.batching import not_mapped
|
||||
from jax._src.lax import lax
|
||||
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable
|
||||
|
||||
|
@ -27,8 +27,6 @@ from jax import tree_util
|
||||
from jax import lax
|
||||
from jax.config import config
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import pxla
|
||||
|
||||
@ -38,6 +36,8 @@ from jax._src import custom_derivatives
|
||||
from jax._src import linear_util as lu
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lax import control_flow as lcf
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib.mlir import ir
|
||||
|
@ -34,9 +34,7 @@ import numpy as np
|
||||
import jax
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax.monitoring import record_event_duration_secs
|
||||
import jax.interpreters.batching as batching
|
||||
import jax.interpreters.mlir as mlir
|
||||
import jax._src.interpreters.xla as xla
|
||||
from jax.interpreters import pxla
|
||||
import jax.interpreters.partial_eval as pe
|
||||
|
||||
@ -53,6 +51,8 @@ from jax._src import util
|
||||
from jax._src.abstract_arrays import array_types
|
||||
from jax._src.config import config, flags
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import use_stablehlo
|
||||
from jax._src.lib import pmap_lib
|
||||
|
@ -47,8 +47,6 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.errors import JAXTypeError
|
||||
from jax.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.tree_util import tree_flatten, tree_map
|
||||
|
||||
@ -70,6 +68,8 @@ from jax._src.config import config
|
||||
from jax._src.config import flags
|
||||
from jax._src.core import ConcreteArray, ShapedArray
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
@ -75,12 +75,12 @@ from typing import (Any, Tuple)
|
||||
import numpy as np
|
||||
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import batching
|
||||
|
||||
from jax._src import ad_util
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
||||
|
@ -23,10 +23,10 @@ from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
|
||||
_max = builtins.max
|
||||
|
@ -19,13 +19,13 @@ from typing import Union, Sequence
|
||||
import numpy as np
|
||||
|
||||
from jax import lax
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax._src.api import jit, linear_transpose, ShapeDtypeStruct
|
||||
from jax._src.core import Primitive, is_constant_shape
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import ducc_fft
|
||||
|
@ -26,12 +26,9 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import tree_util
|
||||
from jax.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters.batching import ConcatAxis
|
||||
from jax.tree_util import tree_map
|
||||
|
||||
from jax._src import ad_util
|
||||
@ -51,6 +48,9 @@ from jax._src.config import config
|
||||
from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
|
||||
raise_to_shaped, abstract_token, canonicalize_shape)
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters.batching import ConcatAxis
|
||||
from jax._src.lax import slicing
|
||||
from jax._src.lax.utils import (
|
||||
_input_dtype,
|
||||
|
@ -22,8 +22,6 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax._src import ad_util
|
||||
@ -32,6 +30,8 @@ from jax._src import dtypes
|
||||
from jax._src.core import (
|
||||
Primitive, ShapedArray, raise_to_shaped, is_constant_shape)
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lax import control_flow
|
||||
from jax._src.lax import eigh as lax_eigh
|
||||
from jax._src.lax import lax as lax_internal
|
||||
|
@ -24,7 +24,6 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
from jax import tree_util
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
@ -34,6 +33,7 @@ from jax._src import dtypes
|
||||
from jax._src import util
|
||||
from jax._src.core import ShapedArray, AxisName, raise_to_shaped
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lax import slicing
|
||||
from jax._src.lib.mlir import ir
|
||||
|
@ -20,8 +20,6 @@ import weakref
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
||||
from jax._src import ad_util
|
||||
@ -29,6 +27,8 @@ from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lax.utils import (
|
||||
_argnum_weak_type,
|
||||
|
@ -19,8 +19,6 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
from jax import tree_util
|
||||
from jax.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax._src import ad_util
|
||||
@ -29,6 +27,8 @@ from jax._src import dtypes
|
||||
from jax._src import util
|
||||
from jax._src.core import ShapedArray, ConcreteArray
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lax import convolution
|
||||
from jax._src.lax import slicing
|
||||
|
@ -28,8 +28,6 @@ from jax._src import core
|
||||
from jax import stages
|
||||
from jax.errors import JAXTypeError
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray as GDA
|
||||
from jax.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax._src.interpreters.pxla import PartitionSpec
|
||||
@ -52,6 +50,8 @@ from jax._src.api_util import (
|
||||
argnames_partial_except, resolve_argnums, FLAGS)
|
||||
from jax._src.config import config
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import func as func_dialect
|
||||
|
@ -25,9 +25,6 @@ from jax import lax
|
||||
from jax import numpy as jnp
|
||||
from jax.config import config
|
||||
from jax.dtypes import float0
|
||||
from jax.interpreters import batching
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax._src import basearray
|
||||
@ -37,6 +34,9 @@ from jax._src import dtypes
|
||||
from jax._src import pretty_printer as pp
|
||||
from jax._src.api import jit, vmap
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lax import utils as lax_utils
|
||||
from jax._src.lib import gpu_prng
|
||||
|
@ -24,7 +24,6 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import lax
|
||||
from jax.config import config
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.numpy.linalg import cholesky, svd, eigh
|
||||
|
||||
@ -34,6 +33,7 @@ from jax._src import prng
|
||||
from jax._src.api import jit, vmap
|
||||
from jax._src.core import NamedShape
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.lib import xla_bridge
|
||||
from jax._src.numpy.lax_numpy import (
|
||||
|
@ -42,11 +42,11 @@ from jax.experimental.sparse.bcsr import BCSR
|
||||
from jax.experimental.sparse.coo import COO
|
||||
from jax.experimental.sparse.csr import CSR, CSC
|
||||
from jax.experimental.sparse.util import _coo_extract
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
|
||||
from jax._src import dtypes
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.typing import Array, DTypeLike, Shape
|
||||
|
||||
|
||||
|
@ -35,13 +35,13 @@ from jax.experimental.sparse.util import (
|
||||
_dot_general_validated_shape, CuSparseEfficiencyWarning,
|
||||
SparseEfficiencyError, SparseEfficiencyWarning, Shape,
|
||||
SparseInfo)
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import mlir
|
||||
import jax.numpy as jnp
|
||||
from jax.util import safe_zip, unzip2, split_list
|
||||
from jax._src import api_util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.lax.lax import (
|
||||
_const, ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
|
||||
DotDimensionNumbers)
|
||||
|
@ -23,6 +23,7 @@ from typing import NamedTuple, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax.numpy as jnp
|
||||
from jax import config
|
||||
from jax import core
|
||||
from jax import lax
|
||||
@ -33,18 +34,19 @@ from jax.experimental.sparse.util import (
|
||||
nfold_vmap, _count_stored_elements,
|
||||
_csr_to_coo, _dot_general_validated_shape,
|
||||
CuSparseEfficiencyWarning, SparseInfo, Shape)
|
||||
import jax.numpy as jnp
|
||||
from jax.util import split_list, safe_zip
|
||||
|
||||
from jax._src import api_util
|
||||
from jax._src.lax.lax import DotDimensionNumbers
|
||||
from jax._src.lib import gpu_sparse
|
||||
from jax.util import split_list, safe_zip
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.typing import Array, ArrayLike, DTypeLike
|
||||
|
||||
|
||||
|
||||
def bcsr_eliminate_zeros(mat: BCSR, nse: Optional[int] = None) -> BCSR:
|
||||
"""Eliminate zeros in BCSR representation."""
|
||||
return BCSR.from_bcoo(bcoo.bcoo_eliminate_zeros(mat.to_bcoo(), nse=nse))
|
||||
|
Loading…
x
Reference in New Issue
Block a user