Move jax.interpreters.pxla to jax._src.interpreters.pxla.

Make jax.interpreters.pxla a shim that at the moment re-exports everything in the implementation, with the goal of reducing it over time.

PiperOrigin-RevId: 507584264
This commit is contained in:
Peter Hawkins 2023-02-06 14:28:36 -08:00 committed by jax authors
parent 3d9ae6b467
commit 38a59a313b
5 changed files with 4203 additions and 3935 deletions

View File

@ -81,7 +81,7 @@ from jax.custom_transpose import custom_transpose
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters import pxla
from jax._src.interpreters import pxla
from jax.interpreters import ad
from jax.interpreters import batching

File diff suppressed because it is too large Load Diff

View File

@ -594,7 +594,7 @@ class Lowered(Stage):
def compile(self) -> Compiled:
"""Compile, returning a corresponding ``Compiled`` instance."""
from jax.interpreters import pxla
from jax._src.interpreters import pxla
if (jax.config.jax_array and
isinstance(self._lowering, pxla.MeshComputation)):

File diff suppressed because it is too large Load Diff

View File

@ -36,6 +36,7 @@ per-file-ignores =
jax/errors.py:F401
jax/flatten_util.py:F401
jax/interpreters/ad.py:F401
jax/interpreters/pxla.py:F401
jax/linear_util.py:F401
jax/prng.py:F401
jax/profiler.py:F401