mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Update JAX official doc: point out that the device numbers are not in numerical order because of the underlying torus hardware topology.
PiperOrigin-RevId: 625401373
This commit is contained in:
parent
e83b0ce3f2
commit
1a650cdc00
@ -412,6 +412,8 @@
|
||||
"id": "uRLpOcmNj_Vt"
|
||||
},
|
||||
"source": [
|
||||
"The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n",
|
||||
"\n",
|
||||
"By writing `PositionalSharding(ndarray_of_devices)`, we fix the device order and the initial shape. Then we can reshape it:"
|
||||
]
|
||||
},
|
||||
|
@ -192,6 +192,8 @@ sharding
|
||||
|
||||
+++ {"id": "uRLpOcmNj_Vt"}
|
||||
|
||||
The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.
|
||||
|
||||
By writing `PositionalSharding(ndarray_of_devices)`, we fix the device order and the initial shape. Then we can reshape it:
|
||||
|
||||
```{code-cell}
|
||||
|
Loading…
x
Reference in New Issue
Block a user