mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove axis_name from pmap_unmapped_aval_handlers
PiperOrigin-RevId: 718859837
This commit is contained in:
parent
e8d40ff1a7
commit
33aa088a5c
@ -896,9 +896,8 @@ def lower_parallel_callable(
|
||||
shape_poly_state=lowering_result.shape_poly_state)
|
||||
|
||||
|
||||
def _pmap_unmap_shaped_array(
|
||||
size: int, axis_name: core.AxisName, axis: int | None, aval: ShapedArray
|
||||
) -> ShapedArray:
|
||||
def _pmap_unmap_shaped_array(size: int, axis: int | None, aval: ShapedArray
|
||||
) -> ShapedArray:
|
||||
if axis is None: return aval
|
||||
elif type(axis) is int:
|
||||
return ShapedArray(tuple_update(aval.shape, axis, size), aval.dtype,
|
||||
@ -911,14 +910,14 @@ _pmap_aval_mapping_handlers: dict[type, AvalMapHandlerPair] = {
|
||||
ShapedArray: (Any, _pmap_unmap_shaped_array),
|
||||
}
|
||||
|
||||
def _pmap_unmapped_aval(size: core.AxisSize, axis_name, axis: int | None,
|
||||
def _pmap_unmapped_aval(size: core.AxisSize, axis: int | None,
|
||||
aval: core.AbstractValue) -> core.AbstractValue:
|
||||
if not config.pmap_no_rank_reduction.value:
|
||||
return core.unmapped_aval(size, axis, aval)
|
||||
|
||||
_, handler = _pmap_aval_mapping_handlers.get(type(aval), (None, None))
|
||||
if handler is not None:
|
||||
return handler(size, axis_name, axis, aval)
|
||||
return handler(size, axis, aval)
|
||||
else:
|
||||
raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}")
|
||||
|
||||
@ -1086,7 +1085,7 @@ class UnloadedPmapExecutable:
|
||||
|
||||
local_unmapped_avals = [
|
||||
_cast_to_shaped_array(
|
||||
_pmap_unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval))
|
||||
_pmap_unmapped_aval(pci.axis_size, out_axis, aval))
|
||||
if out_axis is not None else aval
|
||||
for aval, out_axis in safe_zip(shards.out_sharded_avals, pci.out_axes)]
|
||||
out_specs = [
|
||||
|
Loading…
x
Reference in New Issue
Block a user