mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00

Verified locally by running 'python benchmarks/pmap_benchmark.py'. The script also follows the conventions of the code in examples. A copybara rule will be needed to rewrite the import to 'from jax.benchmarks import benchmark'.