mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Move jax.interpreters.pxla to jax._src.interpreters.pxla.
Make jax.interpreters.pxla a shim that at the moment re-exports everything in the implementation, with the goal of reducing it over time. PiperOrigin-RevId: 507584264
This commit is contained in:
parent
3d9ae6b467
commit
38a59a313b
@ -81,7 +81,7 @@ from jax.custom_transpose import custom_transpose
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
from jax._src.interpreters import pxla
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
|
||||
|
3947
jax/_src/interpreters/pxla.py
Normal file
3947
jax/_src/interpreters/pxla.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -594,7 +594,7 @@ class Lowered(Stage):
|
||||
|
||||
def compile(self) -> Compiled:
|
||||
"""Compile, returning a corresponding ``Compiled`` instance."""
|
||||
from jax.interpreters import pxla
|
||||
from jax._src.interpreters import pxla
|
||||
|
||||
if (jax.config.jax_array and
|
||||
isinstance(self._lowering, pxla.MeshComputation)):
|
||||
|
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user