aot_test: Stop forcing XLA to assume a certain number of devices.

Test cases are still frequently skipped due to lack of CompileOptions
support, but the skip/run behavior does not seem to meaningfully change
compared to a clean checkout. This was verified by inserting an exception
in place of unittest.SkipTest.

PiperOrigin-RevId: 529437419
This commit is contained in:
pizzud 2023-05-04 09:52:50 -07:00 committed by jax authors
parent 68614b4dcc
commit 40d730be49
2 changed files with 2 additions and 25 deletions

View File

@ -210,7 +210,7 @@ jax_test(
tags = ["multiaccelerator"],
deps = [
"//jax:experimental",
],
] + py_deps("numpy"),
)
jax_test(

View File

@ -11,10 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for GlobalDeviceArray."""
"""Tests for AOT compilation."""
import contextlib
import os
import unittest
from absl.testing import absltest
import numpy as np
@ -23,7 +22,6 @@ import jax
import jax.numpy as jnp
from jax._src import core
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.config import flags
from jax.experimental.pjit import pjit
from jax.experimental.serialize_executable import (
@ -41,27 +39,6 @@ with contextlib.suppress(ImportError):
pytestmark = pytest.mark.multiaccelerator
# Run all tests with 8 CPU devices.
def setUpModule():
global prev_xla_flags
prev_xla_flags = os.getenv("XLA_FLAGS")
flags_str = prev_xla_flags or ""
# Don't override user-specified device count, or other XLA flags.
if "xla_force_host_platform_device_count" not in flags_str:
os.environ["XLA_FLAGS"] = (flags_str +
" --xla_force_host_platform_device_count=8")
# Clear any cached backends so new CPU backend will pick up the env var.
xb.get_backend.cache_clear()
# Reset to previous configuration in case other test modules will be run.
def tearDownModule():
if prev_xla_flags is None:
del os.environ["XLA_FLAGS"]
else:
os.environ["XLA_FLAGS"] = prev_xla_flags
xb.get_backend.cache_clear()
class JaxAotTest(jtu.JaxTestCase):
def check_for_compile_options(self):