mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
Convert string axis name into tuple of strings in Mesh constructor
PiperOrigin-RevId: 487930412
This commit is contained in:
parent
6897d37562
commit
e15619ceab
@ -17,6 +17,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
tridiagonal reductions are supported on CPU only.
|
||||
* Breaking Changes
|
||||
* Deleted the `jax_experimental_name_stack` config option.
|
||||
* Convert a string `axis_names` arguments to the
|
||||
{class}`jax.experimental.maps.Mesh` constructor into a singleton tuple
|
||||
instead of unpacking the string into a sequence of character axis names.
|
||||
|
||||
## jaxlib 0.3.25
|
||||
* Changes
|
||||
|
@ -2312,9 +2312,11 @@ class Mesh(ContextDecorator):
|
||||
axis_names: Tuple[MeshAxisName, ...]
|
||||
|
||||
def __init__(self, devices: Union[np.ndarray, Sequence[xc.Device]],
|
||||
axis_names: Sequence[MeshAxisName]):
|
||||
axis_names: Union[str, Sequence[MeshAxisName]]):
|
||||
if not isinstance(devices, np.ndarray):
|
||||
devices = np.array(devices)
|
||||
if isinstance(axis_names, str):
|
||||
axis_names = (axis_names,)
|
||||
assert devices.ndim == len(axis_names)
|
||||
# TODO: Make sure that devices are unique? At least with the quick and
|
||||
# dirty check that the array size is not larger than the number of
|
||||
|
@ -3219,6 +3219,10 @@ class UtilTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(mesh.devices, np.ndarray)
|
||||
self.assertEqual(mesh.size, jax.device_count())
|
||||
|
||||
def test_mesh_with_string_axis_names(self):
|
||||
mesh = maps.Mesh(jax.devices(), 'dp')
|
||||
self.assertTupleEqual(mesh.axis_names, ('dp',))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user