From e15619ceabfe04a3b9e0d315011240bb3facaf55 Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Fri, 11 Nov 2022 15:23:44 -0800 Subject: [PATCH] Convert string axis name into tuple of strings in Mesh constructor PiperOrigin-RevId: 487930412 --- CHANGELOG.md | 3 +++ jax/interpreters/pxla.py | 4 +++- tests/pjit_test.py | 4 ++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 885d2d2e7..2ed42b4d9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,9 @@ Remember to align the itemized text with the first line of an item within a list tridiagonal reductions are supported on CPU only. * Breaking Changes * Deleted the `jax_experimental_name_stack` config option. + * Convert a string `axis_names` arguments to the + {class}`jax.experimental.maps.Mesh` constructor into a singleton tuple + instead of unpacking the string into a sequence of character axis names. ## jaxlib 0.3.25 * Changes diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index d9fe25cc0..84755f11d 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -2312,9 +2312,11 @@ class Mesh(ContextDecorator): axis_names: Tuple[MeshAxisName, ...] def __init__(self, devices: Union[np.ndarray, Sequence[xc.Device]], - axis_names: Sequence[MeshAxisName]): + axis_names: Union[str, Sequence[MeshAxisName]]): if not isinstance(devices, np.ndarray): devices = np.array(devices) + if isinstance(axis_names, str): + axis_names = (axis_names,) assert devices.ndim == len(axis_names) # TODO: Make sure that devices are unique? At least with the quick and # dirty check that the array size is not larger than the number of diff --git a/tests/pjit_test.py b/tests/pjit_test.py index aecd92443..bb6d962cd 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -3219,6 +3219,10 @@ class UtilTest(jtu.JaxTestCase): self.assertIsInstance(mesh.devices, np.ndarray) self.assertEqual(mesh.size, jax.device_count()) + def test_mesh_with_string_axis_names(self): + mesh = maps.Mesh(jax.devices(), 'dp') + self.assertTupleEqual(mesh.axis_names, ('dp',)) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())