Fix import in benchmarks

This works on my machine as 'python benchmarks/pmap_benchmark.py'. It also
follows the code in examples.

This will need a copybara rule to change the import to 'from jax.benchmarks import benchmark'
This commit is contained in:
George Necula 2020-03-31 10:36:47 +02:00
parent a4ceae1c00
commit fd52fbf411

View File

@ -25,7 +25,7 @@ from jax import numpy as np
from jax import pmap
from jax.config import config
import benchmark
from benchmarks import benchmark
import numpy as onp