Delete pxla.make_sharded_device_array.

This function is unused and not exported from JAX.

PiperOrigin-RevId: 549606907
This commit is contained in:
Peter Hawkins 2023-07-20 05:55:23 -07:00 committed by jax authors
parent 1ceddfc98a
commit fa5915b34d

View File

@ -298,46 +298,6 @@ global_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}
### lazy device-memory persistence and result handling
# TODO(yashkatariya, phawkins): Remove this function after March 15, 2023.
def make_sharded_device_array(
aval: ShapedArray,
sharding_spec: Optional[ShardingSpec],
# Any is for JAX extensions implementing their own buffer.
device_buffers: list[Any],
indices: Optional[tuple[Index, ...]] = None,
):
"""Returns a ShardedDeviceArray implementation based on arguments.
Returns either a C++ SDA or a Python DeviceArray when the buffers are not
JAX buffers.
Args:
aval: The `ShapedArray` for this array.
sharding_spec: If `None`, assumes a pmap-style ShardedDeviceArrays over the
first dimension.
device_buffers: If a list of Jax `Buffer` objects, a C++ SDA will be
returned (if the version is high enough). Otherwise, a Python object will
be returned, for JAX extensions not implementing the C++ API.
indices: For caching purposes, will be computed if `None`.
"""
if sharding_spec is None:
sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
mesh = mesh_lib.thread_resources.env.physical_mesh
sharding: sharding_impls.XLACompatibleSharding
if mesh.empty:
sharding = sharding_impls.PmapSharding(
np.asarray([d.device() for d in device_buffers]), sharding_spec)
else:
hlo_sharding = sharding_specs.sharding_spec_sharding_proto(sharding_spec)
pspec = sharding_impls.parse_flatten_op_sharding(
hlo_sharding, mesh)[0].get_partition_spec()
sharding = sharding_impls.NamedSharding(mesh, pspec)
return jax.make_array_from_single_device_arrays(
aval.shape, sharding, device_buffers) # type: ignore
def _hashable_index(idx):
return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx)