This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).
By running single-device tests on a single TPU chip, running the test
suite goes from 1hr 45m to 35m (both timings are running slow tests).
I tried using bazel at first, which already supported parallel
execution across TPU cores, but somehow it still takes 2h 20m! I'm not
sure why it's so slow. It appears that bazel creates many new test
processes over time, vs. pytest reuses the number of processes
initially specified, and starting and stopping the TPU runtime takes a
few seconds so that may be adding up. It also appears that
single-process bazel is slower than single-process pytest, which I
haven't looked into yet.
This also includes some utilites for setting up the self-hosted
runner. Googlers, see go/jax-self-hosted-runners for more setup info.
The workflow is pretty basic currently. We can and should add more
functionality later, such as email notifications. I kept it simple
here for easier reviewing.
Testing:
- Sample workflow run in my fork: https://github.com/skye/jax/actions/runs/3333614180
- Sample PR attempt: (will add soon but I did verify validate_job.sh blocks pull_request workflows)