profiler_test.py fixes and add coverage to Cloud TPU CI

* Add deps to test requirements, including in a new
  `collect-profile-requirements.txt` (to avoid adding tensorflow to
  `test-requirements.txt`).
* Use the correct Python executable in `ProfilerTest.test_remote_profiler`
  (`python` sometimes defaults to python2).
* Run computations for longer in `ProfilerTest.test_remote_profiler`,
  otherwise `collect_profile` sometimes misses it.
This commit is contained in:
Skye Wanderman-Milne 2023-06-09 17:58:42 +00:00
parent 0ec03dbdce
commit 2ca151ef5b
4 changed files with 11 additions and 4 deletions

View File

@ -26,6 +26,7 @@ jobs:
- name: Install JAX test requirements
run: |
pip install -r build/test-requirements.txt
pip install -r build/collect-profile-requirements.txt
- name: Install JAX
run: |
pip uninstall -y jax jaxlib libtpu-nightly

View File

@ -0,0 +1,4 @@
tensorflow
tensorboard-plugin-profile
# Needed for the profile plugin to work without error
protobuf<=3.20

View File

@ -3,6 +3,7 @@ cloudpickle
colorama>=0.4.4
numpy>=1.21
pillow>=9.1.0
portpicker
pytest-xdist
wheel
rich

View File

@ -16,6 +16,7 @@ from functools import partial
import glob
import os
import shutil
import sys
import tempfile
import threading
import time
@ -220,22 +221,22 @@ class ProfilerTest(unittest.TestCase):
"tensorboard_profile_plugin")
def test_remote_profiler(self):
port = portpicker.pick_unused_port()
jax.profiler.start_server(port)
logdir = absltest.get_default_test_tmpdir()
# Remove any existing log files.
shutil.rmtree(logdir, ignore_errors=True)
def on_profile():
os.system(
f"python -m jax.collect_profile {port} 500 --log_dir {logdir} "
"--no_perfetto_link")
f"{sys.executable} -m jax.collect_profile {port} 500 "
f"--log_dir {logdir} --no_perfetto_link")
thread_profiler = threading.Thread(
target=on_profile, args=())
thread_profiler.start()
jax.profiler.start_server(port)
start_time = time.time()
y = jnp.zeros((5, 5))
while time.time() - start_time < 3:
while time.time() - start_time < 5:
y = jnp.dot(y, y)
jax.profiler.stop_server()
thread_profiler.join()