Update flags to use the ABSL typed flag API.

Change flags to use the newer definition style, in which a flag is read through a typed FlagHolder object returned by the DEFINE_... function. The advantage is that `flag.value` has a type known to the type checker, rather than being read as an attribute of a gigantic config dictionary.

For jax.config flags, add a typed FlagHolder object that is returned when a flag is defined, matching the absl API.

Move a number of flags into the files that consume them. There is no reason to define every flag in `config.py`.

This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR.

PiperOrigin-RevId: 551604974
Authored by Peter Hawkins on 2023-07-27 12:15:16 -07:00; committed by jax authors
parent f35f226b44
commit 76cda0ae07
22 changed files with 336 additions and 289 deletions
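
As a quick illustration of the pattern applied throughout the diff below, here is a minimal, self-contained sketch using plain absl flags. The flag and function names are invented for the example and are not part of this change.

from absl import app
from absl import flags

# New style: DEFINE_integer returns a typed flags.FlagHolder[int], so the
# type checker knows that _BATCH_SIZE.value is an int.
_BATCH_SIZE = flags.DEFINE_integer('example_batch_size', 256, 'Batch size')

# Old style, for contrast: define the flag anonymously and read it back as an
# untyped attribute of the global FLAGS object, e.g. FLAGS.example_batch_size.
# flags.DEFINE_integer('example_batch_size', 256, 'Batch size')
# FLAGS = flags.FLAGS

def main(_):
  # Read the flag through the holder rather than through flags.FLAGS.
  print('batch size:', _BATCH_SIZE.value)

if __name__ == '__main__':
  app.run(main)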


@@ -27,8 +27,7 @@ from absl import app
 from absl import flags
-FLAGS = flags.FLAGS
-flags.DEFINE_multi_string(
+_SET_ENV = flags.DEFINE_multi_string(
     "set_env", None,
     "Specifies additional environment variables to be injected into the "
     "environment (via --set_env=variable=value or --set_env=variable). "
@@ -140,8 +139,8 @@ def jax_binary_op(state, **kwargs):
 )
 def main(argv):
-  if FLAGS.set_env:
-    for env_str in FLAGS.set_env:
+  if _SET_ENV.value:
+    for env_str in _SET_ENV.value:
       # Stop matching at the first '=' since we want to capture
       # --set_env='FOO=--foo_a=1 --foo_b=2' all as part of FOO.
       env_list = env_str.split('=', 1)


@@ -83,22 +83,21 @@ import numpy.random as npr
 from dp_accounting import dp_event
 from dp_accounting import rdp
-FLAGS = flags.FLAGS
-flags.DEFINE_boolean(
+_DPSGD = flags.DEFINE_boolean(
     'dpsgd', True, 'If True, train with DP-SGD. If False, '
     'train with vanilla SGD.')
-flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
-flags.DEFINE_float('noise_multiplier', 1.1,
+_LEARNING_RATE = flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
+_NOISE_MULTIPLIER = flags.DEFINE_float('noise_multiplier', 1.1,
                    'Ratio of the standard deviation to the clipping norm')
-flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
-flags.DEFINE_integer('batch_size', 256, 'Batch size')
-flags.DEFINE_integer('epochs', 60, 'Number of epochs')
-flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG')
-flags.DEFINE_integer(
+_L2_NORM_CLIP = flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
+_BATCH_SIZE = flags.DEFINE_integer('batch_size', 256, 'Batch size')
+_EPOCHS = flags.DEFINE_integer('epochs', 60, 'Number of epochs')
+_SEED = flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG')
+_MICROBATCHES = flags.DEFINE_integer(
     'microbatches', None, 'Number of microbatches '
     '(must evenly divide batch_size)')
-flags.DEFINE_string('model_dir', None, 'Model directory')
+_MODEL_DIR = flags.DEFINE_string('model_dir', None, 'Model directory')
 init_random_params, predict = stax.serial(
@@ -163,39 +162,39 @@ def shape_as_image(images, labels, dummy_dim=False):
 def compute_epsilon(steps, num_examples=60000, target_delta=1e-5):
   if num_examples * target_delta > 1.:
     warnings.warn('Your delta might be too high.')
-  q = FLAGS.batch_size / float(num_examples)
+  q = _BATCH_SIZE.value / float(num_examples)
   orders = list(jnp.linspace(1.1, 10.9, 99)) + list(range(11, 64))
   accountant = rdp.rdp_privacy_accountant.RdpAccountant(orders)
   accountant.compose(
       dp_event.PoissonSampledDpEvent(
-          q, dp_event.GaussianDpEvent(FLAGS.noise_multiplier)), steps)
+          q, dp_event.GaussianDpEvent(_NOISE_MULTIPLIER.value)), steps)
   return accountant.get_epsilon(target_delta)
 def main(_):
-  if FLAGS.microbatches:
+  if _MICROBATCHES.value:
     raise NotImplementedError(
         'Microbatches < batch size not currently supported'
     )
   train_images, train_labels, test_images, test_labels = datasets.mnist()
   num_train = train_images.shape[0]
-  num_complete_batches, leftover = divmod(num_train, FLAGS.batch_size)
+  num_complete_batches, leftover = divmod(num_train, _BATCH_SIZE.value)
   num_batches = num_complete_batches + bool(leftover)
-  key = random.PRNGKey(FLAGS.seed)
+  key = random.PRNGKey(_SEED.value)
   def data_stream():
-    rng = npr.RandomState(FLAGS.seed)
+    rng = npr.RandomState(_SEED.value)
     while True:
       perm = rng.permutation(num_train)
       for i in range(num_batches):
-        batch_idx = perm[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
+        batch_idx = perm[i * _BATCH_SIZE.value:(i + 1) * _BATCH_SIZE.value]
         yield train_images[batch_idx], train_labels[batch_idx]
   batches = data_stream()
-  opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)
+  opt_init, opt_update, get_params = optimizers.sgd(_LEARNING_RATE.value)
   @jit
   def update(_, i, opt_state, batch):
@@ -208,19 +207,19 @@ def main(_):
     rng = random.fold_in(rng, i)  # get new key for new random numbers
     return opt_update(
         i,
-        private_grad(params, batch, rng, FLAGS.l2_norm_clip,
-                     FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)
+        private_grad(params, batch, rng, _L2_NORM_CLIP.value,
+                     _NOISE_MULTIPLIER.value, _BATCH_SIZE.value), opt_state)
   _, init_params = init_random_params(key, (-1, 28, 28, 1))
   opt_state = opt_init(init_params)
   itercount = itertools.count()
-  steps_per_epoch = 60000 // FLAGS.batch_size
+  steps_per_epoch = 60000 // _BATCH_SIZE.value
   print('\nStarting training...')
-  for epoch in range(1, FLAGS.epochs + 1):
+  for epoch in range(1, _EPOCHS.value + 1):
     start_time = time.time()
     for _ in range(num_batches):
-      if FLAGS.dpsgd:
+      if _DPSGD.value:
         opt_state = \
           private_update(
             key, next(itercount), opt_state,
@@ -239,7 +238,7 @@ def main(_):
            test_loss, 100 * test_acc))
     # determine privacy loss so far
-    if FLAGS.dpsgd:
+    if _DPSGD.value:
       delta = 1e-5
       num_examples = 60000
       eps = compute_epsilon(epoch * steps_per_epoch, num_examples, delta)


@@ -16,7 +16,6 @@
 """
 from absl import app
-from absl import flags
 from functools import partial
 from jax import grad
 from jax import jit
@@ -27,8 +26,6 @@ import jax.random as random
 import jax.scipy as scipy
 import matplotlib.pyplot as plt
-FLAGS = flags.FLAGS
 def main(unused_argv):


@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
 from collections.abc import Hashable, Iterator
 import contextlib
 import functools
@@ -20,7 +22,7 @@ import logging
 import os
 import sys
 import threading
-from typing import Any, Callable, NamedTuple, Optional
+from typing import Any, Callable, Generic, NamedTuple, Optional, TypeVar
 from jax._src import lib
 from jax._src.lib import jax_jit
@@ -29,6 +31,8 @@ from jax._src.lib import xla_client
 logger = logging.getLogger(__name__)
+_T = TypeVar('_T')
 def bool_env(varname: str, default: bool) -> bool:
   """Read an environment variable and interpret it as a boolean.
@@ -64,6 +68,16 @@ UPGRADE_BOOL_HELP = (
 UPGRADE_BOOL_EXTRA_DESC = " (transient)"
+class FlagHolder(Generic[_T]):
+  def __init__(self, flags: NameSpace, name: str):
+    self._flags = flags
+    self._name = name
+
+  @property
+  def value(self) -> _T:
+    return getattr(self._flags, self._name)
 class Config:
   _HAS_DYNAMIC_ATTRIBUTES = True
@@ -112,26 +126,31 @@ class Config:
     if name not in self.values:
       raise AttributeError(f"Unrecognized config option: {name}")
-  def DEFINE_bool(self, name, default, *args, **kwargs):
+  def DEFINE_bool(self, name, default, *args, **kwargs) -> FlagHolder[bool]:
     update_hook = kwargs.pop("update_hook", None)
     self.add_option(name, default, bool, args, kwargs, update_hook=update_hook)
+    return FlagHolder(self.FLAGS, name)
-  def DEFINE_integer(self, name, default, *args, **kwargs):
+  def DEFINE_integer(self, name, default, *args, **kwargs) -> FlagHolder[int]:
     update_hook = kwargs.pop("update_hook", None)
     self.add_option(name, default, int, args, kwargs, update_hook=update_hook)
+    return FlagHolder(self.FLAGS, name)
-  def DEFINE_float(self, name, default, *args, **kwargs):
+  def DEFINE_float(self, name, default, *args, **kwargs) -> FlagHolder[float]:
     update_hook = kwargs.pop("update_hook", None)
     self.add_option(name, default, float, args, kwargs, update_hook=update_hook)
+    return FlagHolder(self.FLAGS, name)
-  def DEFINE_string(self, name, default, *args, **kwargs):
+  def DEFINE_string(self, name, default, *args, **kwargs) -> FlagHolder[str]:
     update_hook = kwargs.pop("update_hook", None)
     self.add_option(name, default, str, args, kwargs, update_hook=update_hook)
+    return FlagHolder(self.FLAGS, name)
-  def DEFINE_enum(self, name, default, *args, **kwargs):
+  def DEFINE_enum(self, name, default, *args, **kwargs) -> FlagHolder[str]:
     update_hook = kwargs.pop("update_hook", None)
     self.add_option(name, default, 'enum', args, kwargs,
                     update_hook=update_hook)
+    return FlagHolder(self.FLAGS, name)
   def config_with_absl(self):
     # Run this before calling `app.run(main)` etc
@@ -551,6 +570,22 @@ config = Config()
 flags = config
 FLAGS = flags.FLAGS
+def DEFINE_bool(name, default, *args, **kwargs):
+  return flags.DEFINE_bool(name, default, *args, **kwargs)
+
+def DEFINE_integer(name, default, *args, **kwargs):
+  return flags.DEFINE_integer(name, default, *args, **kwargs)
+
+def DEFINE_float(name, default, *args, **kwargs):
+  return flags.DEFINE_float(name, default, *args, **kwargs)
+
+def DEFINE_string(name, default, *args, **kwargs):
+  return flags.DEFINE_string(name, default, *args, **kwargs)
+
+def DEFINE_enum(name, default, *args, **kwargs):
+  return flags.DEFINE_enum(name, default, *args, **kwargs)
 already_configured_with_absl = False
@@ -617,51 +652,6 @@ def update_thread_local_jit_state(**kw):
     tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp)
-flags.DEFINE_integer(
-    'jax_tracer_error_num_traceback_frames',
-    int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
-    help='Set the number of stack frames in JAX tracer error messages.'
-)
-flags.DEFINE_bool(
-    'jax_pprint_use_color',
-    bool_env('JAX_PPRINT_USE_COLOR', True),
-    help='Enable jaxpr pretty-printing with colorful syntax highlighting.'
-)
-flags.DEFINE_bool(
-    'jax_host_callback_inline',
-    bool_env('JAX_HOST_CALLBACK_INLINE', False),
-    help='Inline the host_callback, if not in a staged context.'
-)
-flags.DEFINE_integer(
-    'jax_host_callback_max_queue_byte_size',
-    int_env('JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE', int(256 * 1e6)),
-    help=('The size in bytes of the buffer used to hold outfeeds from each '
-          'device. When this capacity is reached consuming outfeeds from the '
-          'device is paused, thus potentially pausing the device computation, '
-          'until the Python callback consume more outfeeds.'),
-    lower_bound=int(16 * 1e6)
-)
-flags.DEFINE_bool(
-    'jax_host_callback_outfeed',
-    bool_env('JAX_HOST_CALLBACK_OUTFEED', False),
-    help=(
-        'Use outfeed implementation for host_callback, even on CPU and GPU. '
-        'If false, use the CustomCall implementation. '
-        'Has no effect on TPU, since only the outfeed mechanism is implemented.'
-    )
-)
-flags.DEFINE_bool(
-    'jax_host_callback_ad_transforms',
-    bool_env('JAX_HOST_CALLBACK_AD_TRANSFORMS', False),
-    help=(
-        'Enable support for jvp/vjp for the host_callback primitives. Default is '
-        'False, which means that host_callback operates only on primals. '
-        'The flag exists only temporarily, for backward compatibility.'
-    )
-)
 # TODO(b/214340779): remove flag when XLA:CPU is improved.
 jax2tf_associative_scan_reductions = config.define_bool_state(
     name='jax2tf_associative_scan_reductions',
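
A hedged sketch of how a JAX-internal module consumes the FlagHolder API added above; the flag name and the consuming function are illustrative, not part of this commit, though the calls mirror the ones in the diff.

from jax._src import config as jax_config

# The jax config DEFINE_* methods now return a FlagHolder, mirroring absl.
_EXAMPLE_FRAMES = jax_config.DEFINE_integer(
    'example_traceback_frames',
    jax_config.int_env('EXAMPLE_TRACEBACK_FRAMES', 5),
    help='Illustrative flag, not defined by this commit.')

def frames_to_show() -> int:
  # Typed read via the holder; no string-keyed lookup on jax_config.FLAGS.
  return _EXAMPLE_FRAMES.value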


@@ -38,7 +38,7 @@ import numpy as np
 from jax._src import dtypes
 from jax._src import config as jax_config
 from jax._src import effects
-from jax._src.config import FLAGS, config
+from jax._src.config import config
 from jax._src.errors import (
     ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
     TracerIntegerConversionError, UnexpectedTracerError)
@@ -60,6 +60,13 @@ zip, unsafe_zip = safe_zip, zip
 map, unsafe_map = safe_map, map
+_TRACER_ERROR_NUM_TRACEBACK_FRAMES = jax_config.DEFINE_integer(
+    'jax_tracer_error_num_traceback_frames',
+    jax_config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
+    help='Set the number of stack frames in JAX tracer error messages.'
+)
 # -------------------- jaxprs --------------------
 Effect = effects.Effect
@@ -560,7 +567,7 @@ def raise_as_much_as_possible(tracer) -> Tracer:
 def escaped_tracer_error(tracer, detail=None):
-  num_frames = FLAGS.jax_tracer_error_num_traceback_frames
+  num_frames = _TRACER_ERROR_NUM_TRACEBACK_FRAMES.value
   msg = ('Encountered an unexpected tracer. A function transformed by JAX '
          'had a side effect, allowing for a reference to an intermediate value '
          f'with type {tracer.aval.str_short()} wrapped in a '


@@ -32,6 +32,7 @@ import warnings
 import numpy as np
 from jax._src import compilation_cache
+from jax._src import config as jax_config
 from jax._src import core
 from jax._src import dtypes
 from jax._src import linear_util as lu
@@ -43,7 +44,7 @@ from jax._src import traceback_util
 from jax._src import util
 from jax._src import op_shardings
 from jax._src import xla_bridge as xb
-from jax._src.config import config, flags
+from jax._src.config import config
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -64,9 +65,7 @@ JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
 JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
 BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
-FLAGS = flags.FLAGS
-flags.DEFINE_string(
+_DUMP_IR_TO = jax_config.DEFINE_string(
     'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
     help="Path to which the IR that is emitted by JAX as input to the "
          "compiler should be dumped as text files. Optional. If omitted, JAX "
@@ -472,7 +471,7 @@ def _make_string_safe_for_filename(s: str) -> str:
 def _dump_ir_to_file(name: str, ir: str):
   id = next(_ir_dump_counter)
   name = f"jax_ir{id}_{_make_string_safe_for_filename(name)}.mlir"
-  name = path.Path(FLAGS.jax_dump_ir_to) / name
+  name = path.Path(_DUMP_IR_TO.value) / name
   name.write_text(ir)
@@ -481,7 +480,7 @@ def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray,
   sym_name = computation.operation.attributes['sym_name']
   module_name = ir.StringAttr(sym_name).value
-  if FLAGS.jax_dump_ir_to:
+  if _DUMP_IR_TO.value:
     _dump_ir_to_file(module_name, mlir.module_to_string(computation))
   # Persistent compilation cache only implemented on TPU and GPU.


@@ -28,7 +28,7 @@ import warnings
 import ml_dtypes
 import numpy as np
-from jax._src.config import flags, config
+from jax._src.config import config
 from jax._src.typing import DType, DTypeLike
 from jax._src import traceback_util
@@ -43,8 +43,6 @@ else:
   raise ValueError("JAX requires ml_dtypes version 0.2.0 or newer; "
                    f"installed version is {ml_dtypes.__version__}.")
-FLAGS = flags.FLAGS
 class extended(np.generic):
   """Scalar class for extended dtypes.


@@ -32,13 +32,20 @@ from functools import partial
 import sys
 from typing import NamedTuple, Optional, Union
-from jax._src.config import config
+from jax._src import config
 try:
   import colorama  # pytype: disable=import-error
 except ImportError:
   colorama = None
+_PPRINT_USE_COLOR = config.DEFINE_bool(
+    'jax_pprint_use_color',
+    config.bool_env('JAX_PPRINT_USE_COLOR', True),
+    help='Enable jaxpr pretty-printing with colorful syntax highlighting.'
+)
 def _can_use_color() -> bool:
   try:
     # Check if we're in IPython or Colab
@@ -63,7 +70,7 @@ class Doc(abc.ABC):
   def format(self, width: int = 80, use_color: Optional[bool] = None,
              annotation_prefix=" # ") -> str:
     if use_color is None:
-      use_color = CAN_USE_COLOR and config.FLAGS.jax_pprint_use_color
+      use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value
     return _format(self, width, use_color=use_color,
                    annotation_prefix=annotation_prefix)


@@ -16,9 +16,10 @@ from functools import partial
 import operator
 from jax._src import api
+from jax._src import config as jax_config
 from jax._src import dtypes as _dtypes
 from jax._src import xla_bridge
-from jax._src.config import config, flags
+from jax._src.config import config
 from jax._src.tree_util import tree_map, tree_reduce
 import numpy as np
@@ -30,7 +31,11 @@ import numpy as np
 __all__ = ['check_grads', 'check_jvp', 'check_vjp']
-FLAGS = flags.FLAGS
+_TEST_DUT = jax_config.DEFINE_string(
+    'jax_test_dut', '',
+    help=
+    'Describes the device under test in case special consideration is required.'
+)
 EPS = 1e-4
@@ -292,4 +297,4 @@ def check_grads(f, args, order,
 def device_under_test():
-  return getattr(FLAGS, 'jax_test_dut', None) or xla_bridge.get_backend().platform
+  return _TEST_DUT.value or xla_bridge.get_backend().platform


@@ -41,11 +41,12 @@ from jax._src.interpreters import mlir
 from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten
 from jax._src import api
 from jax._src import pjit as pjit_lib
+from jax._src import config as jax_config
 from jax._src import core
 from jax._src import dispatch
 from jax._src import dtypes as _dtypes
 from jax._src.interpreters import pxla
-from jax._src.config import (flags, bool_env, config,
+from jax._src.config import (bool_env, config,
                              raise_persistent_cache_errors,
                              persistent_cache_min_compile_time_secs)
 from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
@@ -60,19 +61,12 @@ from jax._src import xla_bridge
 # jax.test_util. Functionality appearing here is for internal use only, and
 # may be changed or removed at any time and without any deprecation cycle.
-FLAGS = flags.FLAGS
-flags.DEFINE_string(
-    'jax_test_dut', '',
-    help=
-    'Describes the device under test in case special consideration is required.'
-)
-flags.DEFINE_integer(
+_NUM_GENERATED_CASES = jax_config.DEFINE_integer(
     'jax_num_generated_cases',
     int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
     help='Number of generated cases to test')
-flags.DEFINE_integer(
+_MAX_CASES_SAMPLING_RETRIES = jax_config.DEFINE_integer(
     'max_cases_sampling_retries',
     int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')),
    'Number of times a failed test sample should be retried. '
@@ -80,24 +74,23 @@ flags.DEFINE_integer(
     'sampling process is terminated.'
 )
-flags.DEFINE_bool(
+_SKIP_SLOW_TESTS = jax_config.DEFINE_bool(
     'jax_skip_slow_tests',
     bool_env('JAX_SKIP_SLOW_TESTS', False),
     help='Skip tests marked as slow (> 5 sec).'
 )
-flags.DEFINE_string(
+_TEST_TARGETS = jax_config.DEFINE_string(
     'test_targets', os.getenv('JAX_TEST_TARGETS', ''),
     'Regular expression specifying which tests to run, called via re.search on '
     'the test name. If empty or unspecified, run all tests.'
 )
-flags.DEFINE_string(
+_EXCLUDE_TEST_TARGETS = jax_config.DEFINE_string(
     'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
     'Regular expression specifying which tests NOT to run, called via re.search '
     'on the test name. If empty or unspecified, run all tests.'
 )
-flags.DEFINE_bool(
+TEST_WITH_PERSISTENT_COMPILATION_CACHE = jax_config.DEFINE_bool(
     'jax_test_with_persistent_compilation_cache',
     bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False),
     help='If enabled, the persistent compilation cache will be enabled for all '
@@ -731,7 +724,7 @@ def assert_dot_precision(expected_precision, fun, *args):
 def cases_from_gens(*gens):
   sizes = [1, 3, 10]
-  cases_per_size = int(FLAGS.jax_num_generated_cases / len(sizes)) + 1
+  cases_per_size = int(_NUM_GENERATED_CASES.value / len(sizes)) + 1
   for size in sizes:
     for i in range(cases_per_size):
       yield (f'_{size}_{i}',) + tuple(gen(size) for gen in gens)
@@ -744,8 +737,8 @@ def named_cases_from_sampler(gen):
     if not isinstance(x, (list, tuple)):
       x = list(x)
     return [x[rng.randint(len(x))]]
-  while (len(seen) < FLAGS.jax_num_generated_cases and
-         retries < FLAGS.max_cases_sampling_retries):
+  while (len(seen) < _NUM_GENERATED_CASES.value and
+         retries < _MAX_CASES_SAMPLING_RETRIES.value):
     retries += 1
     cases = list(gen(choose_one))
     if not cases:
@@ -773,7 +766,7 @@ def sample_product_testcases(*args, **kw):
   kw = [(k, list(v)) for k, v in kw.items()]
   n = math.prod(len(a) for a in args) * math.prod(len(v) for _, v in kw)
   testcases = []
-  for i in _choice(n, min(n, FLAGS.jax_num_generated_cases)):
+  for i in _choice(n, min(n, _NUM_GENERATED_CASES.value)):
     testcase = {}
     for a in args:
       testcase.update(a[i % len(a)])
@@ -804,12 +797,12 @@ def sample_product(*args, **kw):
 class JaxTestLoader(absltest.TestLoader):
   def getTestCaseNames(self, testCaseClass):
     names = super().getTestCaseNames(testCaseClass)
-    if FLAGS.test_targets:
-      pattern = re.compile(FLAGS.test_targets)
+    if _TEST_TARGETS.value:
+      pattern = re.compile(_TEST_TARGETS.value)
       names = [name for name in names
                if pattern.search(f"{testCaseClass.__name__}.{name}")]
-    if FLAGS.exclude_test_targets:
-      pattern = re.compile(FLAGS.exclude_test_targets)
+    if _EXCLUDE_TEST_TARGETS.value:
+      pattern = re.compile(_EXCLUDE_TEST_TARGETS.value)
       names = [name for name in names
                if not pattern.search(f"{testCaseClass.__name__}.{name}")]
     return names
@@ -874,7 +867,7 @@ class JaxTestCase(parameterized.TestCase):
   @classmethod
   def setUpClass(cls):
-    if FLAGS.jax_test_with_persistent_compilation_cache:
+    if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
      cls._compilation_cache_exit_stack = ExitStack()
      stack = cls._compilation_cache_exit_stack
      stack.enter_context(raise_persistent_cache_errors(True))
@@ -887,7 +880,7 @@ class JaxTestCase(parameterized.TestCase):
   @classmethod
   def tearDownClass(cls):
-    if FLAGS.jax_test_with_persistent_compilation_cache:
+    if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
       cls._compilation_cache_exit_stack.close()
   def rng(self):


@@ -37,7 +37,8 @@ import numpy as np
 from jax._src import lib
 from jax._src import distributed
-from jax._src.config import flags, bool_env, config, int_env
+from jax._src import config as jax_config
+from jax._src.config import bool_env, config, int_env
 from jax._src.lib import xla_client
 from jax._src.lib import xla_extension_version
 from jax._src import traceback_util
@@ -59,37 +60,35 @@ traceback_util.register_exclusion(__file__)
 XlaBackend = xla_client.Client
-FLAGS = flags.FLAGS
 # TODO(phawkins): Remove jax_xla_backend.
-flags.DEFINE_string(
+_XLA_BACKEND = jax_config.DEFINE_string(
     'jax_xla_backend', '',
     'Deprecated, please use --jax_platforms instead.')
-flags.DEFINE_string(
+BACKEND_TARGET = jax_config.DEFINE_string(
     'jax_backend_target',
     os.getenv('JAX_BACKEND_TARGET', '').lower(),
     'Either "local" or "rpc:address" to connect to a remote service target.')
 # TODO(skye): warn when this is used once we test out --jax_platforms a bit
-flags.DEFINE_string(
+_PLATFORM_NAME = jax_config.DEFINE_string(
     'jax_platform_name',
     os.getenv('JAX_PLATFORM_NAME', '').lower(),
     'Deprecated, please use --jax_platforms instead.')
-flags.DEFINE_bool(
+_DISABLE_MOST_OPTIMIZATIONS = jax_config.DEFINE_bool(
     'jax_disable_most_optimizations',
     bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
     'Try not to do much optimization work. This can be useful if the cost of '
     'optimization is greater than that of running a less-optimized program.')
-flags.DEFINE_integer(
+_XLA_PROFILE_VERSION = jax_config.DEFINE_integer(
     'jax_xla_profile_version', int_env('JAX_XLA_PROFILE_VERSION', 0),
     'Optional profile version for XLA compilation. '
    'This is meaningful only when XLA is configured to '
    'support the remote compilation profile feature.')
-flags.DEFINE_string(
+CUDA_VISIBLE_DEVICES = jax_config.DEFINE_string(
     'jax_cuda_visible_devices', 'all',
     'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
     'comma-separate list of integer device IDs.')
-flags.DEFINE_string(
+_ROCM_VISIBLE_DEVICES = jax_config.DEFINE_string(
     'jax_rocm_visible_devices', 'all',
     'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
     'comma-separate list of integer device IDs.')
@@ -171,13 +170,13 @@ def get_compile_options(
   if lib.cuda_path is not None:
     debug_options.xla_gpu_cuda_data_dir = lib.cuda_path
-  if FLAGS.jax_disable_most_optimizations:
+  if _DISABLE_MOST_OPTIMIZATIONS.value:
     debug_options.xla_backend_optimization_level = 0
     debug_options.xla_llvm_disable_expensive_passes = True
     debug_options.xla_test_all_input_layouts = False
-  compile_options.profile_version = FLAGS.jax_xla_profile_version
+  compile_options.profile_version = _XLA_PROFILE_VERSION.value
   return compile_options
@@ -264,9 +263,9 @@ register_backend_factory('cpu',
 def make_gpu_client(
-    *, platform_name: str, visible_devices_flag: str
+    *, platform_name: str, visible_devices_flag: jax_config.FlagHolder[str]
 ) -> xla_client.Client:
-  visible_devices = getattr(FLAGS, visible_devices_flag, "all")
+  visible_devices = visible_devices_flag.value
   allowed_devices = None
   if visible_devices != "all":
     allowed_devices = {int(x) for x in visible_devices.split(",")}
@@ -292,15 +291,25 @@ def make_gpu_client(
 if hasattr(xla_client, "make_gpu_client"):
   register_backend_factory(
-      'cuda', partial(make_gpu_client, platform_name='cuda',
-                      visible_devices_flag='jax_cuda_visible_devices'),
+      "cuda",
+      partial(
+          make_gpu_client,
+          platform_name="cuda",
+          visible_devices_flag=CUDA_VISIBLE_DEVICES,
+      ),
       priority=200,
-      fail_quietly=True)
+      fail_quietly=True,
+  )
   register_backend_factory(
-      'rocm', partial(make_gpu_client, platform_name='rocm',
-                      visible_devices_flag='jax_rocm_visible_devices'),
+      "rocm",
+      partial(
+          make_gpu_client,
+          platform_name="rocm",
+          visible_devices_flag=_ROCM_VISIBLE_DEVICES,
+      ),
      priority=200,
-      fail_quietly=True)
+      fail_quietly=True,
+  )
 if hasattr(xla_client, "make_tpu_client"):
@@ -623,7 +632,7 @@ def backends() -> dict[str, xla_client.Client]:
     # support anything else there at the moment and warning would be pointless.
     if (py_platform.system() != "Darwin" and
         _default_backend.platform == "cpu" and
-        FLAGS.jax_platform_name != 'cpu'):
+        _PLATFORM_NAME.value != 'cpu'):
       logger.warning('No GPU/TPU found, falling back to CPU. '
                      '(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
   return _backends
@@ -677,8 +686,7 @@ def _get_backend_uncached(
   if platform is not None and not isinstance(platform, str):
     return platform
-  platform = (platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name
-              or None)
+  platform = (platform or _XLA_BACKEND.value or _PLATFORM_NAME.value or None)
   bs = backends()
   if platform is not None:


@@ -507,7 +507,7 @@ import warnings
 from jax._src import api
 from jax._src import core
-from jax import config
+from jax._src import config
 from jax import custom_derivatives
 from jax._src import dtypes
 from jax import lax
@@ -531,17 +531,46 @@ from jax._src.lib.mlir.dialects import hlo
 import numpy as np
-FLAGS = config.FLAGS
+_HOST_CALLBACK_INLINE = config.DEFINE_bool(
+    'jax_host_callback_inline',
+    config.bool_env('JAX_HOST_CALLBACK_INLINE', False),
+    help='Inline the host_callback, if not in a staged context.'
+)
+_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE = config.DEFINE_integer(
+    'jax_host_callback_max_queue_byte_size',
+    config.int_env('JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE', int(256 * 1e6)),
+    help=('The size in bytes of the buffer used to hold outfeeds from each '
+          'device. When this capacity is reached consuming outfeeds from the '
+          'device is paused, thus potentially pausing the device computation, '
+          'until the Python callback consume more outfeeds.'),
+    lower_bound=int(16 * 1e6)
+)
+_HOST_CALLBACK_OUTFEED = config.DEFINE_bool(
+    'jax_host_callback_outfeed',
+    config.bool_env('JAX_HOST_CALLBACK_OUTFEED', False),
+    help=(
+        'Use outfeed implementation for host_callback, even on CPU and GPU. '
+        'If false, use the CustomCall implementation. '
+        'Has no effect on TPU, since only the outfeed mechanism is implemented.'
+    )
+)
+_HOST_CALLBACK_AD_TRANSFORMS = config.DEFINE_bool(
+    'jax_host_callback_ad_transforms',
+    config.bool_env('JAX_HOST_CALLBACK_AD_TRANSFORMS', False),
+    help=(
+        'Enable support for jvp/vjp for the host_callback primitives. Default is '
+        'False, which means that host_callback operates only on primals. '
+        'The flag exists only temporarily, for backward compatibility.'
+    )
+)
 logger = logging.getLogger(__name__)
-def _inline_host_callback() -> bool:
-  return FLAGS.jax_host_callback_inline
 def _use_outfeed(platform: str) -> bool:
-  return (platform in ("tpu", "gpu", "cuda", "rocm") or FLAGS.jax_host_callback_outfeed)
+  return (platform in ("tpu", "gpu", "cuda", "rocm") or
+          _HOST_CALLBACK_OUTFEED.value)
 def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend):
@@ -620,7 +649,7 @@ def id_tap(tap_func,
            "pre-apply keyword arguments, either by using a closure or by passing "
            "``functools.partial(tap_func, **kwargs)``.")
     raise TypeError(msg)
-  if FLAGS.jax_host_callback_ad_transforms:
+  if _HOST_CALLBACK_AD_TRANSFORMS.value:
     warnings.warn('The flag jax_host_callback_ad_transforms is for temporary '
                   'backwards compatibility mode. This flag, and the behavior '
                   'it enabled will be removed soon.',
@@ -642,7 +671,7 @@ def id_tap(tap_func,
   if result is not None:
     # Return the results, but add a dependency on the call, to ensure it
     # is kept in the graph.
-    if FLAGS.jax_host_callback_ad_transforms:
+    if _HOST_CALLBACK_AD_TRANSFORMS.value:
       call_flat_results, _ = tree_util.tree_flatten(call_res)
       if call_flat_results:
         call_flat_results = [id_tap_dep_p.bind(r, call_flat_results[0])
@@ -782,7 +811,7 @@ def _call(callback_func: Callable,
           identity=False):
   # Lazy initialization
   _initialize_outfeed_receiver(
-      max_callback_queue_size_bytes=FLAGS.jax_host_callback_max_queue_byte_size)
+      max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value)
   api.check_callable(callback_func)
   flat_args, arg_treedef = tree_util.tree_flatten(arg)
   for arg in flat_args:
@@ -909,7 +938,7 @@ xla.register_translation(id_tap_dep_p,
 id_tap_dep_p.def_abstract_eval(lambda r_a, _: r_a)
 def _id_tap_dep_jvp_rule(primals, tangents):
-  if FLAGS.jax_host_callback_ad_transforms:
+  if _HOST_CALLBACK_AD_TRANSFORMS.value:
     assert False
   tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
   return (id_tap_dep_p.bind(primals[0], primals[1]),
@@ -918,7 +947,7 @@ def _id_tap_dep_jvp_rule(primals, tangents):
 ad.primitive_jvps[id_tap_dep_p] = _id_tap_dep_jvp_rule
 def _id_tap_dep_transpose_rule(cts, arg_res, arg_tap):
-  if FLAGS.jax_host_callback_ad_transforms:
+  if _HOST_CALLBACK_AD_TRANSFORMS.value:
     assert False
   if ad.is_undefined_primal(arg_res):
     ct_res = _instantiate_zeros(cts, arg_res)
@@ -934,7 +963,7 @@ ad.primitive_transposes[id_tap_dep_p] = _id_tap_dep_transpose_rule
 def _id_tap_dep_batching_rule(batched_args, batch_dims):
-  if FLAGS.jax_host_callback_ad_transforms:
+  if _HOST_CALLBACK_AD_TRANSFORMS.value:
    assert False
   arg_res, arg_tap = batched_args
   return id_tap_dep_p.bind(arg_res, arg_tap), batch_dims[0]
@@ -1013,7 +1042,7 @@ outside_call_p.def_abstract_eval(_outside_call_abstract_eval)
 def _outside_call_impl(*args, **params):
   assert "has_token" not in params
-  if _inline_host_callback():
+  if _HOST_CALLBACK_INLINE.value:
     device_index = params["device_index"]
     device = xb.devices()[device_index]
     results = _outside_call_run_callback(args, device, send_infeed=False, **params)
@@ -1400,7 +1429,7 @@ def _outside_call_jvp_rule(primals, tangents, **params):
   assert "has_token" not in params
   if not params["identity"]:
     raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
-  if FLAGS.jax_host_callback_ad_transforms:
+  if _HOST_CALLBACK_AD_TRANSFORMS.value:
     tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
     arg_treedef = params["arg_treedef"]
@@ -1425,7 +1454,7 @@ ad.primitive_jvps[outside_call_p] = _outside_call_jvp_rule
 def _outside_call_partial_eval_rule(trace, *args, **params):
   # partial eval is used after jvp and before transpose.
-  if not FLAGS.jax_host_callback_ad_transforms:
+  if not _HOST_CALLBACK_AD_TRANSFORMS.value:
     # TODO: just remote the partial eval rule
     return trace.default_process_primitive(outside_call_p, args, params)
   transforms = params.get("transforms", ())
@@ -1492,7 +1521,7 @@ def _outside_call_transpose_rule(cts, *args, **params):
         *cts_instantiated,
         **_add_transform(params, "transpose"))
-  if not FLAGS.jax_host_callback_ad_transforms:
+  if not _HOST_CALLBACK_AD_TRANSFORMS.value:
     assert False
   assert len(args) % 2 == 0


@@ -65,7 +65,7 @@ def main(_):
       tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)
   keras_model.fit(train_ds, epochs=FLAGS.num_epochs, validation_data=test_ds)
-  if FLAGS.show_images:
+  if saved_model_main.SHOW_IMAGES.value:
     mnist_lib.plot_images(
         test_ds,
         1,


@@ -38,8 +38,8 @@ import optax
 import tensorflow as tf  # type: ignore
 import tensorflow_datasets as tfds  # type: ignore
-flags.DEFINE_boolean("mock_data", False, "Use fake data, for testing.")
-FLAGS = flags.FLAGS
+_MOCK_DATA = flags.DEFINE_boolean("mock_data", False,
+                                  "Use fake data, for testing.")
 #### Model parameters
@@ -64,7 +64,7 @@ def load_mnist(split: tfds.Split, batch_size: int):
     an iterator with pairs (images, labels). The images have shape
     (B, 28, 28, 1) and the labels have shape (B, 10), where B is the batch_size.
   """
-  if FLAGS.mock_data:
+  if _MOCK_DATA.value:
     with tfds.testing.mock_data(num_examples=batch_size):
       try:
         ds = tfds.load("mnist", split=split)


@@ -37,46 +37,49 @@ import numpy as np
 import tensorflow as tf  # type: ignore
 import tensorflow_datasets as tfds  # type: ignore
-flags.DEFINE_enum("model", "mnist_flax", ["mnist_flax", "mnist_pure_jax"],
-                  "Which model to use.")
-flags.DEFINE_boolean("model_classifier_layer", True,
+_MODEL = flags.DEFINE_enum(
+    "model", "mnist_flax", ["mnist_flax", "mnist_pure_jax"],
+    "Which model to use.")
+_MODEL_CLASSIFIER_LAYER = flags.DEFINE_boolean("model_classifier_layer", True,
                      ("The model should include the classifier layer, or just "
                       "the last layer of logits. Set this to False when you "
                       "want to reuse the classifier-less model in a larger "
                       "model. See keras_reuse_main.py and README.md."))
-flags.DEFINE_string("model_path", "/tmp/jax2tf/saved_models",
+_MODEL_PATH = flags.DEFINE_string("model_path", "/tmp/jax2tf/saved_models",
                     "Path under which to save the SavedModel.")
-flags.DEFINE_integer("model_version", 1,
+_MODEL_VERSION = flags.DEFINE_integer("model_version", 1,
                      ("The version number for the SavedModel. Needed for "
                       "serving, larger versions will take precedence"),
                      lower_bound=1)
-flags.DEFINE_integer("serving_batch_size", 1,
+_SERVING_BATCH_SIZE = flags.DEFINE_integer("serving_batch_size", 1,
                      "For what batch size to prepare the serving signature. "
                      "Use -1 for converting and saving with batch polymorphism.")
 flags.register_validator(
     "serving_batch_size",
-    lambda serving_batch_size: serving_batch_size > 0 or serving_batch_size == -1,
-    message="--serving_batch_size must be either -1 or a positive integer.")
+    lambda serving_batch_size: serving_batch_size > 0
+    or serving_batch_size == -1,
+    message="--serving_batch_size must be either -1 or a positive integer.",
+)
-flags.DEFINE_integer("num_epochs", 3, "For how many epochs to train.",
-                     lower_bound=1)
-flags.DEFINE_boolean(
+_NUM_EPOCHS = flags.DEFINE_integer("num_epochs", 3,
+                                   "For how many epochs to train.",
+                                   lower_bound=1)
+_GENERATE_MODEL = flags.DEFINE_boolean(
     "generate_model", True,
     "Train and save a new model. Otherwise, use an existing SavedModel.")
-flags.DEFINE_boolean(
+_COMPILE_MODEL = flags.DEFINE_boolean(
     "compile_model", True,
     "Enable TensorFlow jit_compiler for the SavedModel. This is "
     "necessary if you want to use the model for TensorFlow serving.")
-flags.DEFINE_boolean("show_model", True, "Show details of saved SavedModel.")
-flags.DEFINE_boolean(
+_SHOW_MODEL = flags.DEFINE_boolean("show_model", True,
+                                   "Show details of saved SavedModel.")
+SHOW_IMAGES = flags.DEFINE_boolean(
     "show_images", False,
     "Plot some sample images with labels and inference results.")
-flags.DEFINE_boolean(
+_TEST_SAVEDMODEL = flags.DEFINE_boolean(
     "test_savedmodel", True,
     "Test TensorFlow inference using the SavedModel w.r.t. the JAX model.")
-FLAGS = flags.FLAGS
 def train_and_save():
   logging.info("Loading the MNIST TensorFlow dataset")
@@ -85,22 +88,22 @@ def train_and_save():
   test_ds = mnist_lib.load_mnist(
       tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)
-  if FLAGS.show_images:
+  if SHOW_IMAGES.value:
     mnist_lib.plot_images(train_ds, 1, 5, "Training images", inference_fn=None)
   the_model_class = pick_model_class()
   model_dir = savedmodel_dir(with_version=True)
-  if FLAGS.generate_model:
+  if _GENERATE_MODEL.value:
     model_descr = model_description()
     logging.info("Generating model for %s", model_descr)
     (predict_fn, predict_params) = the_model_class.train(
         train_ds,
         test_ds,
-        FLAGS.num_epochs,
-        with_classifier=FLAGS.model_classifier_layer)
-    if FLAGS.serving_batch_size == -1:
+        num_epochs=_NUM_EPOCHS.value,
+        with_classifier=_MODEL_CLASSIFIER_LAYER.value)
+    if _SERVING_BATCH_SIZE.value == -1:
       # Batch-polymorphic SavedModel
       input_signatures = [
           tf.TensorSpec((None,) + mnist_lib.input_shape, tf.float32),
@@ -109,7 +112,7 @@ def train_and_save():
     else:
       input_signatures = [
           # The first one will be the serving signature
-          tf.TensorSpec((FLAGS.serving_batch_size,) + mnist_lib.input_shape,
+          tf.TensorSpec((_SERVING_BATCH_SIZE.value,) + mnist_lib.input_shape,
                         tf.float32),
           tf.TensorSpec((mnist_lib.train_batch_size,) + mnist_lib.input_shape,
                         tf.float32),
@@ -126,15 +129,15 @@ def train_and_save():
         with_gradient=True,
         input_signatures=input_signatures,
         polymorphic_shapes=polymorphic_shapes,
-        compile_model=FLAGS.compile_model)
-  if FLAGS.test_savedmodel:
+        compile_model=_COMPILE_MODEL.value)
+  if _TEST_SAVEDMODEL.value:
     tf_accelerator, tolerances = tf_accelerator_and_tolerances()
     with tf.device(tf_accelerator):
       logging.info("Testing savedmodel")
      pure_restored_model = tf.saved_model.load(model_dir)
-      if FLAGS.show_images and FLAGS.model_classifier_layer:
+      if SHOW_IMAGES.value and _MODEL_CLASSIFIER_LAYER.value:
         mnist_lib.plot_images(
             test_ds,
             1,
@@ -149,7 +152,7 @@ def train_and_save():
             pure_restored_model(tf.convert_to_tensor(test_input)),
             predict_fn(predict_params, test_input), **tolerances)
-  if FLAGS.show_model:
+  if _SHOW_MODEL.value:
     def print_model(model_dir: str):
       cmd = f"saved_model_cli show --all --dir {model_dir}"
       print(cmd)
@@ -160,18 +163,18 @@ def train_and_save():
 def pick_model_class():
   """Picks one of PureJaxMNIST or FlaxMNIST."""
-  if FLAGS.model == "mnist_pure_jax":
+  if _MODEL.value == "mnist_pure_jax":
     return mnist_lib.PureJaxMNIST
-  elif FLAGS.model == "mnist_flax":
+  elif _MODEL.value == "mnist_flax":
     return mnist_lib.FlaxMNIST
   else:
-    raise ValueError(f"Unrecognized model: {FLAGS.model}")
+    raise ValueError(f"Unrecognized model: {_MODEL.value}")
 def model_description() -> str:
   """A short description of the picked model."""
   res = pick_model_class().name
-  if not FLAGS.model_classifier_layer:
+  if not _MODEL_CLASSIFIER_LAYER.value:
     res += " (features_only)"
   return res
@@ -179,11 +182,11 @@ def model_description() -> str:
 def savedmodel_dir(with_version: bool = True) -> str:
   """The directory where we save the SavedModel."""
   model_dir = os.path.join(
-      FLAGS.model_path,
-      FLAGS.model + ('' if FLAGS.model_classifier_layer else '_features')
+      _MODEL_PATH.value,
+      _MODEL.value + ('' if _MODEL_CLASSIFIER_LAYER.value else '_features')
   )
   if with_version:
-    model_dir = os.path.join(model_dir, str(FLAGS.model_version))
+    model_dir = os.path.join(model_dir, str(_MODEL_VERSION.value))
   return model_dir


@@ -32,31 +32,32 @@ from tensorflow_serving.apis import predict_pb2  # type: ignore[import]
 from tensorflow_serving.apis import prediction_service_pb2_grpc
-FLAGS = flags.FLAGS
-flags.DEFINE_boolean(
+_USE_GRPC = flags.DEFINE_boolean(
     "use_grpc", True,
     "Use the gRPC API (default), or the HTTP REST API.")
-flags.DEFINE_string(
+_MODEL_SPEC_NAME = flags.DEFINE_string(
     "model_spec_name", "",
     "The name you used to export your model to model server (e.g., mnist_flax).")
-flags.DEFINE_string(
+_PREDICTION_SERVICE_ADDR = flags.DEFINE_string(
     "prediction_service_addr",
     "localhost:8500",
     "Stubby endpoint for the prediction service. If you serve your model "
     "locally using TensorFlow model server, then you can use \"localhost:8500\""
     "for the gRPC server and \"localhost:8501\" for the HTTP REST server.")
-flags.DEFINE_integer("serving_batch_size", 1,
-                     "Batch size for the serving request. Must match the "
-                     "batch size at which the model was saved. Must divide "
-                     "--count_images",
-                     lower_bound=1)
-flags.DEFINE_integer("count_images", 16,
-                     "How many images to test.",
-                     lower_bound=1)
+_SERVING_BATCH_SIZE = flags.DEFINE_integer(
+    "serving_batch_size",
+    1,
+    "Batch size for the serving request. Must match the "
+    "batch size at which the model was saved. Must divide "
+    "--count_images",
+    lower_bound=1,
+)
+_COUNT_IMAGES = flags.DEFINE_integer(
+    "count_images", 16, "How many images to test.", lower_bound=1
+)
 def serving_call_mnist(images):
@@ -69,12 +70,12 @@ def serving_call_mnist(images):
   Returns:
     A numpy.ndarray of shape [B, 10] with the one-hot inference response.
   """
-  if FLAGS.use_grpc:
-    channel = grpc.insecure_channel(FLAGS.prediction_service_addr)
+  if _USE_GRPC.value:
+    channel = grpc.insecure_channel(_PREDICTION_SERVICE_ADDR.value)
     stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
     request = predict_pb2.PredictRequest()
-    request.model_spec.name = FLAGS.model_spec_name
+    request.model_spec.name = _MODEL_SPEC_NAME.value
     request.model_spec.signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
     # You can see the name of the input ("inputs") in the SavedModel dump.
     request.inputs["inputs"].CopyFrom(
@@ -90,7 +91,7 @@ def serving_call_mnist(images):
     images_json = json.dumps(images.tolist())
     # You can see the name of the input ("inputs") in the SavedModel dump.
     data = f'{{"inputs": {images_json}}}'
-    predict_url = f"http://{FLAGS.prediction_service_addr}/v1/models/{FLAGS.model_spec_name}:predict"
+    predict_url = f"http://{_PREDICTION_SERVICE_ADDR.value}/v1/models/{_MODEL_SPEC_NAME.value}:predict"
     response = requests.post(predict_url, data=data)
     if response.status_code != 200:
       msg = (f"Received error response {response.status_code} from model "
@@ -101,14 +102,14 @@ def serving_call_mnist(images):
 def main(_):
-  if FLAGS.count_images % FLAGS.serving_batch_size != 0:
-    raise ValueError(f"The count_images ({FLAGS.count_images}) must be a "
+  if _COUNT_IMAGES.value % _SERVING_BATCH_SIZE.value != 0:
+    raise ValueError(f"The count_images ({_COUNT_IMAGES.value}) must be a "
                      "multiple of "
-                     f"serving_batch_size ({FLAGS.serving_batch_size})")
+                     f"serving_batch_size ({_SERVING_BATCH_SIZE.value})")
   test_ds = mnist_lib.load_mnist(tfds.Split.TEST,
-                                 batch_size=FLAGS.serving_batch_size)
+                                 batch_size=_SERVING_BATCH_SIZE.value)
   images_and_labels = tfds.as_numpy(test_ds.take(
FLAGS.count_images // FLAGS.serving_batch_size)) _COUNT_IMAGES.value // _SERVING_BATCH_SIZE.value))
accurate_count = 0 accurate_count = 0
for batch_idx, (images, labels) in enumerate(images_and_labels): for batch_idx, (images, labels) in enumerate(images_and_labels):
@ -117,7 +118,7 @@ def main(_):
labels_digit = np.argmax(labels, axis=1) labels_digit = np.argmax(labels, axis=1)
accurate_count += np.sum(labels_digit == predictions_digit) accurate_count += np.sum(labels_digit == predictions_digit)
running_accuracy = ( running_accuracy = (
100. * accurate_count / (1 + batch_idx) / FLAGS.serving_batch_size) 100. * accurate_count / (1 + batch_idx) / _SERVING_BATCH_SIZE.value)
logging.info( logging.info(
" predicted digits = %s labels %s. Running accuracy %.3f%%", " predicted digits = %s labels %s. Running accuracy %.3f%%",
predictions_digit, labels_digit, running_accuracy) predictions_digit, labels_digit, running_accuracy)


@ -33,15 +33,17 @@ import tensorflowjs as tfjs
import input_pipeline # type: ignore[import] import input_pipeline # type: ignore[import]
flags.DEFINE_integer("num_epochs", 5, _NUM_EPOCHS = flags.DEFINE_integer(
("Number of epochs to train for.")) "num_epochs", 5, "Number of epochs to train for."
flags.DEFINE_integer("num_classes", 100, "Number of classification classes.") )
_NUM_CLASSES = flags.DEFINE_integer(
"num_classes", 100, "Number of classification classes."
)
flags.register_validator("num_classes", flags.register_validator("num_classes",
lambda value: value >= 1 and value <= 100, lambda value: value >= 1 and value <= 100,
message="--num_classes must be in range [1, 100]") message="--num_classes must be in range [1, 100]")
FLAGS = flags.FLAGS
# The code below is an adaptation for Flax from the work published here: # The code below is an adaptation for Flax from the work published here:
# https://blog.tensorflow.org/2018/07/train-model-in-tfkeras-with-colab-and-run-in-browser-tensorflowjs.html # https://blog.tensorflow.org/2018/07/train-model-in-tfkeras-with-colab-and-run-in-browser-tensorflowjs.html
@ -65,7 +67,7 @@ class QuickDraw(nn.Module):
x = nn.Dense(features=128)(x) x = nn.Dense(features=128)(x)
x = nn.relu(x) x = nn.relu(x)
x = nn.Dense(features=FLAGS.num_classes)(x) x = nn.Dense(features=_NUM_CLASSES.value)(x)
return x return x
@ -75,7 +77,7 @@ def apply_model(state, inputs, labels):
"""Computes gradients, loss and accuracy for a single batch.""" """Computes gradients, loss and accuracy for a single batch."""
def loss_fn(params): def loss_fn(params):
logits = state.apply_fn({'params': params}, inputs) logits = state.apply_fn({'params': params}, inputs)
one_hot = jax.nn.one_hot(labels, FLAGS.num_classes) one_hot = jax.nn.one_hot(labels, _NUM_CLASSES.value)
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
return loss, logits return loss, logits
grad_fn = jax.value_and_grad(loss_fn, has_aux=True) grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
@ -113,7 +115,7 @@ def create_train_state(rng):
def train(state, train_ds, test_ds): def train(state, train_ds, test_ds):
for epoch in range(1, FLAGS.num_epochs+1): for epoch in range(1, _NUM_EPOCHS.value+1):
start_time = time.time() start_time = time.time()
state, train_loss, train_accuracy = run_epoch(state, train_ds) state, train_loss, train_accuracy = run_epoch(state, train_ds)
@ -136,12 +138,12 @@ def main(argv):
base_model_path = "/tmp/jax2tf/tf_js_quickdraw" base_model_path = "/tmp/jax2tf/tf_js_quickdraw"
dataset_path = os.path.join(base_model_path, "data") dataset_path = os.path.join(base_model_path, "data")
classes = input_pipeline.download_dataset(dataset_path, FLAGS.num_classes) classes = input_pipeline.download_dataset(dataset_path, _NUM_CLASSES.value)
assert len(classes) == FLAGS.num_classes, "Incorrect number of classes" assert len(classes) == _NUM_CLASSES.value, "Incorrect number of classes"
print(f"Classes are: {classes}") print(f"Classes are: {classes}")
print("Loading dataset into memory...") print("Loading dataset into memory...")
train_ds, test_ds = input_pipeline.get_datasets(dataset_path, classes) train_ds, test_ds = input_pipeline.get_datasets(dataset_path, classes)
print(f"Starting training for {FLAGS.num_epochs} epochs...") print(f"Starting training for {_NUM_EPOCHS.value} epochs...")
state = create_train_state(jax.random.PRNGKey(0)) state = create_train_state(jax.random.PRNGKey(0))
state = train(state, train_ds, test_ds) state = train(state, train_ds, test_ds)


@ -24,14 +24,19 @@ import numpy as np
import tensorflow as tf # type: ignore[import] import tensorflow as tf # type: ignore[import]
import tensorflow_datasets as tfds # type: ignore[import] import tensorflow_datasets as tfds # type: ignore[import]
flags.DEFINE_string('tflite_file_path', _TFLITE_FILE_PATH = flags.DEFINE_string(
'/usr/local/google/home/qiuminxu/jax2tf/mnist.tflite', 'tflite_file_path',
'Path where to save the TensorFlow Lite file.') '/tmp/mnist.tflite',
flags.DEFINE_integer('serving_batch_size', 4, 'Path where to save the TensorFlow Lite file.',
('For what batch size to prepare the serving signature. ')) )
flags.DEFINE_integer('num_epochs', 10, 'For how many epochs to train.') _SERVING_BATCH_SIZE = flags.DEFINE_integer(
'serving_batch_size',
FLAGS = flags.FLAGS 4,
'For what batch size to prepare the serving signature. ',
)
_NUM_EPOCHS = flags.DEFINE_integer(
'num_epochs', 10, 'For how many epochs to train.'
)
# A helper function to evaluate the TF Lite model using "test" dataset. # A helper function to evaluate the TF Lite model using "test" dataset.
@ -71,10 +76,11 @@ def main(_):
train_ds = mnist_lib.load_mnist( train_ds = mnist_lib.load_mnist(
tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size) tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
test_ds = mnist_lib.load_mnist( test_ds = mnist_lib.load_mnist(
tfds.Split.TEST, batch_size=FLAGS.serving_batch_size) tfds.Split.TEST, batch_size=_SERVING_BATCH_SIZE.value)
(flax_predict, (flax_predict, flax_params) = mnist_lib.FlaxMNIST.train(
flax_params) = mnist_lib.FlaxMNIST.train(train_ds, test_ds, FLAGS.num_epochs) train_ds, test_ds, _NUM_EPOCHS.value
)
def predict(image): def predict(image):
return flax_predict(flax_params, image) return flax_predict(flax_params, image)
@ -84,7 +90,7 @@ def main(_):
jax2tf.convert(predict, enable_xla=False), jax2tf.convert(predict, enable_xla=False),
input_signature=[ input_signature=[
tf.TensorSpec( tf.TensorSpec(
shape=[FLAGS.serving_batch_size, 28, 28, 1], shape=[_SERVING_BATCH_SIZE.value, 28, 28, 1],
dtype=tf.float32, dtype=tf.float32,
name='input') name='input')
], ],
@ -126,7 +132,7 @@ def main(_):
print('Quantized model accuracy = %.4f' % quantized_accuracy) print('Quantized model accuracy = %.4f' % quantized_accuracy)
print('Accuracy drop = %.4f' % (float_accuracy - quantized_accuracy)) print('Accuracy drop = %.4f' % (float_accuracy - quantized_accuracy))
f = open(FLAGS.tflite_file_path, 'wb') f = open(_TFLITE_FILE_PATH.value, 'wb')
f.write(tflite_quantized_model) f.write(tflite_quantized_model)
f.close() f.close()
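
One detail worth noting in the file above: flags.DEFINE_integer returns a FlagHolder, not an int, so call sites that need the number (the load_mnist batch size, the tf.TensorSpec shape) read it through .value. A small sketch of the distinction, with an illustrative flag name:

from absl import app
from absl import flags

_SERVING_BATCH = flags.DEFINE_integer("serving_batch_demo", 4, "Illustrative flag.")

def main(argv):
  del argv  # Unused.
  # The holder is not usable where an int is expected; the parsed int is .value.
  assert not isinstance(_SERVING_BATCH, int)
  assert isinstance(_SERVING_BATCH.value, int)

if __name__ == "__main__":
  app.run(main)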


@ -58,34 +58,32 @@ from jax.experimental.jax2tf.shape_poly import InconclusiveDimensionOperation
from jax.experimental.jax2tf.tests.model_harness import ALL_HARNESSES from jax.experimental.jax2tf.tests.model_harness import ALL_HARNESSES
from jax.experimental.jax2tf.tests.converters import ALL_CONVERTERS from jax.experimental.jax2tf.tests.converters import ALL_CONVERTERS
flags.DEFINE_list("converters", [x.name for x in ALL_CONVERTERS], _CONVERTERS = flags.DEFINE_list("converters", [x.name for x in ALL_CONVERTERS],
"Which converters to test.") "Which converters to test.")
flags.DEFINE_list("examples", [], _EXAMPLES = flags.DEFINE_list("examples", [],
("List of examples to test, e.g.: 'flax/mnist,flax/seq2seq'. " ("List of examples to test, e.g.: 'flax/mnist,flax/seq2seq'. "
"If empty, will test all examples.")) "If empty, will test all examples."))
flags.DEFINE_string("example_prefix", "", _EXAMPLE_PREFIX = flags.DEFINE_string("example_prefix", "",
("Prefix for filtering tests. For instance 'flax/mnist' " ("Prefix for filtering tests. For instance 'flax/mnist' "
"will test all examples starting with 'flax/mnist' " "will test all examples starting with 'flax/mnist' "
"(including all polymorphic tests).")) "(including all polymorphic tests)."))
flags.DEFINE_bool( _WRITE_MARKDOWN = flags.DEFINE_bool(
"write_markdown", True, "write_markdown", True,
"If true, write results as Markdown. Otherwise, only output to stdout.") "If true, write results as Markdown. Otherwise, only output to stdout.")
flags.DEFINE_bool( _FAIL_ON_ERROR = flags.DEFINE_bool(
"fail_on_error", False, "fail_on_error", False,
("If true, exit with an error when a conversion fails. Useful for " ("If true, exit with an error when a conversion fails. Useful for "
"debugging because it will show the entire stack trace.")) "debugging because it will show the entire stack trace."))
FLAGS = flags.FLAGS
def _write_markdown(results: dict[str, list[tuple[str, str,]]]) -> None: def _write_markdown(results: dict[str, list[tuple[str, str,]]]) -> None:
"""Writes all results to Markdown file.""" """Writes all results to Markdown file."""
table_lines = [] table_lines = []
converters = FLAGS.converters converters = _CONVERTERS.value
table_lines.append("| Example | " + " ".join([f"{c} |" for c in converters])) table_lines.append("| Example | " + " ".join([f"{c} |" for c in converters]))
table_lines.append("|" + (" --- |" * (len(converters) + 1))) table_lines.append("|" + (" --- |" * (len(converters) + 1)))
@ -173,7 +171,7 @@ def test_converters():
exit() exit()
def _maybe_reraise(e): def _maybe_reraise(e):
if FLAGS.fail_on_error: if _FAIL_ON_ERROR.value:
raise e raise e
def _format(e): def _format(e):
@ -183,13 +181,13 @@ def test_converters():
return msg return msg
converters = list( converters = list(
filter(lambda x: x.name in FLAGS.converters, ALL_CONVERTERS)) filter(lambda x: x.name in _CONVERTERS.value, ALL_CONVERTERS))
_exit_if_empty(converters, "converters") _exit_if_empty(converters, "converters")
harnesses_to_test = { harnesses_to_test = {
name: fn for name, fn in ALL_HARNESSES.items() name: fn for name, fn in ALL_HARNESSES.items()
if (not FLAGS.examples or name in FLAGS.examples) and if (not _EXAMPLES.value or name in _EXAMPLES.value) and
(not FLAGS.example_prefix or name.startswith(FLAGS.example_prefix)) (not _EXAMPLE_PREFIX.value or name.startswith(_EXAMPLE_PREFIX.value))
} }
_exit_if_empty(harnesses_to_test, "harness") _exit_if_empty(harnesses_to_test, "harness")
@ -243,7 +241,7 @@ def test_converters():
converter_results.append((converter.name, error_msg)) converter_results.append((converter.name, error_msg))
results[harness.name] = converter_results results[harness.name] = converter_results
if FLAGS.write_markdown: if _WRITE_MARKDOWN.value:
_write_markdown(results) _write_markdown(results)
else: else:
print("=== NOT writing results to Markdown.") print("=== NOT writing results to Markdown.")


@ -140,7 +140,6 @@ def jax_to_ir(name, deps, fn, input_shapes, constants = None, format = "HLO"):
from absl import app from absl import app
import jax.tools.jax_to_ir as jax_to_ir import jax.tools.jax_to_ir as jax_to_ir
jax_to_ir.set_up_flags()
app.run(jax_to_ir.main) app.run(jax_to_ir.main)
EOF EOF
""".format(runner = runner), """.format(runner = runner),


@ -87,7 +87,30 @@ try:
except ImportError: except ImportError:
tf = None # type: ignore tf = None # type: ignore
FLAGS = flags.FLAGS
_FN = flags.DEFINE_string(
'fn', None, "Fully-qualified name of function that we're going to convert"
)
_INPUT_SHAPES = flags.DEFINE_string(
'input_shapes', None, 'Python dict indicating XLA shapes of params'
)
_CONSTANTS = flags.DEFINE_string(
'constants', '{}', 'Python dict giving constant values for some params'
)
_EVALED_CONSTANTS = flags.DEFINE_string(
'evaled_constants',
'{}',
'Python dict giving constant values for some params. '
'Values in this dict that are of type str are evaluated '
'using ast.literal_eval.',
)
_IR_FORMAT = flags.DEFINE_enum(
'ir_format', 'HLO', ('HLO', 'TF'), 'Output format.'
)
_IR_DEST = flags.DEFINE_string('ir_dest', None, 'File to write IR to')
_IR_HUMAN_DEST = flags.DEFINE_string(
'ir_human_dest', None, 'File to write human readable debug output'
)
def jax_to_ir(fn, input_shapes, *, constants=None, format): def jax_to_ir(fn, input_shapes, *, constants=None, format):
@ -163,25 +186,25 @@ def main(argv):
if len(argv) != 1: if len(argv) != 1:
raise app.UsageError('No positional arguments are accepted.') raise app.UsageError('No positional arguments are accepted.')
if not FLAGS.ir_dest and not FLAGS.ir_human_dest: if not _IR_DEST.value and not _IR_HUMAN_DEST.value:
raise app.Error('At least one of --ir_dest and ' raise app.Error('At least one of --ir_dest and '
'--ir_human_dest is required.') '--ir_human_dest is required.')
module_name, fn_name = FLAGS.fn.rsplit('.', 1) module_name, fn_name = _FN.value.rsplit('.', 1)
module = importlib.import_module(module_name) module = importlib.import_module(module_name)
fn = getattr(module, fn_name) fn = getattr(module, fn_name)
input_shapes = [(name, parse_shape_str(shape_str)) input_shapes = [(name, parse_shape_str(shape_str))
for name, shape_str in literal_eval(FLAGS.input_shapes)] for name, shape_str in literal_eval(_INPUT_SHAPES.value)]
# Parse --constants and --evaled_constants. # Parse --constants and --evaled_constants.
constants = {} constants = {}
for k, v in literal_eval(FLAGS.constants).items(): for k, v in literal_eval(_CONSTANTS.value).items():
if isinstance(v, list): if isinstance(v, list):
v = jnp.asarray(v) v = jnp.asarray(v)
constants[k] = v constants[k] = v
for k, v in literal_eval(FLAGS.evaled_constants).items(): for k, v in literal_eval(_EVALED_CONSTANTS.value).items():
if isinstance(v, str): if isinstance(v, str):
v = literal_eval(v) v = literal_eval(v)
if isinstance(v, list): if isinstance(v, list):
@ -192,14 +215,14 @@ def main(argv):
constants[k] = v constants[k] = v
ir, debug_ir = jax_to_ir(fn, input_shapes, constants=constants, ir, debug_ir = jax_to_ir(fn, input_shapes, constants=constants,
format=FLAGS.ir_format) format=_IR_FORMAT.value)
if FLAGS.ir_dest: if _IR_DEST.value:
with open(FLAGS.ir_dest, 'wb') as f: with open(_IR_DEST.value, 'wb') as f:
f.write(ir) f.write(ir)
if FLAGS.ir_human_dest: if _IR_HUMAN_DEST.value:
with open(FLAGS.ir_human_dest, 'w') as f: with open(_IR_HUMAN_DEST.value, 'w') as f:
f.write(debug_ir) f.write(debug_ir)
@ -225,21 +248,6 @@ _SHAPE_RE = re.compile(f"^({'|'.join(_DT)})\\[\\s*(\\d*[\\s*,\\d+]*)\\s*\\]$")
def set_up_flags(): def set_up_flags():
flags.DEFINE_string(
'fn', None,
"Fully-qualified name of function that we're going to convert")
flags.DEFINE_string('input_shapes', None,
'Python dict indicating XLA shapes of params')
flags.DEFINE_string('constants', '{}',
'Python dict giving constant values for some params')
flags.DEFINE_string('evaled_constants', '{}',
'Python dict giving constant values for some params. '
'Values in this dict that are of type str are evaluated '
'using ast.literal_eval.')
flags.DEFINE_enum('ir_format', 'HLO', ('HLO', 'TF'), 'Output format.')
flags.DEFINE_string('ir_dest', None, 'File to write IR to')
flags.DEFINE_string('ir_human_dest', None,
'File to write human readable debug output')
flags.mark_flag_as_required('fn') flags.mark_flag_as_required('fn')
flags.mark_flag_as_required('input_shapes') flags.mark_flag_as_required('input_shapes')


@ -22,7 +22,6 @@ import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax._src import core from jax._src import core
from jax._src import test_util as jtu from jax._src import test_util as jtu
from jax._src.config import flags
from jax.experimental.pjit import pjit from jax.experimental.pjit import pjit
from jax.experimental.serialize_executable import ( from jax.experimental.serialize_executable import (
serialize, deserialize_and_load) serialize, deserialize_and_load)
@ -73,7 +72,7 @@ class JaxAotTest(jtu.JaxTestCase):
except NotImplementedError: except NotImplementedError:
raise unittest.SkipTest('PJRT Topology not supported') raise unittest.SkipTest('PJRT Topology not supported')
if flags.FLAGS.jax_test_with_persistent_compilation_cache: if jtu.TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
raise unittest.SkipTest('Compilation caching not yet supported.') raise unittest.SkipTest('Compilation caching not yet supported.')
@jax.jit @jax.jit