Update the non-contiguous error message to not say GDA anymore

PiperOrigin-RevId: 501396344
This commit is contained in:
Yash Katariya 2023-01-11 15:34:35 -08:00 committed by jax authors
parent acd8dadc74
commit 68c43e6c99
2 changed files with 2 additions and 2 deletions

View File

@ -229,7 +229,7 @@ def create_device_mesh(
jax.devices().
contiguous_submeshes: if True, this function will attempt to create a mesh
where each process's local devices form a contiguous submesh. This is
required when passing non-GlobalDeviceArrays to `pjit` (see the
required when passing host local inputs to `pjit` (see the
"Multi-process platforms" note of the [pjit
documentation](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html)
for more information on this constraint). A ValueError will be raised if

View File

@ -2446,7 +2446,7 @@ class Mesh(ContextDecorator):
# subcube that hull will contain non-local devices.
if not is_local_device[subcube_indices].all():
raise ValueError(
"When passing non-GlobalDeviceArray inputs to pjit or xmap, devices "
"When passing host local inputs to pjit or xmap, devices "
"connected to a single host must form a contiguous subcube of the "
"global device mesh")
return Mesh(self.devices[subcube_indices], self.axis_names)