Made a shim to handle configuration without having absl parse command-line flags.

PiperOrigin-RevId: 223391288
This commit is contained in:
Dougal Maclaurin 2018-11-29 12:30:34 -08:00 committed by Roy Frostig
parent 1d2aaad6fe
commit 2df36f7510
19 changed files with 124 additions and 28 deletions

View File

@ -26,6 +26,7 @@ import itertools
import numpy.random as npr
import jax.numpy as np
from jax.config import config
from jax import jit, grad
from jax.experimental import minmax
from jax.experimental import stax
@ -80,6 +81,7 @@ if __name__ == "__main__":
opt_state = opt_init(init_params)
itercount = itertools.count()
print("\nStarting training...")
for epoch in range(num_epochs):
start_time = time.time()
for _ in range(num_batches):

View File

@ -26,6 +26,7 @@ import time
import numpy.random as npr
from jax.api import jit, grad
from jax.config import config
from jax.scipy.misc import logsumexp
import jax.numpy as np
import datasets

View File

@ -28,6 +28,7 @@ import time
import matplotlib.pyplot as plt
import jax.numpy as np
from jax.config import config
from jax import jit, grad, lax, random
from jax.experimental import minmax
from jax.experimental import stax

View File

@ -24,6 +24,7 @@ from __future__ import print_function
import numpy.random as npr
import jax.numpy as np
from jax.config import config
from jax import jit, grad
from jax.experimental import minmax
from jax.experimental import stax

88
jax/config.py Normal file
View File

@ -0,0 +1,88 @@
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Config(object):
def __init__(self):
self.values = {}
self.meta = {}
self.FLAGS = NameSpace(self.read)
self.use_absl = False
def update(self, name, val):
self.check_exists(name)
if name not in self.values:
raise Exception("Unrecognized config option: {}".format(name))
self.values[name] = val
def read(self, name):
if self.use_absl:
return getattr(self.absl_flags.FLAGS, name)
else:
self.check_exists(name)
return self.values[name]
def add_option(self, name, default, opt_type, meta_args, meta_kwargs):
if name in self.values:
raise Exception("Config option {} already defined".format(name))
self.values[name] = default
self.meta[name] = (opt_type, meta_args, meta_kwargs)
def check_exists(self, name):
if name not in self.values:
raise Exception("Unrecognized config option: {}".format(name))
def DEFINE_bool(self, name, default, *args, **kwargs):
self.add_option(name, default, bool, args, kwargs)
def DEFINE_integer(self, name, default, *args, **kwargs):
self.add_option(name, default, int, args, kwargs)
def DEFINE_string(self, name, default, *args, **kwargs):
self.add_option(name, default, str, args, kwargs)
def DEFINE_enum(self, name, default, *args, **kwargs):
self.add_option(name, default, 'enum', args, kwargs)
def config_with_absl(self):
# Run this before calling `app.run(main)` etc
import absl.flags as absl_FLAGS
from absl import app, flags as absl_flags
self.use_absl = True
self.absl_flags = absl_flags
absl_defs = { bool: absl_flags.DEFINE_bool,
int: absl_flags.DEFINE_integer,
str: absl_flags.DEFINE_string,
'enum': absl_flags.DEFINE_enum }
for name, val in self.values.items():
flag_type, meta_args, meta_kwargs = self.meta[name]
absl_defs[flag_type](name, val, *meta_args, **meta_kwargs)
def complete_absl_config(self, absl_flags):
for name, _ in self.values.items():
self.update(name, getattr(absl_flags.FLAGS, name))
class NameSpace(object):
def __init__(self, getter):
self._getter = getter
def __getattr__(self, name):
return self._getter(name)
config = Config()
flags = config

View File

@ -23,7 +23,7 @@ import operator as op
import six
from six.moves import xrange
from absl import flags
from ..config import flags
from .. import core
from .. import ad_util
from ..abstract_arrays import ConcreteArray, ShapedArray, make_shaped_array, array_types

View File

@ -26,7 +26,7 @@ from __future__ import print_function
import os
import warnings
from absl import flags
from ..config import flags
import numpy as onp # 'onp' rather than 'np' to distinguish from autograd.numpy
from . import xla_data_pb2
@ -206,21 +206,8 @@ _dtype_to_32bit_dtype = {
}
def canonicalize_dtype(dtype):
"""Convert from a dtype to a canonical dtype based on FLAGS.jax_enable_x64."""
# This function is a thin wrapper around the memoized _canonicalize_dtype to
# handle the case where FLAGS haven't been parsed yet, for example because
# this function is called at module loading time. This situation can't obtain
# during tracing and instead can arise when there are module-level constants
# computed using lax or lax_numpy.
if FLAGS.is_parsed():
return _canonicalize_dtype(dtype)
else:
return dtype
@memoize
def _canonicalize_dtype(dtype):
def canonicalize_dtype(dtype):
"""Convert from a dtype to a canonical dtype based on FLAGS.jax_enable_x64."""
dtype = onp.dtype(dtype)
if FLAGS.jax_enable_x64:

View File

