diff --git a/tests/source_mapper_test.py b/tests/source_mapper_test.py index 4df62f744..ffb1d6252 100644 --- a/tests/source_mapper_test.py +++ b/tests/source_mapper_test.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys + from absl.testing import absltest from absl.testing import parameterized from jax import numpy as jnp @@ -20,6 +22,10 @@ from jax.experimental import source_mapper class SourceMapperTest(jtu.JaxTestCase): + def setUp(self): + if sys.platform == "win32": + self.skipTest("Only works on non-Windows platforms") + def test_jaxpr_pass(self): def jax_fn(x, y): return x + y