mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 21:36:05 +00:00
tweak readme pmap imports (#2276)
This commit is contained in:
parent
218a1711d2
commit
8372a70079
@ -272,7 +272,8 @@ replicated and executed in parallel accross devices.
|
||||
Here's an example on an 8-GPU machine:
|
||||
|
||||
```python
|
||||
from jax import random
|
||||
from jax import random, pmap
|
||||
import jax.numpy as np
|
||||
|
||||
# Create 8 random 5000 x 6000 matrices, one per GPU
|
||||
keys = random.split(random.PRNGKey(0), 8)
|
||||
|
Loading…
x
Reference in New Issue
Block a user