From 219723c73817ffb884f4b87d7ed81be5a4354254 Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Mon, 6 Feb 2023 22:51:50 -0800 Subject: [PATCH] migrate internal dependencies from `jax.interpreters.ad` to `jax._src.interpreters.ad` ... in preparation for paring down `jax.interpreters.ad`'s exported symbols. Includes some import fixups along the way. PiperOrigin-RevId: 507684262 --- jax/_src/ad_checkpoint.py | 6 ++-- jax/_src/api.py | 15 ++++----- jax/_src/callback.py | 11 ++++--- jax/_src/checkify.py | 31 +++++++++--------- jax/_src/custom_batching.py | 4 +-- jax/_src/custom_derivatives.py | 22 +++++++------ jax/_src/custom_transpose.py | 4 +-- jax/_src/debugging.py | 12 ++++--- jax/_src/dispatch.py | 16 +++++----- jax/_src/interpreters/pxla.py | 13 ++++---- jax/_src/lax/ann.py | 8 +++-- jax/_src/lax/convolution.py | 7 +++-- jax/_src/lax/fft.py | 7 +++-- jax/_src/lax/lax.py | 49 +++++++++++++++-------------- jax/_src/lax/linalg.py | 28 ++++++++--------- jax/_src/lax/parallel.py | 16 +++++----- jax/_src/lax/slicing.py | 15 ++++----- jax/_src/lax/windowed_reductions.py | 5 ++- jax/_src/pjit.py | 43 ++++++++++++------------- jax/_src/prng.py | 10 +++--- jax/_src/random.py | 7 +++-- jax/experimental/jax2tf/jax2tf.py | 26 ++++++++------- jax/experimental/sparse/api.py | 3 +- jax/experimental/sparse/bcoo.py | 2 +- jax/experimental/sparse/coo.py | 2 +- jax/experimental/sparse/csr.py | 2 +- jax/interpreters/xla.py | 2 +- 27 files changed, 196 insertions(+), 170 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 588ea5c19..230fe01be 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -18,8 +18,6 @@ from typing import (Callable, Optional, List, Tuple, Sequence, Set, Union, Any, import types import jax -from jax._src import linear_util as lu -from jax.interpreters import ad from jax.interpreters import batching from jax.interpreters import mlir from jax.interpreters import partial_eval as pe @@ -27,10 +25,12 @@ from jax.interpreters import xla from jax.tree_util import tree_flatten, tree_unflatten from jax._src import ad_util from jax._src import core -from jax._src import util +from jax._src import linear_util as lu from jax._src import source_info_util from jax._src import traceback_util +from jax._src import util from jax._src.api_util import flatten_fun, shaped_abstractify +from jax._src.interpreters import ad from jax._src.lax import lax as lax_internal from jax._src.lax import convolution as lax_convolution from jax._src.lib.mlir.dialects import hlo diff --git a/jax/_src/api.py b/jax/_src/api.py index d68928d7f..247d77d21 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -68,12 +68,8 @@ from jax._src.util import (unzip2, curry, safe_map, safe_zip, prod, split_list, extend_name_stack, new_name_stack, wrap_name, cache, wraps, HashableFunction, weakref_lru_cache) + # Unused imports to be exported -from jax._src.core import ShapedArray, raise_to_shaped -from jax._src.lib.xla_bridge import (device_count, local_device_count, devices, - local_devices, process_index, - process_count, host_id, host_ids, - host_count, default_backend) from jax.ad_checkpoint import checkpoint_policies, checkpoint as new_checkpoint from jax.custom_batching import custom_vmap from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp, @@ -82,8 +78,6 @@ from jax.custom_transpose import custom_transpose from jax.interpreters import partial_eval as pe from jax.interpreters import mlir from jax.interpreters import xla -from jax._src.interpreters import pxla -from jax.interpreters import ad from jax.interpreters import batching from jax._src.config import ( @@ -94,6 +88,13 @@ from jax._src.config import ( _thread_local_state as config_thread_local_state, explicit_device_put_scope as config_explicit_device_put_scope, explicit_device_get_scope as config_explicit_device_get_scope) +from jax._src.core import ShapedArray, raise_to_shaped +from jax._src.interpreters import ad +from jax._src.interpreters import pxla +from jax._src.lib.xla_bridge import (device_count, local_device_count, devices, + local_devices, process_index, + process_count, host_id, host_ids, + host_count, default_backend) traceback_util.register_exclusion(__file__) diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 7f551045a..290a2950d 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -15,19 +15,20 @@ from __future__ import annotations import functools - from typing import Any, Callable, Sequence +import numpy as np + from jax import tree_util +from jax.interpreters import batching +from jax.interpreters import mlir + from jax._src import core from jax._src import dtypes from jax._src import util from jax._src import dispatch +from jax._src.interpreters import ad from jax._src.lib import xla_client as xc -from jax.interpreters import ad -from jax.interpreters import batching -from jax.interpreters import mlir -import numpy as np # `pure_callback_p` is the main primitive for staging out Python pure callbacks. pure_callback_p = core.Primitive("pure_callback") diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 4ed67669c..0717561aa 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -15,10 +15,25 @@ import dataclasses import functools import itertools as it -from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, FrozenSet, Iterable, Type, Set, List, Sequence, Any +from typing import (Union, Optional, Callable, Dict, Tuple, TypeVar, + FrozenSet, Iterable, Type, Set, List, Sequence, Any) + +import numpy as np import jax +import jax.numpy as jnp +import jax.tree_util as jtu from jax import lax +from jax.api_util import flatten_fun +from jax.experimental import maps +from jax.experimental import pjit +from jax.interpreters import batching +from jax.interpreters import mlir +from jax.interpreters import partial_eval as pe +from jax.tree_util import tree_flatten +from jax.tree_util import tree_map +from jax.tree_util import tree_unflatten + from jax._src import linear_util as lu from jax._src import core from jax._src import custom_derivatives @@ -26,24 +41,12 @@ from jax._src import prng from jax._src import source_info_util from jax._src import traceback_util from jax._src.config import config +from jax._src.interpreters import ad from jax._src.lax import control_flow as cf from jax._src.sharding import OpShardingSharding from jax._src.typing import Array from jax._src.util import (as_hashable_function, split_list, safe_map, safe_zip, unzip3, weakref_lru_cache) -from jax.api_util import flatten_fun -from jax.experimental import maps -from jax.experimental import pjit -from jax.interpreters import ad -from jax.interpreters import batching -from jax.interpreters import mlir -from jax.interpreters import partial_eval as pe -from jax.tree_util import tree_flatten -from jax.tree_util import tree_map -from jax.tree_util import tree_unflatten -import jax.numpy as jnp -import jax.tree_util as jtu -import numpy as np source_info_util.register_exclusion(__file__) traceback_util.register_exclusion(__file__) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 5a75cbffe..ed4810f68 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -17,9 +17,7 @@ import operator from typing import Callable, Optional import jax -from jax._src import linear_util as lu from jax import tree_util -from jax.interpreters import ad from jax.interpreters import batching from jax.interpreters.batching import not_mapped from jax.interpreters import mlir @@ -29,10 +27,12 @@ from jax.tree_util import (tree_flatten, tree_map, tree_structure, tree_unflatten, treedef_tuple) from jax._src import core from jax._src import custom_api_util +from jax._src import linear_util as lu from jax._src import source_info_util from jax._src import traceback_util from jax._src import util from jax._src.api_util import flatten_fun_nokwargs +from jax._src.interpreters import ad source_info_util.register_exclusion(__file__) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 5036e731c..99720c42a 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -17,29 +17,31 @@ import inspect from typing import (Callable, Generic, Optional, Sequence, Tuple, TypeVar, Set, Any) -from jax._src import linear_util as lu from jax.custom_transpose import custom_transpose from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple, register_pytree_node_class, tree_leaves) -from jax._src import core -from jax._src import custom_api_util -from jax._src import dtypes -from jax._src.lax import lax -from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable -from jax._src.api_util import argnums_partial, flatten_fun_nokwargs -from jax._src.core import raise_to_shaped from jax.errors import UnexpectedTracerError -from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p from jax.interpreters import partial_eval as pe -from jax._src.interpreters import ad from jax.interpreters import batching from jax.interpreters import mlir from jax.interpreters import xla from jax.interpreters.batching import not_mapped from jax.config import config +from jax._src import core +from jax._src import custom_api_util +from jax._src import dtypes +from jax._src import linear_util as lu from jax._src import traceback_util +from jax._src.ad_util import Zero, zeros_like_aval, stop_gradient_p +from jax._src.api_util import argnums_partial, flatten_fun_nokwargs +from jax._src.core import raise_to_shaped +from jax._src.interpreters import ad +from jax._src.lax import lax +from jax._src.util import cache, safe_zip, safe_map, split_list, Unhashable + + traceback_util.register_exclusion(__file__) map = safe_map diff --git a/jax/_src/custom_transpose.py b/jax/_src/custom_transpose.py index 5f18dd5c4..c70034b55 100644 --- a/jax/_src/custom_transpose.py +++ b/jax/_src/custom_transpose.py @@ -15,8 +15,6 @@ import functools from typing import Any, Callable, Optional, Tuple -from jax._src import linear_util as lu -from jax.interpreters import ad from jax.interpreters import mlir from jax.interpreters import partial_eval as pe from jax.interpreters import xla @@ -26,9 +24,11 @@ from jax._src import ad_util from jax._src import api_util from jax._src import core from jax._src import custom_api_util +from jax._src import linear_util as lu from jax._src import source_info_util from jax._src import traceback_util from jax._src import util +from jax._src.interpreters import ad source_info_util.register_exclusion(__file__) diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 87cfc8860..f253ed9d5 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -12,36 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for JAX debugging primitives and related functionality.""" + import enum import functools import string import sys +from typing import Any, Dict, Callable, Optional, Sequence, Set, Tuple, Union import weakref -from typing import Any, Dict, Callable, Optional, Sequence, Set, Tuple, Union +import numpy as np +import jax.numpy as jnp from jax import tree_util from jax import lax -from jax._src import linear_util as lu from jax.config import config from jax.experimental import pjit -from jax.interpreters import ad from jax.interpreters import batching from jax.interpreters import mlir from jax.interpreters import partial_eval as pe from jax.interpreters import pxla + from jax._src import ad_checkpoint from jax._src import core from jax._src import custom_derivatives +from jax._src import linear_util as lu from jax._src import util +from jax._src.interpreters import ad from jax._src.lax import control_flow as lcf from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.sharding import Sharding, OpShardingSharding -import jax.numpy as jnp -import numpy as np # pytype: disable=import-error try: import rich diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 64951b894..5018eeb66 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -32,34 +32,36 @@ import warnings import numpy as np import jax -from jax._src import linear_util as lu from jax.errors import UnexpectedTracerError from jax.monitoring import record_event_duration_secs -import jax.interpreters.ad as ad import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir import jax.interpreters.xla as xla from jax.interpreters import pxla import jax.interpreters.partial_eval as pe + from jax._src import array from jax._src import core from jax._src import device_array from jax._src import dtypes +from jax._src import linear_util as lu +from jax._src import path from jax._src import profiler from jax._src import stages from jax._src import traceback_util -from jax._src.sharding import (PmapSharding, SingleDeviceSharding, - OpShardingSharding, NamedSharding, PartitionSpec, - Sharding) +from jax._src import util from jax._src.abstract_arrays import array_types from jax._src.config import config, flags +from jax._src.interpreters import ad from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import use_stablehlo from jax._src.lib import xla_bridge as xb from jax._src.lib import xla_client as xc -import jax._src.util as util +from jax._src.sharding import (PmapSharding, SingleDeviceSharding, + OpShardingSharding, NamedSharding, PartitionSpec, + Sharding) from jax._src.util import flatten, unflatten -from jax._src import path + JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration" BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration" diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 20a3eba42..39fe2a2a5 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -42,12 +42,11 @@ import sys import threading from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet, Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast) + import numpy as np import jax -from jax._src import linear_util as lu from jax.errors import JAXTypeError -from jax.interpreters import ad from jax.interpreters import batching from jax.interpreters import mlir from jax.interpreters import partial_eval as pe @@ -59,17 +58,19 @@ from jax._src import api_util from jax._src import basearray from jax._src import core from jax._src import device_array -from jax._src import dtypes -from jax._src import source_info_util -from jax._src import util from jax._src import dispatch +from jax._src import dtypes +from jax._src import linear_util as lu from jax._src import profiler -from jax._src import stages from jax._src import sharding as sharding_internal +from jax._src import source_info_util +from jax._src import stages +from jax._src import util from jax._src.abstract_arrays import array_types from jax._src.config import config from jax._src.config import flags from jax._src.core import ConcreteArray, ShapedArray +from jax._src.interpreters import ad from jax._src.lib import xla_bridge as xb from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension_version diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index b1db8c0da..0ce798f29 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -73,13 +73,17 @@ from functools import partial from typing import (Any, Tuple) import numpy as np -from jax._src import core + +from jax.interpreters import xla +from jax.interpreters import batching + from jax._src import ad_util +from jax._src import core from jax._src import dtypes +from jax._src.interpreters import ad from jax._src.lax import lax from jax._src.lib import xla_client as xc -from jax.interpreters import ad, xla, batching Array = Any diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 536c98a05..578c4bf04 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -19,15 +19,16 @@ from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union import numpy as np -from jax.interpreters import ad -from jax.interpreters import batching -from jax.interpreters import mlir from jax._src import core from jax._src import dtypes from jax._src import util +from jax._src.interpreters import ad from jax._src.lax import lax from jax._src.lib.mlir.dialects import hlo +from jax.interpreters import batching +from jax.interpreters import mlir + _max = builtins.max Array = Any diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 0e1248fad..cc8ac1774 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -18,13 +18,14 @@ from typing import Union, Sequence import numpy as np +from jax import lax +from jax.interpreters import batching from jax.interpreters import mlir from jax.interpreters import xla -from jax import lax -from jax.interpreters import ad -from jax.interpreters import batching + from jax._src.api import jit, linear_transpose, ShapeDtypeStruct from jax._src.core import Primitive, is_constant_shape +from jax._src.interpreters import ad from jax._src.lib.mlir.dialects import hlo from jax._src.lib import xla_client from jax._src.lib import ducc_fft diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index ee54c0532..0fac2d83f 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -25,40 +25,33 @@ import warnings import numpy as np import jax -from jax._src import core +from jax import tree_util +from jax.interpreters import batching +from jax.interpreters import mlir +from jax.interpreters import partial_eval as pe +from jax.interpreters import pxla +from jax.interpreters import xla +from jax.interpreters.batching import ConcatAxis +from jax.tree_util import tree_map + from jax._src import ad_util from jax._src import api from jax._src import api_util from jax._src import array +from jax._src import core from jax._src import device_array from jax._src import dispatch -from jax._src import linear_util as lu from jax._src import dtypes -from jax import tree_util +from jax._src import linear_util as lu +from jax._src import pretty_printer as pp from jax._src import source_info_util -from jax._src.sharding import PmapSharding +from jax._src import util +from jax._src.abstract_arrays import array_types from jax._src.config import config from jax._src.core import (Primitive, UnshapedArray, ShapedArray, ConcreteArray, raise_to_shaped, abstract_token, canonicalize_shape) -from jax._src.abstract_arrays import array_types -from jax.interpreters import partial_eval as pe -from jax.interpreters import mlir -from jax.interpreters import xla -from jax.interpreters import pxla -from jax.interpreters import ad -from jax.interpreters import batching -from jax.interpreters.batching import ConcatAxis -import jax._src.pretty_printer as pp -from jax._src import util -from jax._src.util import (cache, prod, safe_zip, safe_map, canonicalize_axis, - split_list) -from jax.tree_util import tree_map -from jax._src.lib import pytree -from jax._src.lib import xla_bridge -from jax._src.lib import xla_client -from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import chlo -from jax._src.lib.mlir.dialects import hlo +from jax._src.interpreters import ad +from jax._src.lax import slicing from jax._src.lax.utils import ( _input_dtype, standard_abstract_eval, @@ -67,8 +60,16 @@ from jax._src.lax.utils import ( standard_primitive, standard_translate, ) -from jax._src.lax import slicing +from jax._src.lib import pytree +from jax._src.lib import xla_bridge +from jax._src.lib import xla_client +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import chlo +from jax._src.lib.mlir.dialects import hlo +from jax._src.sharding import PmapSharding from jax._src.typing import Array, ArrayLike, DTypeLike, Shape +from jax._src.util import (cache, prod, safe_zip, safe_map, canonicalize_axis, + split_list) xb = xla_bridge xc = xla_client diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index b6165651a..c0f78af28 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -21,38 +21,36 @@ import warnings import numpy as np import jax -from jax._src.numpy import lax_numpy as jnp -from jax._src.numpy.vectorize import vectorize -from jax._src import ad_util -from jax._src import api from jax import lax -from jax._src import dtypes +from jax.interpreters import batching from jax.interpreters import mlir from jax.interpreters import xla -from jax.interpreters import ad -from jax.interpreters import batching -from jax._src.util import prod + +from jax._src import ad_util +from jax._src import api +from jax._src import dtypes from jax._src.core import ( Primitive, ShapedArray, raise_to_shaped, is_constant_shape) -from jax._src.lax.lax import ( - standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex, - _input_dtype) +from jax._src.interpreters import ad from jax._src.lax import control_flow from jax._src.lax import eigh as lax_eigh from jax._src.lax import lax as lax_internal from jax._src.lax import svd as lax_svd -from jax._src.lib import lapack - +from jax._src.lax.lax import ( + standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex, + _input_dtype) from jax._src.lib import gpu_linalg from jax._src.lib import gpu_solver from jax._src.lib import gpu_sparse - +from jax._src.lib import lapack from jax._src.lib import xla_client - from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo +from jax._src.numpy import lax_numpy as jnp +from jax._src.numpy.vectorize import vectorize from jax._src.typing import Array, ArrayLike +from jax._src.util import prod xops = xla_client.ops diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index fedaa7cea..bc663cee7 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -24,21 +24,23 @@ import warnings import numpy as np from jax import tree_util -from jax.interpreters import ad -from jax.interpreters import mlir -from jax.interpreters import xla -from jax.interpreters import pxla from jax.interpreters import batching +from jax.interpreters import mlir +from jax.interpreters import pxla +from jax.interpreters import xla + from jax._src import core from jax._src import dtypes +from jax._src import util from jax._src.core import ShapedArray, AxisName, raise_to_shaped +from jax._src.interpreters import ad from jax._src.lax import lax from jax._src.lax import slicing -from jax._src.numpy import lax_numpy -import jax._src.util as util -from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo, use_stablehlo +from jax._src.numpy import lax_numpy +from jax._src.util import ( + unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis) unsafe_map, map = map, safe_map # type: ignore diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index fbfe931f1..1461ad9dd 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -20,24 +20,25 @@ import weakref import numpy as np import jax -from jax._src import ad_util -from jax._src import core -from jax._src import dtypes -from jax.interpreters import ad from jax.interpreters import batching from jax.interpreters import mlir from jax.interpreters import partial_eval as pe + +from jax._src import ad_util +from jax._src import core +from jax._src import dtypes +from jax._src import util +from jax._src.interpreters import ad +from jax._src.lax import lax from jax._src.lax.utils import ( _argnum_weak_type, _input_dtype, standard_primitive, ) -from jax._src.lax import lax -from jax._src import util -from jax._src.util import safe_map, safe_zip from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.typing import Array, ArrayLike, Shape +from jax._src.util import safe_map, safe_zip map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 221ca51ca..37a1a9f6f 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -18,18 +18,17 @@ import warnings import numpy as np -from jax.interpreters import ad +from jax import tree_util from jax.interpreters import batching from jax.interpreters import mlir from jax.interpreters import xla -from jax import tree_util - from jax._src import ad_util from jax._src import core from jax._src import dtypes from jax._src import util from jax._src.core import ShapedArray, ConcreteArray +from jax._src.interpreters import ad from jax._src.lax import lax from jax._src.lax import convolution from jax._src.lax import slicing diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 483d19582..cc6a51ca9 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -24,41 +24,42 @@ import threading import warnings import jax +from jax import core +from jax import stages +from jax.errors import JAXTypeError from jax.experimental.global_device_array import GlobalDeviceArray as GDA +from jax.interpreters import batching +from jax.interpreters import mlir +from jax.interpreters import partial_eval as pe +from jax.interpreters import pxla +from jax.interpreters import xla +from jax.interpreters.pxla import PartitionSpec +from jax.tree_util import ( + tree_map, tree_flatten, tree_unflatten, treedef_is_leaf, tree_structure, + treedef_tuple) + from jax._src.sharding import ( NamedSharding, Sharding, XLACompatibleSharding, OpShardingSharding, XLADeviceAssignment, SingleDeviceSharding, PmapSharding) -from jax import core -from jax._src import linear_util as lu -from jax import stages from jax._src import array -from jax._src.config import config from jax._src import dispatch +from jax._src import linear_util as lu from jax._src import source_info_util from jax._src import traceback_util -from jax._src.traceback_util import api_boundary -from jax._src.api_util import (argnums_partial_except, flatten_axes, - flatten_fun, flatten_fun_nokwargs, - donation_vector, shaped_abstractify, - check_callable, argnames_partial_except, - resolve_argnums, FLAGS) -from jax.errors import JAXTypeError -from jax.interpreters import ad -from jax.interpreters import mlir -from jax.interpreters import pxla -from jax.interpreters import xla -from jax.interpreters import batching -from jax.interpreters import partial_eval as pe -from jax.interpreters.pxla import PartitionSpec +from jax._src import util +from jax._src.api_util import ( + argnums_partial_except, flatten_axes, flatten_fun, flatten_fun_nokwargs, + donation_vector, shaped_abstractify, check_callable, + argnames_partial_except, resolve_argnums, FLAGS) +from jax._src.config import config +from jax._src.interpreters import ad from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib import xla_bridge as xb from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension_version -from jax.tree_util import (tree_map, tree_flatten, tree_unflatten, - treedef_is_leaf, tree_structure, treedef_tuple) +from jax._src.traceback_util import api_boundary from jax._src.tree_util import prefix_errors -from jax._src import util from jax._src.util import ( HashableFunction, safe_map, safe_zip, wraps, distributed_debug_log, split_list, tuple_insert, weakref_lru_cache, diff --git a/jax/_src/prng.py b/jax/_src/prng.py index ff1df29fc..b385193f2 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -25,26 +25,26 @@ from jax import lax from jax import numpy as jnp from jax.config import config from jax.dtypes import float0 -from jax.interpreters import ad from jax.interpreters import batching from jax.interpreters import mlir from jax.interpreters import pxla from jax.interpreters import xla -from jax._src import basearray -from jax._src.sharding import ( - NamedSharding, PmapSharding, OpShardingSharding) +from jax._src import basearray from jax._src import core from jax._src import dispatch from jax._src import dtypes from jax._src import pretty_printer as pp from jax._src.api import jit, vmap +from jax._src.interpreters import ad from jax._src.lax import lax as lax_internal from jax._src.lax import utils as lax_utils +from jax._src.lib import gpu_prng from jax._src.lib.mlir.dialects import hlo from jax._src.numpy import lax_numpy +from jax._src.sharding import ( + NamedSharding, PmapSharding, OpShardingSharding) from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip -from jax._src.lib import gpu_prng map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip diff --git a/jax/_src/random.py b/jax/_src/random.py index d57ba788c..d8401854d 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -24,18 +24,21 @@ import jax import jax.numpy as jnp from jax import lax from jax.config import config -from jax.interpreters import ad from jax.interpreters import batching from jax.interpreters import mlir from jax.numpy.linalg import cholesky, svd, eigh + from jax._src import core from jax._src import dtypes from jax._src import prng from jax._src.api import jit, vmap from jax._src.core import NamedShape +from jax._src.interpreters import ad from jax._src.lax import lax as lax_internal from jax._src.lib import xla_bridge -from jax._src.numpy.lax_numpy import _arraylike, _check_arraylike, _convert_and_clip_integer, _promote_dtypes_inexact +from jax._src.numpy.lax_numpy import ( + _arraylike, _check_arraylike, _convert_and_clip_integer, + _promote_dtypes_inexact) from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.util import prod, canonicalize_axis diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 969a32f1a..9c05d038b 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -18,21 +18,26 @@ import operator import os import re import threading -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union, cast +from typing import ( + Any, Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union, + cast) from absl import logging +import numpy as np import jax from jax import lax from jax import config -from jax import core, custom_derivatives -from jax._src import linear_util as lu -from jax import random, tree_util +from jax import core +from jax import custom_derivatives +from jax import random from jax import numpy as jnp +from jax import tree_util from jax.experimental import maps from jax.experimental import pjit -from jax._src import sharding -from jax.interpreters import ad +from jax.experimental.global_device_array import GlobalDeviceArray +from jax.experimental.jax2tf import shape_poly +from jax.experimental.jax2tf import impl_no_xla from jax.interpreters import mlir from jax.interpreters import pxla from jax.interpreters import xla @@ -43,10 +48,13 @@ from jax._src import api from jax._src import api_util from jax._src import dispatch from jax._src import dtypes +from jax._src import linear_util as lu from jax._src import prng from jax._src import random as random_internal +from jax._src import sharding from jax._src import source_info_util from jax._src import util +from jax._src.interpreters import ad from jax._src.lax import control_flow as lax_control_flow from jax._src.lax import lax as lax_internal from jax._src.lax import linalg as lax_linalg @@ -55,12 +63,6 @@ from jax._src.lax import windowed_reductions as lax_windowed_reductions from jax._src.lib import xla_client from jax._src.numpy.ufuncs import logaddexp -from jax.experimental.global_device_array import GlobalDeviceArray -from jax.experimental.jax2tf import shape_poly -from jax.experimental.jax2tf import impl_no_xla - - -import numpy as np import tensorflow as tf # type: ignore[import] # These don't have public equivalents. diff --git a/jax/experimental/sparse/api.py b/jax/experimental/sparse/api.py index 51ea16abe..76f7fc4ba 100644 --- a/jax/experimental/sparse/api.py +++ b/jax/experimental/sparse/api.py @@ -42,10 +42,11 @@ from jax.experimental.sparse.bcsr import BCSR from jax.experimental.sparse.coo import COO from jax.experimental.sparse.csr import CSR, CSC from jax.experimental.sparse.util import _coo_extract -from jax.interpreters import ad from jax.interpreters import batching from jax.interpreters import mlir + from jax._src import dtypes +from jax._src.interpreters import ad from jax._src.typing import Array, DTypeLike, Shape diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index edd02cb66..a8bef431d 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -39,9 +39,9 @@ from jax.interpreters import batching from jax.interpreters import partial_eval as pe from jax.interpreters import mlir import jax.numpy as jnp -from jax.interpreters import ad from jax.util import safe_zip, unzip2, split_list from jax._src import api_util +from jax._src.interpreters import ad from jax._src.lax.lax import ( _const, ranges_like, remaining, _dot_general_batch_dim_nums, _dot_general_shape_rule, DotDimensionNumbers) diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py index 70d04d583..822bb29c3 100644 --- a/jax/experimental/sparse/coo.py +++ b/jax/experimental/sparse/coo.py @@ -24,11 +24,11 @@ import numpy as np from jax import core from jax import lax -from jax.interpreters import ad from jax.interpreters import mlir from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse.util import _coo_extract, CuSparseEfficiencyWarning from jax import tree_util +from jax._src.interpreters import ad from jax._src.lax.lax import _const from jax._src.lib.mlir.dialects import hlo from jax._src.lib import gpu_sparse diff --git a/jax/experimental/sparse/csr.py b/jax/experimental/sparse/csr.py index 3200c62e8..c73fac469 100644 --- a/jax/experimental/sparse/csr.py +++ b/jax/experimental/sparse/csr.py @@ -22,13 +22,13 @@ import warnings import numpy as np from jax import core -from jax.interpreters import ad from jax.interpreters import mlir from jax.experimental.sparse._base import JAXSparse from jax.experimental.sparse.coo import _coo_matmat, _coo_matvec, _coo_todense, COOInfo from jax.experimental.sparse.util import _csr_to_coo, _csr_extract, CuSparseEfficiencyWarning from jax import lax from jax import tree_util +from jax._src.interpreters import ad from jax._src.lax.lax import _const from jax._src.lib import gpu_sparse from jax._src.numpy.lax_numpy import _promote_dtypes diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index abe69e59c..b68631ba6 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -28,7 +28,6 @@ import numpy as np from jax.config import config from jax.interpreters import partial_eval as pe -from jax.interpreters import ad from jax._src import core from jax._src import device_array @@ -37,6 +36,7 @@ from jax._src import pretty_printer as pp from jax._src import source_info_util from jax._src.abstract_arrays import numpy_scalar_types from jax._src.core import ConcreteArray, ShapedArray, str_eqn_compact +from jax._src.interpreters import ad from jax._src.util import (prod, new_name_stack, safe_zip, safe_map, partition_list)