1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 21:36:05 +00:00

tweak readme pmap imports ()

This commit is contained in:
brett koonce 2020-02-20 18:44:21 -06:00 committed by GitHub
parent 218a1711d2
commit 8372a70079
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -272,7 +272,8 @@ replicated and executed in parallel accross devices.
Here's an example on an 8-GPU machine:
```python
from jax import random
from jax import random, pmap
import jax.numpy as np
# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.PRNGKey(0), 8)