Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Delete pxla.make_sharded_device_array.
This function is unused and not exported from JAX.
PiperOrigin-RevId: 549606907
This commit is contained in:
parent 1ceddfc98a
commit fa5915b34d
@@ -298,46 +298,6 @@ global_result_handlers: dict[type[core.AbstractValue], PxlaResultHandler] = {}
 
 ### lazy device-memory persistence and result handling
 
-# TODO(yashkatariya, phawkins): Remove this function after March 15, 2023.
-def make_sharded_device_array(
-    aval: ShapedArray,
-    sharding_spec: Optional[ShardingSpec],
-    # Any is for JAX extensions implementing their own buffer.
-    device_buffers: list[Any],
-    indices: Optional[tuple[Index, ...]] = None,
-):
-  """Returns a ShardedDeviceArray implementation based on arguments.
-
-  Returns either a C++ SDA or a Python DeviceArray when the buffers are not
-  JAX buffers.
-
-  Args:
-    aval: The `ShapedArray` for this array.
-    sharding_spec: If `None`, assumes a pmap-style ShardedDeviceArrays over the
-      first dimension.
-    device_buffers: If a list of Jax `Buffer` objects, a C++ SDA will be
-      returned (if the version is high enough). Otherwise, a Python object will
-      be returned, for JAX extensions not implementing the C++ API.
-    indices: For caching purposes, will be computed if `None`.
-  """
-  if sharding_spec is None:
-    sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
-
-  mesh = mesh_lib.thread_resources.env.physical_mesh
-  sharding: sharding_impls.XLACompatibleSharding
-  if mesh.empty:
-    sharding = sharding_impls.PmapSharding(
-        np.asarray([d.device() for d in device_buffers]), sharding_spec)
-  else:
-    hlo_sharding = sharding_specs.sharding_spec_sharding_proto(sharding_spec)
-    pspec = sharding_impls.parse_flatten_op_sharding(
-        hlo_sharding, mesh)[0].get_partition_spec()
-    sharding = sharding_impls.NamedSharding(mesh, pspec)
-
-  return jax.make_array_from_single_device_arrays(
-      aval.shape, sharding, device_buffers)  # type: ignore
-
-
 def _hashable_index(idx):
   return tree_map(lambda x: (x.start, x.stop) if type(x) == slice else x, idx)