Convert string axis name into tuple of strings in Mesh constructor

PiperOrigin-RevId: 487930412
This commit is contained in:
Sharad Vikram 2022-11-11 15:23:44 -08:00 committed by jax authors
parent 6897d37562
commit e15619ceab
3 changed files with 10 additions and 1 deletions

View File

@ -17,6 +17,9 @@ Remember to align the itemized text with the first line of an item within a list
tridiagonal reductions are supported on CPU only.
* Breaking Changes
* Deleted the `jax_experimental_name_stack` config option.
* Convert a string `axis_names` arguments to the
{class}`jax.experimental.maps.Mesh` constructor into a singleton tuple
instead of unpacking the string into a sequence of character axis names.
## jaxlib 0.3.25
* Changes

View File

@ -2312,9 +2312,11 @@ class Mesh(ContextDecorator):
axis_names: Tuple[MeshAxisName, ...]
def __init__(self, devices: Union[np.ndarray, Sequence[xc.Device]],
axis_names: Sequence[MeshAxisName]):
axis_names: Union[str, Sequence[MeshAxisName]]):
if not isinstance(devices, np.ndarray):
devices = np.array(devices)
if isinstance(axis_names, str):
axis_names = (axis_names,)
assert devices.ndim == len(axis_names)
# TODO: Make sure that devices are unique? At least with the quick and
# dirty check that the array size is not larger than the number of

View File

@ -3219,6 +3219,10 @@ class UtilTest(jtu.JaxTestCase):
self.assertIsInstance(mesh.devices, np.ndarray)
self.assertEqual(mesh.size, jax.device_count())
def test_mesh_with_string_axis_names(self):
mesh = maps.Mesh(jax.devices(), 'dp')
self.assertTupleEqual(mesh.axis_names, ('dp',))
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())