Disable pgle_test.py for GPU plugin.

PiperOrigin-RevId: 573304221
This commit is contained in:
Jieying Luo 2023-10-13 13:22:16 -07:00 committed by jax authors
parent 16061e6302
commit 4cd4f3f3b3

View File

@ -23,6 +23,7 @@ from absl.testing import absltest
import jax
from jax import config
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax.sharding import NamedSharding
from jax.experimental import profiler as exp_profiler
import jax.numpy as jnp
@ -62,7 +63,11 @@ class PgleTest(jtu.JaxTestCase):
logging.info('rundir: %s', rundir)
fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir)
if jtu.test_device_matches(['gpu']) and jtu.is_device_cuda():
if (jtu.test_device_matches(['gpu'])
and jtu.is_device_cuda()
# TODO(b/305270770): disabled for GPU plugin. Remove after it is fixed.
and not xla_bridge.get_backend().platform_version.startswith(
'PJRT C API\ncuda')):
self.assertIn(b'custom', fdo_profile)
logging.info('fdo_profile: %s', fdo_profile)