mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Move array.py
and sharding.py
from experimental/
to _src/
.
PiperOrigin-RevId: 477201711
This commit is contained in:
parent
0e116888ea
commit
9e4114f0f1
@ -26,8 +26,8 @@ from jax._src.api_util import shaped_abstractify # technically not an api fn
|
||||
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax.interpreters import pxla
|
||||
from jax.experimental import array
|
||||
from jax.experimental import sharding
|
||||
from jax._src import array
|
||||
from jax._src import sharding
|
||||
from jax.experimental import pjit as pjit_lib
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
@ -22,7 +22,7 @@ from jax._src.util import prod
|
||||
from jax.interpreters.pxla import PartitionSpec as P
|
||||
from jax.interpreters import pxla
|
||||
from jax.experimental import global_device_array as gda
|
||||
from jax.experimental.sharding import MeshPspecSharding
|
||||
from jax._src.sharding import MeshPspecSharding
|
||||
import numpy as np
|
||||
|
||||
mesh_shapes_axes = [
|
||||
|
@ -99,8 +99,6 @@ py_library_providing_imports_info(
|
||||
"experimental/maps.py",
|
||||
"experimental/pjit.py",
|
||||
"experimental/global_device_array.py",
|
||||
"experimental/array.py",
|
||||
"experimental/sharding.py",
|
||||
"experimental/multihost_utils.py",
|
||||
# until checkify is moved out of experimental
|
||||
"experimental/checkify.py",
|
||||
|
@ -114,6 +114,12 @@ from jax._src.api import (
|
||||
xla, # TODO(phawkins): update users to avoid this.
|
||||
xla_computation as xla_computation,
|
||||
)
|
||||
|
||||
from jax._src.array import (
|
||||
make_array_from_single_device_arrays as make_array_from_single_device_arrays,
|
||||
make_array_from_callback as make_array_from_callback,
|
||||
)
|
||||
|
||||
from jax.version import __version__ as __version__
|
||||
from jax.version import __version_info__ as __version_info__
|
||||
|
||||
@ -143,6 +149,7 @@ from jax import numpy as numpy
|
||||
from jax import ops as ops
|
||||
from jax import profiler as profiler
|
||||
from jax import random as random
|
||||
from jax import sharding as sharding
|
||||
from jax import stages as stages
|
||||
from jax import tree_util as tree_util
|
||||
from jax import util as util
|
||||
|
@ -667,7 +667,7 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
|
||||
# all the other arguments stored as attributes.
|
||||
|
||||
def arg_spec(x):
|
||||
from jax.experimental.sharding import PmapSharding
|
||||
from jax._src.sharding import PmapSharding
|
||||
# like xla.arg_spec but duck-types on x.shape and x.dtype
|
||||
aval = None if jax.config.jax_dynamic_shapes else shaped_abstractify(x)
|
||||
if jax.config.jax_array:
|
||||
@ -2869,7 +2869,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
|
||||
buffers = [buf for x, d in zip(xs, devices)
|
||||
for buf in dispatch.device_put(x, d)]
|
||||
if config.jax_array:
|
||||
from jax.experimental import array, sharding
|
||||
from jax._src import array, sharding
|
||||
sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
|
||||
return array.ArrayImpl(
|
||||
stacked_aval,
|
||||
@ -2924,7 +2924,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
|
||||
buf, = dispatch.device_put(x, devices[0])
|
||||
rest_bufs = [buf.copy_to_device(d) for d in devices[1:]]
|
||||
if config.jax_array:
|
||||
from jax.experimental import array, sharding
|
||||
from jax._src import array, sharding
|
||||
sharding_spec = pxla._create_pmap_sharding_spec(aval)
|
||||
return array.ArrayImpl(
|
||||
aval, sharding.PmapSharding(np.array(devices), sharding_spec),
|
||||
|
@ -18,7 +18,6 @@ import operator as op
|
||||
import numpy as np
|
||||
from typing import Sequence, Tuple, Callable, Union, Optional, cast, List
|
||||
|
||||
import jax
|
||||
from jax import core
|
||||
from jax._src import abstract_arrays
|
||||
from jax._src import ad_util
|
||||
@ -33,7 +32,7 @@ from jax._src.lib import xla_client as xc
|
||||
from jax._src.api import device_put
|
||||
from jax._src.typing import ArrayLike
|
||||
from jax.interpreters import pxla, xla, mlir
|
||||
from jax.experimental.sharding import (
|
||||
from jax._src.sharding import (
|
||||
Sharding, SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
|
||||
device_replica_id_map)
|
||||
|
@ -31,7 +31,7 @@ from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.experimental.sharding import OpShardingSharding
|
||||
from jax._src.sharding import OpShardingSharding
|
||||
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
|
||||
from jax._src import source_info_util, traceback_util
|
||||
from jax._src.lax import control_flow as cf
|
||||
|
@ -25,8 +25,7 @@ from jax import tree_util
|
||||
from jax import lax
|
||||
from jax import linear_util as lu
|
||||
from jax.config import config
|
||||
from jax.experimental.sharding import Sharding
|
||||
from jax.experimental.sharding import OpShardingSharding
|
||||
from jax._src.sharding import Sharding, OpShardingSharding
|
||||
from jax.experimental import pjit
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import batching
|
||||
|
@ -90,7 +90,7 @@ _on_exit = False
|
||||
ArgSpec = Tuple[core.AbstractValue, Optional[Device]]
|
||||
|
||||
def arg_spec(x: Any) -> ArgSpec:
|
||||
from jax.experimental.sharding import PmapSharding
|
||||
from jax._src.sharding import PmapSharding
|
||||
|
||||
aval = xla.abstractify(x)
|
||||
try:
|
||||
@ -285,7 +285,7 @@ def not_none_device_or_backend_on_jit(backend, device, num_ins):
|
||||
# TODO(yashkatariya): Remove this entire function when backend and device are
|
||||
# removed as arguments on jit.
|
||||
|
||||
from jax.experimental import sharding
|
||||
from jax._src import sharding
|
||||
|
||||
if device is not None and backend is not None:
|
||||
raise ValueError("can't specify both a device and a backend for jit, "
|
||||
@ -311,8 +311,9 @@ def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
|
||||
keep_unused, *arg_specs):
|
||||
# TODO(yashkatariya): Remove the local imports from here when the functions
|
||||
# in pxla.py move to dispatch.py or a utils file.
|
||||
from jax._src import sharding
|
||||
from jax.interpreters import pxla
|
||||
from jax.experimental import pjit, sharding
|
||||
from jax.experimental import pjit
|
||||
|
||||
in_avals, in_shardings = util.unzip2(arg_specs)
|
||||
|
||||
@ -369,7 +370,7 @@ xla_callable = lu.cache(_xla_callable_uncached)
|
||||
|
||||
|
||||
def is_single_device_sharding(sharding) -> bool:
|
||||
from jax.experimental.sharding import PmapSharding
|
||||
from jax._src.sharding import PmapSharding
|
||||
# Special case PmapSharding here because PmapSharding maps away an axis
|
||||
# and needs to be handled separately.
|
||||
return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
|
||||
@ -748,8 +749,8 @@ class SimpleResultHandler:
|
||||
|
||||
def maybe_create_array_from_da(buf, aval, device):
|
||||
if config.jax_array:
|
||||
from jax.experimental.array import ArrayImpl
|
||||
from jax.experimental.sharding import SingleDeviceSharding
|
||||
from jax._src.array import ArrayImpl
|
||||
from jax._src.sharding import SingleDeviceSharding
|
||||
return ArrayImpl(aval, SingleDeviceSharding(buf.device()), [buf],
|
||||
committed=(device is not None), _skip_checks=True)
|
||||
else:
|
||||
@ -1222,7 +1223,7 @@ def _copy_device_array_to_device(
|
||||
|
||||
def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Array:
|
||||
"""Copies `Array`s with SingleDeviceSharding to a different device."""
|
||||
from jax.experimental import array, sharding
|
||||
from jax._src import array, sharding
|
||||
|
||||
if device is None:
|
||||
# no copying to be done because there's no target specified
|
||||
@ -1250,7 +1251,7 @@ def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Arra
|
||||
|
||||
|
||||
def _device_put_impl(x, device: Optional[Device] = None):
|
||||
from jax.experimental import array, sharding
|
||||
from jax._src import array, sharding
|
||||
|
||||
if device_array.type_is_device_array(x):
|
||||
return _copy_device_array_to_device(x, device)
|
||||
|
@ -39,7 +39,7 @@ def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False)
|
||||
undefined behavior if the DLPack consumer writes to a buffer that JAX
|
||||
owns.
|
||||
"""
|
||||
from jax.experimental import array
|
||||
from jax._src import array
|
||||
if not isinstance(x, (device_array.DeviceArray, array.ArrayImpl)):
|
||||
raise TypeError("Argument to to_dlpack must be a DeviceArray or Array, got {}"
|
||||
.format(type(x)))
|
||||
|
@ -85,7 +85,7 @@ zip, unsafe_zip = safe_zip, zip
|
||||
# TODO(jakevdp): replace this with an isinstance() check when JEP 12049 is complete.
|
||||
def _is_array_or_tracer(operand: Any) -> bool:
|
||||
if config.jax_array:
|
||||
from jax.experimental import array # pylint: disable=g-import-not-at-top
|
||||
from jax._src import array # pylint: disable=g-import-not-at-top
|
||||
return isinstance(operand, (core.Tracer, array.ArrayImpl))
|
||||
else:
|
||||
return isinstance(operand, (core.Tracer, device_array.DeviceArray))
|
||||
@ -1323,8 +1323,8 @@ def full_like(x: ArrayLike, fill_value: ArrayLike, dtype: Optional[DTypeLike] =
|
||||
An ndarray with the same shape as `x` with its entries set equal to
|
||||
`fill_value`, similar to the output of np.full.
|
||||
"""
|
||||
from jax.experimental import array
|
||||
from jax.experimental.sharding import PmapSharding
|
||||
from jax._src import array
|
||||
from jax._src.sharding import PmapSharding
|
||||
|
||||
fill_shape = np.shape(x) if shape is None else canonicalize_shape(shape)
|
||||
weak_type = dtype is None and dtypes.is_weakly_typed(x)
|
||||
|
@ -78,7 +78,7 @@ from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src.ops import scatter
|
||||
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
|
||||
canonicalize_axis as _canonicalize_axis)
|
||||
from jax.experimental.array import ArrayImpl
|
||||
from jax._src.array import ArrayImpl
|
||||
|
||||
newaxis = None
|
||||
|
||||
|
@ -31,7 +31,7 @@ from jax.interpreters import batching
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax.experimental.sharding import (
|
||||
from jax._src.sharding import (
|
||||
MeshPspecSharding, PmapSharding, OpShardingSharding)
|
||||
|
||||
from jax._src import dispatch
|
||||
|
@ -26,8 +26,8 @@ import jax
|
||||
from jax._src import distributed
|
||||
from jax._src.config import config
|
||||
from jax.experimental import global_device_array as gda
|
||||
from jax.experimental import array
|
||||
from jax.experimental import sharding
|
||||
from jax._src import array
|
||||
from jax._src import sharding
|
||||
from jax.experimental.maps import Mesh
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
@ -21,8 +21,8 @@ from jax._src import test_util as jtu
|
||||
from jax._src import util
|
||||
from jax._src import config as jax_config
|
||||
from jax.config import config
|
||||
from jax.experimental import array
|
||||
from jax.experimental.sharding import MeshPspecSharding
|
||||
from jax._src import array
|
||||
from jax._src.sharding import MeshPspecSharding
|
||||
from jax.experimental import PartitionSpec as P
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||
from jax.experimental.gda_serialization import serialization
|
||||
|
@ -30,7 +30,7 @@ from jax import random, tree_util
|
||||
from jax import numpy as jnp
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
from jax.experimental import sharding
|
||||
from jax._src import sharding
|
||||
from jax.interpreters import ad
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval
|
||||
|
@ -39,9 +39,9 @@ from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src.config import config
|
||||
from jax.errors import JAXTypeError
|
||||
from jax.experimental.array import ArrayImpl
|
||||
from jax._src.array import ArrayImpl
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||
from jax.experimental.sharding import MeshPspecSharding
|
||||
from jax._src.sharding import MeshPspecSharding
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters import pxla
|
||||
|
@ -23,7 +23,7 @@ import threading
|
||||
|
||||
from jax.experimental import maps
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray as GDA
|
||||
from jax.experimental.sharding import (
|
||||
from jax._src.sharding import (
|
||||
MeshPspecSharding, Sharding, XLACompatibleSharding, OpShardingSharding,
|
||||
XLADeviceAssignment)
|
||||
from jax import core
|
||||
|
@ -85,7 +85,7 @@ from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
|
||||
split_dict, unzip2)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from jax.experimental.sharding import MeshPspecSharding, XLACompatibleSharding
|
||||
from jax._src.sharding import MeshPspecSharding, XLACompatibleSharding
|
||||
|
||||
# Built in Python lists don't support weak refs but subclasses of lists do.
|
||||
class WeakRefList(list):
|
||||
@ -873,7 +873,7 @@ def _hashable_index(idx):
|
||||
# The fast path is handled directly in shard_args().
|
||||
# TODO(skye): is there a simpler way to rewrite this using sharding_spec?
|
||||
def _shard_sharded_device_array_slow_path(x, devices, indices, mode):
|
||||
from jax.experimental.array import ArrayImpl
|
||||
from jax._src.array import ArrayImpl
|
||||
|
||||
candidates = defaultdict(list)
|
||||
if isinstance(x, ArrayImpl):
|
||||
@ -974,6 +974,7 @@ def _emap_impl(fun: lu.WrappedFun, *args,
|
||||
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
|
||||
donated_invars: Sequence[bool],
|
||||
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
|
||||
from jax._src import array
|
||||
# TODO(sharadmv,mattjj): implement these cases
|
||||
if any(d for d in donated_invars):
|
||||
raise NotImplementedError("Buffer donation not supported in eager pmap.")
|
||||
@ -1002,7 +1003,7 @@ def _emap_impl(fun: lu.WrappedFun, *args,
|
||||
for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals):
|
||||
with jax.disable_jit(False):
|
||||
donate_argnums_ = donate_argnums
|
||||
if isinstance(outval, (ShardedDeviceArray, jax.experimental.array.ArrayImpl)):
|
||||
if isinstance(outval, (ShardedDeviceArray, array.ArrayImpl)):
|
||||
# We don't want to donate if it's already sharded.
|
||||
donate_argnums_ = ()
|
||||
out = jax.pmap(
|
||||
@ -1636,7 +1637,7 @@ class PmapExecutable(stages.XlaExecutable):
|
||||
|
||||
|
||||
def _get_pmap_sharding(devices, specs):
|
||||
from jax.experimental.sharding import PmapSharding
|
||||
from jax._src.sharding import PmapSharding
|
||||
|
||||
return [PmapSharding(devices, spec) for spec in specs]
|
||||
|
||||
@ -1826,7 +1827,7 @@ class ResultsHandler:
|
||||
def _get_sharding_specs(
|
||||
shardings: Sequence[XLACompatibleSharding], avals: Sequence[ShapedArray]
|
||||
) -> Sequence[ShardingSpec]:
|
||||
from jax.experimental import sharding
|
||||
from jax._src import sharding
|
||||
|
||||
if all(isinstance(s, sharding.PmapSharding) for s in shardings):
|
||||
return [s.sharding_spec for s in shardings] # type: ignore
|
||||
@ -1855,7 +1856,7 @@ def global_avals_to_results_handler(
|
||||
shardings: Sequence[XLACompatibleSharding],
|
||||
committed: bool,
|
||||
are_out_shardings_from_xla: Sequence[bool]) -> ResultsHandler:
|
||||
from jax.experimental.sharding import MeshPspecSharding
|
||||
from jax._src.sharding import MeshPspecSharding
|
||||
|
||||
if config.jax_parallel_functions_output_gda or config.jax_array:
|
||||
handlers = [
|
||||
@ -3078,7 +3079,7 @@ def _get_input_metadata(
|
||||
in_shardings: Sequence[XLACompatibleSharding], in_is_global: Sequence[bool]
|
||||
) -> Tuple[Sequence[XLACompatibleSharding], Sequence[Tuple[Optional[Index], ...]],
|
||||
Sequence[ShapedArray]]:
|
||||
from jax.experimental.sharding import MeshPspecSharding
|
||||
from jax._src.sharding import MeshPspecSharding
|
||||
|
||||
shardings, input_indices, input_avals = [], [], []
|
||||
for gaval, i, is_global in safe_zip(global_in_avals, in_shardings, in_is_global):
|
||||
@ -3113,7 +3114,7 @@ def _get_input_metadata(
|
||||
def _get_op_sharding_shardings_from_executable(
|
||||
xla_executable, device_assignment, num_in_avals, num_out_avals):
|
||||
from jax.experimental import pjit
|
||||
from jax.experimental.sharding import OpShardingSharding, SingleDeviceSharding
|
||||
from jax._src.sharding import OpShardingSharding, SingleDeviceSharding
|
||||
|
||||
# When the device assignment only has 1 device, SPMD partitioner will not run.
|
||||
# Hence the op shardings will not be set on the `hlo_module`. In that case,
|
||||
@ -3133,7 +3134,7 @@ def _get_op_sharding_shardings_from_executable(
|
||||
# without mesh.
|
||||
def _get_mesh_pspec_shardings_from_executable(xla_executable, mesh):
|
||||
from jax.experimental import pjit
|
||||
from jax.experimental.sharding import MeshPspecSharding
|
||||
from jax._src.sharding import MeshPspecSharding
|
||||
|
||||
in_pspec, out_pspec = pjit._get_pspec_from_executable(xla_executable, mesh)
|
||||
return ([MeshPspecSharding(mesh, i) for i in in_pspec],
|
||||
@ -3343,8 +3344,8 @@ def _out_shardings_for_trivial(
|
||||
# * if the output is a constant Array, get its .sharding attribute;
|
||||
# * otherwise, the output is a literal or numpy.ndarray constant, so give it
|
||||
# a replicated sharding
|
||||
from jax.experimental import array
|
||||
from jax.experimental import sharding
|
||||
from jax._src import array
|
||||
from jax._src import sharding
|
||||
rep = sharding.OpShardingSharding(
|
||||
device_assignment, sharding._get_replicated_op_sharding())
|
||||
shardings: Dict[core.Var, sharding.XLACompatibleSharding] = {}
|
||||
@ -3368,13 +3369,13 @@ def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args
|
||||
|
||||
@lru_cache()
|
||||
def _create_mesh_pspec_sharding(mesh, pspec, parsed_pspec=None):
|
||||
from jax.experimental.sharding import MeshPspecSharding
|
||||
from jax._src.sharding import MeshPspecSharding
|
||||
return MeshPspecSharding(mesh, pspec, parsed_pspec)
|
||||
|
||||
|
||||
def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
|
||||
from jax.experimental.global_device_array import GlobalDeviceArray
|
||||
from jax.experimental.array import ArrayImpl
|
||||
from jax._src.array import ArrayImpl
|
||||
|
||||
@lru_cache(maxsize=4096)
|
||||
def _cached_check(arg_sharding, in_xla_sharding, arg_type, ndim, committed):
|
||||
|
22
jax/sharding.py
Normal file
22
jax/sharding.py
Normal file
@ -0,0 +1,22 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.sharding import (
|
||||
Sharding as Sharding,
|
||||
XLACompatibleSharding as XLACompatibleSharding,
|
||||
MeshPspecSharding as MeshPspecSharding,
|
||||
SingleDeviceSharding as SingleDeviceSharding,
|
||||
PmapSharding as PmapSharding,
|
||||
OpShardingSharding as OpShardingSharding,
|
||||
)
|
@ -54,7 +54,7 @@ from jax.interpreters import xla
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.interpreters.pxla import PartitionSpec as P
|
||||
from jax.experimental import array, sharding
|
||||
from jax._src import array, sharding
|
||||
from jax.experimental import pjit
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import custom_derivatives
|
||||
|
@ -26,8 +26,8 @@ from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.util import prod, safe_zip
|
||||
from jax.experimental import PartitionSpec as P
|
||||
from jax.experimental import sharding
|
||||
from jax.experimental import array
|
||||
from jax._src import sharding
|
||||
from jax._src import array
|
||||
from jax.experimental import maps
|
||||
|
||||
from jax.config import config
|
||||
|
@ -27,8 +27,8 @@ from jax.config import config
|
||||
from jax.experimental import checkify
|
||||
from jax.experimental import pjit
|
||||
from jax.experimental import maps
|
||||
from jax.experimental.sharding import MeshPspecSharding
|
||||
from jax.experimental import array
|
||||
from jax._src.sharding import MeshPspecSharding
|
||||
from jax._src import array
|
||||
from jax._src.checkify import CheckEffect
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
@ -23,7 +23,7 @@ from jax import lax
|
||||
from jax.config import config
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
from jax.experimental import sharding
|
||||
from jax._src import sharding
|
||||
from jax.interpreters import pxla
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import debugging
|
||||
|
@ -27,7 +27,7 @@ from jax.config import config
|
||||
from jax.interpreters import ad
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import pjit
|
||||
from jax.experimental import sharding
|
||||
from jax._src import sharding
|
||||
from jax.interpreters import mlir
|
||||
from jax._src import ad_checkpoint
|
||||
from jax._src import lib as jaxlib
|
||||
|
@ -49,7 +49,7 @@ from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy.lax_numpy import _promote_dtypes, _promote_dtypes_inexact
|
||||
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
|
||||
from jax._src.util import prod, safe_zip
|
||||
from jax.experimental import array
|
||||
from jax._src import array
|
||||
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
@ -39,7 +39,7 @@ from jax.interpreters import xla
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import batching
|
||||
from jax.interpreters import pxla
|
||||
from jax.experimental import array
|
||||
from jax._src import array
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
|
@ -40,8 +40,8 @@ from jax.experimental import maps
|
||||
from jax.experimental import PartitionSpec as P
|
||||
from jax.experimental.maps import xmap
|
||||
from jax.experimental import global_device_array
|
||||
from jax.experimental import array
|
||||
from jax.experimental.sharding import MeshPspecSharding, Sharding, OpShardingSharding
|
||||
from jax._src import array
|
||||
from jax._src.sharding import MeshPspecSharding, Sharding, OpShardingSharding
|
||||
import jax.experimental.pjit as pjit_lib
|
||||
from jax.experimental.pjit import (pjit, pjit_p, with_sharding_constraint,
|
||||
FROM_GDA, AUTO)
|
||||
|
@ -46,8 +46,8 @@ from jax._src.lib import xla_bridge
|
||||
from jax._src.util import prod, safe_map, safe_zip
|
||||
from jax.interpreters import pxla
|
||||
from jax.interpreters import xla
|
||||
from jax.experimental import array
|
||||
from jax.experimental.sharding import PmapSharding
|
||||
from jax._src import array
|
||||
from jax._src.sharding import PmapSharding
|
||||
from jax.ad_checkpoint import checkpoint as new_checkpoint
|
||||
|
||||
from jax.config import config
|
||||
|
@ -27,7 +27,7 @@ from jax._src import typing
|
||||
from jax import lax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax.experimental.array import ArrayImpl
|
||||
from jax._src.array import ArrayImpl
|
||||
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
|
@ -35,8 +35,8 @@ from jax import core
|
||||
from jax.core import NamedShape
|
||||
from jax.experimental import maps
|
||||
from jax.experimental import global_device_array
|
||||
from jax.experimental import array
|
||||
from jax.experimental.sharding import MeshPspecSharding
|
||||
from jax._src import array
|
||||
from jax._src.sharding import MeshPspecSharding
|
||||
from jax.experimental.pjit import pjit, with_sharding_constraint
|
||||
from jax.experimental.pjit import PartitionSpec as P
|
||||
from jax.experimental.maps import xmap, serial_loop, SerialLoop
|
||||
|
Loading…
x
Reference in New Issue
Block a user