diff --git a/tests/BUILD b/tests/BUILD
index 016eb195c..8648928af 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -838,7 +838,13 @@ jax_multiplatform_test(
 jax_multiplatform_test(
     name = "profiler_test",
     srcs = ["profiler_test.py"],
-    enable_backends = ["cpu"],
+    enable_backends = [
+        "cpu",
+        "gpu",
+    ],
+    deps = [
+        "//jax:profiler",
+    ],
 )
 
 jax_multiplatform_test(
diff --git a/tests/profiler_test.py b/tests/profiler_test.py
index b686d30ad..215e363e4 100644
--- a/tests/profiler_test.py
+++ b/tests/profiler_test.py
@@ -29,6 +29,8 @@ import jax.numpy as jnp
 import jax.profiler
 import jax._src.test_util as jtu
 from jax._src import profiler
+from jax import jit
+
 
 try:
   import portpicker
@@ -169,6 +171,34 @@ class ProfilerTest(unittest.TestCase):
       if jtu.test_device_matches(["tpu"]):
         self.assertIn(b"/device:TPU", proto)
 
+  @jtu.run_on_devices("gpu")
+  @jtu.thread_unsafe_test()
+  def testProgrammaticGpuCuptiTracing(self):
+    """Traces a jitted batched matmul and checks the trace names a GPU device."""
+    @jit
+    def xy_plus_z(x, y, z):
+      return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z
+
+    # Independent keys so the three operands are not identical samples.
+    kx, ky, kz = jax.random.split(jax.random.key(0), 3)
+    s = 1, 16, 16
+    x = jnp.int8(jax.random.normal(kx, shape=s))
+    y = jnp.bfloat16(jax.random.normal(ky, shape=s))
+    z = jnp.float32(jax.random.normal(kz, shape=s))
+    with tempfile.TemporaryDirectory() as tmpdir_string:
+      tmpdir = pathlib.Path(tmpdir_string)
+      with jax.profiler.trace(tmpdir):
+        # Block so the computation finishes inside the trace window.
+        xy_plus_z(x, y, z).block_until_ready()
+
+      proto_paths = tuple(tmpdir.rglob("*.xplane.pb"))
+      # Fail with a clear message instead of an IndexError on the next line.
+      self.assertTrue(proto_paths, "no *.xplane.pb trace file was written")
+      proto_bytes = proto_paths[0].read_bytes()
+      # run_on_devices("gpu") already restricts this test to GPU, so the
+      # check is unconditional rather than guarded by a device test.
+      self.assertIn(b"/device:GPU", proto_bytes)
+
   def testProgrammaticProfilingContextManagerPathlib(self):
     with tempfile.TemporaryDirectory() as tmpdir_string:
       tmpdir = pathlib.Path(tmpdir_string)