mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #6066 from apaszke:xmap-no-mesh-slicing
PiperOrigin-RevId: 378209333
This commit is contained in:
commit
648b5d3265
@ -635,12 +635,11 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
used_mesh_axes = used_resources & resource_env.physical_resource_axes
|
||||
if used_mesh_axes:
|
||||
assert spmd_in_axes is None and spmd_out_axes_thunk is None # No outer xmaps, so should be None
|
||||
submesh = resource_env.physical_mesh[sorted(used_mesh_axes, key=str)]
|
||||
mesh_in_axes, mesh_out_axes = plan.to_mesh_axes(in_axes, out_axes)
|
||||
return pxla.mesh_callable(f,
|
||||
name,
|
||||
backend,
|
||||
submesh,
|
||||
resource_env.physical_mesh,
|
||||
mesh_in_axes,
|
||||
mesh_out_axes,
|
||||
donated_invars,
|
||||
|
@ -1345,15 +1345,6 @@ class Mesh:
|
||||
assert is_local_device[subcube_indices].all()
|
||||
return Mesh(self.devices[subcube_indices], self.axis_names)
|
||||
|
||||
def __getitem__(self, new_axes):
|
||||
axis_pos = {name: i for i, name in enumerate(self.axis_names)}
|
||||
new_devices = self.devices.transpose(tuple(axis_pos[axis] for axis in new_axes) +
|
||||
tuple(axis_pos[axis] for axis in self.axis_names
|
||||
if axis not in new_axes))
|
||||
new_devices = new_devices[(slice(None),) * len(new_axes) +
|
||||
(0,) * (len(self.axis_names) - len(new_axes))]
|
||||
return Mesh(new_devices, new_axes)
|
||||
|
||||
@property
|
||||
def device_ids(self):
|
||||
return np.vectorize(lambda d: d.id, otypes=[int])(self.devices)
|
||||
|
Loading…
x
Reference in New Issue
Block a user