mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #5038 from skye:sharded_jit_namespace
PiperOrigin-RevId: 346121390
This commit is contained in:
commit
2a699d0b04
@ -50,6 +50,13 @@ pytype_library(
|
||||
],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "experimental",
|
||||
srcs = ["experimental/__init__.py"],
|
||||
srcs_version = "PY3",
|
||||
deps = [":jax"],
|
||||
)
|
||||
|
||||
pytype_library(
|
||||
name = "stax",
|
||||
srcs = ["experimental/stax.py"],
|
||||
|
@ -11,3 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# flake8: noqa: F401
|
||||
from ..interpreters.sharded_jit import (sharded_jit, PartitionSpec,
|
||||
with_sharding_constraint)
|
||||
|
@ -28,9 +28,9 @@ from jax import jit, pmap, vjp
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
from jax import tree_util
|
||||
from jax.experimental import (sharded_jit, with_sharding_constraint,
|
||||
PartitionSpec as P)
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters.sharded_jit import sharded_jit, with_sharding_constraint
|
||||
from jax.interpreters.sharded_jit import PartitionSpec as P
|
||||
from jax.util import prod
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user