mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add a profiler test for gpu run
PiperOrigin-RevId: 732247572
This commit is contained in:
parent
70024d2201
commit
48a55a6d71
@ -838,7 +838,13 @@ jax_multiplatform_test(
|
||||
jax_multiplatform_test(
|
||||
name = "profiler_test",
|
||||
srcs = ["profiler_test.py"],
|
||||
enable_backends = ["cpu"],
|
||||
enable_backends = [
|
||||
"cpu",
|
||||
"gpu",
|
||||
],
|
||||
deps = [
|
||||
"//jax:profiler",
|
||||
],
|
||||
)
|
||||
|
||||
jax_multiplatform_test(
|
||||
|
@ -29,6 +29,8 @@ import jax.numpy as jnp
|
||||
import jax.profiler
|
||||
import jax._src.test_util as jtu
|
||||
from jax._src import profiler
|
||||
from jax import jit
|
||||
|
||||
|
||||
try:
|
||||
import portpicker
|
||||
@ -169,6 +171,28 @@ class ProfilerTest(unittest.TestCase):
|
||||
if jtu.test_device_matches(["tpu"]):
|
||||
self.assertIn(b"/device:TPU", proto)
|
||||
|
||||
@jtu.run_on_devices("gpu")
|
||||
@jtu.thread_unsafe_test()
|
||||
def testProgrammaticGpuCuptiTracing(self):
|
||||
@jit
|
||||
def xy_plus_z(x, y, z):
|
||||
return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z
|
||||
k = jax.random.key(0)
|
||||
s = 1, 16, 16
|
||||
jax.devices()
|
||||
x = jnp.int8(jax.random.normal(k, shape=s))
|
||||
y = jnp.bfloat16(jax.random.normal(k, shape=s))
|
||||
z = jnp.float32(jax.random.normal(k, shape=s))
|
||||
with tempfile.TemporaryDirectory() as tmpdir_string:
|
||||
tmpdir = pathlib.Path(tmpdir_string)
|
||||
with jax.profiler.trace(tmpdir):
|
||||
print(xy_plus_z(x, y, z))
|
||||
|
||||
proto_path = tuple(tmpdir.rglob("*.xplane.pb"))
|
||||
proto_bytes = proto_path[0].read_bytes()
|
||||
if jtu.test_device_matches(["gpu"]):
|
||||
self.assertIn(b"/device:GPU", proto_bytes)
|
||||
|
||||
def testProgrammaticProfilingContextManagerPathlib(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir_string:
|
||||
tmpdir = pathlib.Path(tmpdir_string)
|
||||
|
Loading…
x
Reference in New Issue
Block a user