Merge pull request #5038 from skye:sharded_jit_namespace

PiperOrigin-RevId: 346121390
This commit is contained in:
jax authors 2020-12-07 10:18:56 -08:00
commit 2a699d0b04
3 changed files with 13 additions and 2 deletions

View File

@ -50,6 +50,13 @@ pytype_library(
],
)
pytype_library(
name = "experimental",
srcs = ["experimental/__init__.py"],
srcs_version = "PY3",
deps = [":jax"],
)
pytype_library(
name = "stax",
srcs = ["experimental/stax.py"],

View File

@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa: F401
from ..interpreters.sharded_jit import (sharded_jit, PartitionSpec,
with_sharding_constraint)

View File

@ -28,9 +28,9 @@ from jax import jit, pmap, vjp
from jax import lax
from jax import test_util as jtu
from jax import tree_util
from jax.experimental import (sharded_jit, with_sharding_constraint,
PartitionSpec as P)
from jax.interpreters import pxla
from jax.interpreters.sharded_jit import sharded_jit, with_sharding_constraint
from jax.interpreters.sharded_jit import PartitionSpec as P
from jax.util import prod
import jax.numpy as jnp