Remove axis_name from pmap_unmapped_aval_handlers

PiperOrigin-RevId: 718859837
This commit is contained in:
Yash Katariya 2025-01-23 07:34:03 -08:00 committed by jax authors
parent e8d40ff1a7
commit 33aa088a5c

View File

@ -896,9 +896,8 @@ def lower_parallel_callable(
shape_poly_state=lowering_result.shape_poly_state)
def _pmap_unmap_shaped_array(
size: int, axis_name: core.AxisName, axis: int | None, aval: ShapedArray
) -> ShapedArray:
def _pmap_unmap_shaped_array(size: int, axis: int | None, aval: ShapedArray
) -> ShapedArray:
if axis is None: return aval
elif type(axis) is int:
return ShapedArray(tuple_update(aval.shape, axis, size), aval.dtype,
@ -911,14 +910,14 @@ _pmap_aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
ShapedArray: (Any, _pmap_unmap_shaped_array),
}
def _pmap_unmapped_aval(size: core.AxisSize, axis_name, axis: int | None,
def _pmap_unmapped_aval(size: core.AxisSize, axis: int | None,
aval: core.AbstractValue) -> core.AbstractValue:
if not config.pmap_no_rank_reduction.value:
return core.unmapped_aval(size, axis, aval)
_, handler = _pmap_aval_mapping_handlers.get(type(aval), (None, None))
if handler is not None:
return handler(size, axis_name, axis, aval)
return handler(size, axis, aval)
else:
raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}")
@ -1086,7 +1085,7 @@ class UnloadedPmapExecutable:
local_unmapped_avals = [
_cast_to_shaped_array(
_pmap_unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval))
_pmap_unmapped_aval(pci.axis_size, out_axis, aval))
if out_axis is not None else aval
for aval, out_axis in safe_zip(shards.out_sharded_avals, pci.out_axes)]
out_specs = [