diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 49d95cb6f..c5fb921f6 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -896,9 +896,8 @@ def lower_parallel_callable( shape_poly_state=lowering_result.shape_poly_state) -def _pmap_unmap_shaped_array( - size: int, axis_name: core.AxisName, axis: int | None, aval: ShapedArray - ) -> ShapedArray: +def _pmap_unmap_shaped_array(size: int, axis: int | None, aval: ShapedArray + ) -> ShapedArray: if axis is None: return aval elif type(axis) is int: return ShapedArray(tuple_update(aval.shape, axis, size), aval.dtype, @@ -911,14 +910,14 @@ _pmap_aval_mapping_handlers: dict[type, AvalMapHandlerPair] = { ShapedArray: (Any, _pmap_unmap_shaped_array), } -def _pmap_unmapped_aval(size: core.AxisSize, axis_name, axis: int | None, +def _pmap_unmapped_aval(size: core.AxisSize, axis: int | None, aval: core.AbstractValue) -> core.AbstractValue: if not config.pmap_no_rank_reduction.value: return core.unmapped_aval(size, axis, aval) _, handler = _pmap_aval_mapping_handlers.get(type(aval), (None, None)) if handler is not None: - return handler(size, axis_name, axis, aval) + return handler(size, axis, aval) else: raise TypeError(f"no unmapping handler for {aval} of type {type(aval)}") @@ -1086,7 +1085,7 @@ class UnloadedPmapExecutable: local_unmapped_avals = [ _cast_to_shaped_array( - _pmap_unmapped_aval(pci.axis_size, pci.axis_name, out_axis, aval)) + _pmap_unmapped_aval(pci.axis_size, out_axis, aval)) if out_axis is not None else aval for aval, out_axis in safe_zip(shards.out_sharded_avals, pci.out_axes)] out_specs = [