`profiler_test.py:ProfilerTest.test_remote_profiler` fails with the
protobuf upgrade. However, I was seeing mysterious hangs without this,
and in general I think we should be testing with up-to-date deps given
that we don't pin. I'm going to continue working on getting the Cloud TPU
CI green.
* Add deps to test requirements, including in a new
  `collect-profile-requirements.txt` (to avoid adding tensorflow to
  `test-requirements.txt`).
* Use the correct Python executable in `ProfilerTest.test_remote_profiler`
  (`python` sometimes defaults to python2); see the sketch after this list.
* Run computations for longer in `ProfilerTest.test_remote_profiler`,
  otherwise `collect_profile` sometimes misses them.
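For the executable fix, the idea is to launch the helper with the interpreter running the test instead of a bare `python`; a minimal sketch (the module path and arguments here are illustrative, not the exact test code):

```python
import subprocess
import sys

# sys.executable is the interpreter running this test, so the subprocess
# can't silently fall back to a system python2.
proc = subprocess.run(
    [sys.executable, "-m", "jax.collect_profile", "9999", "2000"],
    capture_output=True,
    check=True,
)
```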
Instead, we skip tests that the PJRT C API doesn't support. We had
this tag for feature development so it was easy to broadly disable,
but now we don't expect to need to do that.
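The skip typically looks something like the sketch below; `running_on_pjrt_c_api()` is a hypothetical stand-in for however the runtime is actually detected:

```python
import os
import unittest

def running_on_pjrt_c_api() -> bool:
  # Hypothetical helper: the real check would query the JAX runtime or an
  # environment flag set by the CI job.
  return os.environ.get("JAX_USE_PJRT_C_API", "0") == "1"

class ExampleTest(unittest.TestCase):

  def test_unsupported_feature(self):
    # Skip individually rather than broadly disabling via a tag.
    if running_on_pjrt_c_api():
      self.skipTest("Not supported by the PJRT C API yet")
    # ... actual test body ...
```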
Also shortens the job names so the full name is visible from the
GitHub UI (this was driving me crazy), and marks a new test that can't
be run on the PJRT C API yet.
Example run: https://github.com/google/jax/actions/runs/4019968334
We're seeing failures on v3-8 that don't appear on the current v4-8
testing. v3-8 also exposes 8 devices (vs. 4 on v4-8), and some tests
need 8 devices to run.
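Such tests typically guard on the visible device count so they skip cleanly on smaller topologies; a minimal sketch:

```python
import unittest
import jax

class EightDeviceTest(unittest.TestCase):

  def test_uses_eight_devices(self):
    # Only meaningful on topologies exposing 8 devices (e.g. a v3-8).
    if jax.device_count() < 8:
      self.skipTest(f"Requires 8 devices, found {jax.device_count()}")
    # ... actual test body ...
```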
I just added a v3-8 runner VM.
Also adds a missing pip install command (I only caught this with a
fresh runner since it only needs to be installed once).
This prevents spamming the test output with hundreds of failures when something fundamental is broken.
Also updates some `python3` commands to use `python` for consistency.
This change also marks multiaccelerator test files in a way pytest can
understand (if pytest is installed).
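A sketch of the pattern, with pytest kept optional (the marker name is illustrative):

```python
# At the top of a multiaccelerator test file: attach a module-level pytest
# marker, while staying importable when pytest isn't installed.
try:
  import pytest
  pytestmark = pytest.mark.multiaccelerator
except ImportError:
  pass
```

Runs can then select or exclude these files with `pytest -m multiaccelerator` or `pytest -m 'not multiaccelerator'` (registering the marker in `pytest.ini` avoids unknown-marker warnings).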
By running single-device tests on a single TPU chip, the test suite
run time goes from 1hr 45m to 35m (both timings include slow tests).
I tried using bazel at first, which already supported parallel
execution across TPU cores, but somehow it still took 2h 20m! I'm not
sure why it's so slow. It appears that bazel creates many new test
processes over time, whereas pytest reuses the number of processes
initially specified, and starting and stopping the TPU runtime takes a
few seconds each time, so that may be adding up. It also appears that
single-process bazel is slower than single-process pytest, which I
haven't looked into yet.
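For reference, the per-worker chip pinning can be sketched in a `conftest.py` roughly like this; the `TPU_VISIBLE_DEVICES` variable and the worker-to-chip mapping are assumptions for illustration, not the exact CI setup:

```python
# conftest.py (sketch): pin each pytest-xdist worker process to one TPU chip,
# before JAX initializes its backend.
import os

def pytest_configure(config):
  # pytest-xdist names worker processes "gw0", "gw1", ...; derive a chip
  # index from the worker name. TPU_VISIBLE_DEVICES is an assumption here;
  # the exact variable(s) depend on the libtpu version in use.
  worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
  chip = int(worker.lstrip("gw") or "0")
  os.environ.setdefault("TPU_VISIBLE_DEVICES", str(chip))
```

Each worker then sees only its own chip, so single-device tests can run in parallel, one per chip.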
This also includes some utilities for setting up the self-hosted
runner. Googlers, see go/jax-self-hosted-runners for more setup info.
The workflow is pretty basic currently. We can and should add more
functionality later, such as email notifications. I kept it simple
here for easier reviewing.
Testing:
- Sample workflow run in my fork: https://github.com/skye/jax/actions/runs/3333614180
- Sample PR attempt: (will add soon but I did verify validate_job.sh blocks pull_request workflows)