From 1c84e4a753209e48db34b07eb762c0ef08553ce5 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Thu, 9 Feb 2023 15:11:20 -0800 Subject: [PATCH] migrate internal dependencies from `jax.interpreters.batching` to `jax._src.interpreters.batching` ... in preparation for paring down `jax.interpreters.batching`'s exported symbols. PiperOrigin-RevId: 508487887 --- jax/_src/ad_checkpoint.py | 2 +- jax/_src/api.py | 2 +- jax/_src/callback.py | 2 +- jax/_src/checkify.py | 2 +- jax/_src/custom_batching.py | 4 ++-- jax/_src/custom_derivatives.py | 6 +++--- jax/_src/debugging.py | 4 ++-- jax/_src/dispatch.py | 4 ++-- jax/_src/interpreters/pxla.py | 4 ++-- jax/_src/lax/ann.py | 2 +- jax/_src/lax/convolution.py | 2 +- jax/_src/lax/fft.py | 2 +- jax/_src/lax/lax.py | 6 +++--- jax/_src/lax/linalg.py | 4 ++-- jax/_src/lax/parallel.py | 2 +- jax/_src/lax/slicing.py | 4 ++-- jax/_src/lax/windowed_reductions.py | 4 ++-- jax/_src/pjit.py | 4 ++-- jax/_src/prng.py | 6 +++--- jax/_src/random.py | 2 +- jax/experimental/sparse/api.py | 2 +- jax/experimental/sparse/bcoo.py | 2 +- jax/experimental/sparse/bcsr.py | 12 +++++++----- 23 files changed, 43 insertions(+), 41 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 230fe01be..fd09e6b45 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -18,7 +18,6 @@ from typing import (Callable, Optional, List, Tuple, Sequence, Set, Union, Any, import types import jax -from jax.interpreters import batching from jax.interpreters import mlir from jax.interpreters import partial_eval as pe from jax.interpreters import xla @@ -31,6 +30,7 @@ from jax._src import traceback_util from jax._src import util from jax._src.api_util import flatten_fun, shaped_abstractify from jax._src.interpreters import ad +from jax._src.interpreters import batching from jax._src.lax import lax as lax_internal from jax._src.lax import convolution as lax_convolution from jax._src.lib.mlir.dialects import hlo diff --git a/jax/_src/api.py b/jax/_src/api.py index d06634832..3625f13ac 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -78,7 +78,6 @@ from jax.custom_transpose import custom_transpose from jax.interpreters import partial_eval as pe from jax.interpreters import mlir from jax.interpreters import xla -from jax.interpreters import batching from jax._src.config import ( config, @@ -90,6 +89,7 @@ from jax._src.config import ( explicit_device_get_scope as config_explicit_device_get_scope) from jax._src.core import ShapedArray, raise_to_shaped from jax._src.interpreters import ad +from jax._src.interpreters import batching from jax._src.interpreters import pxla from jax._src.lib.xla_bridge import (device_count, local_device_count, devices, local_devices, process_index, diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 290a2950d..f8511ef40 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -20,7 +20,6 @@ from typing import Any, Callable, Sequence import numpy as np from jax import tree_util -from jax.interpreters import batching from jax.interpreters import mlir from jax._src import core @@ -28,6 +27,7 @@ from jax._src import dtypes from jax._src import util from jax._src import dispatch from jax._src.interpreters import ad +from jax._src.interpreters import batching from jax._src.lib import xla_client as xc # `pure_callback_p` is the main primitive for staging out Python pure callbacks. diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 0717561aa..1abffca88 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -27,7 +27,6 @@ from jax import lax from jax.api_util import flatten_fun from jax.experimental import maps from jax.experimental import pjit -from jax.interpreters import batching from jax.interpreters import mlir from jax.interpreters import partial_eval as pe from jax.tree_util import tree_flatten @@ -42,6 +41,7 @@ from jax._src import source_info_util from jax._src import traceback_util from jax._src.config import config from jax._src.interpreters import ad +from jax._src.interpreters import batching from jax._src.lax import control_flow as cf from jax._src.sharding import OpShardingSharding from jax._src.typing import Array diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index ed4810f68..7baf7aef4 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -18,8 +18,6 @@ from typing import Callable, Optional import jax from jax import tree_util -from jax.interpreters import batching -from jax.interpreters.batching import not_mapped from jax.interpreters import mlir from jax.interpreters import partial_eval as pe from jax.interpreters import xla @@ -33,6 +31,8 @@ from jax._src import traceback_util from jax._src import util from jax._src.api_util import flatten_fun_nokwargs from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters.batching import not_mapped source_info_util.register_exclusion(__file__) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 2d2313c01..c9534d0ac 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -23,10 +23,7 @@ from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, register_pytree_node_class, tree_leaves) from jax.errors import UnexpectedTracerError from jax.interpreters import partial_eval as pe -from jax.interpreters import batching -from jax._src.interpreters import mlir from jax.interpreters import xla -from jax.interpreters.batching import not_mapped from jax.config import config from jax._src import core @@ -38,6 +35,9 @@ from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p from jax._src.api_util import argnums_partial, flatten_fun_nokwargs from jax._src.core import raise_to_shaped from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.interpreters.batching import not_mapped from jax._src.lax import lax from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index d742a9006..df9ec4109 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -27,8 +27,6 @@ from jax import tree_util from jax import lax from jax.config import config from jax.experimental import pjit -from jax.interpreters import batching -from jax._src.interpreters import mlir from jax.interpreters import partial_eval as pe from jax.interpreters import pxla @@ -38,6 +36,8 @@ from jax._src import custom_derivatives from jax._src import linear_util as lu from jax._src import util from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import mlir from jax._src.lax import control_flow as lcf from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 83b9b8821..4692cb54b 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -34,9 +34,7 @@ import numpy as np import jax from jax.errors import UnexpectedTracerError from jax.monitoring import record_event_duration_secs -import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir -import jax._src.interpreters.xla as xla from jax.interpreters import pxla import jax.interpreters.partial_eval as pe @@ -53,6 +51,8 @@ from jax._src import util from jax._src.abstract_arrays import array_types from jax._src.config import config, flags from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import xla from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import use_stablehlo from jax._src.lib import pmap_lib diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 63d60286f..61a6a95a1 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -47,8 +47,6 @@ import numpy as np import jax from jax.errors import JAXTypeError -from jax.interpreters import batching -from jax._src.interpreters import mlir from jax.interpreters import partial_eval as pe from jax.tree_util import tree_flatten, tree_map @@ -70,6 +68,8 @@ from jax._src.config import config from jax._src.config import flags from jax._src.core import ConcreteArray, ShapedArray from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import mlir from jax._src.interpreters import xla from jax._src.lib import xla_bridge as xb from jax._src.lib import xla_client as xc diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index 0ce798f29..00d66550a 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -75,12 +75,12 @@ from typing import (Any, Tuple) import numpy as np from jax.interpreters import xla -from jax.interpreters import batching from jax._src import ad_util from jax._src import core from jax._src import dtypes from jax._src.interpreters import ad +from jax._src.interpreters import batching from jax._src.lax import lax from jax._src.lib import xla_client as xc diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 578c4bf04..227551450 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -23,10 +23,10 @@ from jax._src import core from jax._src import dtypes from jax._src import util from jax._src.interpreters import ad +from jax._src.interpreters import batching from jax._src.lax import lax from jax._src.lib.mlir.dialects import hlo -from jax.interpreters import batching from jax.interpreters import mlir _max = builtins.max diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index cc8ac1774..b199adacb 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -19,13 +19,13 @@ from typing import Union, Sequence import numpy as np from jax import lax -from jax.interpreters import batching from jax.interpreters import mlir from jax.interpreters import xla from jax._src.api import jit, linear_transpose, ShapeDtypeStruct from jax._src.core import Primitive, is_constant_shape from jax._src.interpreters import ad +from jax._src.interpreters import batching from jax._src.lib.mlir.dialects import hlo from jax._src.lib import xla_client from jax._src.lib import ducc_fft diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index a1788fa29..7b4c5cdfb 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -26,12 +26,9 @@ import numpy as np import jax from jax import tree_util -from jax.interpreters import batching -from jax._src.interpreters import mlir from jax.interpreters import partial_eval as pe from jax.interpreters import pxla from jax.interpreters import xla -from jax.interpreters.batching import ConcatAxis from jax.tree_util import tree_map from jax._src import ad_util @@ -51,6 +48,9 @@ from jax._src.config import config from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray, raise_to_shaped, abstract_token, canonicalize_shape) from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.interpreters.batching import ConcatAxis from jax._src.lax import slicing from jax._src.lax.utils import ( _input_dtype, diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 9e7411e2d..6139d26d1 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -22,8 +22,6 @@ import numpy as np import jax from jax import lax -from jax.interpreters import batching -from jax._src.interpreters import mlir from jax.interpreters import xla from jax._src import ad_util @@ -32,6 +30,8 @@ from jax._src import dtypes from jax._src.core import ( Primitive, ShapedArray, raise_to_shaped, is_constant_shape) from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import mlir from jax._src.lax import control_flow from jax._src.lax import eigh as lax_eigh from jax._src.lax import lax as lax_internal diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index bc663cee7..ab6f01d8a 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -24,7 +24,6 @@ import warnings import numpy as np from jax import tree_util -from jax.interpreters import batching from jax.interpreters import mlir from jax.interpreters import pxla from jax.interpreters import xla @@ -34,6 +33,7 @@ from jax._src import dtypes from jax._src import util from jax._src.core import ShapedArray, AxisName, raise_to_shaped from jax._src.interpreters import ad +from jax._src.interpreters import batching from jax._src.lax import lax from jax._src.lax import slicing from jax._src.lib.mlir import ir diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index de892919e..91755d557 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -20,8 +20,6 @@ import weakref import numpy as np import jax -from jax.interpreters import batching -from jax._src.interpreters import mlir from jax.interpreters import partial_eval as pe from jax._src import ad_util @@ -29,6 +27,8 @@ from jax._src import core from jax._src import dtypes from jax._src import util from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import mlir from jax._src.lax import lax from jax._src.lax.utils import ( _argnum_weak_type, diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 1dc3bc706..f7049d600 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -19,8 +19,6 @@ import warnings import numpy as np from jax import tree_util -from jax.interpreters import batching -from jax._src.interpreters import mlir from jax.interpreters import xla from jax._src import ad_util @@ -29,6 +27,8 @@ from jax._src import dtypes from jax._src import util from jax._src.core import ShapedArray, ConcreteArray from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import mlir from jax._src.lax import lax from jax._src.lax import convolution from jax._src.lax import slicing diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index c01b2066a..d1300452d 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -28,8 +28,6 @@ from jax._src import core from jax import stages from jax.errors import JAXTypeError from jax.experimental.global_device_array import GlobalDeviceArray as GDA -from jax.interpreters import batching -from jax._src.interpreters import mlir from jax.interpreters import partial_eval as pe from jax.interpreters import xla from jax._src.interpreters.pxla import PartitionSpec @@ -52,6 +50,8 @@ from jax._src.api_util import ( argnames_partial_except, resolve_argnums, FLAGS) from jax._src.config import config from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import mlir from jax._src.interpreters import pxla from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 78e13ded3..188e1c4cc 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -25,9 +25,6 @@ from jax import lax from jax import numpy as jnp from jax.config import config from jax.dtypes import float0 -from jax.interpreters import batching -from jax._src.interpreters import pxla -from jax._src.interpreters import mlir from jax.interpreters import xla from jax._src import basearray @@ -37,6 +34,9 @@ from jax._src import dtypes from jax._src import pretty_printer as pp from jax._src.api import jit, vmap from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import mlir +from jax._src.interpreters import pxla from jax._src.lax import lax as lax_internal from jax._src.lax import utils as lax_utils from jax._src.lib import gpu_prng diff --git a/jax/_src/random.py b/jax/_src/random.py index d8401854d..72167c7a9 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -24,7 +24,6 @@ import jax import jax.numpy as jnp from jax import lax from jax.config import config -from jax.interpreters import batching from jax.interpreters import mlir from jax.numpy.linalg import cholesky, svd, eigh @@ -34,6 +33,7 @@ from jax._src import prng from jax._src.api import jit, vmap from jax._src.core import NamedShape from jax._src.interpreters import ad +from jax._src.interpreters import batching from jax._src.lax import lax as lax_internal from jax._src.lib import xla_bridge from jax._src.numpy.lax_numpy import ( diff --git a/jax/experimental/sparse/api.py b/jax/experimental/sparse/api.py index 76f7fc4ba..96594e6b2 100644 --- a/jax/experimental/sparse/api.py +++ b/jax/experimental/sparse/api.py @@ -42,11 +42,11 @@ from jax.experimental.sparse.bcsr import BCSR from jax.experimental.sparse.coo import COO from jax.experimental.sparse.csr import CSR, CSC from jax.experimental.sparse.util import _coo_extract -from jax.interpreters import batching from jax.interpreters import mlir from jax._src import dtypes from jax._src.interpreters import ad +from jax._src.interpreters import batching from jax._src.typing import Array, DTypeLike, Shape diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 7b235fa04..c195245d7 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -35,13 +35,13 @@ from jax.experimental.sparse.util import ( _dot_general_validated_shape, CuSparseEfficiencyWarning, SparseEfficiencyError, SparseEfficiencyWarning, Shape, SparseInfo) -from jax.interpreters import batching from jax.interpreters import partial_eval as pe from jax._src.interpreters import mlir import jax.numpy as jnp from jax.util import safe_zip, unzip2, split_list from jax._src import api_util from jax._src.interpreters import ad +from jax._src.interpreters import batching from jax._src.lax.lax import ( _const, ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule, DotDimensionNumbers) diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 557b09ae3..621deadb6 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -23,6 +23,7 @@ from typing import NamedTuple, Optional, Sequence, Tuple, Union import numpy as np +import jax.numpy as jnp from jax import config from jax import core from jax import lax @@ -33,18 +34,19 @@ from jax.experimental.sparse.util import ( nfold_vmap, _count_stored_elements, _csr_to_coo, _dot_general_validated_shape, CuSparseEfficiencyWarning, SparseInfo, Shape) -import jax.numpy as jnp +from jax.util import split_list, safe_zip + from jax._src import api_util from jax._src.lax.lax import DotDimensionNumbers from jax._src.lib import gpu_sparse -from jax.util import split_list, safe_zip -from jax.interpreters import ad -from jax.interpreters import batching -from jax.interpreters import mlir from jax._src.lib.mlir.dialects import hlo +from jax._src.interpreters import ad +from jax._src.interpreters import batching +from jax._src.interpreters import mlir from jax._src.typing import Array, ArrayLike, DTypeLike + def bcsr_eliminate_zeros(mat: BCSR, nse: Optional[int] = None) -> BCSR: """Eliminate zeros in BCSR representation.""" return BCSR.from_bcoo(bcoo.bcoo_eliminate_zeros(mat.to_bcoo(), nse=nse))