Update JAX official doc: point out that the device numbers are not in numerical order because of the underlying torus hardware topology.

PiperOrigin-RevId: 625401373
This commit is contained in:
Yue Sheng 2024-04-16 11:36:18 -07:00 committed by jax authors
parent e83b0ce3f2
commit 1a650cdc00
2 changed files with 4 additions and 0 deletions

View File

@ -412,6 +412,8 @@
"id": "uRLpOcmNj_Vt"
},
"source": [
"The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n",
"\n",
"By writing `PositionalSharding(ndarray_of_devices)`, we fix the device order and the initial shape. Then we can reshape it:"
]
},

View File

@ -192,6 +192,8 @@ sharding
+++ {"id": "uRLpOcmNj_Vt"}
The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.
By writing `PositionalSharding(ndarray_of_devices)`, we fix the device order and the initial shape. Then we can reshape it:
```{code-cell}