Mirror of https://github.com/ROCm/jax.git (synced 2025-04-16 11:56:07 +00:00)

Remove dead code from pxla.py

PiperOrigin-RevId: 521003815

parent db025df030
commit 0b31e8b822
@@ -39,7 +39,6 @@ import itertools as it
 import logging
 import math
 import sys
-import threading
 from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet,
                     Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast,
                     TYPE_CHECKING)
@@ -676,16 +675,6 @@ else:
       raise RuntimeError("ShardedDeviceArray is a backward compatibility shim "
                          "and cannot be instantiated.")
 
-def _one_replica_buffer_indices(indices: Tuple[Index, ...]):
-  """Returns a set of buffer-indices containing one complete copy of the array."""
-  one_replica_indices = []
-  seen_index_hashes = set()
-  for i, index in enumerate(indices):
-    hashed_index = _hashable_index(index)
-    if hashed_index not in seen_index_hashes:
-      one_replica_indices.append(i)
-      seen_index_hashes.add(hashed_index)
-  return one_replica_indices
 
 def _hashable_index(idx):
   return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx)
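For context on the helper deleted above: it walks the per-device index tuples and keeps only the first buffer whose normalized index has not been seen yet, which selects one complete copy of the array. Below is a minimal standalone sketch of that idea; the function names and the toy `indices` value are invented for the illustration and are not part of the diff.

# Standalone illustration only; not the diff's code.
def hashable_index(idx):
  # slice objects are not reliably hashable on older Pythons, so normalize
  # each slice to a (start, stop) tuple before using it as a set key.
  return tuple((s.start, s.stop) if isinstance(s, slice) else s for s in idx)

def one_replica_buffer_indices(indices):
  one_replica, seen = [], set()
  for i, index in enumerate(indices):
    key = hashable_index(index)
    if key not in seen:      # keep only the first buffer holding this shard
      one_replica.append(i)
      seen.add(key)
  return one_replica

# Two devices replicate each of two shards: only buffers 0 and 1 are needed.
indices = [(slice(0, 4),), (slice(4, 8),), (slice(0, 4),), (slice(4, 8),)]
print(one_replica_buffer_indices(indices))  # [0, 1]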
@@ -1635,17 +1624,6 @@ def get_num_partitions(*partitions):
   return num_partitions_set.pop()
 
 
-def get_global_aval(local_aval, global_parts: PartitionsOrReplicated,
-                    local_parts: PartitionsOrReplicated):
-  if global_parts is None:
-    return local_aval
-  assert local_parts is not None
-  global_shape = [dim * _safe_div(ngparts, nlparts)
-                  for dim, ngparts, nlparts
-                  in safe_zip(local_aval.shape, global_parts, local_parts)]
-  return local_aval.update(shape=global_shape)
-
-
 def get_local_aval(global_aval, global_parts: PartitionsOrReplicated,
                    local_parts: PartitionsOrReplicated):
   if global_parts is None:
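The removed get_global_aval scales each local dimension by the ratio of global to local partition counts (with _safe_div asserting exact divisibility). A rough standalone sketch of that arithmetic, using made-up shapes and partition counts:

# Illustrative arithmetic only; the values below are not from the diff.
local_shape = (8, 4)
global_parts = (4, 1)   # partitions of each axis in the global view
local_parts = (2, 1)    # partitions of each axis on this host

global_shape = tuple(dim * (g // l)
                     for dim, g, l in zip(local_shape, global_parts, local_parts))
print(global_shape)  # (16, 4)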
@@ -1698,26 +1676,6 @@ class ResultsHandler:
     return [h(bufs) for h, bufs in safe_zip(self.handlers, out_bufs)]
 
 
-def _get_sharding_specs(
-    shardings: Sequence[sharding_impls.XLACompatibleSharding], avals: Sequence[ShapedArray]
-) -> Sequence[ShardingSpec]:
-  if all(isinstance(s, sharding_impls.PmapSharding) for s in shardings):
-    return [s.sharding_spec for s in shardings]  # type: ignore
-  elif all(isinstance(s, sharding_impls.NamedSharding) for s in shardings):
-    out = []
-    for aval, s in safe_zip(avals, shardings):
-      ns = cast(sharding_impls.NamedSharding, s)
-      out.append(
-          new_mesh_sharding_specs(ns.mesh.shape, ns.mesh.axis_names)(
-              aval.ndim, get_array_mapping(ns.spec)
-          )
-      )
-    return out
-  else:
-    raise ValueError('Getting sharding spec is only supported for '
-                     "PmapSharding and NamedSharding, "
-                     f"but got {shardings}.")
-
 def local_avals_to_results_handler(
     unmapped_local_out_avals: Sequence[ShapedArray],
     local_shardings: Sequence[sharding_impls.XLACompatibleSharding]) -> ResultsHandler:
@@ -3493,39 +3451,6 @@ def maybe_extend_axis_env(*args, **kwargs):
   with core.extend_axis_env(*args, **kwargs):
     yield
 
-class DynamicAxisEnvFrame:
-  __slots__ = ["name", "pmap_trace", "hard_size"]
-  def __init__(self, name, pmap_trace, hard_size):
-    self.name = name
-    self.pmap_trace = pmap_trace
-    self.hard_size = hard_size
-
-class DynamicAxisEnv(list):
-  def __contains__(self, axis_name):
-    return axis_name in (frame.name for frame in self)
-
-  def __getitem__(self, axis_name):
-    if axis_name not in self:
-      raise NameError(f"unbound axis name: {axis_name}")
-    for frame in reversed(self):
-      if frame.name == axis_name:
-        return frame
-
-    raise AssertionError
-
-  @property
-  def sizes(self):
-    return tuple(frame.hard_size for frame in self)
-
-  @property
-  def nreps(self):
-    return math.prod(frame.hard_size for frame in self)
-
-class _ThreadLocalState(threading.local):
-  def __init__(self):
-    self.dynamic_axis_env = DynamicAxisEnv()
-
-_thread_local_state = _ThreadLocalState()
 
 def device_put(x, devices: Sequence[xc.ArrayImpl],
                replicate: bool=False) -> List[xc.ArrayImpl]:
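The deleted DynamicAxisEnv was essentially a list of frames searched innermost-first by axis name. A small self-contained sketch of that lookup behavior follows; the Frame type and the example axis names and sizes are invented here and do not come from the real pmap trace machinery.

# Standalone illustration of name-keyed, innermost-first lookup; not the diff's code.
from typing import NamedTuple

class Frame(NamedTuple):
  name: str
  hard_size: int

class AxisEnv(list):
  def __contains__(self, axis_name):
    return axis_name in (frame.name for frame in self)

  def __getitem__(self, axis_name):
    for frame in reversed(self):   # innermost binding wins
      if frame.name == axis_name:
        return frame
    raise NameError(f"unbound axis name: {axis_name}")

env = AxisEnv([Frame("i", 4), Frame("j", 2)])
print(env["j"].hard_size)  # 2
print("i" in env)          # True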
@@ -18,8 +18,6 @@ from jax._src.interpreters.pxla import (
   ArrayMappingOrAutoOrUnspecified as ArrayMappingOrAutoOrUnspecified,
   AvalDimSharding as AvalDimSharding,
   Chunked as Chunked,
-  DynamicAxisEnv as DynamicAxisEnv,
-  DynamicAxisEnvFrame as DynamicAxisEnvFrame,
   EmapInfo as EmapInfo,
   ExecuteReplicated as ExecuteReplicated,
   Index as Index,
@@ -55,9 +53,7 @@ from jax._src.interpreters.pxla import (
   _UNSPECIFIED as _UNSPECIFIED,
   _create_pmap_sharding_spec as _create_pmap_sharding_spec,
   _get_and_check_device_assignment as _get_and_check_device_assignment,
-  _get_sharding_specs as _get_sharding_specs,
   _is_unspecified as _is_unspecified,
-  _one_replica_buffer_indices as _one_replica_buffer_indices,
   _pmap_sharding_spec as _pmap_sharding_spec,
   are_op_shardings_equal as are_op_shardings_equal,
   array_mapping_to_axis_resources as array_mapping_to_axis_resources,
@@ -67,7 +63,6 @@ from jax._src.interpreters.pxla import (
   find_partitions as find_partitions,
   find_replicas as find_replicas,
   full_to_shard_p as full_to_shard_p,
-  get_global_aval as get_global_aval,
   get_local_aval as get_local_aval,
   get_num_partitions as get_num_partitions,
   global_aval_to_result_handler as global_aval_to_result_handler,