Delete some dead code that pertained to sharded_jit.

sharded_jit is long gone.

PiperOrigin-RevId: 523711890
Peter Hawkins 2023-04-12 08:49:07 -07:00 committed by jax authors
parent 3ca7d67e8d
commit 33acdc0e40
5 changed files with 49 additions and 220 deletions
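The practical upshot for callers is that sharding_specs.pmap_sharding_spec drops the npart and parts arguments that existed only for sharded_jit, as the diffs below show. A minimal sketch of the simplified call with illustrative values; the jax._src.sharding_specs import path and the concrete numbers are assumptions for illustration, not taken from this diff:

# Sketch only: exercises the post-change signature
# pmap_sharding_spec(nrep, axis_size, sharded_shape, map_axis).
from jax._src import sharding_specs

# 8 local replicas, pmap axis of size 8; an (8, 4) array mapped over axis 0
# has per-replica shape (4,) inside the pmap.
spec = sharding_specs.pmap_sharding_spec(
    8,     # nrep: number of local XLA replicas
    8,     # axis_size: size of the pmapped axis
    (4,),  # sharded_shape: shape of the value inside the pmap
    0,     # map_axis: axis the value is mapped along in the outer pmap
)
print(spec)  # sharding spec for the outer, unmapped (8, 4) value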


@@ -31,7 +31,7 @@ import numpy as np
import jax
from jax.errors import JAXTypeError
from jax.tree_util import tree_flatten, tree_map
from jax.tree_util import tree_map
from jax._src import api_util
from jax._src import core
@@ -712,7 +712,7 @@ class ParallelCallableInfo:
class ShardInfo(NamedTuple):
sharded_avals: Sequence[core.AbstractValue]
out_sharded_avals: Sequence[core.AbstractValue]
out_sharded_avals: Sequence[core.ShapedArray]
global_sharded_avals: Sequence[core.AbstractValue]
num_local_shards: int
num_global_shards: int
@@ -761,25 +761,16 @@ def stage_parallel_callable(
check_multihost_collective_allowlist(jaxpr)
replicas = find_replicas(jaxpr, pci.axis_size, pci.global_axis_size)
parts = find_partitions(jaxpr)
num_local_shards = replicas.num_local_replicas * parts.local_num_partitions
num_global_shards = replicas.num_global_replicas * parts.num_partitions
num_local_shards = replicas.num_local_replicas
num_global_shards = replicas.num_global_replicas
shards = ShardInfo(
sharded_avals, out_sharded_avals, sharded_avals,
num_local_shards, num_global_shards)
return jaxpr, consts, replicas, parts, shards
return jaxpr, consts, replicas, shards
def _shardings_to_mlir_shardings(
shardings: Optional[Sequence[PartitionsOrReplicated]]
) -> Optional[Sequence[Optional[xc.OpSharding]]]:
if shardings is None:
return None
return [xla.sharding_to_proto(s) for s in shardings]
@profiler.annotate_function
def lower_parallel_callable(
fun: lu.WrappedFun,
@@ -817,8 +808,7 @@ def lower_parallel_callable(
if xb.process_count(backend) > 1:
if devices:
# This allows each host in a multi-host pmap to run on a different number
# of devices, but precludes nested sharding (i.e. inner pmaps or
# sharded_jits).
# of devices, but precludes nested sharding (i.e. inner pmaps).
no_nested_sharding = True
else:
# This assumes all hosts run on the same number of devices. We make sure
@@ -830,18 +820,12 @@ def lower_parallel_callable(
pci = ParallelCallableInfo(
name, backend, axis_name, axis_size, global_axis_size, devices,
in_axes, out_axes_thunk, avals)
jaxpr, consts, replicas, parts, shards = stage_parallel_callable(pci, fun)
jaxpr, consts, replicas, shards = stage_parallel_callable(pci, fun)
if logger.isEnabledFor(logging.DEBUG):
logger.debug("sharded_avals: %s", shards.sharded_avals)
logger.debug("global_sharded_avals: %s", shards.global_sharded_avals)
logger.debug("num_replicas: %d num_local_replicas: %d",
replicas.num_global_replicas, replicas.num_local_replicas)
logger.debug("num_partitions: %d local_num_partitions: %d",
parts.num_partitions, parts.local_num_partitions)
logger.debug("arg_parts: %s", parts.arg_parts)
logger.debug("local_arg_parts: %s", parts.local_arg_parts)
logger.debug("out_parts: %s", parts.out_parts)
logger.debug("local_out_parts: %s", parts.local_out_parts)
logger.debug("devices: %s", devices)
logger.debug("local_devices: %s", pci.local_devices)
@@ -858,24 +842,21 @@ def lower_parallel_callable(
f"On multi-host platforms, pmapped functions must run across all "
f"devices, i.e. num_replicas * num_partitions should equal the "
f"number of local devices. Got "
f"num_replicas={replicas.num_local_replicas}, "
f"num_partitions={parts.num_partitions}, and "
f"num_replicas={replicas.num_local_replicas}, and "
f"num_local_devices={xb.local_device_count(backend)}")
if no_nested_sharding and (
replicas.jaxpr_replicas > 1 or parts.num_partitions > 1):
if no_nested_sharding and replicas.jaxpr_replicas > 1:
raise ValueError(
f"On multi-host platforms, pmapped functions that both have `devices` "
f"specified and contain an inner_pmap or sharded_jit must specify an "
f"specified and contain an inner_pmap must specify an "
f"`axis_size` (or remove the `devices` argument). Got nested_replicas="
f"{replicas.jaxpr_replicas} and nested_partitions={parts.num_partitions}")
f"{replicas.jaxpr_replicas}")
log_priority = logging.WARNING if config.jax_log_compiles else logging.DEBUG
logger.log(log_priority,
"Compiling %s (%d) for %d devices with args %s. (num_replicas=%d"
" num_partitions=%d)", fun.__name__, id(fun),
shards.num_global_shards, avals, replicas.num_global_replicas,
parts.num_partitions)
"Compiling %s (%d) for %d devices with args %s. (num_replicas=%d)",
fun.__name__, id(fun),
shards.num_global_shards, avals, replicas.num_global_replicas)
axis_env = sharding_impls.AxisEnv(
replicas.num_global_replicas, (axis_name,), (global_axis_size,))
@@ -903,14 +884,14 @@ def lower_parallel_callable(
name_stack,
donated_invars,
replicated_args=replicated_args,
arg_shardings=_shardings_to_mlir_shardings(parts.arg_parts),
result_shardings=_shardings_to_mlir_shardings(parts.out_parts),
arg_shardings=None,
result_shardings=None,
arg_names=jaxpr.debug_info and jaxpr.debug_info.arg_names,
result_names=jaxpr.debug_info and jaxpr.debug_info.result_paths)
module, keepalive, host_callbacks = (
lowering_result.module, lowering_result.keepalive,
lowering_result.host_callbacks)
return PmapComputation(module, pci=pci, replicas=replicas, parts=parts,
return PmapComputation(module, pci=pci, replicas=replicas,
shards=shards, tuple_args=tuple_args,
unordered_effects=unordered_effects,
ordered_effects=ordered_effects,
@@ -951,6 +932,9 @@ class PmapComputation(stages.XlaLowering):
return executable
return self._executable
def _cast_to_shaped_array(aval: core.AbstractValue) -> ShapedArray:
assert isinstance(aval, ShapedArray), aval
return cast(ShapedArray, aval)
@dataclasses.dataclass
class UnloadedPmapExecutable:
@@ -969,7 +953,6 @@ class UnloadedPmapExecutable:
def from_hlo(xla_computation,
pci: ParallelCallableInfo,
replicas: ReplicaInfo,
parts: PartitionInfo,
shards: ShardInfo,
tuple_args: bool,
unordered_effects: List[core.Effect],
@@ -981,11 +964,10 @@ class UnloadedPmapExecutable:
if devices is None:
if shards.num_global_shards > xb.device_count(pci.backend):
msg = ("compiling computation that requires {} logical devices, but only {} XLA "
"devices are available (num_replicas={}, num_partitions={})")
"devices are available (num_replicas={})")
raise ValueError(msg.format(shards.num_global_shards,
xb.device_count(pci.backend),
replicas.num_global_replicas,
parts.num_partitions))
replicas.num_global_replicas))
# On a single host, we simply grab the first N devices from jax.devices().
# In the single host case, we want the default device order of pmap to
# match jax.devices().
@@ -1023,16 +1005,14 @@ class UnloadedPmapExecutable:
# get_default_device_assignment() returns 2D assignment, caller may have
# provided 1D list of devices).
# Convert to 2D in case it's 1D and we have > 1 partitions.
num_partitions = 1
device_assignment: np.ndarray = np.array(devices).reshape(
(replicas.num_global_replicas, parts.num_partitions))
# TODO(b/162356737): Enabling SPMD partitioning causes issues with some
# non-partitioned workloads, so disable unless needed.
use_spmd_partitioning = parts.num_partitions > 1
(replicas.num_global_replicas, num_partitions))
compile_options = xb.get_compile_options(
num_replicas=replicas.num_global_replicas,
num_partitions=parts.num_partitions,
num_partitions=num_partitions,
device_assignment=device_assignment,
use_spmd_partitioning=use_spmd_partitioning,
use_spmd_partitioning=False,
env_options_overrides=compiler_options,
)
compile_options.parameter_is_tupled_arguments = tuple_args
@@ -1042,34 +1022,24 @@ class UnloadedPmapExecutable:
d for d in device_assignment.flat if d.process_index == process_index
])
local_arg_parts_ = parts.local_arg_parts or [None] * len(pci.avals)
input_sharding_specs = [
sharding_specs.pmap_sharding_spec(
replicas.num_local_replicas, pci.axis_size,
parts.local_num_partitions, arg_parts,
cast(ShapedArray, aval).shape, in_axis)
for aval, arg_parts, in_axis in safe_zip(
shards.sharded_avals, local_arg_parts_, pci.in_axes)]
in_shardings = _get_pmap_sharding(local_device_assignment, input_sharding_specs)
nouts = len(shards.out_sharded_avals)
for aval, in_axis in safe_zip(shards.sharded_avals, pci.in_axes)]
in_shardings = _get_pmap_sharding(local_device_assignment,
input_sharding_specs)
out_parts = (None,) * nouts if parts.out_parts is None else parts.out_parts
local_out_parts = (None,) * nouts if parts.local_out_parts is None else parts.local_out_parts
local_out_avals = [
get_local_aval(aval, parts, lparts)
for aval, parts, lparts
in safe_zip(shards.out_sharded_avals, out_parts, local_out_parts)]
local_unmapped_avals = [
core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval)
_cast_to_shaped_array(
core.unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval))
if out_axis is not None else aval
for aval, out_axis in safe_zip(local_out_avals, pci.out_axes)]
for aval, out_axis in safe_zip(shards.out_sharded_avals, pci.out_axes)]
out_specs = [
sharding_specs.pmap_sharding_spec(
replicas.num_local_replicas, pci.axis_size,
parts.local_num_partitions, out_parts, aval.shape, out_axis)
for out_parts, aval, out_axis in safe_zip(
local_out_parts, local_out_avals, pci.out_axes)]
replicas.num_local_replicas, pci.axis_size, aval.shape, out_axis)
for aval, out_axis in safe_zip(
shards.out_sharded_avals, pci.out_axes)]
out_shardings = _get_pmap_sharding(local_device_assignment, out_specs)
if hasattr(pci.backend, "compile_replicated"):
@@ -1193,124 +1163,6 @@ def check_multihost_collective_allowlist(jaxpr):
raise TypeError(msg.format(", ".join(map(str, bad_collectives))))
PartitionsOrReplicated = Optional[Tuple[int, ...]]
class PartitionInfo(NamedTuple):
arg_parts: Optional[Tuple[PartitionsOrReplicated, ...]]
out_parts: Optional[Tuple[PartitionsOrReplicated, ...]]
num_partitions: int
local_arg_parts: Optional[Tuple[PartitionsOrReplicated, ...]]
local_out_parts: Optional[Tuple[PartitionsOrReplicated, ...]]
local_num_partitions: Optional[int]
def _find_partitions(jaxpr):
"""Returns (in_partitions, out_partitions, num_partitions, local_in_parts,
local_out_parts, local_num_partitions).
"""
for eqn in jaxpr.eqns:
if eqn.primitive.name == "sharded_call":
if len(jaxpr.eqns) > 1:
raise NotImplementedError(
"pmap of sharded_jit + non-sharded operations not yet implemented.")
num_partitions = reconcile_num_partitions(eqn.params["call_jaxpr"],
eqn.params["nparts"])
return (eqn.params["in_parts"],
eqn.params["out_parts_thunk"](),
num_partitions,
eqn.params["local_in_parts"],
eqn.params["local_out_parts_thunk"](),
eqn.params["local_nparts"])
return None, None, 1, None, None, None
def find_partitions(jaxpr) -> PartitionInfo:
(arg_parts, out_parts, num_partitions, local_arg_parts, local_out_parts,
local_num_partitions) = _find_partitions(jaxpr)
if local_num_partitions is None:
local_num_partitions = num_partitions
if local_arg_parts is None:
local_arg_parts = arg_parts
if local_out_parts is None:
local_out_parts = out_parts
return PartitionInfo(arg_parts, out_parts, num_partitions,
local_arg_parts, local_out_parts, local_num_partitions)
def reconcile_num_partitions(jaxpr, outer_num_parts: Optional[int]):
"""Returns the total number of partitions to use.
Validates that any inner partitioning matches outer_num_parts if provided, and
returns the number of partitions to use based on outer_num_parts and any inner
partitioning.
"""
inner_num_parts = _inner_partitions(jaxpr, outer_num_parts)
if outer_num_parts is None and inner_num_parts is None:
# No partitions specified anywhere, everything is replicated.
return 1
if outer_num_parts is None:
return inner_num_parts
return outer_num_parts
def _inner_partitions(jaxpr, expected_num_parts: Optional[int]):
"""Returns the total number of partitions from PartitionSpecs inside `jaxpr`.
Also validates that this number matches `expected_num_parts` if provided.
"""
for eqn in jaxpr.eqns:
if eqn.primitive.name in ["sharding_constraint", "infeed"]:
parts = eqn.params["partitions"]
nparts = get_num_partitions(parts)
if expected_num_parts is None:
expected_num_parts = nparts
elif nparts is not None and nparts != expected_num_parts:
# TODO(skye): raise this error as we trace the jaxpr
raise ValueError(
f"with_sharding_constraint with partitions={parts} "
f"(total partitions: {nparts}) doesn't match expected number of "
f"partitions: {expected_num_parts}. If these partitions look "
f"right, check outer sharded_jit and/or other "
f"with_sharding_constraint calls.")
else:
for subjaxpr in core.jaxprs_in_params(eqn.params):
expected_num_parts = _inner_partitions(subjaxpr, expected_num_parts)
return expected_num_parts
def get_num_partitions(*partitions):
partition_specs = tree_flatten(partitions)[0]
if len(partition_specs) == 0:
# Everything is specified as replicated (all Nones).
return None
num_partitions_set = {np.prod(spec) for spec in partition_specs}
if len(num_partitions_set) > 1:
raise ValueError(
f"All partition specs must use the same number of total partitions, "
f"got {partitions}, with distinct number of partitions "
f"{num_partitions_set} (the total number of partitions is the product "
f"of a partition spec)")
assert len(num_partitions_set) == 1
return num_partitions_set.pop()
def get_local_aval(global_aval, global_parts: PartitionsOrReplicated,
local_parts: PartitionsOrReplicated):
if global_parts is None:
return global_aval
assert local_parts is not None
local_shape = [_safe_div(dim, _safe_div(ngparts, nlparts))
for dim, ngparts, nlparts
in safe_zip(global_aval.shape, global_parts, local_parts)]
return global_aval.update(shape=local_shape)
def _safe_div(x, y):
result, ragged = divmod(x, y)
assert not ragged, f"{x} % {y} != 0"
return result
class InputsHandler:
__slots__ = ("handler", "local_devices", "in_shardings", "input_indices")
@@ -1413,7 +1265,7 @@ def replicate(val, axis_size, nrep, devices=None, backend=None, in_axis=0):
replicated_aval = aval
# TODO(skye): figure out how partitioning should work here
sharding_spec = sharding_specs.pmap_sharding_spec(
nrep, axis_size, 1, None, aval.shape, in_axis)
nrep, axis_size, aval.shape, in_axis)
buf = jax.device_put(val, devices[0])
sharding = sharding_impls.PmapSharding(
@@ -3144,7 +2996,6 @@ def get_array_mapping(pspec: PartitionSpec) -> ArrayMappingOrAutoOrUnspecified:
_forbidden_primitives = {
'xla_pmap': 'pmap',
'sharded_call': 'sharded_jit',
}
def _sanitize_mesh_jaxpr(jaxpr):
if isinstance(jaxpr, core.ClosedJaxpr):
@@ -3246,5 +3097,7 @@ def _create_pmap_sharding_spec(aval, sharded_dim=0, sharded_dim_size=None):
def _pmap_sharding_spec(nrep, axis_size, npart, parts,
sharded_aval, map_axis: Optional[int]) -> ShardingSpec:
return sharding_specs.pmap_sharding_spec(nrep, axis_size, npart, parts,
sharded_aval.shape, map_axis)
assert npart == 1, npart
assert parts is None, parts
return sharding_specs.pmap_sharding_spec(
nrep, axis_size, sharded_aval.shape, map_axis)
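With the partition bookkeeping removed from the staging path, the shard counts in ShardInfo are just the replica counts, so a pmap places exactly one shard per mapped device. A small, hedged sanity check of that invariant using only the public pmap API (not code from this commit):

# Sketch: one output shard per local replica/device; there is no extra
# partitioning dimension now that the sharded_jit path is gone.
import jax
import jax.numpy as jnp

n = jax.local_device_count()
x = jnp.arange(n * 4.0).reshape(n, 4)
y = jax.pmap(lambda v: v * 2.0)(x)

# The output's sharding covers exactly the n mapped devices.
assert len(y.sharding.device_set) == n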


@@ -254,22 +254,6 @@ def spec_to_indices(shape: Sequence[int],
return tuple(spec.indices(shape).flat) # type: ignore
def partitioned_sharding_spec(num_partitions: int,
partitions: Optional[Sequence[int]],
shape: Sequence[int]) -> ShardingSpec:
if partitions is None:
maybe_replicate = () if num_partitions == 1 else (Replicated(num_partitions),)
return ShardingSpec(
sharding=[_UNSHARDED_INSTANCE] * len(shape),
mesh_mapping=maybe_replicate)
else:
assert len(partitions) == len(shape)
return ShardingSpec(
# Chunked expects a list of integers
sharding=map(Chunked, [[x] for x in partitions]),
mesh_mapping=map(ShardedAxis, range(len(partitions))))
def make_sharding_spec(axis_sizes, mesh_axis_pos, num_dimensions, aval_axes):
mesh_mapping = [Replicated(axis_size) for axis_size in axis_sizes.values()]
sharding = [_UNSHARDED_INSTANCE] * num_dimensions
@@ -292,15 +276,12 @@ def new_mesh_sharding_specs(axis_sizes, axis_names):
mesh_axis_pos = {name: i for i, name in enumerate(axis_names)}
return functools.partial(make_sharding_spec, axis_sizes, mesh_axis_pos)
def pmap_sharding_spec(nrep, axis_size, npart, parts,
sharded_shape: Sequence[int],
def pmap_sharding_spec(nrep, axis_size, sharded_shape: Sequence[int],
map_axis: Optional[int]) -> ShardingSpec:
"""Sharding spec for arguments or results of a pmap.
Args:
nrep: number of local XLA replicas (product of local axis sizes)
axis_size: local axis size for outer pmap
npart: total number of XLA partitions (required by sharded_jit calls)
parts: the partitioning of the value or None
sharded_aval: the aval of the value inside the outer pmap, an instance of
a ShapedArray.
map_axis: the axis along which the value is mapped in the outer pmap
@@ -309,8 +290,8 @@ def pmap_sharding_spec(nrep, axis_size, npart, parts,
"""
replication_factor, ragged = divmod(nrep, axis_size)
assert not ragged
# get the sharding spec from inner sharded_jits as if we weren't in a pmap
pspec = partitioned_sharding_spec(npart, parts, sharded_shape)
pspec = ShardingSpec(sharding=[_UNSHARDED_INSTANCE] * len(sharded_shape),
mesh_mapping=())
maybe_replicate = () if replication_factor == 1 else (Replicated(replication_factor),)
if map_axis is not None:
sharded_in_axis = sum(not isinstance(s, NoSharding) for s in pspec.sharding[:map_axis])
@@ -342,5 +323,5 @@ def create_pmap_sharding_spec(shape: Tuple[int, ...], sharded_dim: int = 0,
assert sharded_dim_size is not None
sharded_shape = shape
return pmap_sharding_spec(sharded_dim_size, sharded_dim_size, 1, None,
sharded_shape, sharded_dim)
return pmap_sharding_spec(sharded_dim_size, sharded_dim_size, sharded_shape,
sharded_dim)
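The ShardingSpec constructed inline in pmap_sharding_spec above is exactly what the deleted partitioned_sharding_spec returned in the only case the pmap path could still reach (one partition, no partition spec). A sketch that re-creates the deleted helper locally to show the equivalence; it assumes jax._src.sharding_specs exposes ShardingSpec, NoSharding, Replicated, Chunked, and ShardedAxis, and that ShardingSpec compares by value:

# Sketch: local copy of the deleted helper, kept only to demonstrate that the
# inlined replacement is behavior-preserving for num_partitions == 1.
from jax._src.sharding_specs import (Chunked, NoSharding, Replicated,
                                     ShardedAxis, ShardingSpec)

_UNSHARDED_INSTANCE = NoSharding()

def old_partitioned_sharding_spec(num_partitions, partitions, shape):
  if partitions is None:
    maybe_replicate = () if num_partitions == 1 else (Replicated(num_partitions),)
    return ShardingSpec(sharding=[_UNSHARDED_INSTANCE] * len(shape),
                        mesh_mapping=maybe_replicate)
  else:
    assert len(partitions) == len(shape)
    return ShardingSpec(sharding=map(Chunked, [[x] for x in partitions]),
                        mesh_mapping=map(ShardedAxis, range(len(partitions))))

shape = (4, 3)
# The only case reachable from pmap after sharded_jit's removal: npart == 1, parts is None.
inlined = ShardingSpec(sharding=[_UNSHARDED_INSTANCE] * len(shape), mesh_mapping=())
assert old_partitioned_sharding_spec(1, None, shape) == inlined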


@@ -3063,8 +3063,10 @@ def _custom_lin(*args: TfVal, **_) -> Sequence[TfVal]:
tf_impl[ad.custom_lin_p] = _custom_lin
PartitionsOrReplicated = Optional[Tuple[int, ...]]
def split_to_logical_devices(tensor: TfVal,
partition_dimensions: pxla.PartitionsOrReplicated):
partition_dimensions: PartitionsOrReplicated):
"""Like TPUMPStrategy.experimental_split_to_logical_devices.
For jax2tf purposes we want to avoid needing to thread the `strategy` object


@@ -192,8 +192,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
for p in all_primitives:
if p.name == "axis_index":
continue
# TODO: remove once we delete sharded_jit.py
if p.name in ["sharded_call", "sharding_constraint"]:
if p.name == "sharding_constraint":
continue
# TODO: Remove once tensorflow is 2.10.0 everywhere.
if p.name == "optimization_barrier":


@@ -24,8 +24,6 @@ from jax._src.interpreters.pxla import (
MeshDimAssignment as MeshDimAssignment,
MeshExecutable as MeshExecutable,
ParallelCallableInfo as ParallelCallableInfo,
PartitionInfo as PartitionInfo,
PartitionsOrReplicated as PartitionsOrReplicated,
PmapComputation as PmapComputation,
PmapExecutable as PmapExecutable,
PxlaResultHandler as PxlaResultHandler,
@@ -45,11 +43,8 @@ from jax._src.interpreters.pxla import (
array_types as array_types,
custom_resource_typing_rules as custom_resource_typing_rules,
device_put as _deprecated_device_put,
find_partitions as find_partitions,
find_replicas as find_replicas,
full_to_shard_p as full_to_shard_p,
get_local_aval as get_local_aval,
get_num_partitions as get_num_partitions,
global_aval_to_result_handler as global_aval_to_result_handler,
global_avals_to_results_handler as global_avals_to_results_handler,
global_result_handlers as global_result_handlers,
@@ -63,7 +58,6 @@ from jax._src.interpreters.pxla import (
mesh_sharding_specs as mesh_sharding_specs,
multi_host_supported_collectives as multi_host_supported_collectives,
parallel_callable as parallel_callable,
reconcile_num_partitions as reconcile_num_partitions,
replicate as replicate,
resource_typecheck as resource_typecheck,
shard_arg_handlers as shard_arg_handlers,