mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
profiler_test.py fixes and add coverage to Cloud TPU CI
* Add deps to test requirements, including in new `collect-profile-requirements.txt` (to avoid adding tensorflow to `test-requirements.txt`). * Use correct Python executable `ProfilerTest.test_remote_profiler` (`python` sometimes defaults to python2) * Run computations for longer in `ProfilerTest.test_remote_profiler`, othewise `collect_profile` sometimes misses it.
This commit is contained in:
parent
0ec03dbdce
commit
2ca151ef5b
1
.github/workflows/cloud-tpu-ci-nightly.yml
vendored
1
.github/workflows/cloud-tpu-ci-nightly.yml
vendored
@ -26,6 +26,7 @@ jobs:
|
||||
- name: Install JAX test requirements
|
||||
run: |
|
||||
pip install -r build/test-requirements.txt
|
||||
pip install -r build/collect-profile-requirements.txt
|
||||
- name: Install JAX
|
||||
run: |
|
||||
pip uninstall -y jax jaxlib libtpu-nightly
|
||||
|
4
build/collect-profile-requirements.txt
Normal file
4
build/collect-profile-requirements.txt
Normal file
@ -0,0 +1,4 @@
|
||||
tensorflow
|
||||
tensorboard-plugin-profile
|
||||
# Needed for the profile plugin to work without error
|
||||
protobuf<=3.20
|
@ -3,6 +3,7 @@ cloudpickle
|
||||
colorama>=0.4.4
|
||||
numpy>=1.21
|
||||
pillow>=9.1.0
|
||||
portpicker
|
||||
pytest-xdist
|
||||
wheel
|
||||
rich
|
||||
|
@ -16,6 +16,7 @@ from functools import partial
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
@ -220,22 +221,22 @@ class ProfilerTest(unittest.TestCase):
|
||||
"tensorboard_profile_plugin")
|
||||
def test_remote_profiler(self):
|
||||
port = portpicker.pick_unused_port()
|
||||
jax.profiler.start_server(port)
|
||||
|
||||
logdir = absltest.get_default_test_tmpdir()
|
||||
# Remove any existing log files.
|
||||
shutil.rmtree(logdir, ignore_errors=True)
|
||||
def on_profile():
|
||||
os.system(
|
||||
f"python -m jax.collect_profile {port} 500 --log_dir {logdir} "
|
||||
"--no_perfetto_link")
|
||||
f"{sys.executable} -m jax.collect_profile {port} 500 "
|
||||
f"--log_dir {logdir} --no_perfetto_link")
|
||||
|
||||
thread_profiler = threading.Thread(
|
||||
target=on_profile, args=())
|
||||
thread_profiler.start()
|
||||
jax.profiler.start_server(port)
|
||||
start_time = time.time()
|
||||
y = jnp.zeros((5, 5))
|
||||
while time.time() - start_time < 3:
|
||||
while time.time() - start_time < 5:
|
||||
y = jnp.dot(y, y)
|
||||
jax.profiler.stop_server()
|
||||
thread_profiler.join()
|
||||
|
Loading…
x
Reference in New Issue
Block a user