Remove unused exports from jax.interpreters.pxla.

This change removes names exported from jax.interpreters.pxla for which I couldn't find any JAX-external references.

PiperOrigin-RevId: 554483707
This commit is contained in:
Peter Hawkins 2023-08-07 08:26:27 -07:00 committed by jax authors
parent 5fcd9265b1
commit a80d952680

View File

@ -13,70 +13,22 @@
# limitations under the License.
from jax._src.interpreters.pxla import (
AvalDimSharding as AvalDimSharding,
EmapInfo as EmapInfo,
ExecuteReplicated as ExecuteReplicated,
Index as Index,
InputsHandler as InputsHandler,
MapTrace as MapTrace,
MapTracer as MapTracer,
MeshAxisName as MeshAxisName,
MeshComputation as MeshComputation,
MeshDimAssignment as MeshDimAssignment,
MeshExecutable as MeshExecutable,
ParallelCallableInfo as ParallelCallableInfo,
PmapComputation as PmapComputation,
PmapExecutable as PmapExecutable,
PxlaResultHandler as PxlaResultHandler,
ReplicaInfo as ReplicaInfo,
ResultsHandler as ResultsHandler,
SPMDBatchTrace as SPMDBatchTrace,
ShardInfo as ShardInfo,
TileManual as TileManual,
TileVectorize as TileVectorize,
TilingMethod as TilingMethod,
UnloadedMeshExecutable as UnloadedMeshExecutable,
UnloadedPmapExecutable as UnloadedPmapExecutable,
WeakRefList as WeakRefList,
_get_and_check_device_assignment as _get_and_check_device_assignment,
array_types as array_types,
custom_resource_typing_rules as custom_resource_typing_rules,
find_replicas as find_replicas,
full_to_shard_p as full_to_shard_p,
global_aval_to_result_handler as global_aval_to_result_handler,
global_avals_to_results_handler as global_avals_to_results_handler,
global_result_handlers as global_result_handlers,
local_aval_to_result_handler as local_aval_to_result_handler,
local_avals_to_results_handler as local_avals_to_results_handler,
local_result_handlers as local_result_handlers,
lower_mesh_computation as lower_mesh_computation,
lower_parallel_callable as lower_parallel_callable,
lower_sharding_computation as lower_sharding_computation,
maybe_extend_axis_env as maybe_extend_axis_env,
mesh_sharding_specs as mesh_sharding_specs,
multi_host_supported_collectives as multi_host_supported_collectives,
parallel_callable as parallel_callable,
resource_typecheck as resource_typecheck,
shard_arg_handlers as shard_arg_handlers,
shard_args as shard_args,
shard_arg as shard_arg,
shard_aval as shard_aval,
shard_aval_handlers as shard_aval_handlers,
shard_to_full_p as shard_to_full_p,
spmd_primitive_batchers as spmd_primitive_batchers,
stage_parallel_callable as stage_parallel_callable,
tile_aval_nd as tile_aval_nd,
untile_aval_nd as untile_aval_nd,
vtile_by_mesh as vtile_by_mesh,
vtile_manual as vtile_manual,
wrap_name as wrap_name,
xb as xb,
xla_pmap as xla_pmap,
xla_pmap_impl as xla_pmap_impl,
xla_pmap_impl_lazy as xla_pmap_impl_lazy,
shard_args as shard_args,
xla_pmap_p as xla_pmap_p,
)
from jax._src.mesh import (
MeshAxisName as MeshAxisName,
thread_resources as thread_resources,
)
@ -88,8 +40,6 @@ from jax._src.op_shardings import (
from jax._src.sharding_impls import (
ArrayMapping as ArrayMapping,
ArrayMappingOrAutoOrUnspecified as ArrayMappingOrAutoOrUnspecified,
AUTO as AUTO,
UNSPECIFIED as _UNSPECIFIED,
array_mapping_to_axis_resources as array_mapping_to_axis_resources,
is_unspecified as _is_unspecified,
@ -98,12 +48,9 @@ from jax._src.sharding_impls import (
from jax._src.sharding_specs import (
Chunked as Chunked,
NoSharding as NoSharding,
OpShardingType as OpShardingType,
Replicated as Replicated,
ShardedAxis as ShardedAxis,
ShardingSpec as ShardingSpec,
Unstacked as Unstacked,
new_mesh_sharding_specs as new_mesh_sharding_specs,
sharding_spec_sharding_proto as sharding_spec_sharding_proto,
spec_to_indices as spec_to_indices,
)