Add a profiler test for gpu run

PiperOrigin-RevId: 732247572
This commit is contained in:
jax authors 2025-02-28 13:45:11 -08:00
parent 70024d2201
commit 48a55a6d71
2 changed files with 31 additions and 1 deletions

View File

@ -838,7 +838,13 @@ jax_multiplatform_test(
jax_multiplatform_test(
name = "profiler_test",
srcs = ["profiler_test.py"],
enable_backends = ["cpu"],
enable_backends = [
"cpu",
"gpu",
],
deps = [
"//jax:profiler",
],
)
jax_multiplatform_test(

View File

@ -29,6 +29,8 @@ import jax.numpy as jnp
import jax.profiler
import jax._src.test_util as jtu
from jax._src import profiler
from jax import jit
try:
import portpicker
@ -169,6 +171,28 @@ class ProfilerTest(unittest.TestCase):
if jtu.test_device_matches(["tpu"]):
self.assertIn(b"/device:TPU", proto)
@jtu.run_on_devices("gpu")
@jtu.thread_unsafe_test()
def testProgrammaticGpuCuptiTracing(self):
@jit
def xy_plus_z(x, y, z):
return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z
k = jax.random.key(0)
s = 1, 16, 16
jax.devices()
x = jnp.int8(jax.random.normal(k, shape=s))
y = jnp.bfloat16(jax.random.normal(k, shape=s))
z = jnp.float32(jax.random.normal(k, shape=s))
with tempfile.TemporaryDirectory() as tmpdir_string:
tmpdir = pathlib.Path(tmpdir_string)
with jax.profiler.trace(tmpdir):
print(xy_plus_z(x, y, z))
proto_path = tuple(tmpdir.rglob("*.xplane.pb"))
proto_bytes = proto_path[0].read_bytes()
if jtu.test_device_matches(["gpu"]):
self.assertIn(b"/device:GPU", proto_bytes)
def testProgrammaticProfilingContextManagerPathlib(self):
with tempfile.TemporaryDirectory() as tmpdir_string:
tmpdir = pathlib.Path(tmpdir_string)