mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
aot_test: Stop forcing XLA to assume a certain number of devices.
Test cases are still frequently skipped due to lack of CompileOptions support, but the skip/run behavior does not seem to meaningfully change compared to a clean checkout. This was verified by inserting an exception in place of unittest.SkipTest. PiperOrigin-RevId: 529437419
This commit is contained in:
parent
68614b4dcc
commit
40d730be49
@ -210,7 +210,7 @@ jax_test(
|
||||
tags = ["multiaccelerator"],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
],
|
||||
] + py_deps("numpy"),
|
||||
)
|
||||
|
||||
jax_test(
|
||||
|
@ -11,10 +11,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for GlobalDeviceArray."""
|
||||
"""Tests for AOT compilation."""
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import unittest
|
||||
from absl.testing import absltest
|
||||
import numpy as np
|
||||
@ -23,7 +22,6 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.config import flags
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.experimental.serialize_executable import (
|
||||
@ -41,27 +39,6 @@ with contextlib.suppress(ImportError):
|
||||
pytestmark = pytest.mark.multiaccelerator
|
||||
|
||||
|
||||
# Run all tests with 8 CPU devices.
|
||||
def setUpModule():
|
||||
global prev_xla_flags
|
||||
prev_xla_flags = os.getenv("XLA_FLAGS")
|
||||
flags_str = prev_xla_flags or ""
|
||||
# Don't override user-specified device count, or other XLA flags.
|
||||
if "xla_force_host_platform_device_count" not in flags_str:
|
||||
os.environ["XLA_FLAGS"] = (flags_str +
|
||||
" --xla_force_host_platform_device_count=8")
|
||||
# Clear any cached backends so new CPU backend will pick up the env var.
|
||||
xb.get_backend.cache_clear()
|
||||
|
||||
# Reset to previous configuration in case other test modules will be run.
|
||||
def tearDownModule():
|
||||
if prev_xla_flags is None:
|
||||
del os.environ["XLA_FLAGS"]
|
||||
else:
|
||||
os.environ["XLA_FLAGS"] = prev_xla_flags
|
||||
xb.get_backend.cache_clear()
|
||||
|
||||
|
||||
class JaxAotTest(jtu.JaxTestCase):
|
||||
|
||||
def check_for_compile_options(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user