migrate internal dependencies from jax.interpreters.batching to jax._src.interpreters.batching

... in preparation for paring down `jax.interpreters.batching`'s exported symbols.

PiperOrigin-RevId: 508487887
This commit is contained in:
Roy Frostig 2023-02-09 15:11:20 -08:00 committed by jax authors
parent 12dc73dc6e
commit 1c84e4a753
23 changed files with 43 additions and 41 deletions

View File

@ -18,7 +18,6 @@ from typing import (Callable, Optional, List, Tuple, Sequence, Set, Union, Any,
import types
import jax
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
@ -31,6 +30,7 @@ from jax._src import traceback_util
from jax._src import util
from jax._src.api_util import flatten_fun, shaped_abstractify
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.lax import lax as lax_internal
from jax._src.lax import convolution as lax_convolution
from jax._src.lib.mlir.dialects import hlo

View File

@ -78,7 +78,6 @@ from jax.custom_transpose import custom_transpose
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import batching
from jax._src.config import (
config,
@ -90,6 +89,7 @@ from jax._src.config import (
explicit_device_get_scope as config_explicit_device_get_scope)
from jax._src.core import ShapedArray, raise_to_shaped
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import pxla
from jax._src.lib.xla_bridge import (device_count, local_device_count, devices,
local_devices, process_index,

View File

@ -20,7 +20,6 @@ from typing import Any, Callable, Sequence
import numpy as np
from jax import tree_util
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src import core
@ -28,6 +27,7 @@ from jax._src import dtypes
from jax._src import util
from jax._src import dispatch
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.lib import xla_client as xc
# `pure_callback_p` is the main primitive for staging out Python pure callbacks.

View File

@ -27,7 +27,6 @@ from jax import lax
from jax.api_util import flatten_fun
from jax.experimental import maps
from jax.experimental import pjit
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten
@ -42,6 +41,7 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.lax import control_flow as cf
from jax._src.sharding import OpShardingSharding
from jax._src.typing import Array

View File

@ -18,8 +18,6 @@ from typing import Callable, Optional
import jax
from jax import tree_util
from jax.interpreters import batching
from jax.interpreters.batching import not_mapped
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
@ -33,6 +31,8 @@ from jax._src import traceback_util
from jax._src import util
from jax._src.api_util import flatten_fun_nokwargs
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters.batching import not_mapped
source_info_util.register_exclusion(__file__)

View File

@ -23,10 +23,7 @@ from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
register_pytree_node_class, tree_leaves)
from jax.errors import UnexpectedTracerError
from jax.interpreters import partial_eval as pe
from jax.interpreters import batching
from jax._src.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters.batching import not_mapped
from jax.config import config
from jax._src import core
@ -38,6 +35,9 @@ from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p
from jax._src.api_util import argnums_partial, flatten_fun_nokwargs
from jax._src.core import raise_to_shaped
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters.batching import not_mapped
from jax._src.lax import lax
from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable

View File

@ -27,8 +27,6 @@ from jax import tree_util
from jax import lax
from jax.config import config
from jax.experimental import pjit
from jax.interpreters import batching
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla
@ -38,6 +36,8 @@ from jax._src import custom_derivatives
from jax._src import linear_util as lu
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import control_flow as lcf
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir

View File

@ -34,9 +34,7 @@ import numpy as np
import jax
from jax.errors import UnexpectedTracerError
from jax.monitoring import record_event_duration_secs
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
import jax._src.interpreters.xla as xla
from jax.interpreters import pxla
import jax.interpreters.partial_eval as pe
@ -53,6 +51,8 @@ from jax._src import util
from jax._src.abstract_arrays import array_types
from jax._src.config import config, flags
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import xla
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import use_stablehlo
from jax._src.lib import pmap_lib

View File

@ -47,8 +47,6 @@ import numpy as np
import jax
from jax.errors import JAXTypeError
from jax.interpreters import batching
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten, tree_map
@ -70,6 +68,8 @@ from jax._src.config import config
from jax._src.config import flags
from jax._src.core import ConcreteArray, ShapedArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc

View File

@ -75,12 +75,12 @@ from typing import (Any, Tuple)
import numpy as np
from jax.interpreters import xla
from jax.interpreters import batching
from jax._src import ad_util
from jax._src import core
from jax._src import dtypes
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.lax import lax
from jax._src.lib import xla_client as xc

View File

@ -23,10 +23,10 @@ from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.lax import lax
from jax._src.lib.mlir.dialects import hlo
from jax.interpreters import batching
from jax.interpreters import mlir
_max = builtins.max

View File

@ -19,13 +19,13 @@ from typing import Union, Sequence
import numpy as np
from jax import lax
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.api import jit, linear_transpose, ShapeDtypeStruct
from jax._src.core import Primitive, is_constant_shape
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib import xla_client
from jax._src.lib import ducc_fft

View File

@ -26,12 +26,9 @@ import numpy as np
import jax
from jax import tree_util
from jax.interpreters import batching
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla
from jax.interpreters import xla
from jax.interpreters.batching import ConcatAxis
from jax.tree_util import tree_map
from jax._src import ad_util
@ -51,6 +48,9 @@ from jax._src.config import config
from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray,
raise_to_shaped, abstract_token, canonicalize_shape)
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters.batching import ConcatAxis
from jax._src.lax import slicing
from jax._src.lax.utils import (
_input_dtype,

View File

@ -22,8 +22,6 @@ import numpy as np
import jax
from jax import lax
from jax.interpreters import batching
from jax._src.interpreters import mlir
from jax.interpreters import xla
from jax._src import ad_util
@ -32,6 +30,8 @@ from jax._src import dtypes
from jax._src.core import (
Primitive, ShapedArray, raise_to_shaped, is_constant_shape)
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import control_flow
from jax._src.lax import eigh as lax_eigh
from jax._src.lax import lax as lax_internal

View File

@ -24,7 +24,6 @@ import warnings
import numpy as np
from jax import tree_util
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import pxla
from jax.interpreters import xla
@ -34,6 +33,7 @@ from jax._src import dtypes
from jax._src import util
from jax._src.core import ShapedArray, AxisName, raise_to_shaped
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.lax import lax
from jax._src.lax import slicing
from jax._src.lib.mlir import ir

View File

@ -20,8 +20,6 @@ import weakref
import numpy as np
import jax
from jax.interpreters import batching
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax._src import ad_util
@ -29,6 +27,8 @@ from jax._src import core
from jax._src import dtypes
from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import lax
from jax._src.lax.utils import (
_argnum_weak_type,

View File

@ -19,8 +19,6 @@ import warnings
import numpy as np
from jax import tree_util
from jax.interpreters import batching
from jax._src.interpreters import mlir
from jax.interpreters import xla
from jax._src import ad_util
@ -29,6 +27,8 @@ from jax._src import dtypes
from jax._src import util
from jax._src.core import ShapedArray, ConcreteArray
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import lax
from jax._src.lax import convolution
from jax._src.lax import slicing

View File

@ -28,8 +28,6 @@ from jax._src import core
from jax import stages
from jax.errors import JAXTypeError
from jax.experimental.global_device_array import GlobalDeviceArray as GDA
from jax.interpreters import batching
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src.interpreters.pxla import PartitionSpec
@ -52,6 +50,8 @@ from jax._src.api_util import (
argnames_partial_except, resolve_argnums, FLAGS)
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import func as func_dialect

View File

@ -25,9 +25,6 @@ from jax import lax
from jax import numpy as jnp
from jax.config import config
from jax.dtypes import float0
from jax.interpreters import batching
from jax._src.interpreters import pxla
from jax._src.interpreters import mlir
from jax.interpreters import xla
from jax._src import basearray
@ -37,6 +34,9 @@ from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src.api import jit, vmap
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
from jax._src.lax import lax as lax_internal
from jax._src.lax import utils as lax_utils
from jax._src.lib import gpu_prng

View File

@ -24,7 +24,6 @@ import jax
import jax.numpy as jnp
from jax import lax
from jax.config import config
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.numpy.linalg import cholesky, svd, eigh
@ -34,6 +33,7 @@ from jax._src import prng
from jax._src.api import jit, vmap
from jax._src.core import NamedShape
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.lax import lax as lax_internal
from jax._src.lib import xla_bridge
from jax._src.numpy.lax_numpy import (

View File

@ -42,11 +42,11 @@ from jax.experimental.sparse.bcsr import BCSR
from jax.experimental.sparse.coo import COO
from jax.experimental.sparse.csr import CSR, CSC
from jax.experimental.sparse.util import _coo_extract
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src import dtypes
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.typing import Array, DTypeLike, Shape

View File

@ -35,13 +35,13 @@ from jax.experimental.sparse.util import (
_dot_general_validated_shape, CuSparseEfficiencyWarning,
SparseEfficiencyError, SparseEfficiencyWarning, Shape,
SparseInfo)
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import mlir
import jax.numpy as jnp
from jax.util import safe_zip, unzip2, split_list
from jax._src import api_util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.lax.lax import (
_const, ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule,
DotDimensionNumbers)

View File

@ -23,6 +23,7 @@ from typing import NamedTuple, Optional, Sequence, Tuple, Union
import numpy as np
import jax.numpy as jnp
from jax import config
from jax import core
from jax import lax
@ -33,18 +34,19 @@ from jax.experimental.sparse.util import (
nfold_vmap, _count_stored_elements,
_csr_to_coo, _dot_general_validated_shape,
CuSparseEfficiencyWarning, SparseInfo, Shape)
import jax.numpy as jnp
from jax.util import split_list, safe_zip
from jax._src import api_util
from jax._src.lax.lax import DotDimensionNumbers
from jax._src.lib import gpu_sparse
from jax.util import split_list, safe_zip
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.lib.mlir.dialects import hlo
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.typing import Array, ArrayLike, DTypeLike
def bcsr_eliminate_zeros(mat: BCSR, nse: Optional[int] = None) -> BCSR:
"""Eliminate zeros in BCSR representation."""
return BCSR.from_bcoo(bcoo.bcoo_eliminate_zeros(mat.to_bcoo(), nse=nse))