diff --git a/examples/mnist_classifier.py b/examples/mnist_classifier.py index 0eac70541..6336ac675 100644 --- a/examples/mnist_classifier.py +++ b/examples/mnist_classifier.py @@ -26,6 +26,7 @@ import itertools import numpy.random as npr import jax.numpy as np +from jax.config import config from jax import jit, grad from jax.experimental import minmax from jax.experimental import stax @@ -80,6 +81,7 @@ if __name__ == "__main__": opt_state = opt_init(init_params) itercount = itertools.count() + print("\nStarting training...") for epoch in range(num_epochs): start_time = time.time() for _ in range(num_batches): diff --git a/examples/mnist_classifier_fromscratch.py b/examples/mnist_classifier_fromscratch.py index 1f675a68e..b0dfbb5cd 100644 --- a/examples/mnist_classifier_fromscratch.py +++ b/examples/mnist_classifier_fromscratch.py @@ -26,6 +26,7 @@ import time import numpy.random as npr from jax.api import jit, grad +from jax.config import config from jax.scipy.misc import logsumexp import jax.numpy as np import datasets diff --git a/examples/mnist_vae.py b/examples/mnist_vae.py index da7b4c717..aa2fd1725 100644 --- a/examples/mnist_vae.py +++ b/examples/mnist_vae.py @@ -28,6 +28,7 @@ import time import matplotlib.pyplot as plt import jax.numpy as np +from jax.config import config from jax import jit, grad, lax, random from jax.experimental import minmax from jax.experimental import stax diff --git a/examples/resnet50.py b/examples/resnet50.py index f384c2f30..2f92272e0 100644 --- a/examples/resnet50.py +++ b/examples/resnet50.py @@ -24,6 +24,7 @@ from __future__ import print_function import numpy.random as npr import jax.numpy as np +from jax.config import config from jax import jit, grad from jax.experimental import minmax from jax.experimental import stax diff --git a/jax/config.py b/jax/config.py new file mode 100644 index 000000000..e92dfb0eb --- /dev/null +++ b/jax/config.py @@ -0,0 +1,88 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class Config(object): + def __init__(self): + self.values = {} + self.meta = {} + self.FLAGS = NameSpace(self.read) + self.use_absl = False + + def update(self, name, val): + self.check_exists(name) + if name not in self.values: + raise Exception("Unrecognized config option: {}".format(name)) + self.values[name] = val + + def read(self, name): + if self.use_absl: + return getattr(self.absl_flags.FLAGS, name) + else: + self.check_exists(name) + return self.values[name] + + def add_option(self, name, default, opt_type, meta_args, meta_kwargs): + if name in self.values: + raise Exception("Config option {} already defined".format(name)) + self.values[name] = default + self.meta[name] = (opt_type, meta_args, meta_kwargs) + + def check_exists(self, name): + if name not in self.values: + raise Exception("Unrecognized config option: {}".format(name)) + + def DEFINE_bool(self, name, default, *args, **kwargs): + self.add_option(name, default, bool, args, kwargs) + + def DEFINE_integer(self, name, default, *args, **kwargs): + self.add_option(name, default, int, args, kwargs) + + def DEFINE_string(self, name, default, *args, **kwargs): + self.add_option(name, default, str, args, kwargs) + + def DEFINE_enum(self, name, default, *args, **kwargs): + self.add_option(name, default, 'enum', args, kwargs) + + def config_with_absl(self): + # Run this before calling `app.run(main)` etc + import absl.flags as absl_FLAGS + from absl import app, flags as absl_flags + + self.use_absl = True + self.absl_flags = absl_flags + absl_defs = { bool: absl_flags.DEFINE_bool, + int: absl_flags.DEFINE_integer, + str: absl_flags.DEFINE_string, + 'enum': absl_flags.DEFINE_enum } + + for name, val in self.values.items(): + flag_type, meta_args, meta_kwargs = self.meta[name] + absl_defs[flag_type](name, val, *meta_args, **meta_kwargs) + + def complete_absl_config(self, absl_flags): + for name, _ in self.values.items(): + self.update(name, getattr(absl_flags.FLAGS, name)) + + +class NameSpace(object): + def __init__(self, getter): + self._getter = getter + + def __getattr__(self, name): + return self._getter(name) + + +config = Config() +flags = config diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 2e559060f..8f99a7cc0 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -23,7 +23,7 @@ import operator as op import six from six.moves import xrange -from absl import flags +from ..config import flags from .. import core from .. import ad_util from ..abstract_arrays import ConcreteArray, ShapedArray, make_shaped_array, array_types diff --git a/jax/lib/xla_bridge.py b/jax/lib/xla_bridge.py index 442a1f935..053c6bbe9 100644 --- a/jax/lib/xla_bridge.py +++ b/jax/lib/xla_bridge.py @@ -26,7 +26,7 @@ from __future__ import print_function import os import warnings -from absl import flags +from ..config import flags import numpy as onp # 'onp' rather than 'np' to distinguish from autograd.numpy from . import xla_data_pb2 @@ -206,21 +206,8 @@ _dtype_to_32bit_dtype = { } -def canonicalize_dtype(dtype): - """Convert from a dtype to a canonical dtype based on FLAGS.jax_enable_x64.""" - # This function is a thin wrapper around the memoized _canonicalize_dtype to - # handle the case where FLAGS haven't been parsed yet, for example because - # this function is called at module loading time. This situation can't obtain - # during tracing and instead can arise when there are module-level constants - # computed using lax or lax_numpy. - if FLAGS.is_parsed(): - return _canonicalize_dtype(dtype) - else: - return dtype - - @memoize -def _canonicalize_dtype(dtype): +def canonicalize_dtype(dtype): """Convert from a dtype to a canonical dtype based on FLAGS.jax_enable_x64.""" dtype = onp.dtype(dtype) if FLAGS.jax_enable_x64: diff --git a/jax/test_util.py b/jax/test_util.py index 4c01b612a..2c05a0b65 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -19,7 +19,6 @@ from __future__ import print_function import functools import re -from absl import flags from absl.testing import absltest from absl.testing import parameterized @@ -27,6 +26,7 @@ import numpy as onp import numpy.random as npr from . import api +from .config import flags from .util import partial from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce diff --git a/tests/api_test.py b/tests/api_test.py index 188fff8e9..39eb326b7 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -23,6 +23,7 @@ from absl.testing import absltest from jax import test_util as jtu import jax.numpy as np +from jax.config import config from jax import jit, grad, device_get, device_put from jax.core import Primitive from jax.interpreters.partial_eval import def_abstract_eval @@ -239,4 +240,5 @@ class APITest(jtu.JaxTestCase): if __name__ == '__main__': + config.config_with_absl() absltest.main() diff --git a/tests/batching_test.py b/tests/batching_test.py index d42012f51..fd372bc47 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -26,6 +26,7 @@ from jax.abstract_arrays import ShapedArray from jax import lax from jax.api import jit, grad, jvp, vjp, trace_to_jaxpr from jax.api import vmap +from jax.config import config from jax.core import unit from jax.interpreters import partial_eval as pe from jax.util import partial @@ -195,4 +196,5 @@ class BatchingTest(jtu.JaxTestCase): if __name__ == '__main__': + config.config_with_absl() absltest.main() diff --git a/tests/core_test.py b/tests/core_test.py index 52383a6ac..6091d92ef 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -28,6 +28,7 @@ from jax import core from jax import numpy as np from jax import test_util as jtu from jax.api import jvp, linearize, vjp, jit +from jax.config import config from jax.lax import UnshapedArray, ShapedArray, ConcreteArray from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce from jax.util import partial @@ -331,4 +332,5 @@ class CoreTest(jtu.JaxTestCase): if __name__ == '__main__': + config.config_with_absl() absltest.main() diff --git a/tests/lapax_test.py b/tests/lapax_test.py index 07ef0ad20..444156171 100644 --- a/tests/lapax_test.py +++ b/tests/lapax_test.py @@ -25,6 +25,7 @@ from absl.testing import parameterized from jax import jit from jax import test_util as jtu +from jax.config import config from jax.experimental import lapax @@ -202,4 +203,5 @@ class LapaxTest(jtu.JaxTestCase): if __name__ == "__main__": + config.config_with_absl() absltest.main() diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index 6b4db39dd..024c3e991 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -20,7 +20,6 @@ import collections from functools import partial import itertools -from absl import flags from absl.testing import absltest from absl.testing import parameterized @@ -30,8 +29,9 @@ from jax import api from jax import lax from jax import numpy as lnp from jax import test_util as jtu +from jax.config import config -FLAGS = flags.FLAGS +FLAGS = config.FLAGS # We disable the whitespace continuation check in this file because otherwise it # makes the test name formatting unwieldy. @@ -587,4 +587,5 @@ class IndexingTest(jtu.JaxTestCase): if __name__ == "__main__": + config.config_with_absl() absltest.main() diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 70280c0c6..3881dc06b 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -20,7 +20,6 @@ import collections import functools import itertools -from absl import flags from absl.testing import absltest from absl.testing import parameterized @@ -29,8 +28,9 @@ import numpy as onp from jax import api from jax import numpy as lnp from jax import test_util as jtu +from jax.config import config -FLAGS = flags.FLAGS +FLAGS = config.FLAGS all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)] @@ -542,4 +542,5 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): if __name__ == "__main__": + config.config_with_absl() absltest.main() diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index bd587d98c..53c1e688f 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -20,7 +20,6 @@ import collections import functools import itertools -from absl import flags from absl.testing import absltest from absl.testing import parameterized @@ -31,11 +30,12 @@ import scipy.stats as osp_stats from jax import api from jax import test_util as jtu +from jax.config import config from jax.scipy import misc as lsp_misc from jax.scipy import special as lsp_special from jax.scipy import stats as lsp_stats -FLAGS = flags.FLAGS +FLAGS = config.FLAGS all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)] @@ -154,4 +154,5 @@ class LaxBackedScipyTests(jtu.JaxTestCase): if __name__ == "__main__": + config.config_with_absl() absltest.main() diff --git a/tests/lax_test.py b/tests/lax_test.py index 800709fef..f1c3d8580 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -21,7 +21,6 @@ import functools from functools import partial import itertools -from absl import flags from absl.testing import absltest from absl.testing import parameterized @@ -33,10 +32,11 @@ from jax import core from jax import lax from jax import test_util as jtu from jax import lax_reference +from jax.config import config from jax.interpreters import xla from jax.lib import xla_bridge -FLAGS = flags.FLAGS +FLAGS = config.FLAGS def num_float_bits(dtype): @@ -1986,4 +1986,5 @@ class LaxAutodiffTest(jtu.JaxTestCase): if __name__ == '__main__': + config.config_with_absl() absltest.main() diff --git a/tests/minmax_test.py b/tests/minmax_test.py index 5eee658ce..a548d1362 100644 --- a/tests/minmax_test.py +++ b/tests/minmax_test.py @@ -20,7 +20,7 @@ from __future__ import print_function import functools from absl.testing import absltest - +from jax.config import config import jax.numpy as np import jax.test_util as jtu from jax import jit, grad @@ -170,4 +170,5 @@ class OptimizerTests(jtu.JaxTestCase): if __name__ == '__main__': + config.config_with_absl() absltest.main() diff --git a/tests/random_test.py b/tests/random_test.py index c5ac91106..e90582b92 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -16,7 +16,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from absl import flags from absl.testing import absltest from absl.testing import parameterized @@ -28,8 +27,9 @@ from jax import api from jax import lax from jax import random from jax import test_util as jtu +from jax.config import config -FLAGS = flags.FLAGS +FLAGS = config.FLAGS class LaxRandomTest(jtu.JaxTestCase): @@ -150,4 +150,5 @@ class LaxRandomTest(jtu.JaxTestCase): if __name__ == "__main__": + config.config_with_absl() absltest.main() diff --git a/tests/stax_test.py b/tests/stax_test.py index 1689301f9..3b48dadfb 100644 --- a/tests/stax_test.py +++ b/tests/stax_test.py @@ -24,6 +24,7 @@ import numpy as onp from jax import test_util as jtu from jax import random +from jax.config import config from jax.experimental import stax @@ -129,4 +130,5 @@ class StaxTest(jtu.JaxTestCase): if __name__ == "__main__": + config.config_with_absl() absltest.main()