From fa730771467148cfba41aa67c8bde562012adbab Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 25 Jun 2024 09:02:32 -0700 Subject: [PATCH] jax.config: validate on set() --- jax/_src/config.py | 2 ++ tests/BUILD | 5 +++ tests/config_test.py | 73 +++++++++++++++++++++++++++++++++++++++++++ tests/logging_test.py | 4 +-- 4 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 tests/config_test.py diff --git a/jax/_src/config.py b/jax/_src/config.py index 8eadb2609..1d2d8fa44 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -276,6 +276,8 @@ class State(Generic[_T]): type(self).__name__)) def _set(self, value: _T) -> None: + if self._validator: + self._validator(value) self._value = value if self._update_global_hook: self._update_global_hook(value) diff --git a/tests/BUILD b/tests/BUILD index c420a3ba9..6093ffa35 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -75,6 +75,11 @@ jax_test( }, ) +jax_test( + name = "config_test", + srcs = ["config_test.py"], +) + jax_test( name = "core_test", srcs = ["core_test.py"], diff --git a/tests/config_test.py b/tests/config_test.py new file mode 100644 index 000000000..0f49d988a --- /dev/null +++ b/tests/config_test.py @@ -0,0 +1,73 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest + +import jax +from jax._src import test_util as jtu +from jax._src import config + +jax.config.parse_flags_with_absl() + + +jax_test_enum_config = config.enum_state( + name='jax_test_enum_config', + enum_values=['default', 'xxx', 'yyy'], + default='default', + help='Configuration only used for tests.') + + +class ConfigTest(jtu.JaxTestCase): + def test_config_setting_via_update(self): + self.assertEqual(jax_test_enum_config.value, 'default') + + jax.config.update('jax_test_enum_config', 'xxx') + self.assertEqual(jax_test_enum_config.value, 'xxx') + + jax.config.update('jax_test_enum_config', 'yyy') + self.assertEqual(jax_test_enum_config.value, 'yyy') + + jax.config.update('jax_test_enum_config', 'default') + self.assertEqual(jax_test_enum_config.value, 'default') + + def test_config_setting_via_context(self): + self.assertEqual(jax_test_enum_config.value, 'default') + + with jax_test_enum_config('xxx'): + self.assertEqual(jax_test_enum_config.value, 'xxx') + + with jax_test_enum_config('yyy'): + self.assertEqual(jax_test_enum_config.value, 'yyy') + + self.assertEqual(jax_test_enum_config.value, 'xxx') + + self.assertEqual(jax_test_enum_config.value, 'default') + + def test_config_update_validation(self): + self.assertEqual(jax_test_enum_config.value, 'default') + with self.assertRaisesRegex(ValueError, 'new enum value must be in.*'): + jax.config.update('jax_test_enum_config', 'invalid') + # Error should raise before changing the value + self.assertEqual(jax_test_enum_config.value, 'default') + + def test_config_context_validation(self): + self.assertEqual(jax_test_enum_config.value, 'default') + with self.assertRaisesRegex(ValueError, 'new enum value must be in.*'): + with jax_test_enum_config('invalid'): + pass + self.assertEqual(jax_test_enum_config.value, 'default') + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/logging_test.py b/tests/logging_test.py index 29eb5ce35..5a495d47d 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -124,7 +124,7 @@ class LoggingTest(jtu.JaxTestCase): self.assertIn("Compiling ", log_output.getvalue()) # Turn off all debug logging. - with jax_debug_log_modules(None): + with jax_debug_log_modules(""): with capture_jax_logs() as log_output: jax.jit(lambda x: x + 1)(1) self.assertEmpty(log_output.getvalue()) @@ -137,7 +137,7 @@ class LoggingTest(jtu.JaxTestCase): self.assertNotIn("Compiling ", log_output.getvalue()) # Turn everything off again. - with jax_debug_log_modules(None): + with jax_debug_log_modules(""): with capture_jax_logs() as log_output: jax.jit(lambda x: x + 1)(1) self.assertEmpty(log_output.getvalue())