Junwhan Ahn 5046cedbfc Make pxla.shard_arg batch calls to xc.copy_array_to_devices_with_sharding
This CL changes `shard_arg_handlers` to be batched: each handler now receives a list of objects and a list of shardings and returns a list of arrays. This makes it possible to batch backend calls whenever it is beneficial to do so.
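
The batched handler contract can be illustrated with a minimal pure-Python sketch. `batched_handlers`, `batched_shard_args` and the toy handlers below are hypothetical names, not the actual `pxla` API. Arguments are grouped by their exact Python type, each type's handler is called once with all arguments of that type, and the results are scattered back into the original argument order.

```python
from typing import Any, Callable, Dict, List, Sequence

# Hypothetical batched handler type: takes parallel lists of objects and
# shardings and returns one result per object, in the same order.
BatchedHandler = Callable[[Sequence[Any], Sequence[Any]], List[Any]]

# Hypothetical registry keyed by the exact Python type of each argument.
batched_handlers: Dict[type, BatchedHandler] = {}


def batched_shard_args(shardings: Sequence[Any], args: Sequence[Any]) -> List[Any]:
    """Calls each type's handler once with all arguments of that type."""
    if len(shardings) != len(args):
        raise ValueError("shardings and args must have the same length")
    # Group argument positions by exact type (dicts preserve insertion order).
    positions: Dict[type, List[int]] = {}
    for i, arg in enumerate(args):
        positions.setdefault(type(arg), []).append(i)
    results: List[Any] = [None] * len(args)
    for typ, idxs in positions.items():
        outs = batched_handlers[typ]([args[i] for i in idxs],
                                     [shardings[i] for i in idxs])
        if len(outs) != len(idxs):
            raise ValueError(f"handler for {typ.__name__} returned {len(outs)} results")
        # Scatter the batched results back to the original argument positions.
        for i, res in zip(idxs, outs):
            results[i] = res
    return results


# Usage with toy handlers; plain strings stand in for shardings.
calls = []

def int_handler(xs, shs):
    calls.append(len(xs))  # record the batch size of each "backend call"
    return [(x, s) for x, s in zip(xs, shs)]

batched_handlers[int] = int_handler
batched_handlers[str] = lambda xs, shs: [x.upper() for x in xs]

out = batched_shard_args(["s0", "s1", "s2"], [1, "a", 2])
print(out)    # [(1, 's0'), 'A', (2, 's2')]
print(calls)  # [2]
```

Both `int` arguments reach `int_handler` in a single call, which is the property that lets a real handler issue one bulk backend call instead of one call per argument.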

Building on this, the batched `shard_arg` handler for arrays uses the newly added `xla::ifrt::Client::CopyArrays()` (https://github.com/tensorflow/tensorflow/pull/69096) to make bulk copies cheaper in some backend implementations. Since `Client::CopyArrays()` requires all arrays in a batch to share the same source and destination devices, `PyArray::BatchedCopyToDeviceWithSharding()` internally groups arrays by their source/destination devices and memory kinds. The grouping is implemented entirely in C++ for performance when there are many arrays.
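
The grouping step can be sketched in Python as follows. This is an illustration under stated assumptions, not the C++ code of `PyArray::BatchedCopyToDeviceWithSharding()`: `FakeArray`, `Target`, `batched_copy` and the `bulk_copy` callback are hypothetical stand-ins for device arrays, destination shardings and a `CopyArrays()`-style bulk call. Each group key combines source devices, destination devices and both memory kinds, so every bulk call receives arrays that satisfy the same-devices precondition.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Sequence, Tuple


@dataclass(frozen=True)
class FakeArray:
    """Hypothetical stand-in for a device array: only records its placement."""
    name: str
    devices: Tuple[int, ...]
    memory_kind: str


@dataclass(frozen=True)
class Target:
    """Hypothetical stand-in for a destination sharding."""
    devices: Tuple[int, ...]
    memory_kind: str


# Hypothetical bulk-copy callback: copies a whole group to one target.
BulkCopy = Callable[[List[FakeArray], Target], List[FakeArray]]


def batched_copy(arrays: Sequence[FakeArray], targets: Sequence[Target],
                 bulk_copy: BulkCopy) -> List[FakeArray]:
    """Issues one bulk_copy per (src devices, dst devices, memory kinds) group."""
    if len(arrays) != len(targets):
        raise ValueError("arrays and targets must have the same length")
    groups: Dict[tuple, List[int]] = {}
    for i, (arr, tgt) in enumerate(zip(arrays, targets)):
        key = (arr.devices, tgt.devices, arr.memory_kind, tgt.memory_kind)
        groups.setdefault(key, []).append(i)
    results = [None] * len(arrays)
    for idxs in groups.values():
        # All members share source/destination devices and memory kinds,
        # so any member's target describes the whole group.
        copied = bulk_copy([arrays[i] for i in idxs], targets[idxs[0]])
        for i, res in zip(idxs, copied):
            results[i] = res
    return results


# Usage: "a" and "b" share a source device, "c" does not.
calls = []

def fake_bulk_copy(arrs, tgt):
    calls.append([a.name for a in arrs])  # one entry per bulk call
    return [FakeArray(a.name, tgt.devices, tgt.memory_kind) for a in arrs]

arrays = [FakeArray("a", (0,), "device"),
          FakeArray("b", (0,), "device"),
          FakeArray("c", (1,), "device")]
targets = [Target((0, 1), "device")] * 3
out = batched_copy(arrays, targets, fake_bulk_copy)
print(calls)  # [['a', 'b'], ['c']]
```

Three arrays produce two bulk calls here; with many arrays that share placement, the number of backend calls tracks the number of distinct placements rather than the number of arrays.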

PiperOrigin-RevId: 643097852
2024-06-13 13:10:10 -07:00

# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from jax._src.interpreters.pxla import (
  Index as Index,
  MapTracer as MapTracer,
  MeshAxisName as MeshAxisName,
  MeshComputation as MeshComputation,
  MeshExecutable as MeshExecutable,
  PmapExecutable as PmapExecutable,
  global_aval_to_result_handler as global_aval_to_result_handler,
  global_avals_to_results_handler as global_avals_to_results_handler,
  global_result_handlers as global_result_handlers,
  parallel_callable as parallel_callable,
  shard_args as shard_args,
  xla_pmap_p as xla_pmap_p,
)
from jax._src.mesh import (
  thread_resources as thread_resources,
)
from jax._src.op_shardings import (
  are_op_shardings_equal as are_op_shardings_equal,
  is_op_sharding_replicated as is_op_sharding_replicated,
  op_sharding_to_indices as op_sharding_to_indices,
)
from jax._src.sharding_impls import (
  ArrayMapping as ArrayMapping,
  UNSPECIFIED as _UNSPECIFIED,
  array_mapping_to_axis_resources as array_mapping_to_axis_resources,
  is_unspecified as _is_unspecified,
)
from jax._src.sharding_specs import (
  Chunked as Chunked,
  NoSharding as NoSharding,
  Replicated as Replicated,
  ShardedAxis as ShardedAxis,
  ShardingSpec as ShardingSpec,
  Unstacked as Unstacked,
  spec_to_indices as spec_to_indices,
)