Remove dead code from pxla.py

PiperOrigin-RevId: 521003815
This commit is contained in:
Yash Katariya 2023-03-31 13:45:37 -07:00 committed by jax authors
parent db025df030
commit 0b31e8b822
2 changed files with 0 additions and 80 deletions

View File

@ -39,7 +39,6 @@ import itertools as it
import logging
import math
import sys
import threading
from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast,
TYPE_CHECKING)
@ -676,16 +675,6 @@ else:
raise RuntimeError("ShardedDeviceArray is a backward compatibility shim "
"and cannot be instantiated.")
def _one_replica_buffer_indices(indices: Tuple[Index, ...]):
"""Returns a set of buffer-indices containing one complete copy of the array."""
one_replica_indices = []
seen_index_hashes = set()
for i, index in enumerate(indices):
hashed_index = _hashable_index(index)
if hashed_index not in seen_index_hashes:
one_replica_indices.append(i)
seen_index_hashes.add(hashed_index)
return one_replica_indices
def _hashable_index(idx):
return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx)
@ -1635,17 +1624,6 @@ def get_num_partitions(*partitions):
return num_partitions_set.pop()
def get_global_aval(local_aval, global_parts: PartitionsOrReplicated,
local_parts: PartitionsOrReplicated):
if global_parts is None:
return local_aval
assert local_parts is not None
global_shape = [dim * _safe_div(ngparts, nlparts)
for dim, ngparts, nlparts
in safe_zip(local_aval.shape, global_parts, local_parts)]
return local_aval.update(shape=global_shape)
def get_local_aval(global_aval, global_parts: PartitionsOrReplicated,
local_parts: PartitionsOrReplicated):
if global_parts is None:
@ -1698,26 +1676,6 @@ class ResultsHandler:
return [h(bufs) for h, bufs in safe_zip(self.handlers, out_bufs)]
def _get_sharding_specs(
shardings: Sequence[sharding_impls.XLACompatibleSharding], avals: Sequence[ShapedArray]
) -> Sequence[ShardingSpec]:
if all(isinstance(s, sharding_impls.PmapSharding) for s in shardings):
return [s.sharding_spec for s in shardings] # type: ignore
elif all(isinstance(s, sharding_impls.NamedSharding) for s in shardings):
out = []
for aval, s in safe_zip(avals, shardings):
ns = cast(sharding_impls.NamedSharding, s)
out.append(
new_mesh_sharding_specs(ns.mesh.shape, ns.mesh.axis_names)(
aval.ndim, get_array_mapping(ns.spec)
)
)
return out
else:
raise ValueError('Getting sharding spec is only supported for '
"PmapSharding and NamedSharding, "
f"but got {shardings}.")
def local_avals_to_results_handler(
unmapped_local_out_avals: Sequence[ShapedArray],
local_shardings: Sequence[sharding_impls.XLACompatibleSharding]) -> ResultsHandler:
@ -3493,39 +3451,6 @@ def maybe_extend_axis_env(*args, **kwargs):
with core.extend_axis_env(*args, **kwargs):
yield
class DynamicAxisEnvFrame:
__slots__ = ["name", "pmap_trace", "hard_size"]
def __init__(self, name, pmap_trace, hard_size):
self.name = name
self.pmap_trace = pmap_trace
self.hard_size = hard_size
class DynamicAxisEnv(list):
def __contains__(self, axis_name):
return axis_name in (frame.name for frame in self)
def __getitem__(self, axis_name):
if axis_name not in self:
raise NameError(f"unbound axis name: {axis_name}")
for frame in reversed(self):
if frame.name == axis_name:
return frame
raise AssertionError
@property
def sizes(self):
return tuple(frame.hard_size for frame in self)
@property
def nreps(self):
return math.prod(frame.hard_size for frame in self)
class _ThreadLocalState(threading.local):
def __init__(self):
self.dynamic_axis_env = DynamicAxisEnv()
_thread_local_state = _ThreadLocalState()
def device_put(x, devices: Sequence[xc.ArrayImpl],
replicate: bool=False) -> List[xc.ArrayImpl]:

View File

@ -18,8 +18,6 @@ from jax._src.interpreters.pxla import (
ArrayMappingOrAutoOrUnspecified as ArrayMappingOrAutoOrUnspecified,
AvalDimSharding as AvalDimSharding,
Chunked as Chunked,
DynamicAxisEnv as DynamicAxisEnv,
DynamicAxisEnvFrame as DynamicAxisEnvFrame,
EmapInfo as EmapInfo,
ExecuteReplicated as ExecuteReplicated,
Index as Index,
@ -55,9 +53,7 @@ from jax._src.interpreters.pxla import (
_UNSPECIFIED as _UNSPECIFIED,
_create_pmap_sharding_spec as _create_pmap_sharding_spec,
_get_and_check_device_assignment as _get_and_check_device_assignment,
_get_sharding_specs as _get_sharding_specs,
_is_unspecified as _is_unspecified,
_one_replica_buffer_indices as _one_replica_buffer_indices,
_pmap_sharding_spec as _pmap_sharding_spec,
are_op_shardings_equal as are_op_shardings_equal,
array_mapping_to_axis_resources as array_mapping_to_axis_resources,
@ -67,7 +63,6 @@ from jax._src.interpreters.pxla import (
find_partitions as find_partitions,
find_replicas as find_replicas,
full_to_shard_p as full_to_shard_p,
get_global_aval as get_global_aval,
get_local_aval as get_local_aval,
get_num_partitions as get_num_partitions,
global_aval_to_result_handler as global_aval_to_result_handler,