Move array.py and sharding.py from experimental/ to _src/.

PiperOrigin-RevId: 477201711
Author: Yash Katariya, 2022-09-27 10:06:10 -07:00 (committed by jax authors)
Commit: 9e4114f0f1 (parent: 0e116888ea)
32 changed files with 91 additions and 64 deletions
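The change is almost entirely mechanical: call sites that imported the array and sharding modules from jax.experimental now import them from jax._src, and a new public jax/sharding.py re-exports the sharding classes. A minimal before/after sketch of the import pattern applied throughout the diff (names taken from the hunks below):

    # Old internal import path, removed by this commit:
    from jax.experimental import array, sharding
    from jax.experimental.sharding import MeshPspecSharding

    # New internal import path; the implementations now live under jax._src:
    from jax._src import array, sharding
    from jax._src.sharding import MeshPspecSharding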

@@ -26,8 +26,8 @@ from jax._src.api_util import shaped_abstractify # technically not an api fn
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
from jax._src.lib import xla_client as xc
from jax.interpreters import pxla
-from jax.experimental import array
-from jax.experimental import sharding
+from jax._src import array
+from jax._src import sharding
from jax.experimental import pjit as pjit_lib
import jax.numpy as jnp
import numpy as np

@@ -22,7 +22,7 @@ from jax._src.util import prod
from jax.interpreters.pxla import PartitionSpec as P
from jax.interpreters import pxla
from jax.experimental import global_device_array as gda
-from jax.experimental.sharding import MeshPspecSharding
+from jax._src.sharding import MeshPspecSharding
import numpy as np
mesh_shapes_axes = [

@@ -99,8 +99,6 @@ py_library_providing_imports_info(
"experimental/maps.py",
"experimental/pjit.py",
"experimental/global_device_array.py",
-"experimental/array.py",
-"experimental/sharding.py",
"experimental/multihost_utils.py",
# until checkify is moved out of experimental
"experimental/checkify.py",

@@ -114,6 +114,12 @@ from jax._src.api import (
xla, # TODO(phawkins): update users to avoid this.
xla_computation as xla_computation,
)
+from jax._src.array import (
+make_array_from_single_device_arrays as make_array_from_single_device_arrays,
+make_array_from_callback as make_array_from_callback,
+)
from jax.version import __version__ as __version__
from jax.version import __version_info__ as __version_info__
@@ -143,6 +149,7 @@ from jax import numpy as numpy
from jax import ops as ops
from jax import profiler as profiler
from jax import random as random
+from jax import sharding as sharding
from jax import stages as stages
from jax import tree_util as tree_util
from jax import util as util
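This hunk also grows the public surface: make_array_from_single_device_arrays and make_array_from_callback are re-exported at the top level from jax._src.array, and jax.sharding joins the top-level namespace. A hedged usage sketch of the new entry points, assuming a single local device so the callback can return the whole array for its only shard:

    import jax
    import numpy as np
    from jax.sharding import SingleDeviceSharding

    data = np.arange(8, dtype=np.float32)
    sharding = SingleDeviceSharding(jax.devices()[0])

    # data_callback receives the index (a tuple of slices) of each shard
    # and returns the corresponding chunk of the data.
    arr = jax.make_array_from_callback(data.shape, sharding, lambda idx: data[idx])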

@@ -667,7 +667,7 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
# all the other arguments stored as attributes.
def arg_spec(x):
-from jax.experimental.sharding import PmapSharding
+from jax._src.sharding import PmapSharding
# like xla.arg_spec but duck-types on x.shape and x.dtype
aval = None if jax.config.jax_dynamic_shapes else shaped_abstractify(x)
if jax.config.jax_array:
@@ -2869,7 +2869,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]): #
buffers = [buf for x, d in zip(xs, devices)
for buf in dispatch.device_put(x, d)]
if config.jax_array:
-from jax.experimental import array, sharding
+from jax._src import array, sharding
sharding_spec = pxla._create_pmap_sharding_spec(stacked_aval)
return array.ArrayImpl(
stacked_aval,
@@ -2924,7 +2924,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]): # noqa: F811
buf, = dispatch.device_put(x, devices[0])
rest_bufs = [buf.copy_to_device(d) for d in devices[1:]]
if config.jax_array:
-from jax.experimental import array, sharding
+from jax._src import array, sharding
sharding_spec = pxla._create_pmap_sharding_spec(aval)
return array.ArrayImpl(
aval, sharding.PmapSharding(np.array(devices), sharding_spec),
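Both helpers touched above build an array.ArrayImpl backed by a PmapSharding when config.jax_array is enabled; only the import path of those internals changes. A hedged sketch of the public entry points involved:

    import jax
    import numpy as np

    devices = jax.local_devices()
    # device_put_replicated copies one value onto every listed device.
    replicated = jax.device_put_replicated(np.arange(4), devices)
    # device_put_sharded places one equally shaped shard per device.
    sharded = jax.device_put_sharded([np.arange(4) + i for i in range(len(devices))], devices)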

@@ -18,7 +18,6 @@ import operator as op
import numpy as np
from typing import Sequence, Tuple, Callable, Union, Optional, cast, List
import jax
from jax import core
from jax._src import abstract_arrays
from jax._src import ad_util
@@ -33,7 +32,7 @@ from jax._src.lib import xla_client as xc
from jax._src.api import device_put
from jax._src.typing import ArrayLike
from jax.interpreters import pxla, xla, mlir
-from jax.experimental.sharding import (
+from jax._src.sharding import (
Sharding, SingleDeviceSharding, XLACompatibleSharding, PmapSharding,
device_replica_id_map)

@@ -31,7 +31,7 @@ from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
-from jax.experimental.sharding import OpShardingSharding
+from jax._src.sharding import OpShardingSharding
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
from jax._src import source_info_util, traceback_util
from jax._src.lax import control_flow as cf

@@ -25,8 +25,7 @@ from jax import tree_util
from jax import lax
from jax import linear_util as lu
from jax.config import config
-from jax.experimental.sharding import Sharding
-from jax.experimental.sharding import OpShardingSharding
+from jax._src.sharding import Sharding, OpShardingSharding
from jax.experimental import pjit
from jax.interpreters import ad
from jax.interpreters import batching

@@ -90,7 +90,7 @@ _on_exit = False
ArgSpec = Tuple[core.AbstractValue, Optional[Device]]
def arg_spec(x: Any) -> ArgSpec:
-from jax.experimental.sharding import PmapSharding
+from jax._src.sharding import PmapSharding
aval = xla.abstractify(x)
try:
@@ -285,7 +285,7 @@ def not_none_device_or_backend_on_jit(backend, device, num_ins):
# TODO(yashkatariya): Remove this entire function when backend and device are
# removed as arguments on jit.
-from jax.experimental import sharding
+from jax._src import sharding
if device is not None and backend is not None:
raise ValueError("can't specify both a device and a backend for jit, "
@@ -311,8 +311,9 @@ def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
keep_unused, *arg_specs):
# TODO(yashkatariya): Remove the local imports from here when the functions
# in pxla.py move to dispatch.py or a utils file.
+from jax._src import sharding
from jax.interpreters import pxla
-from jax.experimental import pjit, sharding
+from jax.experimental import pjit
in_avals, in_shardings = util.unzip2(arg_specs)
@@ -369,7 +370,7 @@ xla_callable = lu.cache(_xla_callable_uncached)
def is_single_device_sharding(sharding) -> bool:
-from jax.experimental.sharding import PmapSharding
+from jax._src.sharding import PmapSharding
# Special case PmapSharding here because PmapSharding maps away an axis
# and needs to be handled separately.
return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding)
@@ -748,8 +749,8 @@ class SimpleResultHandler:
def maybe_create_array_from_da(buf, aval, device):
if config.jax_array:
-from jax.experimental.array import ArrayImpl
-from jax.experimental.sharding import SingleDeviceSharding
+from jax._src.array import ArrayImpl
+from jax._src.sharding import SingleDeviceSharding
return ArrayImpl(aval, SingleDeviceSharding(buf.device()), [buf],
committed=(device is not None), _skip_checks=True)
else:
@@ -1222,7 +1223,7 @@ def _copy_device_array_to_device(
def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Array:
"""Copies `Array`s with SingleDeviceSharding to a different device."""
-from jax.experimental import array, sharding
+from jax._src import array, sharding
if device is None:
# no copying to be done because there's no target specified
@@ -1250,7 +1251,7 @@ def _copy_array_to_device(x: jax.Array, device: Optional[xc.Device]) -> jax.Arra
def _device_put_impl(x, device: Optional[Device] = None):
-from jax.experimental import array, sharding
+from jax._src import array, sharding
if device_array.type_is_device_array(x):
return _copy_device_array_to_device(x, device)
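The dispatch.py call sites above (arg_spec, is_single_device_sharding, _copy_array_to_device, _device_put_impl) keep their logic; only the sharding/array import paths move. Their public surface is jax.device_put, sketched here under the assumption of at least one local device:

    import jax
    import numpy as np

    # Commits the value to a specific device; dispatch routes this through
    # _device_put_impl and the sharding classes imported above.
    y = jax.device_put(np.arange(4), jax.devices()[0])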

@@ -39,7 +39,7 @@ def to_dlpack(x: device_array.DeviceArrayProtocol, take_ownership: bool = False)
undefined behavior if the DLPack consumer writes to a buffer that JAX
owns.
"""
-from jax.experimental import array
+from jax._src import array
if not isinstance(x, (device_array.DeviceArray, array.ArrayImpl)):
raise TypeError("Argument to to_dlpack must be a DeviceArray or Array, got {}"
.format(type(x)))
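to_dlpack's isinstance check now names ArrayImpl from jax._src.array instead of jax.experimental.array. A hedged round-trip sketch with the public dlpack helpers (capsule-based API as it existed around this commit):

    import jax.numpy as jnp
    from jax.dlpack import to_dlpack, from_dlpack

    x = jnp.arange(4.0)
    capsule = to_dlpack(x)    # export the device buffer as a DLPack capsule
    y = from_dlpack(capsule)  # re-import it into JAX without a copy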

@@ -85,7 +85,7 @@ zip, unsafe_zip = safe_zip, zip
# TODO(jakevdp): replace this with an isinstance() check when JEP 12049 is complete.
def _is_array_or_tracer(operand: Any) -> bool:
if config.jax_array:
-from jax.experimental import array # pylint: disable=g-import-not-at-top
+from jax._src import array # pylint: disable=g-import-not-at-top
return isinstance(operand, (core.Tracer, array.ArrayImpl))
else:
return isinstance(operand, (core.Tracer, device_array.DeviceArray))
@@ -1323,8 +1323,8 @@ def full_like(x: ArrayLike, fill_value: ArrayLike, dtype: Optional[DTypeLike] =
An ndarray with the same shape as `x` with its entries set equal to
`fill_value`, similar to the output of np.full.
"""
-from jax.experimental import array
-from jax.experimental.sharding import PmapSharding
+from jax._src import array
+from jax._src.sharding import PmapSharding
fill_shape = np.shape(x) if shape is None else canonicalize_shape(shape)
weak_type = dtype is None and dtypes.is_weakly_typed(x)
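The docstring above describes the behaviour exposed publicly through jnp.full_like; a minimal usage sketch:

    import jax.numpy as jnp

    x = jnp.zeros((3, 4), dtype=jnp.float32)
    y = jnp.full_like(x, 7)  # same shape and dtype as x, every entry set to 7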

@@ -78,7 +78,7 @@ from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
canonicalize_axis as _canonicalize_axis)
-from jax.experimental.array import ArrayImpl
+from jax._src.array import ArrayImpl
newaxis = None

@@ -31,7 +31,7 @@ from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import pxla
from jax.interpreters import xla
-from jax.experimental.sharding import (
+from jax._src.sharding import (
MeshPspecSharding, PmapSharding, OpShardingSharding)
from jax._src import dispatch

@@ -26,8 +26,8 @@ import jax
from jax._src import distributed
from jax._src.config import config
from jax.experimental import global_device_array as gda
-from jax.experimental import array
-from jax.experimental import sharding
+from jax._src import array
+from jax._src import sharding
from jax.experimental.maps import Mesh
import jax.numpy as jnp
import numpy as np

@@ -21,8 +21,8 @@ from jax._src import test_util as jtu
from jax._src import util
from jax._src import config as jax_config
from jax.config import config
-from jax.experimental import array
-from jax.experimental.sharding import MeshPspecSharding
+from jax._src import array
+from jax._src.sharding import MeshPspecSharding
from jax.experimental import PartitionSpec as P
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.gda_serialization import serialization

@@ -30,7 +30,7 @@ from jax import random, tree_util
from jax import numpy as jnp
from jax.experimental import maps
from jax.experimental import pjit
-from jax.experimental import sharding
+from jax._src import sharding
from jax.interpreters import ad
from jax.interpreters import mlir
from jax.interpreters import partial_eval

@@ -39,9 +39,9 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.config import config
from jax.errors import JAXTypeError
-from jax.experimental.array import ArrayImpl
+from jax._src.array import ArrayImpl
from jax.experimental.global_device_array import GlobalDeviceArray
-from jax.experimental.sharding import MeshPspecSharding
+from jax._src.sharding import MeshPspecSharding
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla

@@ -23,7 +23,7 @@ import threading
from jax.experimental import maps
from jax.experimental.global_device_array import GlobalDeviceArray as GDA
-from jax.experimental.sharding import (
+from jax._src.sharding import (
MeshPspecSharding, Sharding, XLACompatibleSharding, OpShardingSharding,
XLADeviceAssignment)
from jax import core

@@ -85,7 +85,7 @@ from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
split_dict, unzip2)
if TYPE_CHECKING:
-from jax.experimental.sharding import MeshPspecSharding, XLACompatibleSharding
+from jax._src.sharding import MeshPspecSharding, XLACompatibleSharding
# Built in Python lists don't support weak refs but subclasses of lists do.
class WeakRefList(list):
@@ -873,7 +873,7 @@ def _hashable_index(idx):
# The fast path is handled directly in shard_args().
# TODO(skye): is there a simpler way to rewrite this using sharding_spec?
def _shard_sharded_device_array_slow_path(x, devices, indices, mode):
-from jax.experimental.array import ArrayImpl
+from jax._src.array import ArrayImpl
candidates = defaultdict(list)
if isinstance(x, ArrayImpl):
@@ -974,6 +974,7 @@ def _emap_impl(fun: lu.WrappedFun, *args,
out_axes_thunk: Callable[[], Sequence[Optional[int]]],
donated_invars: Sequence[bool],
global_arg_shapes: Sequence[Optional[Tuple[int, ...]]]):
+from jax._src import array
# TODO(sharadmv,mattjj): implement these cases
if any(d for d in donated_invars):
raise NotImplementedError("Buffer donation not supported in eager pmap.")
@@ -1002,7 +1003,7 @@ def _emap_impl(fun: lu.WrappedFun, *args,
for out_axis_src, out_axis, outval in zip(out_axes_src, out_axes, outvals):
with jax.disable_jit(False):
donate_argnums_ = donate_argnums
-if isinstance(outval, (ShardedDeviceArray, jax.experimental.array.ArrayImpl)):
+if isinstance(outval, (ShardedDeviceArray, array.ArrayImpl)):
# We don't want to donate if it's already sharded.
donate_argnums_ = ()
out = jax.pmap(
@@ -1636,7 +1637,7 @@ class PmapExecutable(stages.XlaExecutable):
def _get_pmap_sharding(devices, specs):
-from jax.experimental.sharding import PmapSharding
+from jax._src.sharding import PmapSharding
return [PmapSharding(devices, spec) for spec in specs]
@@ -1826,7 +1827,7 @@ class ResultsHandler:
def _get_sharding_specs(
shardings: Sequence[XLACompatibleSharding], avals: Sequence[ShapedArray]
) -> Sequence[ShardingSpec]:
-from jax.experimental import sharding
+from jax._src import sharding
if all(isinstance(s, sharding.PmapSharding) for s in shardings):
return [s.sharding_spec for s in shardings] # type: ignore
@@ -1855,7 +1856,7 @@ def global_avals_to_results_handler(
shardings: Sequence[XLACompatibleSharding],
committed: bool,
are_out_shardings_from_xla: Sequence[bool]) -> ResultsHandler:
-from jax.experimental.sharding import MeshPspecSharding
+from jax._src.sharding import MeshPspecSharding
if config.jax_parallel_functions_output_gda or config.jax_array:
handlers = [
@@ -3078,7 +3079,7 @@ def _get_input_metadata(
in_shardings: Sequence[XLACompatibleSharding], in_is_global: Sequence[bool]
) -> Tuple[Sequence[XLACompatibleSharding], Sequence[Tuple[Optional[Index], ...]],
Sequence[ShapedArray]]:
-from jax.experimental.sharding import MeshPspecSharding
+from jax._src.sharding import MeshPspecSharding
shardings, input_indices, input_avals = [], [], []
for gaval, i, is_global in safe_zip(global_in_avals, in_shardings, in_is_global):
@@ -3113,7 +3114,7 @@ def _get_input_metadata(
def _get_op_sharding_shardings_from_executable(
xla_executable, device_assignment, num_in_avals, num_out_avals):
from jax.experimental import pjit
-from jax.experimental.sharding import OpShardingSharding, SingleDeviceSharding
+from jax._src.sharding import OpShardingSharding, SingleDeviceSharding
# When the device assignment only has 1 device, SPMD partitioner will not run.
# Hence the op shardings will not be set on the `hlo_module`. In that case,
@@ -3133,7 +3134,7 @@ def _get_op_sharding_shardings_from_executable(
# without mesh.
def _get_mesh_pspec_shardings_from_executable(xla_executable, mesh):
from jax.experimental import pjit
-from jax.experimental.sharding import MeshPspecSharding
+from jax._src.sharding import MeshPspecSharding
in_pspec, out_pspec = pjit._get_pspec_from_executable(xla_executable, mesh)
return ([MeshPspecSharding(mesh, i) for i in in_pspec],
@@ -3343,8 +3344,8 @@ def _out_shardings_for_trivial(
# * if the output is a constant Array, get its .sharding attribute;
# * otherwise, the output is a literal or numpy.ndarray constant, so give it
# a replicated sharding
-from jax.experimental import array
-from jax.experimental import sharding
+from jax._src import array
+from jax._src import sharding
rep = sharding.OpShardingSharding(
device_assignment, sharding._get_replicated_op_sharding())
shardings: Dict[core.Var, sharding.XLACompatibleSharding] = {}
@@ -3368,13 +3369,13 @@ def _execute_trivial(jaxpr, consts, in_handler, out_handler, kept_var_idx, *args
@lru_cache()
def _create_mesh_pspec_sharding(mesh, pspec, parsed_pspec=None):
-from jax.experimental.sharding import MeshPspecSharding
+from jax._src.sharding import MeshPspecSharding
return MeshPspecSharding(mesh, pspec, parsed_pspec)
def _check_gda_or_array_xla_sharding_match(args, in_xla_shardings):
from jax.experimental.global_device_array import GlobalDeviceArray
-from jax.experimental.array import ArrayImpl
+from jax._src.array import ArrayImpl
@lru_cache(maxsize=4096)
def _cached_check(arg_sharding, in_xla_sharding, arg_type, ndim, committed):

jax/sharding.py (new file, 22 lines)
@@ -0,0 +1,22 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.sharding import (
Sharding as Sharding,
XLACompatibleSharding as XLACompatibleSharding,
MeshPspecSharding as MeshPspecSharding,
SingleDeviceSharding as SingleDeviceSharding,
PmapSharding as PmapSharding,
OpShardingSharding as OpShardingSharding,
)
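The new module is a thin public alias layer: each class is re-exported under its own name from jax._src.sharding. A hedged sketch of downstream use of the new path (the describe helper is hypothetical):

    from jax.sharding import MeshPspecSharding, PmapSharding, Sharding

    def describe(s: Sharding):
        # Works for any of the re-exported classes; device_set is part of the
        # Sharding interface used elsewhere in this diff.
        return type(s).__name__, len(s.device_set)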

@@ -54,7 +54,7 @@ from jax.interpreters import xla
from jax.interpreters import pxla
from jax.interpreters import partial_eval as pe
from jax.interpreters.pxla import PartitionSpec as P
-from jax.experimental import array, sharding
+from jax._src import array, sharding
from jax.experimental import pjit
from jax._src import config as jax_config
from jax._src import custom_derivatives

@@ -26,8 +26,8 @@ from jax._src.lib import xla_client as xc
from jax._src.lib import xla_bridge as xb
from jax._src.util import prod, safe_zip
from jax.experimental import PartitionSpec as P
-from jax.experimental import sharding
-from jax.experimental import array
+from jax._src import sharding
+from jax._src import array
from jax.experimental import maps
from jax.config import config

@@ -27,8 +27,8 @@ from jax.config import config
from jax.experimental import checkify
from jax.experimental import pjit
from jax.experimental import maps
-from jax.experimental.sharding import MeshPspecSharding
-from jax.experimental import array
+from jax._src.sharding import MeshPspecSharding
+from jax._src import array
from jax._src.checkify import CheckEffect
import jax.numpy as jnp

@@ -23,7 +23,7 @@ from jax import lax
from jax.config import config
from jax.experimental import maps
from jax.experimental import pjit
-from jax.experimental import sharding
+from jax._src import sharding
from jax.interpreters import pxla
from jax._src import ad_checkpoint
from jax._src import debugging

@@ -27,7 +27,7 @@ from jax.config import config
from jax.interpreters import ad
from jax.experimental import maps
from jax.experimental import pjit
-from jax.experimental import sharding
+from jax._src import sharding
from jax.interpreters import mlir
from jax._src import ad_checkpoint
from jax._src import lib as jaxlib

@@ -49,7 +49,7 @@ from jax._src.lax import lax as lax_internal
from jax._src.numpy.lax_numpy import _promote_dtypes, _promote_dtypes_inexact
from jax._src.numpy.util import _parse_numpydoc, ParsedDoc, _wraps
from jax._src.util import prod, safe_zip
-from jax.experimental import array
+from jax._src import array
from jax.config import config
config.parse_flags_with_absl()

@@ -39,7 +39,7 @@ from jax.interpreters import xla
from jax.interpreters import mlir
from jax.interpreters import batching
from jax.interpreters import pxla
-from jax.experimental import array
+from jax._src import array
from jax._src.lib.mlir.dialects import mhlo
from jax._src import dispatch
from jax._src import dtypes

@@ -40,8 +40,8 @@ from jax.experimental import maps
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import xmap
from jax.experimental import global_device_array
-from jax.experimental import array
-from jax.experimental.sharding import MeshPspecSharding, Sharding, OpShardingSharding
+from jax._src import array
+from jax._src.sharding import MeshPspecSharding, Sharding, OpShardingSharding
import jax.experimental.pjit as pjit_lib
from jax.experimental.pjit import (pjit, pjit_p, with_sharding_constraint,
FROM_GDA, AUTO)

@@ -46,8 +46,8 @@ from jax._src.lib import xla_bridge
from jax._src.util import prod, safe_map, safe_zip
from jax.interpreters import pxla
from jax.interpreters import xla
-from jax.experimental import array
-from jax.experimental.sharding import PmapSharding
+from jax._src import array
+from jax._src.sharding import PmapSharding
from jax.ad_checkpoint import checkpoint as new_checkpoint
from jax.config import config

@@ -27,7 +27,7 @@ from jax._src import typing
from jax import lax
import jax.numpy as jnp
-from jax.experimental.array import ArrayImpl
+from jax._src.array import ArrayImpl
from absl.testing import absltest
import numpy as np

@@ -35,8 +35,8 @@ from jax import core
from jax.core import NamedShape
from jax.experimental import maps
from jax.experimental import global_device_array
-from jax.experimental import array
-from jax.experimental.sharding import MeshPspecSharding
+from jax._src import array
+from jax._src.sharding import MeshPspecSharding
from jax.experimental.pjit import pjit, with_sharding_constraint
from jax.experimental.pjit import PartitionSpec as P
from jax.experimental.maps import xmap, serial_loop, SerialLoop