mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Move jax.interpreters.mlir to jax._src.interpreters.mlir.
Replace jax.interpreters.mlir with a shim that re-exports names that are likely to be used externally. PiperOrigin-RevId: 508187063
This commit is contained in:
parent
3e349c7bed
commit
cc8d7fae32
@ -24,7 +24,7 @@ from jax.tree_util import (tree_flatten, tree_unflatten, tree_map,
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters.batching import not_mapped
|
||||
from jax.config import config
|
||||
|
@ -28,7 +28,7 @@ from jax import lax
|
||||
from jax.config import config
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import pxla
|
||||
|
||||
|
1891
jax/_src/interpreters/mlir.py
Normal file
1891
jax/_src/interpreters/mlir.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -48,7 +48,7 @@ import numpy as np
|
||||
import jax
|
||||
from jax.errors import JAXTypeError
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.tree_util import tree_flatten, tree_map
|
||||
|
||||
|
@ -27,7 +27,7 @@ from jax.config import config
|
||||
from jax.core import ConcreteArray, raise_to_shaped
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.tree_util import tree_flatten, tree_unflatten
|
||||
|
@ -26,7 +26,7 @@ from jax.config import config
|
||||
from jax.core import ConcreteArray, ShapedArray, raise_to_shaped
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
import jax._src.pretty_printer as pp
|
||||
|
@ -27,7 +27,7 @@ import numpy as np
|
||||
import jax
|
||||
from jax import tree_util
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
|
@ -23,7 +23,7 @@ import numpy as np
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax._src import ad_util
|
||||
|
@ -21,7 +21,7 @@ import numpy as np
|
||||
|
||||
import jax
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
|
||||
from jax._src import ad_util
|
||||
|
@ -20,7 +20,7 @@ import numpy as np
|
||||
|
||||
from jax import tree_util
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
|
||||
from jax._src import ad_util
|
||||
|
@ -29,7 +29,7 @@ from jax import stages
|
||||
from jax.errors import JAXTypeError
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray as GDA
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters.pxla import PartitionSpec
|
||||
|
@ -26,7 +26,7 @@ from jax import numpy as jnp
|
||||
from jax.config import config
|
||||
from jax.dtypes import float0
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
|
||||
|
@ -35,7 +35,7 @@ import numpy.random as npr
|
||||
import jax
|
||||
from jax import lax
|
||||
from jax.experimental.compilation_cache import compilation_cache
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten
|
||||
from jax._src import api
|
||||
from jax._src import pjit as pjit_lib
|
||||
|
@ -22,7 +22,7 @@ from jax.errors import UnexpectedTracerError
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib.mlir import ir
|
||||
import jax.interpreters.pxla as pxla
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src import custom_api_util
|
||||
from jax._src.lib import xla_client as xc
|
||||
|
@ -513,7 +513,7 @@ from jax import lax
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import ad, batching, pxla
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.interpreters import xla
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import dispatch
|
||||
|
@ -37,7 +37,7 @@ from jax.errors import JAXTypeError
|
||||
from jax._src.array import ArrayImpl
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||
from jax._src.sharding import NamedSharding
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
|
@ -40,7 +40,7 @@ from jax._src.util import (prod, HashableFunction, unzip2, as_hashable_function,
|
||||
memoize, partition_list, merge_lists)
|
||||
from jax.api_util import flatten_fun_nokwargs, shaped_abstractify
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import xla
|
||||
from jax._src.interpreters import pxla
|
||||
|
@ -37,7 +37,7 @@ from jax.experimental.sparse.util import (
|
||||
SparseInfo)
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
import jax.numpy as jnp
|
||||
from jax.util import safe_zip, unzip2, split_list
|
||||
from jax._src import api_util
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -37,6 +37,7 @@ per-file-ignores =
|
||||
jax/flatten_util.py:F401
|
||||
jax/interpreters/ad.py:F401
|
||||
jax/interpreters/batching.py:F401
|
||||
jax/interpreters/mlir.py:F401
|
||||
jax/interpreters/pxla.py:F401
|
||||
jax/interpreters/xla.py:F401
|
||||
jax/linear_util.py:F401
|
||||
|
@ -52,7 +52,7 @@ from jax._src import api, dtypes, dispatch, lib, api_util
|
||||
from jax.core import Primitive
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import batching
|
||||
|
@ -27,7 +27,7 @@ from jax.interpreters import ad
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
from jax._src import sharding
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import dispatch
|
||||
from jax._src import test_util as jtu
|
||||
|
@ -36,7 +36,7 @@ from jax import tree_util
|
||||
import jax.util
|
||||
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import mlir
|
||||
from jax._src.interpreters import mlir
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import pxla
|
||||
from jax._src import array
|
||||
|
Loading…
x
Reference in New Issue
Block a user