Merge pull request #6066 from apaszke:xmap-no-mesh-slicing

PiperOrigin-RevId: 378209333
This commit is contained in:
jax authors 2021-06-08 11:54:05 -07:00
commit 648b5d3265
2 changed files with 1 additions and 11 deletions

View File

@ -635,12 +635,11 @@ def make_xmap_callable(fun: lu.WrappedFun,
used_mesh_axes = used_resources & resource_env.physical_resource_axes
if used_mesh_axes:
assert spmd_in_axes is None and spmd_out_axes_thunk is None # No outer xmaps, so should be None
submesh = resource_env.physical_mesh[sorted(used_mesh_axes, key=str)]
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
return pxla.mesh_callable(f,
name,
backend,
submesh,
resource_env.physical_mesh,
mesh_in_axes,
mesh_out_axes,
donated_invars,

View File

@ -1345,15 +1345,6 @@ class Mesh:
assert is_local_device[subcube_indices].all()
return Mesh(self.devices[subcube_indices], self.axis_names)
def __getitem__(self, new_axes):
axis_pos = {name: i for i, name in enumerate(self.axis_names)}
new_devices = self.devices.transpose(tuple(axis_pos[axis] for axis in new_axes) +
tuple(axis_pos[axis] for axis in self.axis_names
if axis not in new_axes))
new_devices = new_devices[(slice(None),) * len(new_axes) +
(0,) * (len(self.axis_names) - len(new_axes))]
return Mesh(new_devices, new_axes)
@property
def device_ids(self):
return np.vectorize(lambda d: d.id, otypes=[int])(self.devices)