@ -19,7 +19,6 @@ from __future__ import print_function
import functools
import re
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
@ -27,6 +26,7 @@ import numpy as onp
import numpy.random as npr
from . import api
from .config import flags
from .util import partial
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce

View File

@ -23,6 +23,7 @@ from absl.testing import absltest
from jax import test_util as jtu
import jax.numpy as np
from jax.config import config
from jax import jit, grad, device_get, device_put
from jax.core import Primitive
from jax.interpreters.partial_eval import def_abstract_eval
@ -239,4 +240,5 @@ class APITest(jtu.JaxTestCase):
if __name__ == '__main__':
config.config_with_absl()
absltest.main()

View File

@ -26,6 +26,7 @@ from jax.abstract_arrays import ShapedArray
from jax import lax
from jax.api import jit, grad, jvp, vjp, trace_to_jaxpr
from jax.api import vmap
from jax.config import config
from jax.core import unit
from jax.interpreters import partial_eval as pe
from jax.util import partial
@ -195,4 +196,5 @@ class BatchingTest(jtu.JaxTestCase):
if __name__ == '__main__':
config.config_with_absl()
absltest.main()

View File

@ -28,6 +28,7 @@ from jax import core
from jax import numpy as np
from jax import test_util as jtu
from jax.api import jvp, linearize, vjp, jit
from jax.config import config
from jax.lax import UnshapedArray, ShapedArray, ConcreteArray
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce
from jax.util import partial
@ -331,4 +332,5 @@ class CoreTest(jtu.JaxTestCase):
if __name__ == '__main__':
config.config_with_absl()
absltest.main()

View File

@ -25,6 +25,7 @@ from absl.testing import parameterized
from jax import jit
from jax import test_util as jtu
from jax.config import config
from jax.experimental import lapax
@ -202,4 +203,5 @@ class LapaxTest(jtu.JaxTestCase):
if __name__ == "__main__":
config.config_with_absl()
absltest.main()

View File

@ -20,7 +20,6 @@ import collections
from functools import partial
import itertools
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
@ -30,8 +29,9 @@ from jax import api
from jax import lax
from jax import numpy as lnp
from jax import test_util as jtu
from jax.config import config
FLAGS = flags.FLAGS
FLAGS = config.FLAGS
# We disable the whitespace continuation check in this file because otherwise it
# makes the test name formatting unwieldy.
@ -587,4 +587,5 @@ class IndexingTest(jtu.JaxTestCase):
if __name__ == "__main__":
config.config_with_absl()
absltest.main()

View File

@ -20,7 +20,6 @@ import collections
import functools
import itertools
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
@ -29,8 +28,9 @@ import numpy as onp
from jax import api
from jax import numpy as lnp
from jax import test_util as jtu
from jax.config import config
FLAGS = flags.FLAGS
FLAGS = config.FLAGS
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
@ -542,4 +542,5 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if __name__ == "__main__":
config.config_with_absl()
absltest.main()

View File

@ -20,7 +20,6 @@ import collections
import functools
import itertools
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
@ -31,11 +30,12 @@ import scipy.stats as osp_stats
from jax import api
from jax import test_util as jtu
from jax.config import config
from jax.scipy import misc as lsp_misc
from jax.scipy import special as lsp_special
from jax.scipy import stats as lsp_stats
FLAGS = flags.FLAGS
FLAGS = config.FLAGS
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
@ -154,4 +154,5 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
if __name__ == "__main__":
config.config_with_absl()
absltest.main()

View File

@ -21,7 +21,6 @@ import functools
from functools import partial
import itertools
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
@ -33,10 +32,11 @@ from jax import core
from jax import lax
from jax import test_util as jtu
from jax import lax_reference
from jax.config import config
from jax.interpreters import xla
from jax.lib import xla_bridge
FLAGS = flags.FLAGS
FLAGS = config.FLAGS
def num_float_bits(dtype):
@ -1986,4 +1986,5 @@ class LaxAutodiffTest(jtu.JaxTestCase):
if __name__ == '__main__':
config.config_with_absl()
absltest.main()

View File

@ -20,7 +20,7 @@ from __future__ import print_function
import functools
from absl.testing import absltest
from jax.config import config
import jax.numpy as np
import jax.test_util as jtu
from jax import jit, grad
@ -170,4 +170,5 @@ class OptimizerTests(jtu.JaxTestCase):
if __name__ == '__main__':
config.config_with_absl()
absltest.main()

View File

@ -16,7 +16,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
@ -28,8 +27,9 @@ from jax import api
from jax import lax
from jax import random
from jax import test_util as jtu
from jax.config import config
FLAGS = flags.FLAGS
FLAGS = config.FLAGS
class LaxRandomTest(jtu.JaxTestCase):
@ -150,4 +150,5 @@ class LaxRandomTest(jtu.JaxTestCase):
if __name__ == "__main__":
config.config_with_absl()
absltest.main()

View File

@ -24,6 +24,7 @@ import numpy as onp
from jax import test_util as jtu
from jax import random
from jax.config import config
from jax.experimental import stax
@ -129,4 +130,5 @@ class StaxTest(jtu.JaxTestCase):
if __name__ == "__main__":
config.config_with_absl()
absltest.main()