Update flags to use the ABSL typed flag API.

Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary.

For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API.

Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`.

This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.

PiperOrigin-RevId: 551604974
This commit is contained in:
Peter Hawkins 2023-07-27 12:15:16 -07:00 committed by jax authors
parent f35f226b44
commit 76cda0ae07
22 changed files with 336 additions and 289 deletions

View File

@ -27,8 +27,7 @@ from absl import app
from absl import flags
FLAGS = flags.FLAGS
flags.DEFINE_multi_string(
_SET_ENV = flags.DEFINE_multi_string(
"set_env", None,
"Specifies additional environment variables to be injected into the "
"environment (via --set_env=variable=value or --set_env=variable). "
@ -140,8 +139,8 @@ def jax_binary_op(state, **kwargs):
)
def main(argv):
if FLAGS.set_env:
for env_str in FLAGS.set_env:
if _SET_ENV.value:
for env_str in _SET_ENV.value:
# Stop matching at the first '=' since we want to capture
# --set_env='FOO=--foo_a=1 --foo_b=2' all as part of FOO.
env_list = env_str.split('=', 1)

View File

@ -83,22 +83,21 @@ import numpy.random as npr
from dp_accounting import dp_event
from dp_accounting import rdp
FLAGS = flags.FLAGS
flags.DEFINE_boolean(
_DPSGD = flags.DEFINE_boolean(
'dpsgd', True, 'If True, train with DP-SGD. If False, '
'train with vanilla SGD.')
flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
flags.DEFINE_float('noise_multiplier', 1.1,
_LEARNING_RATE = flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
_NOISE_MULTIPLIER = flags.DEFINE_float('noise_multiplier', 1.1,
'Ratio of the standard deviation to the clipping norm')
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
flags.DEFINE_integer('batch_size', 256, 'Batch size')
flags.DEFINE_integer('epochs', 60, 'Number of epochs')
flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG')
flags.DEFINE_integer(
_L2_NORM_CLIP = flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
_BATCH_SIZE = flags.DEFINE_integer('batch_size', 256, 'Batch size')
_EPOCHS = flags.DEFINE_integer('epochs', 60, 'Number of epochs')
_SEED = flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG')
_MICROBATCHES = flags.DEFINE_integer(
'microbatches', None, 'Number of microbatches '
'(must evenly divide batch_size)')
flags.DEFINE_string('model_dir', None, 'Model directory')
_MODEL_DIR = flags.DEFINE_string('model_dir', None, 'Model directory')
init_random_params, predict = stax.serial(
@ -163,39 +162,39 @@ def shape_as_image(images, labels, dummy_dim=False):
def compute_epsilon(steps, num_examples=60000, target_delta=1e-5):
if num_examples * target_delta > 1.:
warnings.warn('Your delta might be too high.')
q = FLAGS.batch_size / float(num_examples)
q = _BATCH_SIZE.value / float(num_examples)
orders = list(jnp.linspace(1.1, 10.9, 99)) + list(range(11, 64))
accountant = rdp.rdp_privacy_accountant.RdpAccountant(orders)
accountant.compose(
dp_event.PoissonSampledDpEvent(
q, dp_event.GaussianDpEvent(FLAGS.noise_multiplier)), steps)
q, dp_event.GaussianDpEvent(_NOISE_MULTIPLIER.value)), steps)
return accountant.get_epsilon(target_delta)
def main(_):
if FLAGS.microbatches:
if _MICROBATCHES.value:
raise NotImplementedError(
'Microbatches < batch size not currently supported'
)
train_images, train_labels, test_images, test_labels = datasets.mnist()
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, FLAGS.batch_size)
num_complete_batches, leftover = divmod(num_train, _BATCH_SIZE.value)
num_batches = num_complete_batches + bool(leftover)
key = random.PRNGKey(FLAGS.seed)
key = random.PRNGKey(_SEED.value)
def data_stream():
rng = npr.RandomState(FLAGS.seed)
rng = npr.RandomState(_SEED.value)
while True:
perm = rng.permutation(num_train)
for i in range(num_batches):
batch_idx = perm[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
batch_idx = perm[i * _BATCH_SIZE.value:(i + 1) * _BATCH_SIZE.value]
yield train_images[batch_idx], train_labels[batch_idx]
batches = data_stream()
opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)
opt_init, opt_update, get_params = optimizers.sgd(_LEARNING_RATE.value)
@jit
def update(_, i, opt_state, batch):
@ -208,19 +207,19 @@ def main(_):
rng = random.fold_in(rng, i) # get new key for new random numbers
return opt_update(
i,
private_grad(params, batch, rng, FLAGS.l2_norm_clip,
FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)
private_grad(params, batch, rng, _L2_NORM_CLIP.value,
_NOISE_MULTIPLIER.value, _BATCH_SIZE.value), opt_state)
_, init_params = init_random_params(key, (-1, 28, 28, 1))
opt_state = opt_init(init_params)
itercount = itertools.count()
steps_per_epoch = 60000 // FLAGS.batch_size
steps_per_epoch = 60000 // _BATCH_SIZE.value
print('\nStarting training...')
for epoch in range(1, FLAGS.epochs + 1):
for epoch in range(1, _EPOCHS.value + 1):
start_time = time.time()
for _ in range(num_batches):
if FLAGS.dpsgd:
if _DPSGD.value:
opt_state = \
private_update(
key, next(itercount), opt_state,
@ -239,7 +238,7 @@ def main(_):
test_loss, 100 * test_acc))
# determine privacy loss so far
if FLAGS.dpsgd:
if _DPSGD.value:
delta = 1e-5
num_examples = 60000
eps = compute_epsilon(epoch * steps_per_epoch, num_examples, delta)

View File

@ -16,7 +16,6 @@
"""
from absl import app
from absl import flags
from functools import partial
from jax import grad
from jax import jit
@ -27,8 +26,6 @@ import jax.random as random
import jax.scipy as scipy
import matplotlib.pyplot as plt
FLAGS = flags.FLAGS
def main(unused_argv):

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Hashable, Iterator
import contextlib
import functools
@ -20,7 +22,7 @@ import logging
import os
import sys
import threading
from typing import Any, Callable, NamedTuple, Optional
from typing import Any, Callable, Generic, NamedTuple, Optional, TypeVar
from jax._src import lib
from jax._src.lib import jax_jit
@ -29,6 +31,8 @@ from jax._src.lib import xla_client
logger = logging.getLogger(__name__)
_T = TypeVar('_T')
def bool_env(varname: str, default: bool) -> bool:
"""Read an environment variable and interpret it as a boolean.
@ -64,6 +68,16 @@ UPGRADE_BOOL_HELP = (
UPGRADE_BOOL_EXTRA_DESC = " (transient)"
class FlagHolder(Generic[_T]):
def __init__(self, flags: NameSpace, name: str):
self._flags = flags
self._name = name
@property
def value(self) -> _T:
return getattr(self._flags, self._name)
class Config:
_HAS_DYNAMIC_ATTRIBUTES = True
@ -112,26 +126,31 @@ class Config:
if name not in self.values:
raise AttributeError(f"Unrecognized config option: {name}")
def DEFINE_bool(self, name, default, *args, **kwargs):
def DEFINE_bool(self, name, default, *args, **kwargs) -> FlagHolder[bool]:
update_hook = kwargs.pop("update_hook", None)
self.add_option(name, default, bool, args, kwargs, update_hook=update_hook)
return FlagHolder(self.FLAGS, name)
def DEFINE_integer(self, name, default, *args, **kwargs):
def DEFINE_integer(self, name, default, *args, **kwargs) -> FlagHolder[int]:
update_hook = kwargs.pop("update_hook", None)
self.add_option(name, default, int, args, kwargs, update_hook=update_hook)
return FlagHolder(self.FLAGS, name)
def DEFINE_float(self, name, default, *args, **kwargs):
def DEFINE_float(self, name, default, *args, **kwargs) -> FlagHolder[float]:
update_hook = kwargs.pop("update_hook", None)
self.add_option(name, default, float, args, kwargs, update_hook=update_hook)
return FlagHolder(self.FLAGS, name)
def DEFINE_string(self, name, default, *args, **kwargs):
def DEFINE_string(self, name, default, *args, **kwargs) -> FlagHolder[str]:
update_hook = kwargs.pop("update_hook", None)
self.add_option(name, default, str, args, kwargs, update_hook=update_hook)
return FlagHolder(self.FLAGS, name)
def DEFINE_enum(self, name, default, *args, **kwargs):
def DEFINE_enum(self, name, default, *args, **kwargs) -> FlagHolder[str]:
update_hook = kwargs.pop("update_hook", None)
self.add_option(name, default, 'enum', args, kwargs,
update_hook=update_hook)
return FlagHolder(self.FLAGS, name)
def config_with_absl(self):
# Run this before calling `app.run(main)` etc
@ -551,6 +570,22 @@ config = Config()
flags = config
FLAGS = flags.FLAGS
def DEFINE_bool(name, default, *args, **kwargs):
return flags.DEFINE_bool(name, default, *args, **kwargs)
def DEFINE_integer(name, default, *args, **kwargs):
return flags.DEFINE_integer(name, default, *args, **kwargs)
def DEFINE_float(name, default, *args, **kwargs):
return flags.DEFINE_float(name, default, *args, **kwargs)
def DEFINE_string(name, default, *args, **kwargs):
return flags.DEFINE_string(name, default, *args, **kwargs)
def DEFINE_enum(name, default, *args, **kwargs):
return flags.DEFINE_enum(name, default, *args, **kwargs)
already_configured_with_absl = False
@ -617,51 +652,6 @@ def update_thread_local_jit_state(**kw):
tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp)
flags.DEFINE_integer(
'jax_tracer_error_num_traceback_frames',
int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
help='Set the number of stack frames in JAX tracer error messages.'
)
flags.DEFINE_bool(
'jax_pprint_use_color',
bool_env('JAX_PPRINT_USE_COLOR', True),
help='Enable jaxpr pretty-printing with colorful syntax highlighting.'
)
flags.DEFINE_bool(
'jax_host_callback_inline',
bool_env('JAX_HOST_CALLBACK_INLINE', False),
help='Inline the host_callback, if not in a staged context.'
)
flags.DEFINE_integer(
'jax_host_callback_max_queue_byte_size',
int_env('JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE', int(256 * 1e6)),
help=('The size in bytes of the buffer used to hold outfeeds from each '
'device. When this capacity is reached consuming outfeeds from the '
'device is paused, thus potentially pausing the device computation, '
'until the Python callback consume more outfeeds.'),
lower_bound=int(16 * 1e6)
)
flags.DEFINE_bool(
'jax_host_callback_outfeed',
bool_env('JAX_HOST_CALLBACK_OUTFEED', False),
help=(
'Use outfeed implementation for host_callback, even on CPU and GPU. '
'If false, use the CustomCall implementation. '
'Has no effect on TPU, since only the outfeed mechanism is implemented.'
)
)
flags.DEFINE_bool(
'jax_host_callback_ad_transforms',
bool_env('JAX_HOST_CALLBACK_AD_TRANSFORMS', False),
help=(
'Enable support for jvp/vjp for the host_callback primitives. Default is '
'False, which means that host_callback operates only on primals. '
'The flag exists only temporarily, for backward compatibility.'
)
)
# TODO(b/214340779): remove flag when XLA:CPU is improved.
jax2tf_associative_scan_reductions = config.define_bool_state(
name='jax2tf_associative_scan_reductions',

View File

@ -38,7 +38,7 @@ import numpy as np
from jax._src import dtypes
from jax._src import config as jax_config
from jax._src import effects
from jax._src.config import FLAGS, config
from jax._src.config import config
from jax._src.errors import (
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
@ -60,6 +60,13 @@ zip, unsafe_zip = safe_zip, zip
map, unsafe_map = safe_map, map
_TRACER_ERROR_NUM_TRACEBACK_FRAMES = jax_config.DEFINE_integer(
'jax_tracer_error_num_traceback_frames',
jax_config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
help='Set the number of stack frames in JAX tracer error messages.'
)
# -------------------- jaxprs --------------------
Effect = effects.Effect
@ -560,7 +567,7 @@ def raise_as_much_as_possible(tracer) -> Tracer:
def escaped_tracer_error(tracer, detail=None):
num_frames = FLAGS.jax_tracer_error_num_traceback_frames
num_frames = _TRACER_ERROR_NUM_TRACEBACK_FRAMES.value
msg = ('Encountered an unexpected tracer. A function transformed by JAX '
'had a side effect, allowing for a reference to an intermediate value '
f'with type {tracer.aval.str_short()} wrapped in a '

View File

@ -32,6 +32,7 @@ import warnings
import numpy as np
from jax._src import compilation_cache
from jax._src import config as jax_config
from jax._src import core
from jax._src import dtypes
from jax._src import linear_util as lu
@ -43,7 +44,7 @@ from jax._src import traceback_util
from jax._src import util
from jax._src import op_shardings
from jax._src import xla_bridge as xb
from jax._src.config import config, flags
from jax._src.config import config
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -64,9 +65,7 @@ JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
FLAGS = flags.FLAGS
flags.DEFINE_string(
_DUMP_IR_TO = jax_config.DEFINE_string(
'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
help="Path to which the IR that is emitted by JAX as input to the "
"compiler should be dumped as text files. Optional. If omitted, JAX "
@ -472,7 +471,7 @@ def _make_string_safe_for_filename(s: str) -> str:
def _dump_ir_to_file(name: str, ir: str):
id = next(_ir_dump_counter)
name = f"jax_ir{id}_{_make_string_safe_for_filename(name)}.mlir"
name = path.Path(FLAGS.jax_dump_ir_to) / name
name = path.Path(_DUMP_IR_TO.value) / name
name.write_text(ir)
@ -481,7 +480,7 @@ def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray,
sym_name = computation.operation.attributes['sym_name']
module_name = ir.StringAttr(sym_name).value
if FLAGS.jax_dump_ir_to:
if _DUMP_IR_TO.value:
_dump_ir_to_file(module_name, mlir.module_to_string(computation))
# Persistent compilation cache only implemented on TPU and GPU.

View File

@ -28,7 +28,7 @@ import warnings
import ml_dtypes
import numpy as np
from jax._src.config import flags, config
from jax._src.config import config
from jax._src.typing import DType, DTypeLike
from jax._src import traceback_util
@ -43,8 +43,6 @@ else:
raise ValueError("JAX requires ml_dtypes version 0.2.0 or newer; "
f"installed version is {ml_dtypes.__version__}.")
FLAGS = flags.FLAGS
class extended(np.generic):
"""Scalar class for extended dtypes.

View File

@ -32,13 +32,20 @@ from functools import partial
import sys
from typing import NamedTuple, Optional, Union
from jax._src.config import config
from jax._src import config
try:
import colorama # pytype: disable=import-error
except ImportError:
colorama = None
_PPRINT_USE_COLOR = config.DEFINE_bool(
'jax_pprint_use_color',
config.bool_env('JAX_PPRINT_USE_COLOR', True),
help='Enable jaxpr pretty-printing with colorful syntax highlighting.'
)
def _can_use_color() -> bool:
try:
# Check if we're in IPython or Colab
@ -63,7 +70,7 @@ class Doc(abc.ABC):
def format(self, width: int = 80, use_color: Optional[bool] = None,
annotation_prefix=" # ") -> str:
if use_color is None:
use_color = CAN_USE_COLOR and config.FLAGS.jax_pprint_use_color
use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value
return _format(self, width, use_color=use_color,
annotation_prefix=annotation_prefix)

View File

@ -16,9 +16,10 @@ from functools import partial
import operator
from jax._src import api
from jax._src import config as jax_config
from jax._src import dtypes as _dtypes
from jax._src import xla_bridge
from jax._src.config import config, flags
from jax._src.config import config
from jax._src.tree_util import tree_map, tree_reduce
import numpy as np
@ -30,7 +31,11 @@ import numpy as np
__all__ = ['check_grads', 'check_jvp', 'check_vjp']
FLAGS = flags.FLAGS
_TEST_DUT = jax_config.DEFINE_string(
'jax_test_dut', '',
help=
'Describes the device under test in case special consideration is required.'
)
EPS = 1e-4
@ -292,4 +297,4 @@ def check_grads(f, args, order,
def device_under_test():
return getattr(FLAGS, 'jax_test_dut', None) or xla_bridge.get_backend().platform
return _TEST_DUT.value or xla_bridge.get_backend().platform

View File

@ -41,11 +41,12 @@ from jax._src.interpreters import mlir
from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten
from jax._src import api
from jax._src import pjit as pjit_lib
from jax._src import config as jax_config
from jax._src import core
from jax._src import dispatch
from jax._src import dtypes as _dtypes
from jax._src.interpreters import pxla
from jax._src.config import (flags, bool_env, config,
from jax._src.config import (bool_env, config,
raise_persistent_cache_errors,
persistent_cache_min_compile_time_secs)
from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
@ -60,19 +61,12 @@ from jax._src import xla_bridge
# jax.test_util. Functionality appearing here is for internal use only, and
# may be changed or removed at any time and without any deprecation cycle.
FLAGS = flags.FLAGS
flags.DEFINE_string(
'jax_test_dut', '',
help=
'Describes the device under test in case special consideration is required.'
)
flags.DEFINE_integer(
_NUM_GENERATED_CASES = jax_config.DEFINE_integer(
'jax_num_generated_cases',
int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
help='Number of generated cases to test')
flags.DEFINE_integer(
_MAX_CASES_SAMPLING_RETRIES = jax_config.DEFINE_integer(
'max_cases_sampling_retries',
int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')),
'Number of times a failed test sample should be retried. '
@ -80,24 +74,23 @@ flags.DEFINE_integer(
'sampling process is terminated.'
)
flags.DEFINE_bool(
_SKIP_SLOW_TESTS = jax_config.DEFINE_bool(
'jax_skip_slow_tests',
bool_env('JAX_SKIP_SLOW_TESTS', False),
help='Skip tests marked as slow (> 5 sec).'
)
flags.DEFINE_string(
_TEST_TARGETS = jax_config.DEFINE_string(
'test_targets', os.getenv('JAX_TEST_TARGETS', ''),
'Regular expression specifying which tests to run, called via re.search on '
'the test name. If empty or unspecified, run all tests.'
)
flags.DEFINE_string(
_EXCLUDE_TEST_TARGETS = jax_config.DEFINE_string(
'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
'Regular expression specifying which tests NOT to run, called via re.search '
'on the test name. If empty or unspecified, run all tests.'
)
flags.DEFINE_bool(
TEST_WITH_PERSISTENT_COMPILATION_CACHE = jax_config.DEFINE_bool(
'jax_test_with_persistent_compilation_cache',
bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False),
help='If enabled, the persistent compilation cache will be enabled for all '
@ -731,7 +724,7 @@ def assert_dot_precision(expected_precision, fun, *args):
def cases_from_gens(*gens):
sizes = [1, 3, 10]
cases_per_size = int(FLAGS.jax_num_generated_cases / len(sizes)) + 1
cases_per_size = int(_NUM_GENERATED_CASES.value / len(sizes)) + 1
for size in sizes:
for i in range(cases_per_size):
yield (f'_{size}_{i}',) + tuple(gen(size) for gen in gens)
@ -744,8 +737,8 @@ def named_cases_from_sampler(gen):
if not isinstance(x, (list, tuple)):
x = list(x)
return [x[rng.randint(len(x))]]
while (len(seen) < FLAGS.jax_num_generated_cases and
retries < FLAGS.max_cases_sampling_retries):
while (len(seen) < _NUM_GENERATED_CASES.value and
retries < _MAX_CASES_SAMPLING_RETRIES.value):
retries += 1
cases = list(gen(choose_one))
if not cases:
@ -773,7 +766,7 @@ def sample_product_testcases(*args, **kw):
kw = [(k, list(v)) for k, v in kw.items()]
n = math.prod(len(a) for a in args) * math.prod(len(v) for _, v in kw)
testcases = []
for i in _choice(n, min(n, FLAGS.jax_num_generated_cases)):
for i in _choice(n, min(n, _NUM_GENERATED_CASES.value)):
testcase = {}
for a in args:
testcase.update(a[i % len(a)])
@ -804,12 +797,12 @@ def sample_product(*args, **kw):
class JaxTestLoader(absltest.TestLoader):
def getTestCaseNames(self, testCaseClass):
names = super().getTestCaseNames(testCaseClass)
if FLAGS.test_targets:
pattern = re.compile(FLAGS.test_targets)
if _TEST_TARGETS.value:
pattern = re.compile(_TEST_TARGETS.value)
names = [name for name in names
if pattern.search(f"{testCaseClass.__name__}.{name}")]
if FLAGS.exclude_test_targets:
pattern = re.compile(FLAGS.exclude_test_targets)
if _EXCLUDE_TEST_TARGETS.value:
pattern = re.compile(_EXCLUDE_TEST_TARGETS.value)
names = [name for name in names
if not pattern.search(f"{testCaseClass.__name__}.{name}")]
return names
@ -874,7 +867,7 @@ class JaxTestCase(parameterized.TestCase):
@classmethod
def setUpClass(cls):
if FLAGS.jax_test_with_persistent_compilation_cache:
if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
cls._compilation_cache_exit_stack = ExitStack()
stack = cls._compilation_cache_exit_stack
stack.enter_context(raise_persistent_cache_errors(True))
@ -887,7 +880,7 @@ class JaxTestCase(parameterized.TestCase):
@classmethod
def tearDownClass(cls):
if FLAGS.jax_test_with_persistent_compilation_cache:
if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
cls._compilation_cache_exit_stack.close()
def rng(self):

View File

@ -37,7 +37,8 @@ import numpy as np
from jax._src import lib
from jax._src import distributed
from jax._src.config import flags, bool_env, config, int_env
from jax._src import config as jax_config
from jax._src.config import bool_env, config, int_env
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src import traceback_util
@ -59,37 +60,35 @@ traceback_util.register_exclusion(__file__)
XlaBackend = xla_client.Client
FLAGS = flags.FLAGS
# TODO(phawkins): Remove jax_xla_backend.
flags.DEFINE_string(
_XLA_BACKEND = jax_config.DEFINE_string(
'jax_xla_backend', '',
'Deprecated, please use --jax_platforms instead.')
flags.DEFINE_string(
BACKEND_TARGET = jax_config.DEFINE_string(
'jax_backend_target',
os.getenv('JAX_BACKEND_TARGET', '').lower(),
'Either "local" or "rpc:address" to connect to a remote service target.')
# TODO(skye): warn when this is used once we test out --jax_platforms a bit
flags.DEFINE_string(
_PLATFORM_NAME = jax_config.DEFINE_string(
'jax_platform_name',
os.getenv('JAX_PLATFORM_NAME', '').lower(),
'Deprecated, please use --jax_platforms instead.')
flags.DEFINE_bool(
_DISABLE_MOST_OPTIMIZATIONS = jax_config.DEFINE_bool(
'jax_disable_most_optimizations',
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
'Try not to do much optimization work. This can be useful if the cost of '
'optimization is greater than that of running a less-optimized program.')
flags.DEFINE_integer(
_XLA_PROFILE_VERSION = jax_config.DEFINE_integer(
'jax_xla_profile_version', int_env('JAX_XLA_PROFILE_VERSION', 0),
'Optional profile version for XLA compilation. '
'This is meaningful only when XLA is configured to '
'support the remote compilation profile feature.')
flags.DEFINE_string(
CUDA_VISIBLE_DEVICES = jax_config.DEFINE_string(
'jax_cuda_visible_devices', 'all',
'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
'comma-separate list of integer device IDs.')
flags.DEFINE_string(
_ROCM_VISIBLE_DEVICES = jax_config.DEFINE_string(
'jax_rocm_visible_devices', 'all',
'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
'comma-separate list of integer device IDs.')
@ -171,13 +170,13 @@ def get_compile_options(
if lib.cuda_path is not None:
debug_options.xla_gpu_cuda_data_dir = lib.cuda_path
if FLAGS.jax_disable_most_optimizations:
if _DISABLE_MOST_OPTIMIZATIONS.value:
debug_options.xla_backend_optimization_level = 0
debug_options.xla_llvm_disable_expensive_passes = True
debug_options.xla_test_all_input_layouts = False
compile_options.profile_version = FLAGS.jax_xla_profile_version
compile_options.profile_version = _XLA_PROFILE_VERSION.value
return compile_options
@ -264,9 +263,9 @@ register_backend_factory('cpu',
def make_gpu_client(
*, platform_name: str, visible_devices_flag: str
*, platform_name: str, visible_devices_flag: jax_config.FlagHolder[str]
) -> xla_client.Client:
visible_devices = getattr(FLAGS, visible_devices_flag, "all")
visible_devices = visible_devices_flag.value
allowed_devices = None
if visible_devices != "all":
allowed_devices = {int(x) for x in visible_devices.split(",")}
@ -292,15 +291,25 @@ def make_gpu_client(
if hasattr(xla_client, "make_gpu_client"):
register_backend_factory(
'cuda', partial(make_gpu_client, platform_name='cuda',
visible_devices_flag='jax_cuda_visible_devices'),
"cuda",
partial(
make_gpu_client,
platform_name="cuda",
visible_devices_flag=CUDA_VISIBLE_DEVICES,
),
priority=200,
fail_quietly=True)
fail_quietly=True,
)
register_backend_factory(
'rocm', partial(make_gpu_client, platform_name='rocm',
visible_devices_flag='jax_rocm_visible_devices'),
"rocm",
partial(
make_gpu_client,
platform_name="rocm",
visible_devices_flag=_ROCM_VISIBLE_DEVICES,
),
priority=200,
fail_quietly=True)
fail_quietly=True,
)
if hasattr(xla_client, "make_tpu_client"):
@ -623,7 +632,7 @@ def backends() -> dict[str, xla_client.Client]:
# support anything else there at the moment and warning would be pointless.
if (py_platform.system() != "Darwin" and
_default_backend.platform == "cpu" and
FLAGS.jax_platform_name != 'cpu'):
_PLATFORM_NAME.value != 'cpu'):
logger.warning('No GPU/TPU found, falling back to CPU. '
'(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
return _backends
@ -677,8 +686,7 @@ def _get_backend_uncached(
if platform is not None and not isinstance(platform, str):
return platform
platform = (platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name
or None)
platform = (platform or _XLA_BACKEND.value or _PLATFORM_NAME.value or None)
bs = backends()
if platform is not None:

View File

@ -507,7 +507,7 @@ import warnings
from jax._src import api
from jax._src import core
from jax import config
from jax._src import config
from jax import custom_derivatives
from jax._src import dtypes
from jax import lax
@ -531,17 +531,46 @@ from jax._src.lib.mlir.dialects import hlo
import numpy as np
FLAGS = config.FLAGS
_HOST_CALLBACK_INLINE = config.DEFINE_bool(
'jax_host_callback_inline',
config.bool_env('JAX_HOST_CALLBACK_INLINE', False),
help='Inline the host_callback, if not in a staged context.'
)
_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE = config.DEFINE_integer(
'jax_host_callback_max_queue_byte_size',
config.int_env('JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE', int(256 * 1e6)),
help=('The size in bytes of the buffer used to hold outfeeds from each '
'device. When this capacity is reached consuming outfeeds from the '
'device is paused, thus potentially pausing the device computation, '
'until the Python callback consume more outfeeds.'),
lower_bound=int(16 * 1e6)
)
_HOST_CALLBACK_OUTFEED = config.DEFINE_bool(
'jax_host_callback_outfeed',
config.bool_env('JAX_HOST_CALLBACK_OUTFEED', False),
help=(
'Use outfeed implementation for host_callback, even on CPU and GPU. '
'If false, use the CustomCall implementation. '
'Has no effect on TPU, since only the outfeed mechanism is implemented.'
)
)
_HOST_CALLBACK_AD_TRANSFORMS = config.DEFINE_bool(
'jax_host_callback_ad_transforms',
config.bool_env('JAX_HOST_CALLBACK_AD_TRANSFORMS', False),
help=(
'Enable support for jvp/vjp for the host_callback primitives. Default is '
'False, which means that host_callback operates only on primals. '
'The flag exists only temporarily, for backward compatibility.'
)
)
logger = logging.getLogger(__name__)
def _inline_host_callback() -> bool:
return FLAGS.jax_host_callback_inline
def _use_outfeed(platform: str) -> bool:
return (platform in ("tpu", "gpu", "cuda", "rocm") or FLAGS.jax_host_callback_outfeed)
return (platform in ("tpu", "gpu", "cuda", "rocm") or
_HOST_CALLBACK_OUTFEED.value)
def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend):
@ -620,7 +649,7 @@ def id_tap(tap_func,
"pre-apply keyword arguments, either by using a closure or by passing "
"``functools.partial(tap_func, **kwargs)``.")
raise TypeError(msg)
if FLAGS.jax_host_callback_ad_transforms:
if _HOST_CALLBACK_AD_TRANSFORMS.value:
warnings.warn('The flag jax_host_callback_ad_transforms is for temporary '
'backwards compatibility mode. This flag, and the behavior '
'it enabled will be removed soon.',
@ -642,7 +671,7 @@ def id_tap(tap_func,
if result is not None:
# Return the results, but add a dependency on the call, to ensure it
# is kept in the graph.
if FLAGS.jax_host_callback_ad_transforms:
if _HOST_CALLBACK_AD_TRANSFORMS.value:
call_flat_results, _ = tree_util.tree_flatten(call_res)
if call_flat_results:
call_flat_results = [id_tap_dep_p.bind(r, call_flat_results[0])
@ -782,7 +811,7 @@ def _call(callback_func: Callable,
identity=False):
# Lazy initialization
_initialize_outfeed_receiver(
max_callback_queue_size_bytes=FLAGS.jax_host_callback_max_queue_byte_size)
max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value)
api.check_callable(callback_func)
flat_args, arg_treedef = tree_util.tree_flatten(arg)
for arg in flat_args:
@ -909,7 +938,7 @@ xla.register_translation(id_tap_dep_p,
id_tap_dep_p.def_abstract_eval(lambda r_a, _: r_a)
def _id_tap_dep_jvp_rule(primals, tangents):
if FLAGS.jax_host_callback_ad_transforms:
if _HOST_CALLBACK_AD_TRANSFORMS.value:
assert False
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
return (id_tap_dep_p.bind(primals[0], primals[1]),
@ -918,7 +947,7 @@ def _id_tap_dep_jvp_rule(primals, tangents):
ad.primitive_jvps[id_tap_dep_p] = _id_tap_dep_jvp_rule
def _id_tap_dep_transpose_rule(cts, arg_res, arg_tap):
if FLAGS.jax_host_callback_ad_transforms:
if _HOST_CALLBACK_AD_TRANSFORMS.value:
assert False
if ad.is_undefined_primal(arg_res):
ct_res = _instantiate_zeros(cts, arg_res)
@ -934,7 +963,7 @@ ad.primitive_transposes[id_tap_dep_p] = _id_tap_dep_transpose_rule
def _id_tap_dep_batching_rule(batched_args, batch_dims):
if FLAGS.jax_host_callback_ad_transforms:
if _HOST_CALLBACK_AD_TRANSFORMS.value:
assert False
arg_res, arg_tap = batched_args
return id_tap_dep_p.bind(arg_res, arg_tap), batch_dims[0]
@ -1013,7 +1042,7 @@ outside_call_p.def_abstract_eval(_outside_call_abstract_eval)
def _outside_call_impl(*args, **params):
assert "has_token" not in params
if _inline_host_callback():
if _HOST_CALLBACK_INLINE.value:
device_index = params["device_index"]
device = xb.devices()[device_index]
results = _outside_call_run_callback(args, device, send_infeed=False, **params)
@ -1400,7 +1429,7 @@ def _outside_call_jvp_rule(primals, tangents, **params):
assert "has_token" not in params
if not params["identity"]:
raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
if FLAGS.jax_host_callback_ad_transforms:
if _HOST_CALLBACK_AD_TRANSFORMS.value:
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
arg_treedef = params["arg_treedef"]
@ -1425,7 +1454,7 @@ ad.primitive_jvps[outside_call_p] = _outside_call_jvp_rule
def _outside_call_partial_eval_rule(trace, *args, **params):
# partial eval is used after jvp and before transpose.
if not FLAGS.jax_host_callback_ad_transforms:
if not _HOST_CALLBACK_AD_TRANSFORMS.value:
# TODO: just remote the partial eval rule
return trace.default_process_primitive(outside_call_p, args, params)
transforms = params.get("transforms", ())
@ -1492,7 +1521,7 @@ def _outside_call_transpose_rule(cts, *args, **params):
*cts_instantiated,
**_add_transform(params, "transpose"))
if not FLAGS.jax_host_callback_ad_transforms:
if not _HOST_CALLBACK_AD_TRANSFORMS.value:
assert False
assert len(args) % 2 == 0

View File

@ -65,7 +65,7 @@ def main(_):
tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)
keras_model.fit(train_ds, epochs=FLAGS.num_epochs, validation_data=test_ds)
if FLAGS.show_images:
if saved_model_main.SHOW_IMAGES.value:
mnist_lib.plot_images(
test_ds,
1,

View File

@ -38,8 +38,8 @@ import optax
import tensorflow as tf # type: ignore
import tensorflow_datasets as tfds # type: ignore
flags.DEFINE_boolean("mock_data", False, "Use fake data, for testing.")
FLAGS = flags.FLAGS
_MOCK_DATA = flags.DEFINE_boolean("mock_data", False,
"Use fake data, for testing.")
#### Model parameters
@ -64,7 +64,7 @@ def load_mnist(split: tfds.Split, batch_size: int):
an iterator with pairs (images, labels). The images have shape
(B, 28, 28, 1) and the labels have shape (B, 10), where B is the batch_size.
"""
if FLAGS.mock_data:
if _MOCK_DATA.value:
with tfds.testing.mock_data(num_examples=batch_size):
try:
ds = tfds.load("mnist", split=split)

View File

@ -37,46 +37,49 @@ import numpy as np
import tensorflow as tf # type: ignore
import tensorflow_datasets as tfds # type: ignore
flags.DEFINE_enum("model", "mnist_flax", ["mnist_flax", "mnist_pure_jax"],
"Which model to use.")
flags.DEFINE_boolean("model_classifier_layer", True,
_MODEL = flags.DEFINE_enum(
"model", "mnist_flax", ["mnist_flax", "mnist_pure_jax"],
"Which model to use.")
_MODEL_CLASSIFIER_LAYER = flags.DEFINE_boolean("model_classifier_layer", True,
("The model should include the classifier layer, or just "
"the last layer of logits. Set this to False when you "
"want to reuse the classifier-less model in a larger "
"model. See keras_reuse_main.py and README.md."))
flags.DEFINE_string("model_path", "/tmp/jax2tf/saved_models",
_MODEL_PATH = flags.DEFINE_string("model_path", "/tmp/jax2tf/saved_models",
"Path under which to save the SavedModel.")
flags.DEFINE_integer("model_version", 1,
_MODEL_VERSION = flags.DEFINE_integer("model_version", 1,
("The version number for the SavedModel. Needed for "
"serving, larger versions will take precedence"),
lower_bound=1)
flags.DEFINE_integer("serving_batch_size", 1,
_SERVING_BATCH_SIZE = flags.DEFINE_integer("serving_batch_size", 1,
"For what batch size to prepare the serving signature. "
"Use -1 for converting and saving with batch polymorphism.")
flags.register_validator(
"serving_batch_size",
lambda serving_batch_size: serving_batch_size > 0 or serving_batch_size == -1,
message="--serving_batch_size must be either -1 or a positive integer.")
lambda serving_batch_size: serving_batch_size > 0
or serving_batch_size == -1,
message="--serving_batch_size must be either -1 or a positive integer.",
)
flags.DEFINE_integer("num_epochs", 3, "For how many epochs to train.",
lower_bound=1)
flags.DEFINE_boolean(
_NUM_EPOCHS = flags.DEFINE_integer("num_epochs", 3,
"For how many epochs to train.",
lower_bound=1)
_GENERATE_MODEL = flags.DEFINE_boolean(
"generate_model", True,
"Train and save a new model. Otherwise, use an existing SavedModel.")
flags.DEFINE_boolean(
_COMPILE_MODEL = flags.DEFINE_boolean(
"compile_model", True,
"Enable TensorFlow jit_compiler for the SavedModel. This is "
"necessary if you want to use the model for TensorFlow serving.")
flags.DEFINE_boolean("show_model", True, "Show details of saved SavedModel.")
flags.DEFINE_boolean(
_SHOW_MODEL = flags.DEFINE_boolean("show_model", True,
"Show details of saved SavedModel.")
SHOW_IMAGES = flags.DEFINE_boolean(
"show_images", False,
"Plot some sample images with labels and inference results.")
flags.DEFINE_boolean(
_TEST_SAVEDMODEL = flags.DEFINE_boolean(
"test_savedmodel", True,
"Test TensorFlow inference using the SavedModel w.r.t. the JAX model.")
FLAGS = flags.FLAGS
def train_and_save():
logging.info("Loading the MNIST TensorFlow dataset")
@ -85,22 +88,22 @@ def train_and_save():
test_ds = mnist_lib.load_mnist(
tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)
if FLAGS.show_images:
if SHOW_IMAGES.value:
mnist_lib.plot_images(train_ds, 1, 5, "Training images", inference_fn=None)
the_model_class = pick_model_class()
model_dir = savedmodel_dir(with_version=True)
if FLAGS.generate_model:
if _GENERATE_MODEL.value:
model_descr = model_description()
logging.info("Generating model for %s", model_descr)
(predict_fn, predict_params) = the_model_class.train(
train_ds,
test_ds,
FLAGS.num_epochs,
with_classifier=FLAGS.model_classifier_layer)
num_epochs=_NUM_EPOCHS.value,
with_classifier=_MODEL_CLASSIFIER_LAYER.value)
if FLAGS.serving_batch_size == -1:
if _SERVING_BATCH_SIZE.value == -1:
# Batch-polymorphic SavedModel
input_signatures = [
tf.TensorSpec((None,) + mnist_lib.input_shape, tf.float32),
@ -109,7 +112,7 @@ def train_and_save():
else:
input_signatures = [
# The first one will be the serving signature
tf.TensorSpec((FLAGS.serving_batch_size,) + mnist_lib.input_shape,
tf.TensorSpec((_SERVING_BATCH_SIZE.value,) + mnist_lib.input_shape,
tf.float32),
tf.TensorSpec((mnist_lib.train_batch_size,) + mnist_lib.input_shape,
tf.float32),
@ -126,15 +129,15 @@ def train_and_save():
with_gradient=True,
input_signatures=input_signatures,
polymorphic_shapes=polymorphic_shapes,
compile_model=FLAGS.compile_model)
compile_model=_COMPILE_MODEL.value)
if FLAGS.test_savedmodel:
if _TEST_SAVEDMODEL.value:
tf_accelerator, tolerances = tf_accelerator_and_tolerances()
with tf.device(tf_accelerator):
logging.info("Testing savedmodel")
pure_restored_model = tf.saved_model.load(model_dir)
if FLAGS.show_images and FLAGS.model_classifier_layer:
if SHOW_IMAGES.value and _MODEL_CLASSIFIER_LAYER.value:
mnist_lib.plot_images(
test_ds,
1,
@ -149,7 +152,7 @@ def train_and_save():
pure_restored_model(tf.convert_to_tensor(test_input)),
predict_fn(predict_params, test_input), **tolerances)
if FLAGS.show_model:
if _SHOW_MODEL.value:
def print_model(model_dir: str):
cmd = f"saved_model_cli show --all --dir {model_dir}"
print(cmd)
@ -160,18 +163,18 @@ def train_and_save():
def pick_model_class():
"""Picks one of PureJaxMNIST or FlaxMNIST."""
if FLAGS.model == "mnist_pure_jax":
if _MODEL.value == "mnist_pure_jax":
return mnist_lib.PureJaxMNIST
elif FLAGS.model == "mnist_flax":
elif _MODEL.value == "mnist_flax":
return mnist_lib.FlaxMNIST
else:
raise ValueError(f"Unrecognized model: {FLAGS.model}")
raise ValueError(f"Unrecognized model: {_MODEL.value}")
def model_description() -> str:
"""A short description of the picked model."""
res = pick_model_class().name
if not FLAGS.model_classifier_layer:
if not _MODEL_CLASSIFIER_LAYER.value:
res += " (features_only)"
return res
@ -179,11 +182,11 @@ def model_description() -> str:
def savedmodel_dir(with_version: bool = True) -> str:
"""The directory where we save the SavedModel."""
model_dir = os.path.join(
FLAGS.model_path,
FLAGS.model + ('' if FLAGS.model_classifier_layer else '_features')
_MODEL_PATH.value,
_MODEL.value + ('' if _MODEL_CLASSIFIER_LAYER.value else '_features')
)
if with_version:
model_dir = os.path.join(model_dir, str(FLAGS.model_version))
model_dir = os.path.join(model_dir, str(_MODEL_VERSION.value))
return model_dir

View File

@ -32,31 +32,32 @@ from tensorflow_serving.apis import predict_pb2 # type: ignore[import]
from tensorflow_serving.apis import prediction_service_pb2_grpc
FLAGS = flags.FLAGS
flags.DEFINE_boolean(
_USE_GRPC = flags.DEFINE_boolean(
"use_grpc", True,
"Use the gRPC API (default), or the HTTP REST API.")
flags.DEFINE_string(
_MODEL_SPEC_NAME = flags.DEFINE_string(
"model_spec_name", "",
"The name you used to export your model to model server (e.g., mnist_flax).")
flags.DEFINE_string(
_PREDICTION_SERVICE_ADDR = flags.DEFINE_string(
"prediction_service_addr",
"localhost:8500",
"Stubby endpoint for the prediction service. If you serve your model "
"locally using TensorFlow model server, then you can use \"localhost:8500\""
"for the gRPC server and \"localhost:8501\" for the HTTP REST server.")
flags.DEFINE_integer("serving_batch_size", 1,
"Batch size for the serving request. Must match the "
"batch size at which the model was saved. Must divide "
"--count_images",
lower_bound=1)
flags.DEFINE_integer("count_images", 16,
"How many images to test.",
lower_bound=1)
_SERVING_BATCH_SIZE = flags.DEFINE_integer(
"serving_batch_size",
1,
"Batch size for the serving request. Must match the "
"batch size at which the model was saved. Must divide "
"--count_images",
lower_bound=1,
)
_COUNT_IMAGES = flags.DEFINE_integer(
"count_images", 16, "How many images to test.", lower_bound=1
)
def serving_call_mnist(images):
@ -69,12 +70,12 @@ def serving_call_mnist(images):
Returns:
A numpy.ndarray of shape [B, 10] with the one-hot inference response.
"""
if FLAGS.use_grpc:
channel = grpc.insecure_channel(FLAGS.prediction_service_addr)
if _USE_GRPC.value:
channel = grpc.insecure_channel(_PREDICTION_SERVICE_ADDR.value)
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
request = predict_pb2.PredictRequest()
request.model_spec.name = FLAGS.model_spec_name
request.model_spec.name = _MODEL_SPEC_NAME.value
request.model_spec.signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
# You can see the name of the input ("inputs") in the SavedModel dump.
request.inputs["inputs"].CopyFrom(
@ -90,7 +91,7 @@ def serving_call_mnist(images):
images_json = json.dumps(images.tolist())
# You can see the name of the input ("inputs") in the SavedModel dump.
data = f'{{"inputs": {images_json}}}'
predict_url = f"http://{FLAGS.prediction_service_addr}/v1/models/{FLAGS.model_spec_name}:predict"
predict_url = f"http://{_PREDICTION_SERVICE_ADDR.value}/v1/models/{_MODEL_SPEC_NAME.value}:predict"
response = requests.post(predict_url, data=data)
if response.status_code != 200:
msg = (f"Received error response {response.status_code} from model "
@ -101,14 +102,14 @@ def serving_call_mnist(images):
def main(_):
if FLAGS.count_images % FLAGS.serving_batch_size != 0:
raise ValueError(f"The count_images ({FLAGS.count_images}) must be a "
if _COUNT_IMAGES.value % _SERVING_BATCH_SIZE.value != 0:
raise ValueError(f"The count_images ({_COUNT_IMAGES.value}) must be a "
"multiple of "
f"serving_batch_size ({FLAGS.serving_batch_size})")
f"serving_batch_size ({_SERVING_BATCH_SIZE.value})")
test_ds = mnist_lib.load_mnist(tfds.Split.TEST,
batch_size=FLAGS.serving_batch_size)
batch_size=_SERVING_BATCH_SIZE.value)
images_and_labels = tfds.as_numpy(test_ds.take(
FLAGS.count_images // FLAGS.serving_batch_size))
_COUNT_IMAGES.value // _SERVING_BATCH_SIZE.value))
accurate_count = 0
for batch_idx, (images, labels) in enumerate(images_and_labels):
@ -117,7 +118,7 @@ def main(_):
labels_digit = np.argmax(labels, axis=1)
accurate_count += np.sum(labels_digit == predictions_digit)
running_accuracy = (
100. * accurate_count / (1 + batch_idx) / FLAGS.serving_batch_size)
100. * accurate_count / (1 + batch_idx) / _SERVING_BATCH_SIZE.value)
logging.info(
" predicted digits = %s labels %s. Running accuracy %.3f%%",
predictions_digit, labels_digit, running_accuracy)

View File

@ -33,15 +33,17 @@ import tensorflowjs as tfjs
import input_pipeline # type: ignore[import]
flags.DEFINE_integer("num_epochs", 5,
("Number of epochs to train for."))
flags.DEFINE_integer("num_classes", 100, "Number of classification classes.")
_NUM_EPOCHS = flags.DEFINE_integer(
"num_epochs", 5, "Number of epochs to train for."
)
_NUM_CLASSES = flags.DEFINE_integer(
"num_classes", 100, "Number of classification classes."
)
flags.register_validator("num_classes",
lambda value: value >= 1 and value <= 100,
message="--num_classes must be in range [1, 100]")
FLAGS = flags.FLAGS
# The code below is an adaptation for Flax from the work published here:
# https://blog.tensorflow.org/2018/07/train-model-in-tfkeras-with-colab-and-run-in-browser-tensorflowjs.html
@ -65,7 +67,7 @@ class QuickDraw(nn.Module):
x = nn.Dense(features=128)(x)
x = nn.relu(x)
x = nn.Dense(features=FLAGS.num_classes)(x)
x = nn.Dense(features=_NUM_CLASSES.value)(x)
return x
@ -75,7 +77,7 @@ def apply_model(state, inputs, labels):
"""Computes gradients, loss and accuracy for a single batch."""
def loss_fn(params):
logits = state.apply_fn({'params': params}, inputs)
one_hot = jax.nn.one_hot(labels, FLAGS.num_classes)
one_hot = jax.nn.one_hot(labels, _NUM_CLASSES.value)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
@ -113,7 +115,7 @@ def create_train_state(rng):
def train(state, train_ds, test_ds):
for epoch in range(1, FLAGS.num_epochs+1):
for epoch in range(1, _NUM_EPOCHS.value+1):
start_time = time.time()
state, train_loss, train_accuracy = run_epoch(state, train_ds)
@ -136,12 +138,12 @@ def main(argv):
base_model_path = "/tmp/jax2tf/tf_js_quickdraw"
dataset_path = os.path.join(base_model_path, "data")
classes = input_pipeline.download_dataset(dataset_path, FLAGS.num_classes)
assert len(classes) == FLAGS.num_classes, "Incorrect number of classes"
classes = input_pipeline.download_dataset(dataset_path, _NUM_CLASSES.value)
assert len(classes) == _NUM_CLASSES.value, "Incorrect number of classes"
print(f"Classes are: {classes}")
print("Loading dataset into memory...")
train_ds, test_ds = input_pipeline.get_datasets(dataset_path, classes)
print(f"Starting training for {FLAGS.num_epochs} epochs...")
print(f"Starting training for {_NUM_EPOCHS.value} epochs...")
state = create_train_state(jax.random.PRNGKey(0))
state = train(state, train_ds, test_ds)

View File

@ -24,14 +24,19 @@ import numpy as np
import tensorflow as tf # type: ignore[import]
import tensorflow_datasets as tfds # type: ignore[import]
flags.DEFINE_string('tflite_file_path',
'/usr/local/google/home/qiuminxu/jax2tf/mnist.tflite',
'Path where to save the TensorFlow Lite file.')
flags.DEFINE_integer('serving_batch_size', 4,
('For what batch size to prepare the serving signature. '))
flags.DEFINE_integer('num_epochs', 10, 'For how many epochs to train.')
FLAGS = flags.FLAGS
_TFLITE_FILE_PATH = flags.DEFINE_string(
'tflite_file_path',
'/tmp/mnist.tflite',
'Path where to save the TensorFlow Lite file.',
)
_SERVING_BATCH_SIZE = flags.DEFINE_integer(
'serving_batch_size',
4,
'For what batch size to prepare the serving signature. ',
)
_NUM_EPOCHS = flags.DEFINE_integer(
'num_epochs', 10, 'For how many epochs to train.'
)
# A helper function to evaluate the TF Lite model using "test" dataset.
@ -71,10 +76,11 @@ def main(_):
train_ds = mnist_lib.load_mnist(
tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
test_ds = mnist_lib.load_mnist(
tfds.Split.TEST, batch_size=FLAGS.serving_batch_size)
tfds.Split.TEST, batch_size=_SERVING_BATCH_SIZE)
(flax_predict,
flax_params) = mnist_lib.FlaxMNIST.train(train_ds, test_ds, FLAGS.num_epochs)
(flax_predict, flax_params) = mnist_lib.FlaxMNIST.train(
train_ds, test_ds, _NUM_EPOCHS.value
)
def predict(image):
return flax_predict(flax_params, image)
@ -84,7 +90,7 @@ def main(_):
jax2tf.convert(predict, enable_xla=False),
input_signature=[
tf.TensorSpec(
shape=[FLAGS.serving_batch_size, 28, 28, 1],
shape=[_SERVING_BATCH_SIZE, 28, 28, 1],
dtype=tf.float32,
name='input')
],
@ -126,7 +132,7 @@ def main(_):
print('Quantized model accuracy = %.4f' % quantized_accuracy)
print('Accuracy drop = %.4f' % (float_accuracy - quantized_accuracy))
f = open(FLAGS.tflite_file_path, 'wb')
f = open(_TFLITE_FILE_PATH.value, 'wb')
f.write(tflite_quantized_model)
f.close()

View File

@ -58,34 +58,32 @@ from jax.experimental.jax2tf.shape_poly import InconclusiveDimensionOperation
from jax.experimental.jax2tf.tests.model_harness import ALL_HARNESSES
from jax.experimental.jax2tf.tests.converters import ALL_CONVERTERS
flags.DEFINE_list("converters", [x.name for x in ALL_CONVERTERS],
_CONVERTERS = flags.DEFINE_list("converters", [x.name for x in ALL_CONVERTERS],
"Which converters to test.")
flags.DEFINE_list("examples", [],
_EXAMPLES = flags.DEFINE_list("examples", [],
("List of examples to test, e.g.: 'flax/mnist,flax/seq2seq'. "
"If empty, will test all examples."))
flags.DEFINE_string("example_prefix", "",
_EXAMPLE_PREFIX = flags.DEFINE_string("example_prefix", "",
("Prefix for filtering tests. For instance 'flax/mnist' "
"will test all examples starting with 'flax/mnist' "
"(including all polymorphic tests)."))
flags.DEFINE_bool(
_WRITE_MARKDOWN = flags.DEFINE_bool(
"write_markdown", True,
"If true, write results as Markdown. Otherwise, only output to stdout.")
flags.DEFINE_bool(
_FAIL_ON_ERROR = flags.DEFINE_bool(
"fail_on_error", False,
("If true, exit with an error when a conversion fails. Useful for "
"debugging because it will show the entire stack trace."))
FLAGS = flags.FLAGS
def _write_markdown(results: dict[str, list[tuple[str, str,]]]) -> None:
"""Writes all results to Markdown file."""
table_lines = []
converters = FLAGS.converters
converters = _CONVERTERS.value
table_lines.append("| Example | " + " ".join([f"{c} |" for c in converters]))
table_lines.append("|" + (" --- |" * (len(converters) + 1)))
@ -173,7 +171,7 @@ def test_converters():
exit()
def _maybe_reraise(e):
if FLAGS.fail_on_error:
if _FAIL_ON_ERROR.value:
raise e
def _format(e):
@ -183,13 +181,13 @@ def test_converters():
return msg
converters = list(
filter(lambda x: x.name in FLAGS.converters, ALL_CONVERTERS))
filter(lambda x: x.name in _CONVERTERS.value, ALL_CONVERTERS))
_exit_if_empty(converters, "converters")
harnesses_to_test = {
name: fn for name, fn in ALL_HARNESSES.items()
if (not FLAGS.examples or name in FLAGS.examples) and
(not FLAGS.example_prefix or name.startswith(FLAGS.example_prefix))
if (not _EXAMPLES.value or name in _EXAMPLES.value) and
(not _EXAMPLE_PREFIX.value or name.startswith(_EXAMPLE_PREFIX.value))
}
_exit_if_empty(harnesses_to_test, "harness")
@ -243,7 +241,7 @@ def test_converters():
converter_results.append((converter.name, error_msg))
results[harness.name] = converter_results
if FLAGS.write_markdown:
if _WRITE_MARKDOWN:
_write_markdown(results)
else:
print("=== NOT writing results to Markdown.")

View File

@ -140,7 +140,6 @@ def jax_to_ir(name, deps, fn, input_shapes, constants = None, format = "HLO"):
from absl import app
import jax.tools.jax_to_ir as jax_to_ir
jax_to_ir.set_up_flags()
app.run(jax_to_ir.main)
EOF
""".format(runner = runner),

View File

@ -87,7 +87,30 @@ try:
except ImportError:
tf = None # type: ignore
FLAGS = flags.FLAGS
_FN = flags.DEFINE_string(
'fn', None, "Fully-qualified name of function that we're going to convert"
)
_INPUT_SHAPES = flags.DEFINE_string(
'input_shapes', None, 'Python dict indicating XLA shapes of params'
)
_CONSTANTS = flags.DEFINE_string(
'constants', '{}', 'Python dict giving constant values for some params'
)
_EVALED_CONSTANTS = flags.DEFINE_string(
'evaled_constants',
'{}',
'Python dict giving constant values for some params. '
'Values in this dict that are of type str are evaluated '
'using ast.literal_eval.',
)
_IR_FORMAT = flags.DEFINE_enum(
'ir_format', 'HLO', ('HLO', 'TF'), 'Output format.'
)
_IR_DEST = flags.DEFINE_string('ir_dest', None, 'File to write IR to')
_IR_HUMAN_DEST = flags.DEFINE_string(
'ir_human_dest', None, 'File to write human readable debug output'
)
def jax_to_ir(fn, input_shapes, *, constants=None, format):
@ -163,25 +186,25 @@ def main(argv):
if len(argv) != 1:
raise app.UsageError('No positional arguments are accepted.')
if not FLAGS.ir_dest and not FLAGS.ir_human_dest:
if not _IR_DEST.value and not _IR_HUMAN_DEST.value:
raise app.Error('At least one of --ir_dest and '
'--ir_human_dest is required.')
module_name, fn_name = FLAGS.fn.rsplit('.', 1)
module_name, fn_name = _FN.value.rsplit('.', 1)
module = importlib.import_module(module_name)
fn = getattr(module, fn_name)
input_shapes = [(name, parse_shape_str(shape_str))
for name, shape_str in literal_eval(FLAGS.input_shapes)]
for name, shape_str in literal_eval(_INPUT_SHAPES.value)]
# Parse --constants and --evaled_constants.
constants = {}
for k, v in literal_eval(FLAGS.constants).items():
for k, v in literal_eval(_CONSTANTS.value).items():
if isinstance(v, list):
v = jnp.asarray(v)
constants[k] = v
for k, v in literal_eval(FLAGS.evaled_constants).items():
for k, v in literal_eval(_EVALED_CONSTANTS.value).items():
if isinstance(v, str):
v = literal_eval(v)
if isinstance(v, list):
@ -192,14 +215,14 @@ def main(argv):
constants[k] = v
ir, debug_ir = jax_to_ir(fn, input_shapes, constants=constants,
format=FLAGS.ir_format)
format=_IR_FORMAT.value)
if FLAGS.ir_dest:
with open(FLAGS.ir_dest, 'wb') as f:
if _IR_DEST.value:
with open(_IR_DEST.value, 'wb') as f:
f.write(ir)
if FLAGS.ir_human_dest:
with open(FLAGS.ir_human_dest, 'w') as f:
if _IR_HUMAN_DEST.value:
with open(_IR_HUMAN_DEST.value, 'w') as f:
f.write(debug_ir)
@ -225,21 +248,6 @@ _SHAPE_RE = re.compile(f"^({'|'.join(_DT)})\\[\\s*(\\d*[\\s*,\\d+]*)\\s*\\]$")
def set_up_flags():
flags.DEFINE_string(
'fn', None,
"Fully-qualified name of function that we're going to convert")
flags.DEFINE_string('input_shapes', None,
'Python dict indicating XLA shapes of params')
flags.DEFINE_string('constants', '{}',
'Python dict giving constant values for some params')
flags.DEFINE_string('evaled_constants', '{}',
'Python dict giving constant values for some params. '
'Values in this dict that are of type str are evaluated '
'using ast.literal_eval.')
flags.DEFINE_enum('ir_format', 'HLO', ('HLO', 'TF'), 'Output format.')
flags.DEFINE_string('ir_dest', None, 'File to write IR to')
flags.DEFINE_string('ir_human_dest', None,
'File to write human readable debug output')
flags.mark_flag_as_required('fn')
flags.mark_flag_as_required('input_shapes')

View File

@ -22,7 +22,6 @@ import jax
import jax.numpy as jnp
from jax._src import core
from jax._src import test_util as jtu
from jax._src.config import flags
from jax.experimental.pjit import pjit
from jax.experimental.serialize_executable import (
serialize, deserialize_and_load)
@ -73,7 +72,7 @@ class JaxAotTest(jtu.JaxTestCase):
except NotImplementedError:
raise unittest.SkipTest('PJRT Topology not supported')
if flags.FLAGS.jax_test_with_persistent_compilation_cache:
if jtu.TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
raise unittest.SkipTest('Compilation caching not yet supported.')
@jax.jit