mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Remove unused exports from jax.interpreters.pxla.
This change removes names exported from jax.interpreters.pxla for which I couldn't find any JAX-external references. PiperOrigin-RevId: 554483707
This commit is contained in:
parent
5fcd9265b1
commit
a80d952680
@ -13,70 +13,22 @@
|
||||
# limitations under the License.
|
||||
|
||||
from jax._src.interpreters.pxla import (
|
||||
AvalDimSharding as AvalDimSharding,
|
||||
EmapInfo as EmapInfo,
|
||||
ExecuteReplicated as ExecuteReplicated,
|
||||
Index as Index,
|
||||
InputsHandler as InputsHandler,
|
||||
MapTrace as MapTrace,
|
||||
MapTracer as MapTracer,
|
||||
MeshAxisName as MeshAxisName,
|
||||
MeshComputation as MeshComputation,
|
||||
MeshDimAssignment as MeshDimAssignment,
|
||||
MeshExecutable as MeshExecutable,
|
||||
ParallelCallableInfo as ParallelCallableInfo,
|
||||
PmapComputation as PmapComputation,
|
||||
PmapExecutable as PmapExecutable,
|
||||
PxlaResultHandler as PxlaResultHandler,
|
||||
ReplicaInfo as ReplicaInfo,
|
||||
ResultsHandler as ResultsHandler,
|
||||
SPMDBatchTrace as SPMDBatchTrace,
|
||||
ShardInfo as ShardInfo,
|
||||
TileManual as TileManual,
|
||||
TileVectorize as TileVectorize,
|
||||
TilingMethod as TilingMethod,
|
||||
UnloadedMeshExecutable as UnloadedMeshExecutable,
|
||||
UnloadedPmapExecutable as UnloadedPmapExecutable,
|
||||
WeakRefList as WeakRefList,
|
||||
_get_and_check_device_assignment as _get_and_check_device_assignment,
|
||||
array_types as array_types,
|
||||
custom_resource_typing_rules as custom_resource_typing_rules,
|
||||
find_replicas as find_replicas,
|
||||
full_to_shard_p as full_to_shard_p,
|
||||
global_aval_to_result_handler as global_aval_to_result_handler,
|
||||
global_avals_to_results_handler as global_avals_to_results_handler,
|
||||
global_result_handlers as global_result_handlers,
|
||||
local_aval_to_result_handler as local_aval_to_result_handler,
|
||||
local_avals_to_results_handler as local_avals_to_results_handler,
|
||||
local_result_handlers as local_result_handlers,
|
||||
lower_mesh_computation as lower_mesh_computation,
|
||||
lower_parallel_callable as lower_parallel_callable,
|
||||
lower_sharding_computation as lower_sharding_computation,
|
||||
maybe_extend_axis_env as maybe_extend_axis_env,
|
||||
mesh_sharding_specs as mesh_sharding_specs,
|
||||
multi_host_supported_collectives as multi_host_supported_collectives,
|
||||
parallel_callable as parallel_callable,
|
||||
resource_typecheck as resource_typecheck,
|
||||
shard_arg_handlers as shard_arg_handlers,
|
||||
shard_args as shard_args,
|
||||
shard_arg as shard_arg,
|
||||
shard_aval as shard_aval,
|
||||
shard_aval_handlers as shard_aval_handlers,
|
||||
shard_to_full_p as shard_to_full_p,
|
||||
spmd_primitive_batchers as spmd_primitive_batchers,
|
||||
stage_parallel_callable as stage_parallel_callable,
|
||||
tile_aval_nd as tile_aval_nd,
|
||||
untile_aval_nd as untile_aval_nd,
|
||||
vtile_by_mesh as vtile_by_mesh,
|
||||
vtile_manual as vtile_manual,
|
||||
wrap_name as wrap_name,
|
||||
xb as xb,
|
||||
xla_pmap as xla_pmap,
|
||||
xla_pmap_impl as xla_pmap_impl,
|
||||
xla_pmap_impl_lazy as xla_pmap_impl_lazy,
|
||||
shard_args as shard_args,
|
||||
xla_pmap_p as xla_pmap_p,
|
||||
)
|
||||
from jax._src.mesh import (
|
||||
MeshAxisName as MeshAxisName,
|
||||
thread_resources as thread_resources,
|
||||
)
|
||||
|
||||
@ -88,8 +40,6 @@ from jax._src.op_shardings import (
|
||||
|
||||
from jax._src.sharding_impls import (
|
||||
ArrayMapping as ArrayMapping,
|
||||
ArrayMappingOrAutoOrUnspecified as ArrayMappingOrAutoOrUnspecified,
|
||||
AUTO as AUTO,
|
||||
UNSPECIFIED as _UNSPECIFIED,
|
||||
array_mapping_to_axis_resources as array_mapping_to_axis_resources,
|
||||
is_unspecified as _is_unspecified,
|
||||
@ -98,12 +48,9 @@ from jax._src.sharding_impls import (
|
||||
from jax._src.sharding_specs import (
|
||||
Chunked as Chunked,
|
||||
NoSharding as NoSharding,
|
||||
OpShardingType as OpShardingType,
|
||||
Replicated as Replicated,
|
||||
ShardedAxis as ShardedAxis,
|
||||
ShardingSpec as ShardingSpec,
|
||||
Unstacked as Unstacked,
|
||||
new_mesh_sharding_specs as new_mesh_sharding_specs,
|
||||
sharding_spec_sharding_proto as sharding_spec_sharding_proto,
|
||||
spec_to_indices as spec_to_indices,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user