mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Testing: avoid global fixture for doctests
This commit is contained in:
parent
438b56c483
commit
cee3af580b
2
.github/workflows/ci-build.yaml
vendored
2
.github/workflows/ci-build.yaml
vendored
@ -141,6 +141,8 @@ jobs:
|
||||
run: |
|
||||
pip install -r docs/requirements.txt
|
||||
- name: Test documentation
|
||||
env:
|
||||
XLA_FLAGS: "--xla_force_host_platform_device_count=8"
|
||||
run: |
|
||||
pytest -n 1 docs
|
||||
pytest -n 1 --doctest-modules --ignore=jax/experimental/jax2tf jax
|
||||
|
@ -15,7 +15,6 @@
|
||||
|
||||
import jax
|
||||
import numpy
|
||||
import os
|
||||
import pytest
|
||||
|
||||
|
||||
@ -25,10 +24,3 @@ def add_imports(doctest_namespace):
|
||||
doctest_namespace["lax"] = jax.lax
|
||||
doctest_namespace["jnp"] = jax.numpy
|
||||
doctest_namespace["np"] = numpy
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def spoof_devices(doctest_namespace):
|
||||
# Set up runtime to mimic an 8-core machine
|
||||
flags = os.environ.get('XLA_FLAGS', '')
|
||||
os.environ['XLA_FLAGS'] = flags + " --xla_force_host_platform_device_count=8"
|
||||
|
Loading…
x
Reference in New Issue
Block a user