From 0b31e8b8227e7542046bbae17fff2c37f188bde5 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 31 Mar 2023 13:45:37 -0700 Subject: [PATCH] Remove dead code from pxla.py PiperOrigin-RevId: 521003815 --- jax/_src/interpreters/pxla.py | 75 ----------------------------------- jax/interpreters/pxla.py | 5 --- 2 files changed, 80 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index be5e02af5..78b46f8e4 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -39,7 +39,6 @@ import itertools as it import logging import math import sys -import threading from typing import (Any, Callable, Dict, List, NamedTuple, Optional, FrozenSet, Sequence, Set, Tuple, Type, Union, Iterable, Mapping, cast, TYPE_CHECKING) @@ -676,16 +675,6 @@ else: raise RuntimeError("ShardedDeviceArray is a backward compatibility shim " "and cannot be instantiated.") -def _one_replica_buffer_indices(indices: Tuple[Index, ...]): - """Returns a set of buffer-indices containing one complete copy of the array.""" - one_replica_indices = [] - seen_index_hashes = set() - for i, index in enumerate(indices): - hashed_index = _hashable_index(index) - if hashed_index not in seen_index_hashes: - one_replica_indices.append(i) - seen_index_hashes.add(hashed_index) - return one_replica_indices def _hashable_index(idx): return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx) @@ -1635,17 +1624,6 @@ def get_num_partitions(*partitions): return num_partitions_set.pop() -def get_global_aval(local_aval, global_parts: PartitionsOrReplicated, - local_parts: PartitionsOrReplicated): - if global_parts is None: - return local_aval - assert local_parts is not None - global_shape = [dim * _safe_div(ngparts, nlparts) - for dim, ngparts, nlparts - in safe_zip(local_aval.shape, global_parts, local_parts)] - return local_aval.update(shape=global_shape) - - def get_local_aval(global_aval, global_parts: PartitionsOrReplicated, local_parts: PartitionsOrReplicated): if global_parts is None: @@ -1698,26 +1676,6 @@ class ResultsHandler: return [h(bufs) for h, bufs in safe_zip(self.handlers, out_bufs)] -def _get_sharding_specs( - shardings: Sequence[sharding_impls.XLACompatibleSharding], avals: Sequence[ShapedArray] -) -> Sequence[ShardingSpec]: - if all(isinstance(s, sharding_impls.PmapSharding) for s in shardings): - return [s.sharding_spec for s in shardings] # type: ignore - elif all(isinstance(s, sharding_impls.NamedSharding) for s in shardings): - out = [] - for aval, s in safe_zip(avals, shardings): - ns = cast(sharding_impls.NamedSharding, s) - out.append( - new_mesh_sharding_specs(ns.mesh.shape, ns.mesh.axis_names)( - aval.ndim, get_array_mapping(ns.spec) - ) - ) - return out - else: - raise ValueError('Getting sharding spec is only supported for ' - "PmapSharding and NamedSharding, " - f"but got {shardings}.") - def local_avals_to_results_handler( unmapped_local_out_avals: Sequence[ShapedArray], local_shardings: Sequence[sharding_impls.XLACompatibleSharding]) -> ResultsHandler: @@ -3493,39 +3451,6 @@ def maybe_extend_axis_env(*args, **kwargs): with core.extend_axis_env(*args, **kwargs): yield -class DynamicAxisEnvFrame: - __slots__ = ["name", "pmap_trace", "hard_size"] - def __init__(self, name, pmap_trace, hard_size): - self.name = name - self.pmap_trace = pmap_trace - self.hard_size = hard_size - -class DynamicAxisEnv(list): - def __contains__(self, axis_name): - return axis_name in (frame.name for frame in self) - - def __getitem__(self, axis_name): - if axis_name not in self: - raise NameError(f"unbound axis name: {axis_name}") - for frame in reversed(self): - if frame.name == axis_name: - return frame - - raise AssertionError - - @property - def sizes(self): - return tuple(frame.hard_size for frame in self) - - @property - def nreps(self): - return math.prod(frame.hard_size for frame in self) - -class _ThreadLocalState(threading.local): - def __init__(self): - self.dynamic_axis_env = DynamicAxisEnv() - -_thread_local_state = _ThreadLocalState() def device_put(x, devices: Sequence[xc.ArrayImpl], replicate: bool=False) -> List[xc.ArrayImpl]: diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index b13507b46..d1ffac7b3 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -18,8 +18,6 @@ from jax._src.interpreters.pxla import ( ArrayMappingOrAutoOrUnspecified as ArrayMappingOrAutoOrUnspecified, AvalDimSharding as AvalDimSharding, Chunked as Chunked, - DynamicAxisEnv as DynamicAxisEnv, - DynamicAxisEnvFrame as DynamicAxisEnvFrame, EmapInfo as EmapInfo, ExecuteReplicated as ExecuteReplicated, Index as Index, @@ -55,9 +53,7 @@ from jax._src.interpreters.pxla import ( _UNSPECIFIED as _UNSPECIFIED, _create_pmap_sharding_spec as _create_pmap_sharding_spec, _get_and_check_device_assignment as _get_and_check_device_assignment, - _get_sharding_specs as _get_sharding_specs, _is_unspecified as _is_unspecified, - _one_replica_buffer_indices as _one_replica_buffer_indices, _pmap_sharding_spec as _pmap_sharding_spec, are_op_shardings_equal as are_op_shardings_equal, array_mapping_to_axis_resources as array_mapping_to_axis_resources, @@ -67,7 +63,6 @@ from jax._src.interpreters.pxla import ( find_partitions as find_partitions, find_replicas as find_replicas, full_to_shard_p as full_to_shard_p, - get_global_aval as get_global_aval, get_local_aval as get_local_aval, get_num_partitions as get_num_partitions, global_aval_to_result_handler as global_aval_to_result_handler,