From 76cda0ae0732b8795e7aec22d575dd40b89fdf5f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 27 Jul 2023 12:15:16 -0700 Subject: [PATCH] Update flags to use the ABSL typed flag API. Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary. For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API. Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`. This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR. PiperOrigin-RevId: 551604974 --- benchmarks/math_benchmark.py | 7 +- examples/differentially_private_sgd.py | 47 +++++----- examples/gaussian_process_regression.py | 3 - jax/_src/config.py | 92 +++++++++---------- jax/_src/core.py | 11 ++- jax/_src/dispatch.py | 11 +-- jax/_src/dtypes.py | 4 +- jax/_src/pretty_printer.py | 11 ++- jax/_src/public_test_util.py | 11 ++- jax/_src/test_util.py | 43 ++++----- jax/_src/xla_bridge.py | 54 ++++++----- jax/experimental/host_callback.py | 63 +++++++++---- .../jax2tf/examples/keras_reuse_main.py | 2 +- jax/experimental/jax2tf/examples/mnist_lib.py | 6 +- .../jax2tf/examples/saved_model_main.py | 71 +++++++------- .../examples/serving/model_server_request.py | 47 +++++----- .../examples/tf_js/quickdraw/quickdraw.py | 22 +++-- .../jax2tf/examples/tflite/mnist/mnist.py | 32 ++++--- .../jax2tf/tests/models_test_main.py | 24 +++-- jax/tools/build_defs.bzl | 1 - jax/tools/jax_to_ir.py | 60 ++++++------ tests/aot_test.py | 3 +- 22 files changed, 336 insertions(+), 289 deletions(-) diff --git a/benchmarks/math_benchmark.py b/benchmarks/math_benchmark.py index 74674c159..52f449ae9 100644 --- a/benchmarks/math_benchmark.py +++ b/benchmarks/math_benchmark.py @@ -27,8 +27,7 @@ from absl import app from absl import flags -FLAGS = flags.FLAGS -flags.DEFINE_multi_string( +_SET_ENV = flags.DEFINE_multi_string( "set_env", None, "Specifies additional environment variables to be injected into the " "environment (via --set_env=variable=value or --set_env=variable). " @@ -140,8 +139,8 @@ def jax_binary_op(state, **kwargs): ) def main(argv): - if FLAGS.set_env: - for env_str in FLAGS.set_env: + if _SET_ENV.value: + for env_str in _SET_ENV.value: # Stop matching at the first '=' since we want to capture # --set_env='FOO=--foo_a=1 --foo_b=2' all as part of FOO. env_list = env_str.split('=', 1) diff --git a/examples/differentially_private_sgd.py b/examples/differentially_private_sgd.py index 45be4b2bc..4777554b1 100644 --- a/examples/differentially_private_sgd.py +++ b/examples/differentially_private_sgd.py @@ -83,22 +83,21 @@ import numpy.random as npr from dp_accounting import dp_event from dp_accounting import rdp -FLAGS = flags.FLAGS -flags.DEFINE_boolean( +_DPSGD = flags.DEFINE_boolean( 'dpsgd', True, 'If True, train with DP-SGD. 
If False, ' 'train with vanilla SGD.') -flags.DEFINE_float('learning_rate', .15, 'Learning rate for training') -flags.DEFINE_float('noise_multiplier', 1.1, +_LEARNING_RATE = flags.DEFINE_float('learning_rate', .15, 'Learning rate for training') +_NOISE_MULTIPLIER = flags.DEFINE_float('noise_multiplier', 1.1, 'Ratio of the standard deviation to the clipping norm') -flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm') -flags.DEFINE_integer('batch_size', 256, 'Batch size') -flags.DEFINE_integer('epochs', 60, 'Number of epochs') -flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG') -flags.DEFINE_integer( +_L2_NORM_CLIP = flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm') +_BATCH_SIZE = flags.DEFINE_integer('batch_size', 256, 'Batch size') +_EPOCHS = flags.DEFINE_integer('epochs', 60, 'Number of epochs') +_SEED = flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG') +_MICROBATCHES = flags.DEFINE_integer( 'microbatches', None, 'Number of microbatches ' '(must evenly divide batch_size)') -flags.DEFINE_string('model_dir', None, 'Model directory') +_MODEL_DIR = flags.DEFINE_string('model_dir', None, 'Model directory') init_random_params, predict = stax.serial( @@ -163,39 +162,39 @@ def shape_as_image(images, labels, dummy_dim=False): def compute_epsilon(steps, num_examples=60000, target_delta=1e-5): if num_examples * target_delta > 1.: warnings.warn('Your delta might be too high.') - q = FLAGS.batch_size / float(num_examples) + q = _BATCH_SIZE.value / float(num_examples) orders = list(jnp.linspace(1.1, 10.9, 99)) + list(range(11, 64)) accountant = rdp.rdp_privacy_accountant.RdpAccountant(orders) accountant.compose( dp_event.PoissonSampledDpEvent( - q, dp_event.GaussianDpEvent(FLAGS.noise_multiplier)), steps) + q, dp_event.GaussianDpEvent(_NOISE_MULTIPLIER.value)), steps) return accountant.get_epsilon(target_delta) def main(_): - if FLAGS.microbatches: + if _MICROBATCHES.value: raise NotImplementedError( 'Microbatches < batch size not currently supported' ) train_images, train_labels, test_images, test_labels = datasets.mnist() num_train = train_images.shape[0] - num_complete_batches, leftover = divmod(num_train, FLAGS.batch_size) + num_complete_batches, leftover = divmod(num_train, _BATCH_SIZE.value) num_batches = num_complete_batches + bool(leftover) - key = random.PRNGKey(FLAGS.seed) + key = random.PRNGKey(_SEED.value) def data_stream(): - rng = npr.RandomState(FLAGS.seed) + rng = npr.RandomState(_SEED.value) while True: perm = rng.permutation(num_train) for i in range(num_batches): - batch_idx = perm[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size] + batch_idx = perm[i * _BATCH_SIZE.value:(i + 1) * _BATCH_SIZE.value] yield train_images[batch_idx], train_labels[batch_idx] batches = data_stream() - opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate) + opt_init, opt_update, get_params = optimizers.sgd(_LEARNING_RATE.value) @jit def update(_, i, opt_state, batch): @@ -208,19 +207,19 @@ def main(_): rng = random.fold_in(rng, i) # get new key for new random numbers return opt_update( i, - private_grad(params, batch, rng, FLAGS.l2_norm_clip, - FLAGS.noise_multiplier, FLAGS.batch_size), opt_state) + private_grad(params, batch, rng, _L2_NORM_CLIP.value, + _NOISE_MULTIPLIER.value, _BATCH_SIZE.value), opt_state) _, init_params = init_random_params(key, (-1, 28, 28, 1)) opt_state = opt_init(init_params) itercount = itertools.count() - steps_per_epoch = 60000 // FLAGS.batch_size + steps_per_epoch = 60000 // _BATCH_SIZE.value print('\nStarting training...') - for epoch in range(1, 
FLAGS.epochs + 1): + for epoch in range(1, _EPOCHS.value + 1): start_time = time.time() for _ in range(num_batches): - if FLAGS.dpsgd: + if _DPSGD.value: opt_state = \ private_update( key, next(itercount), opt_state, @@ -239,7 +238,7 @@ def main(_): test_loss, 100 * test_acc)) # determine privacy loss so far - if FLAGS.dpsgd: + if _DPSGD.value: delta = 1e-5 num_examples = 60000 eps = compute_epsilon(epoch * steps_per_epoch, num_examples, delta) diff --git a/examples/gaussian_process_regression.py b/examples/gaussian_process_regression.py index e819f6331..35afcca6f 100644 --- a/examples/gaussian_process_regression.py +++ b/examples/gaussian_process_regression.py @@ -16,7 +16,6 @@ """ from absl import app -from absl import flags from functools import partial from jax import grad from jax import jit @@ -27,8 +26,6 @@ import jax.random as random import jax.scipy as scipy import matplotlib.pyplot as plt -FLAGS = flags.FLAGS - def main(unused_argv): diff --git a/jax/_src/config.py b/jax/_src/config.py index 99d3aca08..fa42727e5 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from collections.abc import Hashable, Iterator import contextlib import functools @@ -20,7 +22,7 @@ import logging import os import sys import threading -from typing import Any, Callable, NamedTuple, Optional +from typing import Any, Callable, Generic, NamedTuple, Optional, TypeVar from jax._src import lib from jax._src.lib import jax_jit @@ -29,6 +31,8 @@ from jax._src.lib import xla_client logger = logging.getLogger(__name__) +_T = TypeVar('_T') + def bool_env(varname: str, default: bool) -> bool: """Read an environment variable and interpret it as a boolean. 
@@ -64,6 +68,16 @@ UPGRADE_BOOL_HELP = ( UPGRADE_BOOL_EXTRA_DESC = " (transient)" +class FlagHolder(Generic[_T]): + def __init__(self, flags: NameSpace, name: str): + self._flags = flags + self._name = name + + @property + def value(self) -> _T: + return getattr(self._flags, self._name) + + class Config: _HAS_DYNAMIC_ATTRIBUTES = True @@ -112,26 +126,31 @@ class Config: if name not in self.values: raise AttributeError(f"Unrecognized config option: {name}") - def DEFINE_bool(self, name, default, *args, **kwargs): + def DEFINE_bool(self, name, default, *args, **kwargs) -> FlagHolder[bool]: update_hook = kwargs.pop("update_hook", None) self.add_option(name, default, bool, args, kwargs, update_hook=update_hook) + return FlagHolder(self.FLAGS, name) - def DEFINE_integer(self, name, default, *args, **kwargs): + def DEFINE_integer(self, name, default, *args, **kwargs) -> FlagHolder[int]: update_hook = kwargs.pop("update_hook", None) self.add_option(name, default, int, args, kwargs, update_hook=update_hook) + return FlagHolder(self.FLAGS, name) - def DEFINE_float(self, name, default, *args, **kwargs): + def DEFINE_float(self, name, default, *args, **kwargs) -> FlagHolder[float]: update_hook = kwargs.pop("update_hook", None) self.add_option(name, default, float, args, kwargs, update_hook=update_hook) + return FlagHolder(self.FLAGS, name) - def DEFINE_string(self, name, default, *args, **kwargs): + def DEFINE_string(self, name, default, *args, **kwargs) -> FlagHolder[str]: update_hook = kwargs.pop("update_hook", None) self.add_option(name, default, str, args, kwargs, update_hook=update_hook) + return FlagHolder(self.FLAGS, name) - def DEFINE_enum(self, name, default, *args, **kwargs): + def DEFINE_enum(self, name, default, *args, **kwargs) -> FlagHolder[str]: update_hook = kwargs.pop("update_hook", None) self.add_option(name, default, 'enum', args, kwargs, update_hook=update_hook) + return FlagHolder(self.FLAGS, name) def config_with_absl(self): # Run this before calling `app.run(main)` etc @@ -551,6 +570,22 @@ config = Config() flags = config FLAGS = flags.FLAGS +def DEFINE_bool(name, default, *args, **kwargs): + return flags.DEFINE_bool(name, default, *args, **kwargs) + +def DEFINE_integer(name, default, *args, **kwargs): + return flags.DEFINE_integer(name, default, *args, **kwargs) + +def DEFINE_float(name, default, *args, **kwargs): + return flags.DEFINE_float(name, default, *args, **kwargs) + +def DEFINE_string(name, default, *args, **kwargs): + return flags.DEFINE_string(name, default, *args, **kwargs) + +def DEFINE_enum(name, default, *args, **kwargs): + return flags.DEFINE_enum(name, default, *args, **kwargs) + + already_configured_with_absl = False @@ -617,51 +652,6 @@ def update_thread_local_jit_state(**kw): tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp) -flags.DEFINE_integer( - 'jax_tracer_error_num_traceback_frames', - int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5), - help='Set the number of stack frames in JAX tracer error messages.' -) - -flags.DEFINE_bool( - 'jax_pprint_use_color', - bool_env('JAX_PPRINT_USE_COLOR', True), - help='Enable jaxpr pretty-printing with colorful syntax highlighting.' -) - -flags.DEFINE_bool( - 'jax_host_callback_inline', - bool_env('JAX_HOST_CALLBACK_INLINE', False), - help='Inline the host_callback, if not in a staged context.' 
-) -flags.DEFINE_integer( - 'jax_host_callback_max_queue_byte_size', - int_env('JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE', int(256 * 1e6)), - help=('The size in bytes of the buffer used to hold outfeeds from each ' - 'device. When this capacity is reached consuming outfeeds from the ' - 'device is paused, thus potentially pausing the device computation, ' - 'until the Python callback consume more outfeeds.'), - lower_bound=int(16 * 1e6) -) -flags.DEFINE_bool( - 'jax_host_callback_outfeed', - bool_env('JAX_HOST_CALLBACK_OUTFEED', False), - help=( - 'Use outfeed implementation for host_callback, even on CPU and GPU. ' - 'If false, use the CustomCall implementation. ' - 'Has no effect on TPU, since only the outfeed mechanism is implemented.' - ) -) -flags.DEFINE_bool( - 'jax_host_callback_ad_transforms', - bool_env('JAX_HOST_CALLBACK_AD_TRANSFORMS', False), - help=( - 'Enable support for jvp/vjp for the host_callback primitives. Default is ' - 'False, which means that host_callback operates only on primals. ' - 'The flag exists only temporarily, for backward compatibility.' - ) -) - # TODO(b/214340779): remove flag when XLA:CPU is improved. jax2tf_associative_scan_reductions = config.define_bool_state( name='jax2tf_associative_scan_reductions', diff --git a/jax/_src/core.py b/jax/_src/core.py index badb0d619..e75384851 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -38,7 +38,7 @@ import numpy as np from jax._src import dtypes from jax._src import config as jax_config from jax._src import effects -from jax._src.config import FLAGS, config +from jax._src.config import config from jax._src.errors import ( ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError, TracerIntegerConversionError, UnexpectedTracerError) @@ -60,6 +60,13 @@ zip, unsafe_zip = safe_zip, zip map, unsafe_map = safe_map, map +_TRACER_ERROR_NUM_TRACEBACK_FRAMES = jax_config.DEFINE_integer( + 'jax_tracer_error_num_traceback_frames', + jax_config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5), + help='Set the number of stack frames in JAX tracer error messages.' +) + + # -------------------- jaxprs -------------------- Effect = effects.Effect @@ -560,7 +567,7 @@ def raise_as_much_as_possible(tracer) -> Tracer: def escaped_tracer_error(tracer, detail=None): - num_frames = FLAGS.jax_tracer_error_num_traceback_frames + num_frames = _TRACER_ERROR_NUM_TRACEBACK_FRAMES.value msg = ('Encountered an unexpected tracer. 
A function transformed by JAX ' 'had a side effect, allowing for a reference to an intermediate value ' f'with type {tracer.aval.str_short()} wrapped in a ' diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 474cb052a..c6564c157 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -32,6 +32,7 @@ import warnings import numpy as np from jax._src import compilation_cache +from jax._src import config as jax_config from jax._src import core from jax._src import dtypes from jax._src import linear_util as lu @@ -43,7 +44,7 @@ from jax._src import traceback_util from jax._src import util from jax._src import op_shardings from jax._src import xla_bridge as xb -from jax._src.config import config, flags +from jax._src.config import config from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -64,9 +65,7 @@ JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration" JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration" BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration" -FLAGS = flags.FLAGS - -flags.DEFINE_string( +_DUMP_IR_TO = jax_config.DEFINE_string( 'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''), help="Path to which the IR that is emitted by JAX as input to the " "compiler should be dumped as text files. Optional. If omitted, JAX " @@ -472,7 +471,7 @@ def _make_string_safe_for_filename(s: str) -> str: def _dump_ir_to_file(name: str, ir: str): id = next(_ir_dump_counter) name = f"jax_ir{id}_{_make_string_safe_for_filename(name)}.mlir" - name = path.Path(FLAGS.jax_dump_ir_to) / name + name = path.Path(_DUMP_IR_TO.value) / name name.write_text(ir) @@ -481,7 +480,7 @@ def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray, sym_name = computation.operation.attributes['sym_name'] module_name = ir.StringAttr(sym_name).value - if FLAGS.jax_dump_ir_to: + if _DUMP_IR_TO.value: _dump_ir_to_file(module_name, mlir.module_to_string(computation)) # Persistent compilation cache only implemented on TPU and GPU. diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index f69b7ab92..6be56e0ba 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -28,7 +28,7 @@ import warnings import ml_dtypes import numpy as np -from jax._src.config import flags, config +from jax._src.config import config from jax._src.typing import DType, DTypeLike from jax._src import traceback_util @@ -43,8 +43,6 @@ else: raise ValueError("JAX requires ml_dtypes version 0.2.0 or newer; " f"installed version is {ml_dtypes.__version__}.") -FLAGS = flags.FLAGS - class extended(np.generic): """Scalar class for extended dtypes. diff --git a/jax/_src/pretty_printer.py b/jax/_src/pretty_printer.py index 7d9b3065d..7a1d84053 100644 --- a/jax/_src/pretty_printer.py +++ b/jax/_src/pretty_printer.py @@ -32,13 +32,20 @@ from functools import partial import sys from typing import NamedTuple, Optional, Union -from jax._src.config import config +from jax._src import config try: import colorama # pytype: disable=import-error except ImportError: colorama = None + +_PPRINT_USE_COLOR = config.DEFINE_bool( + 'jax_pprint_use_color', + config.bool_env('JAX_PPRINT_USE_COLOR', True), + help='Enable jaxpr pretty-printing with colorful syntax highlighting.' 
+) + def _can_use_color() -> bool: try: # Check if we're in IPython or Colab @@ -63,7 +70,7 @@ class Doc(abc.ABC): def format(self, width: int = 80, use_color: Optional[bool] = None, annotation_prefix=" # ") -> str: if use_color is None: - use_color = CAN_USE_COLOR and config.FLAGS.jax_pprint_use_color + use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value return _format(self, width, use_color=use_color, annotation_prefix=annotation_prefix) diff --git a/jax/_src/public_test_util.py b/jax/_src/public_test_util.py index c0d6e3754..ad9367cd9 100644 --- a/jax/_src/public_test_util.py +++ b/jax/_src/public_test_util.py @@ -16,9 +16,10 @@ from functools import partial import operator from jax._src import api +from jax._src import config as jax_config from jax._src import dtypes as _dtypes from jax._src import xla_bridge -from jax._src.config import config, flags +from jax._src.config import config from jax._src.tree_util import tree_map, tree_reduce import numpy as np @@ -30,7 +31,11 @@ import numpy as np __all__ = ['check_grads', 'check_jvp', 'check_vjp'] -FLAGS = flags.FLAGS +_TEST_DUT = jax_config.DEFINE_string( + 'jax_test_dut', '', + help= + 'Describes the device under test in case special consideration is required.' +) EPS = 1e-4 @@ -292,4 +297,4 @@ def check_grads(f, args, order, def device_under_test(): - return getattr(FLAGS, 'jax_test_dut', None) or xla_bridge.get_backend().platform + return _TEST_DUT.value or xla_bridge.get_backend().platform diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index 08369c08d..5986b021e 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -41,11 +41,12 @@ from jax._src.interpreters import mlir from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten from jax._src import api from jax._src import pjit as pjit_lib +from jax._src import config as jax_config from jax._src import core from jax._src import dispatch from jax._src import dtypes as _dtypes from jax._src.interpreters import pxla -from jax._src.config import (flags, bool_env, config, +from jax._src.config import (bool_env, config, raise_persistent_cache_errors, persistent_cache_min_compile_time_secs) from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact @@ -60,19 +61,12 @@ from jax._src import xla_bridge # jax.test_util. Functionality appearing here is for internal use only, and # may be changed or removed at any time and without any deprecation cycle. -FLAGS = flags.FLAGS -flags.DEFINE_string( - 'jax_test_dut', '', - help= - 'Describes the device under test in case special consideration is required.' -) - -flags.DEFINE_integer( +_NUM_GENERATED_CASES = jax_config.DEFINE_integer( 'jax_num_generated_cases', int(os.getenv('JAX_NUM_GENERATED_CASES', '10')), help='Number of generated cases to test') -flags.DEFINE_integer( +_MAX_CASES_SAMPLING_RETRIES = jax_config.DEFINE_integer( 'max_cases_sampling_retries', int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')), 'Number of times a failed test sample should be retried. ' @@ -80,24 +74,23 @@ flags.DEFINE_integer( 'sampling process is terminated.' ) -flags.DEFINE_bool( +_SKIP_SLOW_TESTS = jax_config.DEFINE_bool( 'jax_skip_slow_tests', bool_env('JAX_SKIP_SLOW_TESTS', False), help='Skip tests marked as slow (> 5 sec).' ) -flags.DEFINE_string( +_TEST_TARGETS = jax_config.DEFINE_string( 'test_targets', os.getenv('JAX_TEST_TARGETS', ''), 'Regular expression specifying which tests to run, called via re.search on ' 'the test name. If empty or unspecified, run all tests.' 
) -flags.DEFINE_string( +_EXCLUDE_TEST_TARGETS = jax_config.DEFINE_string( 'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''), 'Regular expression specifying which tests NOT to run, called via re.search ' 'on the test name. If empty or unspecified, run all tests.' ) - -flags.DEFINE_bool( +TEST_WITH_PERSISTENT_COMPILATION_CACHE = jax_config.DEFINE_bool( 'jax_test_with_persistent_compilation_cache', bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False), help='If enabled, the persistent compilation cache will be enabled for all ' @@ -731,7 +724,7 @@ def assert_dot_precision(expected_precision, fun, *args): def cases_from_gens(*gens): sizes = [1, 3, 10] - cases_per_size = int(FLAGS.jax_num_generated_cases / len(sizes)) + 1 + cases_per_size = int(_NUM_GENERATED_CASES.value / len(sizes)) + 1 for size in sizes: for i in range(cases_per_size): yield (f'_{size}_{i}',) + tuple(gen(size) for gen in gens) @@ -744,8 +737,8 @@ def named_cases_from_sampler(gen): if not isinstance(x, (list, tuple)): x = list(x) return [x[rng.randint(len(x))]] - while (len(seen) < FLAGS.jax_num_generated_cases and - retries < FLAGS.max_cases_sampling_retries): + while (len(seen) < _NUM_GENERATED_CASES.value and + retries < _MAX_CASES_SAMPLING_RETRIES.value): retries += 1 cases = list(gen(choose_one)) if not cases: @@ -773,7 +766,7 @@ def sample_product_testcases(*args, **kw): kw = [(k, list(v)) for k, v in kw.items()] n = math.prod(len(a) for a in args) * math.prod(len(v) for _, v in kw) testcases = [] - for i in _choice(n, min(n, FLAGS.jax_num_generated_cases)): + for i in _choice(n, min(n, _NUM_GENERATED_CASES.value)): testcase = {} for a in args: testcase.update(a[i % len(a)]) @@ -804,12 +797,12 @@ def sample_product(*args, **kw): class JaxTestLoader(absltest.TestLoader): def getTestCaseNames(self, testCaseClass): names = super().getTestCaseNames(testCaseClass) - if FLAGS.test_targets: - pattern = re.compile(FLAGS.test_targets) + if _TEST_TARGETS.value: + pattern = re.compile(_TEST_TARGETS.value) names = [name for name in names if pattern.search(f"{testCaseClass.__name__}.{name}")] - if FLAGS.exclude_test_targets: - pattern = re.compile(FLAGS.exclude_test_targets) + if _EXCLUDE_TEST_TARGETS.value: + pattern = re.compile(_EXCLUDE_TEST_TARGETS.value) names = [name for name in names if not pattern.search(f"{testCaseClass.__name__}.{name}")] return names @@ -874,7 +867,7 @@ class JaxTestCase(parameterized.TestCase): @classmethod def setUpClass(cls): - if FLAGS.jax_test_with_persistent_compilation_cache: + if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value: cls._compilation_cache_exit_stack = ExitStack() stack = cls._compilation_cache_exit_stack stack.enter_context(raise_persistent_cache_errors(True)) @@ -887,7 +880,7 @@ class JaxTestCase(parameterized.TestCase): @classmethod def tearDownClass(cls): - if FLAGS.jax_test_with_persistent_compilation_cache: + if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value: cls._compilation_cache_exit_stack.close() def rng(self): diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 0ccca28cb..0cdb6c20b 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -37,7 +37,8 @@ import numpy as np from jax._src import lib from jax._src import distributed -from jax._src.config import flags, bool_env, config, int_env +from jax._src import config as jax_config +from jax._src.config import bool_env, config, int_env from jax._src.lib import xla_client from jax._src.lib import xla_extension_version from jax._src import traceback_util @@ -59,37 +60,35 @@ 
traceback_util.register_exclusion(__file__) XlaBackend = xla_client.Client -FLAGS = flags.FLAGS - # TODO(phawkins): Remove jax_xla_backend. -flags.DEFINE_string( +_XLA_BACKEND = jax_config.DEFINE_string( 'jax_xla_backend', '', 'Deprecated, please use --jax_platforms instead.') -flags.DEFINE_string( +BACKEND_TARGET = jax_config.DEFINE_string( 'jax_backend_target', os.getenv('JAX_BACKEND_TARGET', '').lower(), 'Either "local" or "rpc:address" to connect to a remote service target.') # TODO(skye): warn when this is used once we test out --jax_platforms a bit -flags.DEFINE_string( +_PLATFORM_NAME = jax_config.DEFINE_string( 'jax_platform_name', os.getenv('JAX_PLATFORM_NAME', '').lower(), 'Deprecated, please use --jax_platforms instead.') -flags.DEFINE_bool( +_DISABLE_MOST_OPTIMIZATIONS = jax_config.DEFINE_bool( 'jax_disable_most_optimizations', bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False), 'Try not to do much optimization work. This can be useful if the cost of ' 'optimization is greater than that of running a less-optimized program.') -flags.DEFINE_integer( +_XLA_PROFILE_VERSION = jax_config.DEFINE_integer( 'jax_xla_profile_version', int_env('JAX_XLA_PROFILE_VERSION', 0), 'Optional profile version for XLA compilation. ' 'This is meaningful only when XLA is configured to ' 'support the remote compilation profile feature.') -flags.DEFINE_string( +CUDA_VISIBLE_DEVICES = jax_config.DEFINE_string( 'jax_cuda_visible_devices', 'all', 'Restricts the set of CUDA devices that JAX will use. Either "all", or a ' 'comma-separate list of integer device IDs.') -flags.DEFINE_string( +_ROCM_VISIBLE_DEVICES = jax_config.DEFINE_string( 'jax_rocm_visible_devices', 'all', 'Restricts the set of ROCM devices that JAX will use. Either "all", or a ' 'comma-separate list of integer device IDs.') @@ -171,13 +170,13 @@ def get_compile_options( if lib.cuda_path is not None: debug_options.xla_gpu_cuda_data_dir = lib.cuda_path - if FLAGS.jax_disable_most_optimizations: + if _DISABLE_MOST_OPTIMIZATIONS.value: debug_options.xla_backend_optimization_level = 0 debug_options.xla_llvm_disable_expensive_passes = True debug_options.xla_test_all_input_layouts = False - compile_options.profile_version = FLAGS.jax_xla_profile_version + compile_options.profile_version = _XLA_PROFILE_VERSION.value return compile_options @@ -264,9 +263,9 @@ register_backend_factory('cpu', def make_gpu_client( - *, platform_name: str, visible_devices_flag: str + *, platform_name: str, visible_devices_flag: jax_config.FlagHolder[str] ) -> xla_client.Client: - visible_devices = getattr(FLAGS, visible_devices_flag, "all") + visible_devices = visible_devices_flag.value allowed_devices = None if visible_devices != "all": allowed_devices = {int(x) for x in visible_devices.split(",")} @@ -292,15 +291,25 @@ def make_gpu_client( if hasattr(xla_client, "make_gpu_client"): register_backend_factory( - 'cuda', partial(make_gpu_client, platform_name='cuda', - visible_devices_flag='jax_cuda_visible_devices'), + "cuda", + partial( + make_gpu_client, + platform_name="cuda", + visible_devices_flag=CUDA_VISIBLE_DEVICES, + ), priority=200, - fail_quietly=True) + fail_quietly=True, + ) register_backend_factory( - 'rocm', partial(make_gpu_client, platform_name='rocm', - visible_devices_flag='jax_rocm_visible_devices'), + "rocm", + partial( + make_gpu_client, + platform_name="rocm", + visible_devices_flag=_ROCM_VISIBLE_DEVICES, + ), priority=200, - fail_quietly=True) + fail_quietly=True, + ) if hasattr(xla_client, "make_tpu_client"): @@ -623,7 +632,7 @@ def backends() -> 
dict[str, xla_client.Client]: # support anything else there at the moment and warning would be pointless. if (py_platform.system() != "Darwin" and _default_backend.platform == "cpu" and - FLAGS.jax_platform_name != 'cpu'): + _PLATFORM_NAME.value != 'cpu'): logger.warning('No GPU/TPU found, falling back to CPU. ' '(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)') return _backends @@ -677,8 +686,7 @@ def _get_backend_uncached( if platform is not None and not isinstance(platform, str): return platform - platform = (platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name - or None) + platform = (platform or _XLA_BACKEND.value or _PLATFORM_NAME.value or None) bs = backends() if platform is not None: diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index c89d9b5d0..aa3c3ae93 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -507,7 +507,7 @@ import warnings from jax._src import api from jax._src import core -from jax import config +from jax._src import config from jax import custom_derivatives from jax._src import dtypes from jax import lax @@ -531,17 +531,46 @@ from jax._src.lib.mlir.dialects import hlo import numpy as np -FLAGS = config.FLAGS +_HOST_CALLBACK_INLINE = config.DEFINE_bool( + 'jax_host_callback_inline', + config.bool_env('JAX_HOST_CALLBACK_INLINE', False), + help='Inline the host_callback, if not in a staged context.' +) +_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE = config.DEFINE_integer( + 'jax_host_callback_max_queue_byte_size', + config.int_env('JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE', int(256 * 1e6)), + help=('The size in bytes of the buffer used to hold outfeeds from each ' + 'device. When this capacity is reached consuming outfeeds from the ' + 'device is paused, thus potentially pausing the device computation, ' + 'until the Python callback consume more outfeeds.'), + lower_bound=int(16 * 1e6) +) +_HOST_CALLBACK_OUTFEED = config.DEFINE_bool( + 'jax_host_callback_outfeed', + config.bool_env('JAX_HOST_CALLBACK_OUTFEED', False), + help=( + 'Use outfeed implementation for host_callback, even on CPU and GPU. ' + 'If false, use the CustomCall implementation. ' + 'Has no effect on TPU, since only the outfeed mechanism is implemented.' + ) +) +_HOST_CALLBACK_AD_TRANSFORMS = config.DEFINE_bool( + 'jax_host_callback_ad_transforms', + config.bool_env('JAX_HOST_CALLBACK_AD_TRANSFORMS', False), + help=( + 'Enable support for jvp/vjp for the host_callback primitives. Default is ' + 'False, which means that host_callback operates only on primals. ' + 'The flag exists only temporarily, for backward compatibility.' + ) +) + logger = logging.getLogger(__name__) -def _inline_host_callback() -> bool: - return FLAGS.jax_host_callback_inline - - def _use_outfeed(platform: str) -> bool: - return (platform in ("tpu", "gpu", "cuda", "rocm") or FLAGS.jax_host_callback_outfeed) + return (platform in ("tpu", "gpu", "cuda", "rocm") or + _HOST_CALLBACK_OUTFEED.value) def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend): @@ -620,7 +649,7 @@ def id_tap(tap_func, "pre-apply keyword arguments, either by using a closure or by passing " "``functools.partial(tap_func, **kwargs)``.") raise TypeError(msg) - if FLAGS.jax_host_callback_ad_transforms: + if _HOST_CALLBACK_AD_TRANSFORMS.value: warnings.warn('The flag jax_host_callback_ad_transforms is for temporary ' 'backwards compatibility mode. 
This flag, and the behavior ' 'it enabled will be removed soon.', @@ -642,7 +671,7 @@ def id_tap(tap_func, if result is not None: # Return the results, but add a dependency on the call, to ensure it # is kept in the graph. - if FLAGS.jax_host_callback_ad_transforms: + if _HOST_CALLBACK_AD_TRANSFORMS.value: call_flat_results, _ = tree_util.tree_flatten(call_res) if call_flat_results: call_flat_results = [id_tap_dep_p.bind(r, call_flat_results[0]) @@ -782,7 +811,7 @@ def _call(callback_func: Callable, identity=False): # Lazy initialization _initialize_outfeed_receiver( - max_callback_queue_size_bytes=FLAGS.jax_host_callback_max_queue_byte_size) + max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value) api.check_callable(callback_func) flat_args, arg_treedef = tree_util.tree_flatten(arg) for arg in flat_args: @@ -909,7 +938,7 @@ xla.register_translation(id_tap_dep_p, id_tap_dep_p.def_abstract_eval(lambda r_a, _: r_a) def _id_tap_dep_jvp_rule(primals, tangents): - if FLAGS.jax_host_callback_ad_transforms: + if _HOST_CALLBACK_AD_TRANSFORMS.value: assert False tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals)) return (id_tap_dep_p.bind(primals[0], primals[1]), @@ -918,7 +947,7 @@ def _id_tap_dep_jvp_rule(primals, tangents): ad.primitive_jvps[id_tap_dep_p] = _id_tap_dep_jvp_rule def _id_tap_dep_transpose_rule(cts, arg_res, arg_tap): - if FLAGS.jax_host_callback_ad_transforms: + if _HOST_CALLBACK_AD_TRANSFORMS.value: assert False if ad.is_undefined_primal(arg_res): ct_res = _instantiate_zeros(cts, arg_res) @@ -934,7 +963,7 @@ ad.primitive_transposes[id_tap_dep_p] = _id_tap_dep_transpose_rule def _id_tap_dep_batching_rule(batched_args, batch_dims): - if FLAGS.jax_host_callback_ad_transforms: + if _HOST_CALLBACK_AD_TRANSFORMS.value: assert False arg_res, arg_tap = batched_args return id_tap_dep_p.bind(arg_res, arg_tap), batch_dims[0] @@ -1013,7 +1042,7 @@ outside_call_p.def_abstract_eval(_outside_call_abstract_eval) def _outside_call_impl(*args, **params): assert "has_token" not in params - if _inline_host_callback(): + if _HOST_CALLBACK_INLINE.value: device_index = params["device_index"] device = xb.devices()[device_index] results = _outside_call_run_callback(args, device, send_infeed=False, **params) @@ -1400,7 +1429,7 @@ def _outside_call_jvp_rule(primals, tangents, **params): assert "has_token" not in params if not params["identity"]: raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.") - if FLAGS.jax_host_callback_ad_transforms: + if _HOST_CALLBACK_AD_TRANSFORMS.value: tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals)) arg_treedef = params["arg_treedef"] @@ -1425,7 +1454,7 @@ ad.primitive_jvps[outside_call_p] = _outside_call_jvp_rule def _outside_call_partial_eval_rule(trace, *args, **params): # partial eval is used after jvp and before transpose. 
- if not FLAGS.jax_host_callback_ad_transforms: + if not _HOST_CALLBACK_AD_TRANSFORMS.value: # TODO: just remote the partial eval rule return trace.default_process_primitive(outside_call_p, args, params) transforms = params.get("transforms", ()) @@ -1492,7 +1521,7 @@ def _outside_call_transpose_rule(cts, *args, **params): *cts_instantiated, **_add_transform(params, "transpose")) - if not FLAGS.jax_host_callback_ad_transforms: + if not _HOST_CALLBACK_AD_TRANSFORMS.value: assert False assert len(args) % 2 == 0 diff --git a/jax/experimental/jax2tf/examples/keras_reuse_main.py b/jax/experimental/jax2tf/examples/keras_reuse_main.py index 71f94e135..1f8fbea5b 100644 --- a/jax/experimental/jax2tf/examples/keras_reuse_main.py +++ b/jax/experimental/jax2tf/examples/keras_reuse_main.py @@ -65,7 +65,7 @@ def main(_): tfds.Split.TEST, batch_size=mnist_lib.test_batch_size) keras_model.fit(train_ds, epochs=FLAGS.num_epochs, validation_data=test_ds) - if FLAGS.show_images: + if saved_model_main.SHOW_IMAGES.value: mnist_lib.plot_images( test_ds, 1, diff --git a/jax/experimental/jax2tf/examples/mnist_lib.py b/jax/experimental/jax2tf/examples/mnist_lib.py index 658cd7536..c11c5c4e3 100644 --- a/jax/experimental/jax2tf/examples/mnist_lib.py +++ b/jax/experimental/jax2tf/examples/mnist_lib.py @@ -38,8 +38,8 @@ import optax import tensorflow as tf # type: ignore import tensorflow_datasets as tfds # type: ignore -flags.DEFINE_boolean("mock_data", False, "Use fake data, for testing.") -FLAGS = flags.FLAGS +_MOCK_DATA = flags.DEFINE_boolean("mock_data", False, + "Use fake data, for testing.") #### Model parameters @@ -64,7 +64,7 @@ def load_mnist(split: tfds.Split, batch_size: int): an iterator with pairs (images, labels). The images have shape (B, 28, 28, 1) and the labels have shape (B, 10), where B is the batch_size. """ - if FLAGS.mock_data: + if _MOCK_DATA.value: with tfds.testing.mock_data(num_examples=batch_size): try: ds = tfds.load("mnist", split=split) diff --git a/jax/experimental/jax2tf/examples/saved_model_main.py b/jax/experimental/jax2tf/examples/saved_model_main.py index 67a5cd383..0dfd93821 100644 --- a/jax/experimental/jax2tf/examples/saved_model_main.py +++ b/jax/experimental/jax2tf/examples/saved_model_main.py @@ -37,46 +37,49 @@ import numpy as np import tensorflow as tf # type: ignore import tensorflow_datasets as tfds # type: ignore -flags.DEFINE_enum("model", "mnist_flax", ["mnist_flax", "mnist_pure_jax"], - "Which model to use.") -flags.DEFINE_boolean("model_classifier_layer", True, +_MODEL = flags.DEFINE_enum( + "model", "mnist_flax", ["mnist_flax", "mnist_pure_jax"], + "Which model to use.") +_MODEL_CLASSIFIER_LAYER = flags.DEFINE_boolean("model_classifier_layer", True, ("The model should include the classifier layer, or just " "the last layer of logits. Set this to False when you " "want to reuse the classifier-less model in a larger " "model. See keras_reuse_main.py and README.md.")) -flags.DEFINE_string("model_path", "/tmp/jax2tf/saved_models", +_MODEL_PATH = flags.DEFINE_string("model_path", "/tmp/jax2tf/saved_models", "Path under which to save the SavedModel.") -flags.DEFINE_integer("model_version", 1, +_MODEL_VERSION = flags.DEFINE_integer("model_version", 1, ("The version number for the SavedModel. Needed for " "serving, larger versions will take precedence"), lower_bound=1) -flags.DEFINE_integer("serving_batch_size", 1, +_SERVING_BATCH_SIZE = flags.DEFINE_integer("serving_batch_size", 1, "For what batch size to prepare the serving signature. 
" "Use -1 for converting and saving with batch polymorphism.") flags.register_validator( "serving_batch_size", - lambda serving_batch_size: serving_batch_size > 0 or serving_batch_size == -1, - message="--serving_batch_size must be either -1 or a positive integer.") + lambda serving_batch_size: serving_batch_size > 0 + or serving_batch_size == -1, + message="--serving_batch_size must be either -1 or a positive integer.", +) -flags.DEFINE_integer("num_epochs", 3, "For how many epochs to train.", - lower_bound=1) -flags.DEFINE_boolean( +_NUM_EPOCHS = flags.DEFINE_integer("num_epochs", 3, + "For how many epochs to train.", + lower_bound=1) +_GENERATE_MODEL = flags.DEFINE_boolean( "generate_model", True, "Train and save a new model. Otherwise, use an existing SavedModel.") -flags.DEFINE_boolean( +_COMPILE_MODEL = flags.DEFINE_boolean( "compile_model", True, "Enable TensorFlow jit_compiler for the SavedModel. This is " "necessary if you want to use the model for TensorFlow serving.") -flags.DEFINE_boolean("show_model", True, "Show details of saved SavedModel.") -flags.DEFINE_boolean( +_SHOW_MODEL = flags.DEFINE_boolean("show_model", True, + "Show details of saved SavedModel.") +SHOW_IMAGES = flags.DEFINE_boolean( "show_images", False, "Plot some sample images with labels and inference results.") -flags.DEFINE_boolean( +_TEST_SAVEDMODEL = flags.DEFINE_boolean( "test_savedmodel", True, "Test TensorFlow inference using the SavedModel w.r.t. the JAX model.") -FLAGS = flags.FLAGS - def train_and_save(): logging.info("Loading the MNIST TensorFlow dataset") @@ -85,22 +88,22 @@ def train_and_save(): test_ds = mnist_lib.load_mnist( tfds.Split.TEST, batch_size=mnist_lib.test_batch_size) - if FLAGS.show_images: + if SHOW_IMAGES.value: mnist_lib.plot_images(train_ds, 1, 5, "Training images", inference_fn=None) the_model_class = pick_model_class() model_dir = savedmodel_dir(with_version=True) - if FLAGS.generate_model: + if _GENERATE_MODEL.value: model_descr = model_description() logging.info("Generating model for %s", model_descr) (predict_fn, predict_params) = the_model_class.train( train_ds, test_ds, - FLAGS.num_epochs, - with_classifier=FLAGS.model_classifier_layer) + num_epochs=_NUM_EPOCHS.value, + with_classifier=_MODEL_CLASSIFIER_LAYER.value) - if FLAGS.serving_batch_size == -1: + if _SERVING_BATCH_SIZE.value == -1: # Batch-polymorphic SavedModel input_signatures = [ tf.TensorSpec((None,) + mnist_lib.input_shape, tf.float32), @@ -109,7 +112,7 @@ def train_and_save(): else: input_signatures = [ # The first one will be the serving signature - tf.TensorSpec((FLAGS.serving_batch_size,) + mnist_lib.input_shape, + tf.TensorSpec((_SERVING_BATCH_SIZE.value,) + mnist_lib.input_shape, tf.float32), tf.TensorSpec((mnist_lib.train_batch_size,) + mnist_lib.input_shape, tf.float32), @@ -126,15 +129,15 @@ def train_and_save(): with_gradient=True, input_signatures=input_signatures, polymorphic_shapes=polymorphic_shapes, - compile_model=FLAGS.compile_model) + compile_model=_COMPILE_MODEL.value) - if FLAGS.test_savedmodel: + if _TEST_SAVEDMODEL.value: tf_accelerator, tolerances = tf_accelerator_and_tolerances() with tf.device(tf_accelerator): logging.info("Testing savedmodel") pure_restored_model = tf.saved_model.load(model_dir) - if FLAGS.show_images and FLAGS.model_classifier_layer: + if SHOW_IMAGES.value and _MODEL_CLASSIFIER_LAYER.value: mnist_lib.plot_images( test_ds, 1, @@ -149,7 +152,7 @@ def train_and_save(): pure_restored_model(tf.convert_to_tensor(test_input)), predict_fn(predict_params, test_input), 
**tolerances) - if FLAGS.show_model: + if _SHOW_MODEL.value: def print_model(model_dir: str): cmd = f"saved_model_cli show --all --dir {model_dir}" print(cmd) @@ -160,18 +163,18 @@ def train_and_save(): def pick_model_class(): """Picks one of PureJaxMNIST or FlaxMNIST.""" - if FLAGS.model == "mnist_pure_jax": + if _MODEL.value == "mnist_pure_jax": return mnist_lib.PureJaxMNIST - elif FLAGS.model == "mnist_flax": + elif _MODEL.value == "mnist_flax": return mnist_lib.FlaxMNIST else: - raise ValueError(f"Unrecognized model: {FLAGS.model}") + raise ValueError(f"Unrecognized model: {_MODEL.value}") def model_description() -> str: """A short description of the picked model.""" res = pick_model_class().name - if not FLAGS.model_classifier_layer: + if not _MODEL_CLASSIFIER_LAYER.value: res += " (features_only)" return res @@ -179,11 +182,11 @@ def model_description() -> str: def savedmodel_dir(with_version: bool = True) -> str: """The directory where we save the SavedModel.""" model_dir = os.path.join( - FLAGS.model_path, - FLAGS.model + ('' if FLAGS.model_classifier_layer else '_features') + _MODEL_PATH.value, + _MODEL.value + ('' if _MODEL_CLASSIFIER_LAYER.value else '_features') ) if with_version: - model_dir = os.path.join(model_dir, str(FLAGS.model_version)) + model_dir = os.path.join(model_dir, str(_MODEL_VERSION.value)) return model_dir diff --git a/jax/experimental/jax2tf/examples/serving/model_server_request.py b/jax/experimental/jax2tf/examples/serving/model_server_request.py index 8fd0b6fc6..4319c1026 100644 --- a/jax/experimental/jax2tf/examples/serving/model_server_request.py +++ b/jax/experimental/jax2tf/examples/serving/model_server_request.py @@ -32,31 +32,32 @@ from tensorflow_serving.apis import predict_pb2 # type: ignore[import] from tensorflow_serving.apis import prediction_service_pb2_grpc -FLAGS = flags.FLAGS - -flags.DEFINE_boolean( +_USE_GRPC = flags.DEFINE_boolean( "use_grpc", True, "Use the gRPC API (default), or the HTTP REST API.") -flags.DEFINE_string( +_MODEL_SPEC_NAME = flags.DEFINE_string( "model_spec_name", "", "The name you used to export your model to model server (e.g., mnist_flax).") -flags.DEFINE_string( +_PREDICTION_SERVICE_ADDR = flags.DEFINE_string( "prediction_service_addr", "localhost:8500", "Stubby endpoint for the prediction service. If you serve your model " "locally using TensorFlow model server, then you can use \"localhost:8500\"" "for the gRPC server and \"localhost:8501\" for the HTTP REST server.") -flags.DEFINE_integer("serving_batch_size", 1, - "Batch size for the serving request. Must match the " - "batch size at which the model was saved. Must divide " - "--count_images", - lower_bound=1) -flags.DEFINE_integer("count_images", 16, - "How many images to test.", - lower_bound=1) +_SERVING_BATCH_SIZE = flags.DEFINE_integer( + "serving_batch_size", + 1, + "Batch size for the serving request. Must match the " + "batch size at which the model was saved. Must divide " + "--count_images", + lower_bound=1, +) +_COUNT_IMAGES = flags.DEFINE_integer( + "count_images", 16, "How many images to test.", lower_bound=1 +) def serving_call_mnist(images): @@ -69,12 +70,12 @@ def serving_call_mnist(images): Returns: A numpy.ndarray of shape [B, 10] with the one-hot inference response. 
""" - if FLAGS.use_grpc: - channel = grpc.insecure_channel(FLAGS.prediction_service_addr) + if _USE_GRPC.value: + channel = grpc.insecure_channel(_PREDICTION_SERVICE_ADDR.value) stub = prediction_service_pb2_grpc.PredictionServiceStub(channel) request = predict_pb2.PredictRequest() - request.model_spec.name = FLAGS.model_spec_name + request.model_spec.name = _MODEL_SPEC_NAME.value request.model_spec.signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY # You can see the name of the input ("inputs") in the SavedModel dump. request.inputs["inputs"].CopyFrom( @@ -90,7 +91,7 @@ def serving_call_mnist(images): images_json = json.dumps(images.tolist()) # You can see the name of the input ("inputs") in the SavedModel dump. data = f'{{"inputs": {images_json}}}' - predict_url = f"http://{FLAGS.prediction_service_addr}/v1/models/{FLAGS.model_spec_name}:predict" + predict_url = f"http://{_PREDICTION_SERVICE_ADDR.value}/v1/models/{_MODEL_SPEC_NAME.value}:predict" response = requests.post(predict_url, data=data) if response.status_code != 200: msg = (f"Received error response {response.status_code} from model " @@ -101,14 +102,14 @@ def serving_call_mnist(images): def main(_): - if FLAGS.count_images % FLAGS.serving_batch_size != 0: - raise ValueError(f"The count_images ({FLAGS.count_images}) must be a " + if _COUNT_IMAGES.value % _SERVING_BATCH_SIZE.value != 0: + raise ValueError(f"The count_images ({_COUNT_IMAGES.value}) must be a " "multiple of " - f"serving_batch_size ({FLAGS.serving_batch_size})") + f"serving_batch_size ({_SERVING_BATCH_SIZE.value})") test_ds = mnist_lib.load_mnist(tfds.Split.TEST, - batch_size=FLAGS.serving_batch_size) + batch_size=_SERVING_BATCH_SIZE.value) images_and_labels = tfds.as_numpy(test_ds.take( - FLAGS.count_images // FLAGS.serving_batch_size)) + _COUNT_IMAGES.value // _SERVING_BATCH_SIZE.value)) accurate_count = 0 for batch_idx, (images, labels) in enumerate(images_and_labels): @@ -117,7 +118,7 @@ def main(_): labels_digit = np.argmax(labels, axis=1) accurate_count += np.sum(labels_digit == predictions_digit) running_accuracy = ( - 100. * accurate_count / (1 + batch_idx) / FLAGS.serving_batch_size) + 100. * accurate_count / (1 + batch_idx) / _SERVING_BATCH_SIZE.value) logging.info( " predicted digits = %s labels %s. Running accuracy %.3f%%", predictions_digit, labels_digit, running_accuracy) diff --git a/jax/experimental/jax2tf/examples/tf_js/quickdraw/quickdraw.py b/jax/experimental/jax2tf/examples/tf_js/quickdraw/quickdraw.py index 94f21724d..08a207e12 100644 --- a/jax/experimental/jax2tf/examples/tf_js/quickdraw/quickdraw.py +++ b/jax/experimental/jax2tf/examples/tf_js/quickdraw/quickdraw.py @@ -33,15 +33,17 @@ import tensorflowjs as tfjs import input_pipeline # type: ignore[import] -flags.DEFINE_integer("num_epochs", 5, - ("Number of epochs to train for.")) -flags.DEFINE_integer("num_classes", 100, "Number of classification classes.") +_NUM_EPOCHS = flags.DEFINE_integer( + "num_epochs", 5, "Number of epochs to train for." +) +_NUM_CLASSES = flags.DEFINE_integer( + "num_classes", 100, "Number of classification classes." 
+) flags.register_validator("num_classes", lambda value: value >= 1 and value <= 100, message="--num_classes must be in range [1, 100]") -FLAGS = flags.FLAGS # The code below is an adaptation for Flax from the work published here: # https://blog.tensorflow.org/2018/07/train-model-in-tfkeras-with-colab-and-run-in-browser-tensorflowjs.html @@ -65,7 +67,7 @@ class QuickDraw(nn.Module): x = nn.Dense(features=128)(x) x = nn.relu(x) - x = nn.Dense(features=FLAGS.num_classes)(x) + x = nn.Dense(features=_NUM_CLASSES.value)(x) return x @@ -75,7 +77,7 @@ def apply_model(state, inputs, labels): """Computes gradients, loss and accuracy for a single batch.""" def loss_fn(params): logits = state.apply_fn({'params': params}, inputs) - one_hot = jax.nn.one_hot(labels, FLAGS.num_classes) + one_hot = jax.nn.one_hot(labels, _NUM_CLASSES.value) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits grad_fn = jax.value_and_grad(loss_fn, has_aux=True) @@ -113,7 +115,7 @@ def create_train_state(rng): def train(state, train_ds, test_ds): - for epoch in range(1, FLAGS.num_epochs+1): + for epoch in range(1, _NUM_EPOCHS.value+1): start_time = time.time() state, train_loss, train_accuracy = run_epoch(state, train_ds) @@ -136,12 +138,12 @@ def main(argv): base_model_path = "/tmp/jax2tf/tf_js_quickdraw" dataset_path = os.path.join(base_model_path, "data") - classes = input_pipeline.download_dataset(dataset_path, FLAGS.num_classes) - assert len(classes) == FLAGS.num_classes, "Incorrect number of classes" + classes = input_pipeline.download_dataset(dataset_path, _NUM_CLASSES.value) + assert len(classes) == _NUM_CLASSES.value, "Incorrect number of classes" print(f"Classes are: {classes}") print("Loading dataset into memory...") train_ds, test_ds = input_pipeline.get_datasets(dataset_path, classes) - print(f"Starting training for {FLAGS.num_epochs} epochs...") + print(f"Starting training for {_NUM_EPOCHS.value} epochs...") state = create_train_state(jax.random.PRNGKey(0)) state = train(state, train_ds, test_ds) diff --git a/jax/experimental/jax2tf/examples/tflite/mnist/mnist.py b/jax/experimental/jax2tf/examples/tflite/mnist/mnist.py index 5452ea13d..6a2494d40 100644 --- a/jax/experimental/jax2tf/examples/tflite/mnist/mnist.py +++ b/jax/experimental/jax2tf/examples/tflite/mnist/mnist.py @@ -24,14 +24,19 @@ import numpy as np import tensorflow as tf # type: ignore[import] import tensorflow_datasets as tfds # type: ignore[import] -flags.DEFINE_string('tflite_file_path', - '/usr/local/google/home/qiuminxu/jax2tf/mnist.tflite', - 'Path where to save the TensorFlow Lite file.') -flags.DEFINE_integer('serving_batch_size', 4, - ('For what batch size to prepare the serving signature. ')) -flags.DEFINE_integer('num_epochs', 10, 'For how many epochs to train.') - -FLAGS = flags.FLAGS +_TFLITE_FILE_PATH = flags.DEFINE_string( + 'tflite_file_path', + '/tmp/mnist.tflite', + 'Path where to save the TensorFlow Lite file.', +) +_SERVING_BATCH_SIZE = flags.DEFINE_integer( + 'serving_batch_size', + 4, + 'For what batch size to prepare the serving signature. ', +) +_NUM_EPOCHS = flags.DEFINE_integer( + 'num_epochs', 10, 'For how many epochs to train.' +) # A helper function to evaluate the TF Lite model using "test" dataset. 
@@ -71,10 +76,11 @@ def main(_):
   train_ds = mnist_lib.load_mnist(
       tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
   test_ds = mnist_lib.load_mnist(
-      tfds.Split.TEST, batch_size=FLAGS.serving_batch_size)
+      tfds.Split.TEST, batch_size=_SERVING_BATCH_SIZE.value)
 
-  (flax_predict,
-   flax_params) = mnist_lib.FlaxMNIST.train(train_ds, test_ds, FLAGS.num_epochs)
+  (flax_predict, flax_params) = mnist_lib.FlaxMNIST.train(
+      train_ds, test_ds, _NUM_EPOCHS.value
+  )
 
   def predict(image):
     return flax_predict(flax_params, image)
@@ -84,7 +90,7 @@ def main(_):
       jax2tf.convert(predict, enable_xla=False),
       input_signature=[
           tf.TensorSpec(
-              shape=[FLAGS.serving_batch_size, 28, 28, 1],
+              shape=[_SERVING_BATCH_SIZE.value, 28, 28, 1],
               dtype=tf.float32,
               name='input')
       ],
@@ -126,7 +132,7 @@ def main(_):
   print('Quantized model accuracy = %.4f' % quantized_accuracy)
   print('Accuracy drop = %.4f' % (float_accuracy - quantized_accuracy))
 
-  f = open(FLAGS.tflite_file_path, 'wb')
+  f = open(_TFLITE_FILE_PATH.value, 'wb')
   f.write(tflite_quantized_model)
   f.close()
diff --git a/jax/experimental/jax2tf/tests/models_test_main.py b/jax/experimental/jax2tf/tests/models_test_main.py
index b5ab4dfa4..23ee0b315 100644
--- a/jax/experimental/jax2tf/tests/models_test_main.py
+++ b/jax/experimental/jax2tf/tests/models_test_main.py
@@ -58,34 +58,32 @@ from jax.experimental.jax2tf.shape_poly import InconclusiveDimensionOperation
 from jax.experimental.jax2tf.tests.model_harness import ALL_HARNESSES
 from jax.experimental.jax2tf.tests.converters import ALL_CONVERTERS
 
-flags.DEFINE_list("converters", [x.name for x in ALL_CONVERTERS],
+_CONVERTERS = flags.DEFINE_list("converters", [x.name for x in ALL_CONVERTERS],
                   "Which converters to test.")
-flags.DEFINE_list("examples", [],
+_EXAMPLES = flags.DEFINE_list("examples", [],
                   ("List of examples to test, e.g.: 'flax/mnist,flax/seq2seq'. "
                    "If empty, will test all examples."))
-flags.DEFINE_string("example_prefix", "",
+_EXAMPLE_PREFIX = flags.DEFINE_string("example_prefix", "",
                     ("Prefix for filtering tests. For instance 'flax/mnist' "
                      "will test all examples starting with 'flax/mnist' "
                      "(including all polymorphic tests)."))
-flags.DEFINE_bool(
+_WRITE_MARKDOWN = flags.DEFINE_bool(
     "write_markdown", True,
     "If true, write results as Markdown. Otherwise, only output to stdout.")
-flags.DEFINE_bool(
+_FAIL_ON_ERROR = flags.DEFINE_bool(
     "fail_on_error", False,
     ("If true, exit with an error when a conversion fails. Useful for "
      "debugging because it will show the entire stack trace."))
 
-FLAGS = flags.FLAGS
-
 
 def _write_markdown(results: dict[str, list[tuple[str, str,]]]) -> None:
   """Writes all results to Markdown file."""
   table_lines = []
 
-  converters = FLAGS.converters
+  converters = _CONVERTERS.value
   table_lines.append("| Example | " + " ".join([f"{c} |" for c in converters]))
   table_lines.append("|" + (" --- |" * (len(converters) + 1)))
@@ -173,7 +171,7 @@ def test_converters():
     exit()
 
   def _maybe_reraise(e):
-    if FLAGS.fail_on_error:
+    if _FAIL_ON_ERROR.value:
       raise e
 
   def _format(e):
@@ -183,13 +181,13 @@ def test_converters():
     return msg
 
   converters = list(
-      filter(lambda x: x.name in FLAGS.converters, ALL_CONVERTERS))
+      filter(lambda x: x.name in _CONVERTERS.value, ALL_CONVERTERS))
   _exit_if_empty(converters, "converters")
 
   harnesses_to_test = {
       name: fn for name, fn in ALL_HARNESSES.items()
-      if (not FLAGS.examples or name in FLAGS.examples) and
-      (not FLAGS.example_prefix or name.startswith(FLAGS.example_prefix))
+      if (not _EXAMPLES.value or name in _EXAMPLES.value) and
+      (not _EXAMPLE_PREFIX.value or name.startswith(_EXAMPLE_PREFIX.value))
   }
   _exit_if_empty(harnesses_to_test, "harness")
@@ -243,7 +241,7 @@ def test_converters():
       converter_results.append((converter.name, error_msg))
     results[harness.name] = converter_results
 
-  if FLAGS.write_markdown:
+  if _WRITE_MARKDOWN.value:
     _write_markdown(results)
   else:
     print("=== NOT writing results to Markdown.")
diff --git a/jax/tools/build_defs.bzl b/jax/tools/build_defs.bzl
index 928e3c567..1540afe42 100644
--- a/jax/tools/build_defs.bzl
+++ b/jax/tools/build_defs.bzl
@@ -140,7 +140,6 @@ def jax_to_ir(name, deps, fn, input_shapes, constants = None, format = "HLO"):
 from absl import app
 import jax.tools.jax_to_ir as jax_to_ir
 
-jax_to_ir.set_up_flags()
 app.run(jax_to_ir.main)
 EOF
 """.format(runner = runner),
diff --git a/jax/tools/jax_to_ir.py b/jax/tools/jax_to_ir.py
index 435ce57b3..a0f86bd62 100644
--- a/jax/tools/jax_to_ir.py
+++ b/jax/tools/jax_to_ir.py
@@ -87,7 +87,30 @@ try:
 except ImportError:
   tf = None  # type: ignore
 
-FLAGS = flags.FLAGS
+
+_FN = flags.DEFINE_string(
+    'fn', None, "Fully-qualified name of function that we're going to convert"
+)
+_INPUT_SHAPES = flags.DEFINE_string(
+    'input_shapes', None, 'Python dict indicating XLA shapes of params'
+)
+_CONSTANTS = flags.DEFINE_string(
+    'constants', '{}', 'Python dict giving constant values for some params'
+)
+_EVALED_CONSTANTS = flags.DEFINE_string(
+    'evaled_constants',
+    '{}',
+    'Python dict giving constant values for some params. '
+    'Values in this dict that are of type str are evaluated '
+    'using ast.literal_eval.',
+)
+_IR_FORMAT = flags.DEFINE_enum(
+    'ir_format', 'HLO', ('HLO', 'TF'), 'Output format.'
+) +_IR_DEST = flags.DEFINE_string('ir_dest', None, 'File to write IR to') +_IR_HUMAN_DEST = flags.DEFINE_string( + 'ir_human_dest', None, 'File to write human readable debug output' +) def jax_to_ir(fn, input_shapes, *, constants=None, format): @@ -163,25 +186,25 @@ def main(argv): if len(argv) != 1: raise app.UsageError('No positional arguments are accepted.') - if not FLAGS.ir_dest and not FLAGS.ir_human_dest: + if not _IR_DEST.value and not _IR_HUMAN_DEST.value: raise app.Error('At least one of --ir_dest and ' '--ir_human_dest is required.') - module_name, fn_name = FLAGS.fn.rsplit('.', 1) + module_name, fn_name = _FN.value.rsplit('.', 1) module = importlib.import_module(module_name) fn = getattr(module, fn_name) input_shapes = [(name, parse_shape_str(shape_str)) - for name, shape_str in literal_eval(FLAGS.input_shapes)] + for name, shape_str in literal_eval(_INPUT_SHAPES.value)] # Parse --constants and --evaled_constants. constants = {} - for k, v in literal_eval(FLAGS.constants).items(): + for k, v in literal_eval(_CONSTANTS.value).items(): if isinstance(v, list): v = jnp.asarray(v) constants[k] = v - for k, v in literal_eval(FLAGS.evaled_constants).items(): + for k, v in literal_eval(_EVALED_CONSTANTS.value).items(): if isinstance(v, str): v = literal_eval(v) if isinstance(v, list): @@ -192,14 +215,14 @@ def main(argv): constants[k] = v ir, debug_ir = jax_to_ir(fn, input_shapes, constants=constants, - format=FLAGS.ir_format) + format=_IR_FORMAT.value) - if FLAGS.ir_dest: - with open(FLAGS.ir_dest, 'wb') as f: + if _IR_DEST.value: + with open(_IR_DEST.value, 'wb') as f: f.write(ir) - if FLAGS.ir_human_dest: - with open(FLAGS.ir_human_dest, 'w') as f: + if _IR_HUMAN_DEST.value: + with open(_IR_HUMAN_DEST.value, 'w') as f: f.write(debug_ir) @@ -225,21 +248,6 @@ _SHAPE_RE = re.compile(f"^({'|'.join(_DT)})\\[\\s*(\\d*[\\s*,\\d+]*)\\s*\\]$") def set_up_flags(): - flags.DEFINE_string( - 'fn', None, - "Fully-qualified name of function that we're going to convert") - flags.DEFINE_string('input_shapes', None, - 'Python dict indicating XLA shapes of params') - flags.DEFINE_string('constants', '{}', - 'Python dict giving constant values for some params') - flags.DEFINE_string('evaled_constants', '{}', - 'Python dict giving constant values for some params. ' - 'Values in this dict that are of type str are evaluated ' - 'using ast.literal_eval.') - flags.DEFINE_enum('ir_format', 'HLO', ('HLO', 'TF'), 'Output format.') - flags.DEFINE_string('ir_dest', None, 'File to write IR to') - flags.DEFINE_string('ir_human_dest', None, - 'File to write human readable debug output') flags.mark_flag_as_required('fn') flags.mark_flag_as_required('input_shapes') diff --git a/tests/aot_test.py b/tests/aot_test.py index 24ae9476d..318f9c685 100644 --- a/tests/aot_test.py +++ b/tests/aot_test.py @@ -22,7 +22,6 @@ import jax import jax.numpy as jnp from jax._src import core from jax._src import test_util as jtu -from jax._src.config import flags from jax.experimental.pjit import pjit from jax.experimental.serialize_executable import ( serialize, deserialize_and_load) @@ -73,7 +72,7 @@ class JaxAotTest(jtu.JaxTestCase): except NotImplementedError: raise unittest.SkipTest('PJRT Topology not supported') - if flags.FLAGS.jax_test_with_persistent_compilation_cache: + if jtu.TEST_WITH_PERSISTENT_COMPILATION_CACHE.value: raise unittest.SkipTest('Compilation caching not yet supported.') @jax.jit
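
For reference, a minimal sketch of the flag style this patch moves to, using absl directly
(the flag name and helper function below are illustrative, not taken from the diff);
the DEFINE_* wrappers in jax._src.config now follow the same convention by returning a
FlagHolder:

    from absl import flags

    # Old style (the left-hand sides removed by this diff): the flag is read back as
    # an untyped attribute of the global flags.FLAGS object, e.g.
    #   flags.DEFINE_integer('example_num_steps', 10, 'How many steps to run.')
    #   steps = flags.FLAGS.example_num_steps   # type checker sees Any

    # New style: DEFINE_integer returns a typed FlagHolder[int]; read it via .value.
    _NUM_STEPS = flags.DEFINE_integer('example_num_steps', 10,
                                      'How many steps to run.')

    def scaled_steps(scale: int) -> int:
      # _NUM_STEPS.value is an int, so this expression type-checks.
      # (As with any absl flag, .value is readable only after flags are parsed,
      # e.g. inside main() when invoked via app.run.)
      return _NUM_STEPS.value * scale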