Move jax.interpreters.mlir to jax._src.interpreters.mlir.

Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally.

PiperOrigin-RevId: 508187063
This commit is contained in:
Peter Hawkins 2023-02-08 14:38:22 -08:00 committed by jax authors
parent 3e349c7bed
commit cc8d7fae32
23 changed files with 1964 additions and 1897 deletions

View File

@ -24,7 +24,7 @@ from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
from jax.errors import UnexpectedTracerError
from jax.interpreters import partial_eval as pe
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters.batching import not_mapped
from jax.config import config

View File

@ -28,7 +28,7 @@ from jax import lax
from jax.config import config
from jax.experimental import pjit
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla

File diff suppressed because it is too large Load Diff

View File

@ -48,7 +48,7 @@ import numpy as np
import jax
from jax.errors import JAXTypeError
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten, tree_map

View File

@ -27,7 +27,7 @@ from jax.config import config
from jax.core import ConcreteArray, raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten

View File

@ -26,7 +26,7 @@ from jax.config import config
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
import jax._src.pretty_printer as pp

View File

@ -27,7 +27,7 @@ import numpy as np
import jax
from jax import tree_util
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla
from jax.interpreters import xla

View File

@ -23,7 +23,7 @@ import numpy as np
import jax
from jax import lax
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.interpreters import xla
from jax._src import ad_util

View File

@ -21,7 +21,7 @@ import numpy as np
import jax
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax._src import ad_util

View File

@ -20,7 +20,7 @@ import numpy as np
from jax import tree_util
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.interpreters import xla
from jax._src import ad_util

View File

@ -29,7 +29,7 @@ from jax import stages
from jax.errors import JAXTypeError
from jax.experimental.global_device_array import GlobalDeviceArray as GDA
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.interpreters.pxla import PartitionSpec

View File

@ -26,7 +26,7 @@ from jax import numpy as jnp
from jax.config import config
from jax.dtypes import float0
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.interpreters import pxla
from jax.interpreters import xla

View File

@ -35,7 +35,7 @@ import numpy.random as npr
import jax
from jax import lax
from jax.experimental.compilation_cache import compilation_cache
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten
from jax._src import api
from jax._src import pjit as pjit_lib

View File

@ -22,7 +22,7 @@ from jax.errors import UnexpectedTracerError
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir import ir
import jax.interpreters.pxla as pxla
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax._src import custom_api_util
from jax._src.lib import xla_client as xc

View File

@ -513,7 +513,7 @@ from jax import lax
from jax.experimental import pjit
from jax.interpreters import ad, batching, pxla
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src import ad_checkpoint
from jax._src import dispatch

View File

@ -37,7 +37,7 @@ from jax.errors import JAXTypeError
from jax._src.array import ArrayImpl
from jax.experimental.global_device_array import GlobalDeviceArray
from jax._src.sharding import NamedSharding
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax._src.interpreters import pxla
from jax.interpreters import xla

View File

@ -40,7 +40,7 @@ from jax._src.util import (prod, HashableFunction, unzip2, as_hashable_function,
memoize, partition_list, merge_lists)
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
from jax.interpreters import batching
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax._src.interpreters import pxla

View File

@ -37,7 +37,7 @@ from jax.experimental.sparse.util import (
SparseInfo)
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax._src.interpreters import mlir
import jax.numpy as jnp
from jax.util import safe_zip, unzip2, split_list
from jax._src import api_util

File diff suppressed because it is too large Load Diff

View File

@ -37,6 +37,7 @@ per-file-ignores =
jax/flatten_util.py:F401
jax/interpreters/ad.py:F401
jax/interpreters/batching.py:F401
jax/interpreters/mlir.py:F401
jax/interpreters/pxla.py:F401
jax/interpreters/xla.py:F401
jax/linear_util.py:F401

View File

@ -52,7 +52,7 @@ from jax._src import api, dtypes, dispatch, lib, api_util
from jax.core import Primitive
from jax.errors import UnexpectedTracerError
from jax.interpreters import ad
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import batching

View File

@ -27,7 +27,7 @@ from jax.interpreters import ad
from jax.experimental import maps
from jax.experimental import pjit
from jax._src import sharding
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax._src import ad_checkpoint
from jax._src import dispatch
from jax._src import test_util as jtu

View File

@ -36,7 +36,7 @@ from jax import tree_util
import jax.util
from jax.interpreters import xla
from jax.interpreters import mlir
from jax._src.interpreters import mlir
from jax.interpreters import batching
from jax.interpreters import pxla
from jax._src import array