mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Made a shim to handle configuration without having absl parse command-line flags.
PiperOrigin-RevId: 223391288
This commit is contained in:
parent
1d2aaad6fe
commit
2df36f7510
@ -26,6 +26,7 @@ import itertools
|
||||
import numpy.random as npr
|
||||
|
||||
import jax.numpy as np
|
||||
from jax.config import config
|
||||
from jax import jit, grad
|
||||
from jax.experimental import minmax
|
||||
from jax.experimental import stax
|
||||
@ -80,6 +81,7 @@ if __name__ == "__main__":
|
||||
opt_state = opt_init(init_params)
|
||||
itercount = itertools.count()
|
||||
|
||||
print("\nStarting training...")
|
||||
for epoch in range(num_epochs):
|
||||
start_time = time.time()
|
||||
for _ in range(num_batches):
|
||||
|
@ -26,6 +26,7 @@ import time
|
||||
import numpy.random as npr
|
||||
|
||||
from jax.api import jit, grad
|
||||
from jax.config import config
|
||||
from jax.scipy.misc import logsumexp
|
||||
import jax.numpy as np
|
||||
import datasets
|
||||
|
@ -28,6 +28,7 @@ import time
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import jax.numpy as np
|
||||
from jax.config import config
|
||||
from jax import jit, grad, lax, random
|
||||
from jax.experimental import minmax
|
||||
from jax.experimental import stax
|
||||
|
@ -24,6 +24,7 @@ from __future__ import print_function
|
||||
import numpy.random as npr
|
||||
|
||||
import jax.numpy as np
|
||||
from jax.config import config
|
||||
from jax import jit, grad
|
||||
from jax.experimental import minmax
|
||||
from jax.experimental import stax
|
||||
|
88
jax/config.py
Normal file
88
jax/config.py
Normal file
@ -0,0 +1,88 @@
|
||||
# Copyright 2018 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
class Config(object):
|
||||
def __init__(self):
|
||||
self.values = {}
|
||||
self.meta = {}
|
||||
self.FLAGS = NameSpace(self.read)
|
||||
self.use_absl = False
|
||||
|
||||
def update(self, name, val):
|
||||
self.check_exists(name)
|
||||
if name not in self.values:
|
||||
raise Exception("Unrecognized config option: {}".format(name))
|
||||
self.values[name] = val
|
||||
|
||||
def read(self, name):
|
||||
if self.use_absl:
|
||||
return getattr(self.absl_flags.FLAGS, name)
|
||||
else:
|
||||
self.check_exists(name)
|
||||
return self.values[name]
|
||||
|
||||
def add_option(self, name, default, opt_type, meta_args, meta_kwargs):
|
||||
if name in self.values:
|
||||
raise Exception("Config option {} already defined".format(name))
|
||||
self.values[name] = default
|
||||
self.meta[name] = (opt_type, meta_args, meta_kwargs)
|
||||
|
||||
def check_exists(self, name):
|
||||
if name not in self.values:
|
||||
raise Exception("Unrecognized config option: {}".format(name))
|
||||
|
||||
def DEFINE_bool(self, name, default, *args, **kwargs):
|
||||
self.add_option(name, default, bool, args, kwargs)
|
||||
|
||||
def DEFINE_integer(self, name, default, *args, **kwargs):
|
||||
self.add_option(name, default, int, args, kwargs)
|
||||
|
||||
def DEFINE_string(self, name, default, *args, **kwargs):
|
||||
self.add_option(name, default, str, args, kwargs)
|
||||
|
||||
def DEFINE_enum(self, name, default, *args, **kwargs):
|
||||
self.add_option(name, default, 'enum', args, kwargs)
|
||||
|
||||
def config_with_absl(self):
|
||||
# Run this before calling `app.run(main)` etc
|
||||
import absl.flags as absl_FLAGS
|
||||
from absl import app, flags as absl_flags
|
||||
|
||||
self.use_absl = True
|
||||
self.absl_flags = absl_flags
|
||||
absl_defs = { bool: absl_flags.DEFINE_bool,
|
||||
int: absl_flags.DEFINE_integer,
|
||||
str: absl_flags.DEFINE_string,
|
||||
'enum': absl_flags.DEFINE_enum }
|
||||
|
||||
for name, val in self.values.items():
|
||||
flag_type, meta_args, meta_kwargs = self.meta[name]
|
||||
absl_defs[flag_type](name, val, *meta_args, **meta_kwargs)
|
||||
|
||||
def complete_absl_config(self, absl_flags):
|
||||
for name, _ in self.values.items():
|
||||
self.update(name, getattr(absl_flags.FLAGS, name))
|
||||
|
||||
|
||||
class NameSpace(object):
|
||||
def __init__(self, getter):
|
||||
self._getter = getter
|
||||
|
||||
def __getattr__(self, name):
|
||||
return self._getter(name)
|
||||
|
||||
|
||||
config = Config()
|
||||
flags = config
|
@ -23,7 +23,7 @@ import operator as op
|
||||
import six
|
||||
from six.moves import xrange
|
||||
|
||||
from absl import flags
|
||||
from ..config import flags
|
||||
from .. import core
|
||||
from .. import ad_util
|
||||
from ..abstract_arrays import ConcreteArray, ShapedArray, make_shaped_array, array_types
|
||||
|
@ -26,7 +26,7 @@ from __future__ import print_function
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from absl import flags
|
||||
from ..config import flags
|
||||
import numpy as onp # 'onp' rather than 'np' to distinguish from autograd.numpy
|
||||
|
||||
from . import xla_data_pb2
|
||||
@ -206,21 +206,8 @@ _dtype_to_32bit_dtype = {
|
||||
}
|
||||
|
||||
|
||||
def canonicalize_dtype(dtype):
|
||||
"""Convert from a dtype to a canonical dtype based on FLAGS.jax_enable_x64."""
|
||||
# This function is a thin wrapper around the memoized _canonicalize_dtype to
|
||||
# handle the case where FLAGS haven't been parsed yet, for example because
|
||||
# this function is called at module loading time. This situation can't obtain
|
||||
# during tracing and instead can arise when there are module-level constants
|
||||
# computed using lax or lax_numpy.
|
||||
if FLAGS.is_parsed():
|
||||
return _canonicalize_dtype(dtype)
|
||||
else:
|
||||
return dtype
|
||||
|
||||
|
||||
@memoize
|
||||
def _canonicalize_dtype(dtype):
|
||||
def canonicalize_dtype(dtype):
|
||||
"""Convert from a dtype to a canonical dtype based on FLAGS.jax_enable_x64."""
|
||||
dtype = onp.dtype(dtype)
|
||||
if FLAGS.jax_enable_x64:
|
||||
|
@ -19,7 +19,6 @@ from __future__ import print_function
|
||||
import functools
|
||||
import re
|
||||
|
||||
from absl import flags
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
@ -27,6 +26,7 @@ import numpy as onp
|
||||
import numpy.random as npr
|
||||
|
||||
from . import api
|
||||
from .config import flags
|
||||
from .util import partial
|
||||
from .tree_util import tree_multimap, tree_all, tree_map, tree_reduce
|
||||
|
||||
|
@ -23,6 +23,7 @@ from absl.testing import absltest
|
||||
from jax import test_util as jtu
|
||||
|
||||
import jax.numpy as np
|
||||
from jax.config import config
|
||||
from jax import jit, grad, device_get, device_put
|
||||
from jax.core import Primitive
|
||||
from jax.interpreters.partial_eval import def_abstract_eval
|
||||
@ -239,4 +240,5 @@ class APITest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config.config_with_absl()
|
||||
absltest.main()
|
||||
|
@ -26,6 +26,7 @@ from jax.abstract_arrays import ShapedArray
|
||||
from jax import lax
|
||||
from jax.api import jit, grad, jvp, vjp, trace_to_jaxpr
|
||||
from jax.api import vmap
|
||||
from jax.config import config
|
||||
from jax.core import unit
|
||||
from jax.interpreters import partial_eval as pe
|
||||
from jax.util import partial
|
||||
@ -195,4 +196,5 @@ class BatchingTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config.config_with_absl()
|
||||
absltest.main()
|
||||
|
@ -28,6 +28,7 @@ from jax import core
|
||||
from jax import numpy as np
|
||||
from jax import test_util as jtu
|
||||
from jax.api import jvp, linearize, vjp, jit
|
||||
from jax.config import config
|
||||
from jax.lax import UnshapedArray, ShapedArray, ConcreteArray
|
||||
from jax.tree_util import tree_flatten, tree_unflatten, tree_multimap, tree_reduce
|
||||
from jax.util import partial
|
||||
@ -331,4 +332,5 @@ class CoreTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config.config_with_absl()
|
||||
absltest.main()
|
||||
|
@ -25,6 +25,7 @@ from absl.testing import parameterized
|
||||
|
||||
from jax import jit
|
||||
from jax import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax.experimental import lapax
|
||||
|
||||
|
||||
@ -202,4 +203,5 @@ class LapaxTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config.config_with_absl()
|
||||
absltest.main()
|
||||
|
@ -20,7 +20,6 @@ import collections
|
||||
from functools import partial
|
||||
import itertools
|
||||
|
||||
from absl import flags
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
@ -30,8 +29,9 @@ from jax import api
|
||||
from jax import lax
|
||||
from jax import numpy as lnp
|
||||
from jax import test_util as jtu
|
||||
from jax.config import config
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
# We disable the whitespace continuation check in this file because otherwise it
|
||||
# makes the test name formatting unwieldy.
|
||||
@ -587,4 +587,5 @@ class IndexingTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config.config_with_absl()
|
||||
absltest.main()
|
||||
|
@ -20,7 +20,6 @@ import collections
|
||||
import functools
|
||||
import itertools
|
||||
|
||||
from absl import flags
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
@ -29,8 +28,9 @@ import numpy as onp
|
||||
from jax import api
|
||||
from jax import numpy as lnp
|
||||
from jax import test_util as jtu
|
||||
from jax.config import config
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4), (2, 3, 4)]
|
||||
|
||||
@ -542,4 +542,5 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config.config_with_absl()
|
||||
absltest.main()
|
||||
|
@ -20,7 +20,6 @@ import collections
|
||||
import functools
|
||||
import itertools
|
||||
|
||||
from absl import flags
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
@ -31,11 +30,12 @@ import scipy.stats as osp_stats
|
||||
|
||||
from jax import api
|
||||
from jax import test_util as jtu
|
||||
from jax.config import config
|
||||
from jax.scipy import misc as lsp_misc
|
||||
from jax.scipy import special as lsp_special
|
||||
from jax.scipy import stats as lsp_stats
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
|
||||
|
||||
@ -154,4 +154,5 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config.config_with_absl()
|
||||
absltest.main()
|
||||
|
@ -21,7 +21,6 @@ import functools
|
||||
from functools import partial
|
||||
import itertools
|
||||
|
||||
from absl import flags
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
@ -33,10 +32,11 @@ from jax import core
|
||||
from jax import lax
|
||||
from jax import test_util as jtu
|
||||
from jax import lax_reference
|
||||
from jax.config import config
|
||||
from jax.interpreters import xla
|
||||
from jax.lib import xla_bridge
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
||||
def num_float_bits(dtype):
|
||||
@ -1986,4 +1986,5 @@ class LaxAutodiffTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config.config_with_absl()
|
||||
absltest.main()
|
||||
|
@ -20,7 +20,7 @@ from __future__ import print_function
|
||||
import functools
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
from jax.config import config
|
||||
import jax.numpy as np
|
||||
import jax.test_util as jtu
|
||||
from jax import jit, grad
|
||||
@ -170,4 +170,5 @@ class OptimizerTests(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config.config_with_absl()
|
||||
absltest.main()
|
||||
|
@ -16,7 +16,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl import flags
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
@ -28,8 +27,9 @@ from jax import api
|
||||
from jax import lax
|
||||
from jax import random
|
||||
from jax import test_util as jtu
|
||||
from jax.config import config
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
FLAGS = config.FLAGS
|
||||
|
||||
|
||||
class LaxRandomTest(jtu.JaxTestCase):
|
||||
@ -150,4 +150,5 @@ class LaxRandomTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config.config_with_absl()
|
||||
absltest.main()
|
||||
|
@ -24,6 +24,7 @@ import numpy as onp
|
||||
|
||||
from jax import test_util as jtu
|
||||
from jax import random
|
||||
from jax.config import config
|
||||
from jax.experimental import stax
|
||||
|
||||
|
||||
@ -129,4 +130,5 @@ class StaxTest(jtu.JaxTestCase):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
config.config_with_absl()
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user