mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Disable pgle_test.py for GPU plugin.
PiperOrigin-RevId: 573304221
This commit is contained in:
parent
16061e6302
commit
4cd4f3f3b3
@ -23,6 +23,7 @@ from absl.testing import absltest
|
||||
import jax
|
||||
from jax import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge
|
||||
from jax.sharding import NamedSharding
|
||||
from jax.experimental import profiler as exp_profiler
|
||||
import jax.numpy as jnp
|
||||
@ -62,7 +63,11 @@ class PgleTest(jtu.JaxTestCase):
|
||||
logging.info('rundir: %s', rundir)
|
||||
fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir)
|
||||
|
||||
if jtu.test_device_matches(['gpu']) and jtu.is_device_cuda():
|
||||
if (jtu.test_device_matches(['gpu'])
|
||||
and jtu.is_device_cuda()
|
||||
# TODO(b/305270770): disabled for GPU plugin. Remove after it is fixed.
|
||||
and not xla_bridge.get_backend().platform_version.startswith(
|
||||
'PJRT C API\ncuda')):
|
||||
self.assertIn(b'custom', fdo_profile)
|
||||
|
||||
logging.info('fdo_profile: %s', fdo_profile)
|
||||
|
Loading…
x
Reference in New Issue
Block a user