mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Update flags to use the ABSL typed flag API.
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary. For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API. Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`. This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR. PiperOrigin-RevId: 551604974
This commit is contained in:
parent
f35f226b44
commit
76cda0ae07
@ -27,8 +27,7 @@ from absl import app
|
||||
from absl import flags
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_multi_string(
|
||||
_SET_ENV = flags.DEFINE_multi_string(
|
||||
"set_env", None,
|
||||
"Specifies additional environment variables to be injected into the "
|
||||
"environment (via --set_env=variable=value or --set_env=variable). "
|
||||
@ -140,8 +139,8 @@ def jax_binary_op(state, **kwargs):
|
||||
)
|
||||
|
||||
def main(argv):
|
||||
if FLAGS.set_env:
|
||||
for env_str in FLAGS.set_env:
|
||||
if _SET_ENV.value:
|
||||
for env_str in _SET_ENV.value:
|
||||
# Stop matching at the first '=' since we want to capture
|
||||
# --set_env='FOO=--foo_a=1 --foo_b=2' all as part of FOO.
|
||||
env_list = env_str.split('=', 1)
|
||||
|
@ -83,22 +83,21 @@ import numpy.random as npr
|
||||
from dp_accounting import dp_event
|
||||
from dp_accounting import rdp
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
_DPSGD = flags.DEFINE_boolean(
|
||||
'dpsgd', True, 'If True, train with DP-SGD. If False, '
|
||||
'train with vanilla SGD.')
|
||||
flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
|
||||
flags.DEFINE_float('noise_multiplier', 1.1,
|
||||
_LEARNING_RATE = flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
|
||||
_NOISE_MULTIPLIER = flags.DEFINE_float('noise_multiplier', 1.1,
|
||||
'Ratio of the standard deviation to the clipping norm')
|
||||
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
|
||||
flags.DEFINE_integer('batch_size', 256, 'Batch size')
|
||||
flags.DEFINE_integer('epochs', 60, 'Number of epochs')
|
||||
flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG')
|
||||
flags.DEFINE_integer(
|
||||
_L2_NORM_CLIP = flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
|
||||
_BATCH_SIZE = flags.DEFINE_integer('batch_size', 256, 'Batch size')
|
||||
_EPOCHS = flags.DEFINE_integer('epochs', 60, 'Number of epochs')
|
||||
_SEED = flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG')
|
||||
_MICROBATCHES = flags.DEFINE_integer(
|
||||
'microbatches', None, 'Number of microbatches '
|
||||
'(must evenly divide batch_size)')
|
||||
flags.DEFINE_string('model_dir', None, 'Model directory')
|
||||
_MODEL_DIR = flags.DEFINE_string('model_dir', None, 'Model directory')
|
||||
|
||||
|
||||
init_random_params, predict = stax.serial(
|
||||
@ -163,39 +162,39 @@ def shape_as_image(images, labels, dummy_dim=False):
|
||||
def compute_epsilon(steps, num_examples=60000, target_delta=1e-5):
|
||||
if num_examples * target_delta > 1.:
|
||||
warnings.warn('Your delta might be too high.')
|
||||
q = FLAGS.batch_size / float(num_examples)
|
||||
q = _BATCH_SIZE.value / float(num_examples)
|
||||
orders = list(jnp.linspace(1.1, 10.9, 99)) + list(range(11, 64))
|
||||
accountant = rdp.rdp_privacy_accountant.RdpAccountant(orders)
|
||||
accountant.compose(
|
||||
dp_event.PoissonSampledDpEvent(
|
||||
q, dp_event.GaussianDpEvent(FLAGS.noise_multiplier)), steps)
|
||||
q, dp_event.GaussianDpEvent(_NOISE_MULTIPLIER.value)), steps)
|
||||
return accountant.get_epsilon(target_delta)
|
||||
|
||||
|
||||
def main(_):
|
||||
|
||||
if FLAGS.microbatches:
|
||||
if _MICROBATCHES.value:
|
||||
raise NotImplementedError(
|
||||
'Microbatches < batch size not currently supported'
|
||||
)
|
||||
|
||||
train_images, train_labels, test_images, test_labels = datasets.mnist()
|
||||
num_train = train_images.shape[0]
|
||||
num_complete_batches, leftover = divmod(num_train, FLAGS.batch_size)
|
||||
num_complete_batches, leftover = divmod(num_train, _BATCH_SIZE.value)
|
||||
num_batches = num_complete_batches + bool(leftover)
|
||||
key = random.PRNGKey(FLAGS.seed)
|
||||
key = random.PRNGKey(_SEED.value)
|
||||
|
||||
def data_stream():
|
||||
rng = npr.RandomState(FLAGS.seed)
|
||||
rng = npr.RandomState(_SEED.value)
|
||||
while True:
|
||||
perm = rng.permutation(num_train)
|
||||
for i in range(num_batches):
|
||||
batch_idx = perm[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
|
||||
batch_idx = perm[i * _BATCH_SIZE.value:(i + 1) * _BATCH_SIZE.value]
|
||||
yield train_images[batch_idx], train_labels[batch_idx]
|
||||
|
||||
batches = data_stream()
|
||||
|
||||
opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)
|
||||
opt_init, opt_update, get_params = optimizers.sgd(_LEARNING_RATE.value)
|
||||
|
||||
@jit
|
||||
def update(_, i, opt_state, batch):
|
||||
@ -208,19 +207,19 @@ def main(_):
|
||||
rng = random.fold_in(rng, i) # get new key for new random numbers
|
||||
return opt_update(
|
||||
i,
|
||||
private_grad(params, batch, rng, FLAGS.l2_norm_clip,
|
||||
FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)
|
||||
private_grad(params, batch, rng, _L2_NORM_CLIP.value,
|
||||
_NOISE_MULTIPLIER.value, _BATCH_SIZE.value), opt_state)
|
||||
|
||||
_, init_params = init_random_params(key, (-1, 28, 28, 1))
|
||||
opt_state = opt_init(init_params)
|
||||
itercount = itertools.count()
|
||||
|
||||
steps_per_epoch = 60000 // FLAGS.batch_size
|
||||
steps_per_epoch = 60000 // _BATCH_SIZE.value
|
||||
print('\nStarting training...')
|
||||
for epoch in range(1, FLAGS.epochs + 1):
|
||||
for epoch in range(1, _EPOCHS.value + 1):
|
||||
start_time = time.time()
|
||||
for _ in range(num_batches):
|
||||
if FLAGS.dpsgd:
|
||||
if _DPSGD.value:
|
||||
opt_state = \
|
||||
private_update(
|
||||
key, next(itercount), opt_state,
|
||||
@ -239,7 +238,7 @@ def main(_):
|
||||
test_loss, 100 * test_acc))
|
||||
|
||||
# determine privacy loss so far
|
||||
if FLAGS.dpsgd:
|
||||
if _DPSGD.value:
|
||||
delta = 1e-5
|
||||
num_examples = 60000
|
||||
eps = compute_epsilon(epoch * steps_per_epoch, num_examples, delta)
|
||||
|
@ -16,7 +16,6 @@
|
||||
"""
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from functools import partial
|
||||
from jax import grad
|
||||
from jax import jit
|
||||
@ -27,8 +26,6 @@ import jax.random as random
|
||||
import jax.scipy as scipy
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
|
||||
|
@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Hashable, Iterator
|
||||
import contextlib
|
||||
import functools
|
||||
@ -20,7 +22,7 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any, Callable, NamedTuple, Optional
|
||||
from typing import Any, Callable, Generic, NamedTuple, Optional, TypeVar
|
||||
|
||||
from jax._src import lib
|
||||
from jax._src.lib import jax_jit
|
||||
@ -29,6 +31,8 @@ from jax._src.lib import xla_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_T = TypeVar('_T')
|
||||
|
||||
|
||||
def bool_env(varname: str, default: bool) -> bool:
|
||||
"""Read an environment variable and interpret it as a boolean.
|
||||
@ -64,6 +68,16 @@ UPGRADE_BOOL_HELP = (
|
||||
UPGRADE_BOOL_EXTRA_DESC = " (transient)"
|
||||
|
||||
|
||||
class FlagHolder(Generic[_T]):
|
||||
def __init__(self, flags: NameSpace, name: str):
|
||||
self._flags = flags
|
||||
self._name = name
|
||||
|
||||
@property
|
||||
def value(self) -> _T:
|
||||
return getattr(self._flags, self._name)
|
||||
|
||||
|
||||
class Config:
|
||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||
|
||||
@ -112,26 +126,31 @@ class Config:
|
||||
if name not in self.values:
|
||||
raise AttributeError(f"Unrecognized config option: {name}")
|
||||
|
||||
def DEFINE_bool(self, name, default, *args, **kwargs):
|
||||
def DEFINE_bool(self, name, default, *args, **kwargs) -> FlagHolder[bool]:
|
||||
update_hook = kwargs.pop("update_hook", None)
|
||||
self.add_option(name, default, bool, args, kwargs, update_hook=update_hook)
|
||||
return FlagHolder(self.FLAGS, name)
|
||||
|
||||
def DEFINE_integer(self, name, default, *args, **kwargs):
|
||||
def DEFINE_integer(self, name, default, *args, **kwargs) -> FlagHolder[int]:
|
||||
update_hook = kwargs.pop("update_hook", None)
|
||||
self.add_option(name, default, int, args, kwargs, update_hook=update_hook)
|
||||
return FlagHolder(self.FLAGS, name)
|
||||
|
||||
def DEFINE_float(self, name, default, *args, **kwargs):
|
||||
def DEFINE_float(self, name, default, *args, **kwargs) -> FlagHolder[float]:
|
||||
update_hook = kwargs.pop("update_hook", None)
|
||||
self.add_option(name, default, float, args, kwargs, update_hook=update_hook)
|
||||
return FlagHolder(self.FLAGS, name)
|
||||
|
||||
def DEFINE_string(self, name, default, *args, **kwargs):
|
||||
def DEFINE_string(self, name, default, *args, **kwargs) -> FlagHolder[str]:
|
||||
update_hook = kwargs.pop("update_hook", None)
|
||||
self.add_option(name, default, str, args, kwargs, update_hook=update_hook)
|
||||
return FlagHolder(self.FLAGS, name)
|
||||
|
||||
def DEFINE_enum(self, name, default, *args, **kwargs):
|
||||
def DEFINE_enum(self, name, default, *args, **kwargs) -> FlagHolder[str]:
|
||||
update_hook = kwargs.pop("update_hook", None)
|
||||
self.add_option(name, default, 'enum', args, kwargs,
|
||||
update_hook=update_hook)
|
||||
return FlagHolder(self.FLAGS, name)
|
||||
|
||||
def config_with_absl(self):
|
||||
# Run this before calling `app.run(main)` etc
|
||||
@ -551,6 +570,22 @@ config = Config()
|
||||
flags = config
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
def DEFINE_bool(name, default, *args, **kwargs):
|
||||
return flags.DEFINE_bool(name, default, *args, **kwargs)
|
||||
|
||||
def DEFINE_integer(name, default, *args, **kwargs):
|
||||
return flags.DEFINE_integer(name, default, *args, **kwargs)
|
||||
|
||||
def DEFINE_float(name, default, *args, **kwargs):
|
||||
return flags.DEFINE_float(name, default, *args, **kwargs)
|
||||
|
||||
def DEFINE_string(name, default, *args, **kwargs):
|
||||
return flags.DEFINE_string(name, default, *args, **kwargs)
|
||||
|
||||
def DEFINE_enum(name, default, *args, **kwargs):
|
||||
return flags.DEFINE_enum(name, default, *args, **kwargs)
|
||||
|
||||
|
||||
already_configured_with_absl = False
|
||||
|
||||
|
||||
@ -617,51 +652,6 @@ def update_thread_local_jit_state(**kw):
|
||||
tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp)
|
||||
|
||||
|
||||
flags.DEFINE_integer(
|
||||
'jax_tracer_error_num_traceback_frames',
|
||||
int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
|
||||
help='Set the number of stack frames in JAX tracer error messages.'
|
||||
)
|
||||
|
||||
flags.DEFINE_bool(
|
||||
'jax_pprint_use_color',
|
||||
bool_env('JAX_PPRINT_USE_COLOR', True),
|
||||
help='Enable jaxpr pretty-printing with colorful syntax highlighting.'
|
||||
)
|
||||
|
||||
flags.DEFINE_bool(
|
||||
'jax_host_callback_inline',
|
||||
bool_env('JAX_HOST_CALLBACK_INLINE', False),
|
||||
help='Inline the host_callback, if not in a staged context.'
|
||||
)
|
||||
flags.DEFINE_integer(
|
||||
'jax_host_callback_max_queue_byte_size',
|
||||
int_env('JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE', int(256 * 1e6)),
|
||||
help=('The size in bytes of the buffer used to hold outfeeds from each '
|
||||
'device. When this capacity is reached consuming outfeeds from the '
|
||||
'device is paused, thus potentially pausing the device computation, '
|
||||
'until the Python callback consume more outfeeds.'),
|
||||
lower_bound=int(16 * 1e6)
|
||||
)
|
||||
flags.DEFINE_bool(
|
||||
'jax_host_callback_outfeed',
|
||||
bool_env('JAX_HOST_CALLBACK_OUTFEED', False),
|
||||
help=(
|
||||
'Use outfeed implementation for host_callback, even on CPU and GPU. '
|
||||
'If false, use the CustomCall implementation. '
|
||||
'Has no effect on TPU, since only the outfeed mechanism is implemented.'
|
||||
)
|
||||
)
|
||||
flags.DEFINE_bool(
|
||||
'jax_host_callback_ad_transforms',
|
||||
bool_env('JAX_HOST_CALLBACK_AD_TRANSFORMS', False),
|
||||
help=(
|
||||
'Enable support for jvp/vjp for the host_callback primitives. Default is '
|
||||
'False, which means that host_callback operates only on primals. '
|
||||
'The flag exists only temporarily, for backward compatibility.'
|
||||
)
|
||||
)
|
||||
|
||||
# TODO(b/214340779): remove flag when XLA:CPU is improved.
|
||||
jax2tf_associative_scan_reductions = config.define_bool_state(
|
||||
name='jax2tf_associative_scan_reductions',
|
||||
|
@ -38,7 +38,7 @@ import numpy as np
|
||||
from jax._src import dtypes
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import effects
|
||||
from jax._src.config import FLAGS, config
|
||||
from jax._src.config import config
|
||||
from jax._src.errors import (
|
||||
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
|
||||
TracerIntegerConversionError, UnexpectedTracerError)
|
||||
@ -60,6 +60,13 @@ zip, unsafe_zip = safe_zip, zip
|
||||
map, unsafe_map = safe_map, map
|
||||
|
||||
|
||||
_TRACER_ERROR_NUM_TRACEBACK_FRAMES = jax_config.DEFINE_integer(
|
||||
'jax_tracer_error_num_traceback_frames',
|
||||
jax_config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
|
||||
help='Set the number of stack frames in JAX tracer error messages.'
|
||||
)
|
||||
|
||||
|
||||
# -------------------- jaxprs --------------------
|
||||
|
||||
Effect = effects.Effect
|
||||
@ -560,7 +567,7 @@ def raise_as_much_as_possible(tracer) -> Tracer:
|
||||
|
||||
|
||||
def escaped_tracer_error(tracer, detail=None):
|
||||
num_frames = FLAGS.jax_tracer_error_num_traceback_frames
|
||||
num_frames = _TRACER_ERROR_NUM_TRACEBACK_FRAMES.value
|
||||
msg = ('Encountered an unexpected tracer. A function transformed by JAX '
|
||||
'had a side effect, allowing for a reference to an intermediate value '
|
||||
f'with type {tracer.aval.str_short()} wrapped in a '
|
||||
|
@ -32,6 +32,7 @@ import warnings
|
||||
import numpy as np
|
||||
|
||||
from jax._src import compilation_cache
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src import linear_util as lu
|
||||
@ -43,7 +44,7 @@ from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src import op_shardings
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.config import config, flags
|
||||
from jax._src.config import config
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
@ -64,9 +65,7 @@ JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
|
||||
JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
|
||||
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string(
|
||||
_DUMP_IR_TO = jax_config.DEFINE_string(
|
||||
'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
|
||||
help="Path to which the IR that is emitted by JAX as input to the "
|
||||
"compiler should be dumped as text files. Optional. If omitted, JAX "
|
||||
@ -472,7 +471,7 @@ def _make_string_safe_for_filename(s: str) -> str:
|
||||
def _dump_ir_to_file(name: str, ir: str):
|
||||
id = next(_ir_dump_counter)
|
||||
name = f"jax_ir{id}_{_make_string_safe_for_filename(name)}.mlir"
|
||||
name = path.Path(FLAGS.jax_dump_ir_to) / name
|
||||
name = path.Path(_DUMP_IR_TO.value) / name
|
||||
name.write_text(ir)
|
||||
|
||||
|
||||
@ -481,7 +480,7 @@ def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray,
|
||||
sym_name = computation.operation.attributes['sym_name']
|
||||
module_name = ir.StringAttr(sym_name).value
|
||||
|
||||
if FLAGS.jax_dump_ir_to:
|
||||
if _DUMP_IR_TO.value:
|
||||
_dump_ir_to_file(module_name, mlir.module_to_string(computation))
|
||||
|
||||
# Persistent compilation cache only implemented on TPU and GPU.
|
||||
|
@ -28,7 +28,7 @@ import warnings
|
||||
import ml_dtypes
|
||||
import numpy as np
|
||||
|
||||
from jax._src.config import flags, config
|
||||
from jax._src.config import config
|
||||
from jax._src.typing import DType, DTypeLike
|
||||
|
||||
from jax._src import traceback_util
|
||||
@ -43,8 +43,6 @@ else:
|
||||
raise ValueError("JAX requires ml_dtypes version 0.2.0 or newer; "
|
||||
f"installed version is {ml_dtypes.__version__}.")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class extended(np.generic):
|
||||
"""Scalar class for extended dtypes.
|
||||
|
@ -32,13 +32,20 @@ from functools import partial
|
||||
import sys
|
||||
from typing import NamedTuple, Optional, Union
|
||||
|
||||
from jax._src.config import config
|
||||
from jax._src import config
|
||||
|
||||
try:
|
||||
import colorama # pytype: disable=import-error
|
||||
except ImportError:
|
||||
colorama = None
|
||||
|
||||
|
||||
_PPRINT_USE_COLOR = config.DEFINE_bool(
|
||||
'jax_pprint_use_color',
|
||||
config.bool_env('JAX_PPRINT_USE_COLOR', True),
|
||||
help='Enable jaxpr pretty-printing with colorful syntax highlighting.'
|
||||
)
|
||||
|
||||
def _can_use_color() -> bool:
|
||||
try:
|
||||
# Check if we're in IPython or Colab
|
||||
@ -63,7 +70,7 @@ class Doc(abc.ABC):
|
||||
def format(self, width: int = 80, use_color: Optional[bool] = None,
|
||||
annotation_prefix=" # ") -> str:
|
||||
if use_color is None:
|
||||
use_color = CAN_USE_COLOR and config.FLAGS.jax_pprint_use_color
|
||||
use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value
|
||||
return _format(self, width, use_color=use_color,
|
||||
annotation_prefix=annotation_prefix)
|
||||
|
||||
|
@ -16,9 +16,10 @@ from functools import partial
|
||||
import operator
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import dtypes as _dtypes
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.config import config, flags
|
||||
from jax._src.config import config
|
||||
from jax._src.tree_util import tree_map, tree_reduce
|
||||
|
||||
import numpy as np
|
||||
@ -30,7 +31,11 @@ import numpy as np
|
||||
__all__ = ['check_grads', 'check_jvp', 'check_vjp']
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
_TEST_DUT = jax_config.DEFINE_string(
|
||||
'jax_test_dut', '',
|
||||
help=
|
||||
'Describes the device under test in case special consideration is required.'
|
||||
)
|
||||
|
||||
EPS = 1e-4
|
||||
|
||||
@ -292,4 +297,4 @@ def check_grads(f, args, order,
|
||||
|
||||
|
||||
def device_under_test():
|
||||
return getattr(FLAGS, 'jax_test_dut', None) or xla_bridge.get_backend().platform
|
||||
return _TEST_DUT.value or xla_bridge.get_backend().platform
|
||||
|
@ -41,11 +41,12 @@ from jax._src.interpreters import mlir
|
||||
from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten
|
||||
from jax._src import api
|
||||
from jax._src import pjit as pjit_lib
|
||||
from jax._src import config as jax_config
|
||||
from jax._src import core
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes as _dtypes
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.config import (flags, bool_env, config,
|
||||
from jax._src.config import (bool_env, config,
|
||||
raise_persistent_cache_errors,
|
||||
persistent_cache_min_compile_time_secs)
|
||||
from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
|
||||
@ -60,19 +61,12 @@ from jax._src import xla_bridge
|
||||
# jax.test_util. Functionality appearing here is for internal use only, and
|
||||
# may be changed or removed at any time and without any deprecation cycle.
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string(
|
||||
'jax_test_dut', '',
|
||||
help=
|
||||
'Describes the device under test in case special consideration is required.'
|
||||
)
|
||||
|
||||
flags.DEFINE_integer(
|
||||
_NUM_GENERATED_CASES = jax_config.DEFINE_integer(
|
||||
'jax_num_generated_cases',
|
||||
int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
|
||||
help='Number of generated cases to test')
|
||||
|
||||
flags.DEFINE_integer(
|
||||
_MAX_CASES_SAMPLING_RETRIES = jax_config.DEFINE_integer(
|
||||
'max_cases_sampling_retries',
|
||||
int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')),
|
||||
'Number of times a failed test sample should be retried. '
|
||||
@ -80,24 +74,23 @@ flags.DEFINE_integer(
|
||||
'sampling process is terminated.'
|
||||
)
|
||||
|
||||
flags.DEFINE_bool(
|
||||
_SKIP_SLOW_TESTS = jax_config.DEFINE_bool(
|
||||
'jax_skip_slow_tests',
|
||||
bool_env('JAX_SKIP_SLOW_TESTS', False),
|
||||
help='Skip tests marked as slow (> 5 sec).'
|
||||
)
|
||||
|
||||
flags.DEFINE_string(
|
||||
_TEST_TARGETS = jax_config.DEFINE_string(
|
||||
'test_targets', os.getenv('JAX_TEST_TARGETS', ''),
|
||||
'Regular expression specifying which tests to run, called via re.search on '
|
||||
'the test name. If empty or unspecified, run all tests.'
|
||||
)
|
||||
flags.DEFINE_string(
|
||||
_EXCLUDE_TEST_TARGETS = jax_config.DEFINE_string(
|
||||
'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
|
||||
'Regular expression specifying which tests NOT to run, called via re.search '
|
||||
'on the test name. If empty or unspecified, run all tests.'
|
||||
)
|
||||
|
||||
flags.DEFINE_bool(
|
||||
TEST_WITH_PERSISTENT_COMPILATION_CACHE = jax_config.DEFINE_bool(
|
||||
'jax_test_with_persistent_compilation_cache',
|
||||
bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False),
|
||||
help='If enabled, the persistent compilation cache will be enabled for all '
|
||||
@ -731,7 +724,7 @@ def assert_dot_precision(expected_precision, fun, *args):
|
||||
|
||||
def cases_from_gens(*gens):
|
||||
sizes = [1, 3, 10]
|
||||
cases_per_size = int(FLAGS.jax_num_generated_cases / len(sizes)) + 1
|
||||
cases_per_size = int(_NUM_GENERATED_CASES.value / len(sizes)) + 1
|
||||
for size in sizes:
|
||||
for i in range(cases_per_size):
|
||||
yield (f'_{size}_{i}',) + tuple(gen(size) for gen in gens)
|
||||
@ -744,8 +737,8 @@ def named_cases_from_sampler(gen):
|
||||
if not isinstance(x, (list, tuple)):
|
||||
x = list(x)
|
||||
return [x[rng.randint(len(x))]]
|
||||
while (len(seen) < FLAGS.jax_num_generated_cases and
|
||||
retries < FLAGS.max_cases_sampling_retries):
|
||||
while (len(seen) < _NUM_GENERATED_CASES.value and
|
||||
retries < _MAX_CASES_SAMPLING_RETRIES.value):
|
||||
retries += 1
|
||||
cases = list(gen(choose_one))
|
||||
if not cases:
|
||||
@ -773,7 +766,7 @@ def sample_product_testcases(*args, **kw):
|
||||
kw = [(k, list(v)) for k, v in kw.items()]
|
||||
n = math.prod(len(a) for a in args) * math.prod(len(v) for _, v in kw)
|
||||
testcases = []
|
||||
for i in _choice(n, min(n, FLAGS.jax_num_generated_cases)):
|
||||
for i in _choice(n, min(n, _NUM_GENERATED_CASES.value)):
|
||||
testcase = {}
|
||||
for a in args:
|
||||
testcase.update(a[i % len(a)])
|
||||
@ -804,12 +797,12 @@ def sample_product(*args, **kw):
|
||||
class JaxTestLoader(absltest.TestLoader):
|
||||
def getTestCaseNames(self, testCaseClass):
|
||||
names = super().getTestCaseNames(testCaseClass)
|
||||
if FLAGS.test_targets:
|
||||
pattern = re.compile(FLAGS.test_targets)
|
||||
if _TEST_TARGETS.value:
|
||||
pattern = re.compile(_TEST_TARGETS.value)
|
||||
names = [name for name in names
|
||||
if pattern.search(f"{testCaseClass.__name__}.{name}")]
|
||||
if FLAGS.exclude_test_targets:
|
||||
pattern = re.compile(FLAGS.exclude_test_targets)
|
||||
if _EXCLUDE_TEST_TARGETS.value:
|
||||
pattern = re.compile(_EXCLUDE_TEST_TARGETS.value)
|
||||
names = [name for name in names
|
||||
if not pattern.search(f"{testCaseClass.__name__}.{name}")]
|
||||
return names
|
||||
@ -874,7 +867,7 @@ class JaxTestCase(parameterized.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
if FLAGS.jax_test_with_persistent_compilation_cache:
|
||||
if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
|
||||
cls._compilation_cache_exit_stack = ExitStack()
|
||||
stack = cls._compilation_cache_exit_stack
|
||||
stack.enter_context(raise_persistent_cache_errors(True))
|
||||
@ -887,7 +880,7 @@ class JaxTestCase(parameterized.TestCase):
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
if FLAGS.jax_test_with_persistent_compilation_cache:
|
||||
if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
|
||||
cls._compilation_cache_exit_stack.close()
|
||||
|
||||
def rng(self):
|
||||
|
@ -37,7 +37,8 @@ import numpy as np
|
||||
|
||||
from jax._src import lib
|
||||
from jax._src import distributed
|
||||
from jax._src.config import flags, bool_env, config, int_env
|
||||
from jax._src import config as jax_config
|
||||
from jax._src.config import bool_env, config, int_env
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src import traceback_util
|
||||
@ -59,37 +60,35 @@ traceback_util.register_exclusion(__file__)
|
||||
|
||||
XlaBackend = xla_client.Client
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
# TODO(phawkins): Remove jax_xla_backend.
|
||||
flags.DEFINE_string(
|
||||
_XLA_BACKEND = jax_config.DEFINE_string(
|
||||
'jax_xla_backend', '',
|
||||
'Deprecated, please use --jax_platforms instead.')
|
||||
flags.DEFINE_string(
|
||||
BACKEND_TARGET = jax_config.DEFINE_string(
|
||||
'jax_backend_target',
|
||||
os.getenv('JAX_BACKEND_TARGET', '').lower(),
|
||||
'Either "local" or "rpc:address" to connect to a remote service target.')
|
||||
# TODO(skye): warn when this is used once we test out --jax_platforms a bit
|
||||
flags.DEFINE_string(
|
||||
_PLATFORM_NAME = jax_config.DEFINE_string(
|
||||
'jax_platform_name',
|
||||
os.getenv('JAX_PLATFORM_NAME', '').lower(),
|
||||
'Deprecated, please use --jax_platforms instead.')
|
||||
flags.DEFINE_bool(
|
||||
_DISABLE_MOST_OPTIMIZATIONS = jax_config.DEFINE_bool(
|
||||
'jax_disable_most_optimizations',
|
||||
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
|
||||
'Try not to do much optimization work. This can be useful if the cost of '
|
||||
'optimization is greater than that of running a less-optimized program.')
|
||||
flags.DEFINE_integer(
|
||||
_XLA_PROFILE_VERSION = jax_config.DEFINE_integer(
|
||||
'jax_xla_profile_version', int_env('JAX_XLA_PROFILE_VERSION', 0),
|
||||
'Optional profile version for XLA compilation. '
|
||||
'This is meaningful only when XLA is configured to '
|
||||
'support the remote compilation profile feature.')
|
||||
flags.DEFINE_string(
|
||||
CUDA_VISIBLE_DEVICES = jax_config.DEFINE_string(
|
||||
'jax_cuda_visible_devices', 'all',
|
||||
'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
|
||||
'comma-separate list of integer device IDs.')
|
||||
flags.DEFINE_string(
|
||||
_ROCM_VISIBLE_DEVICES = jax_config.DEFINE_string(
|
||||
'jax_rocm_visible_devices', 'all',
|
||||
'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
|
||||
'comma-separate list of integer device IDs.')
|
||||
@ -171,13 +170,13 @@ def get_compile_options(
|
||||
if lib.cuda_path is not None:
|
||||
debug_options.xla_gpu_cuda_data_dir = lib.cuda_path
|
||||
|
||||
if FLAGS.jax_disable_most_optimizations:
|
||||
if _DISABLE_MOST_OPTIMIZATIONS.value:
|
||||
|
||||
debug_options.xla_backend_optimization_level = 0
|
||||
debug_options.xla_llvm_disable_expensive_passes = True
|
||||
debug_options.xla_test_all_input_layouts = False
|
||||
|
||||
compile_options.profile_version = FLAGS.jax_xla_profile_version
|
||||
compile_options.profile_version = _XLA_PROFILE_VERSION.value
|
||||
return compile_options
|
||||
|
||||
|
||||
@ -264,9 +263,9 @@ register_backend_factory('cpu',
|
||||
|
||||
|
||||
def make_gpu_client(
|
||||
*, platform_name: str, visible_devices_flag: str
|
||||
*, platform_name: str, visible_devices_flag: jax_config.FlagHolder[str]
|
||||
) -> xla_client.Client:
|
||||
visible_devices = getattr(FLAGS, visible_devices_flag, "all")
|
||||
visible_devices = visible_devices_flag.value
|
||||
allowed_devices = None
|
||||
if visible_devices != "all":
|
||||
allowed_devices = {int(x) for x in visible_devices.split(",")}
|
||||
@ -292,15 +291,25 @@ def make_gpu_client(
|
||||
|
||||
if hasattr(xla_client, "make_gpu_client"):
|
||||
register_backend_factory(
|
||||
'cuda', partial(make_gpu_client, platform_name='cuda',
|
||||
visible_devices_flag='jax_cuda_visible_devices'),
|
||||
"cuda",
|
||||
partial(
|
||||
make_gpu_client,
|
||||
platform_name="cuda",
|
||||
visible_devices_flag=CUDA_VISIBLE_DEVICES,
|
||||
),
|
||||
priority=200,
|
||||
fail_quietly=True)
|
||||
fail_quietly=True,
|
||||
)
|
||||
register_backend_factory(
|
||||
'rocm', partial(make_gpu_client, platform_name='rocm',
|
||||
visible_devices_flag='jax_rocm_visible_devices'),
|
||||
"rocm",
|
||||
partial(
|
||||
make_gpu_client,
|
||||
platform_name="rocm",
|
||||
visible_devices_flag=_ROCM_VISIBLE_DEVICES,
|
||||
),
|
||||
priority=200,
|
||||
fail_quietly=True)
|
||||
fail_quietly=True,
|
||||
)
|
||||
|
||||
|
||||
if hasattr(xla_client, "make_tpu_client"):
|
||||
@ -623,7 +632,7 @@ def backends() -> dict[str, xla_client.Client]:
|
||||
# support anything else there at the moment and warning would be pointless.
|
||||
if (py_platform.system() != "Darwin" and
|
||||
_default_backend.platform == "cpu" and
|
||||
FLAGS.jax_platform_name != 'cpu'):
|
||||
_PLATFORM_NAME.value != 'cpu'):
|
||||
logger.warning('No GPU/TPU found, falling back to CPU. '
|
||||
'(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
|
||||
return _backends
|
||||
@ -677,8 +686,7 @@ def _get_backend_uncached(
|
||||
if platform is not None and not isinstance(platform, str):
|
||||
return platform
|
||||
|
||||
platform = (platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name
|
||||
or None)
|
||||
platform = (platform or _XLA_BACKEND.value or _PLATFORM_NAME.value or None)
|
||||
|
||||
bs = backends()
|
||||
if platform is not None:
|
||||
|
@ -507,7 +507,7 @@ import warnings
|
||||
|
||||
from jax._src import api
|
||||
from jax._src import core
|
||||
from jax import config
|
||||
from jax._src import config
|
||||
from jax import custom_derivatives
|
||||
from jax._src import dtypes
|
||||
from jax import lax
|
||||
@ -531,17 +531,46 @@ from jax._src.lib.mlir.dialects import hlo
|
||||
import numpy as np
|
||||
|
||||
|
||||
FLAGS = config.FLAGS
|
||||
_HOST_CALLBACK_INLINE = config.DEFINE_bool(
|
||||
'jax_host_callback_inline',
|
||||
config.bool_env('JAX_HOST_CALLBACK_INLINE', False),
|
||||
help='Inline the host_callback, if not in a staged context.'
|
||||
)
|
||||
_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE = config.DEFINE_integer(
|
||||
'jax_host_callback_max_queue_byte_size',
|
||||
config.int_env('JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE', int(256 * 1e6)),
|
||||
help=('The size in bytes of the buffer used to hold outfeeds from each '
|
||||
'device. When this capacity is reached consuming outfeeds from the '
|
||||
'device is paused, thus potentially pausing the device computation, '
|
||||
'until the Python callback consume more outfeeds.'),
|
||||
lower_bound=int(16 * 1e6)
|
||||
)
|
||||
_HOST_CALLBACK_OUTFEED = config.DEFINE_bool(
|
||||
'jax_host_callback_outfeed',
|
||||
config.bool_env('JAX_HOST_CALLBACK_OUTFEED', False),
|
||||
help=(
|
||||
'Use outfeed implementation for host_callback, even on CPU and GPU. '
|
||||
'If false, use the CustomCall implementation. '
|
||||
'Has no effect on TPU, since only the outfeed mechanism is implemented.'
|
||||
)
|
||||
)
|
||||
_HOST_CALLBACK_AD_TRANSFORMS = config.DEFINE_bool(
|
||||
'jax_host_callback_ad_transforms',
|
||||
config.bool_env('JAX_HOST_CALLBACK_AD_TRANSFORMS', False),
|
||||
help=(
|
||||
'Enable support for jvp/vjp for the host_callback primitives. Default is '
|
||||
'False, which means that host_callback operates only on primals. '
|
||||
'The flag exists only temporarily, for backward compatibility.'
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _inline_host_callback() -> bool:
|
||||
return FLAGS.jax_host_callback_inline
|
||||
|
||||
|
||||
def _use_outfeed(platform: str) -> bool:
|
||||
return (platform in ("tpu", "gpu", "cuda", "rocm") or FLAGS.jax_host_callback_outfeed)
|
||||
return (platform in ("tpu", "gpu", "cuda", "rocm") or
|
||||
_HOST_CALLBACK_OUTFEED.value)
|
||||
|
||||
|
||||
def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend):
|
||||
@ -620,7 +649,7 @@ def id_tap(tap_func,
|
||||
"pre-apply keyword arguments, either by using a closure or by passing "
|
||||
"``functools.partial(tap_func, **kwargs)``.")
|
||||
raise TypeError(msg)
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
if _HOST_CALLBACK_AD_TRANSFORMS.value:
|
||||
warnings.warn('The flag jax_host_callback_ad_transforms is for temporary '
|
||||
'backwards compatibility mode. This flag, and the behavior '
|
||||
'it enabled will be removed soon.',
|
||||
@ -642,7 +671,7 @@ def id_tap(tap_func,
|
||||
if result is not None:
|
||||
# Return the results, but add a dependency on the call, to ensure it
|
||||
# is kept in the graph.
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
if _HOST_CALLBACK_AD_TRANSFORMS.value:
|
||||
call_flat_results, _ = tree_util.tree_flatten(call_res)
|
||||
if call_flat_results:
|
||||
call_flat_results = [id_tap_dep_p.bind(r, call_flat_results[0])
|
||||
@ -782,7 +811,7 @@ def _call(callback_func: Callable,
|
||||
identity=False):
|
||||
# Lazy initialization
|
||||
_initialize_outfeed_receiver(
|
||||
max_callback_queue_size_bytes=FLAGS.jax_host_callback_max_queue_byte_size)
|
||||
max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value)
|
||||
api.check_callable(callback_func)
|
||||
flat_args, arg_treedef = tree_util.tree_flatten(arg)
|
||||
for arg in flat_args:
|
||||
@ -909,7 +938,7 @@ xla.register_translation(id_tap_dep_p,
|
||||
id_tap_dep_p.def_abstract_eval(lambda r_a, _: r_a)
|
||||
|
||||
def _id_tap_dep_jvp_rule(primals, tangents):
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
if _HOST_CALLBACK_AD_TRANSFORMS.value:
|
||||
assert False
|
||||
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
|
||||
return (id_tap_dep_p.bind(primals[0], primals[1]),
|
||||
@ -918,7 +947,7 @@ def _id_tap_dep_jvp_rule(primals, tangents):
|
||||
ad.primitive_jvps[id_tap_dep_p] = _id_tap_dep_jvp_rule
|
||||
|
||||
def _id_tap_dep_transpose_rule(cts, arg_res, arg_tap):
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
if _HOST_CALLBACK_AD_TRANSFORMS.value:
|
||||
assert False
|
||||
if ad.is_undefined_primal(arg_res):
|
||||
ct_res = _instantiate_zeros(cts, arg_res)
|
||||
@ -934,7 +963,7 @@ ad.primitive_transposes[id_tap_dep_p] = _id_tap_dep_transpose_rule
|
||||
|
||||
|
||||
def _id_tap_dep_batching_rule(batched_args, batch_dims):
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
if _HOST_CALLBACK_AD_TRANSFORMS.value:
|
||||
assert False
|
||||
arg_res, arg_tap = batched_args
|
||||
return id_tap_dep_p.bind(arg_res, arg_tap), batch_dims[0]
|
||||
@ -1013,7 +1042,7 @@ outside_call_p.def_abstract_eval(_outside_call_abstract_eval)
|
||||
|
||||
def _outside_call_impl(*args, **params):
|
||||
assert "has_token" not in params
|
||||
if _inline_host_callback():
|
||||
if _HOST_CALLBACK_INLINE.value:
|
||||
device_index = params["device_index"]
|
||||
device = xb.devices()[device_index]
|
||||
results = _outside_call_run_callback(args, device, send_infeed=False, **params)
|
||||
@ -1400,7 +1429,7 @@ def _outside_call_jvp_rule(primals, tangents, **params):
|
||||
assert "has_token" not in params
|
||||
if not params["identity"]:
|
||||
raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
|
||||
if FLAGS.jax_host_callback_ad_transforms:
|
||||
if _HOST_CALLBACK_AD_TRANSFORMS.value:
|
||||
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
|
||||
|
||||
arg_treedef = params["arg_treedef"]
|
||||
@ -1425,7 +1454,7 @@ ad.primitive_jvps[outside_call_p] = _outside_call_jvp_rule
|
||||
|
||||
def _outside_call_partial_eval_rule(trace, *args, **params):
|
||||
# partial eval is used after jvp and before transpose.
|
||||
if not FLAGS.jax_host_callback_ad_transforms:
|
||||
if not _HOST_CALLBACK_AD_TRANSFORMS.value:
|
||||
# TODO: just remote the partial eval rule
|
||||
return trace.default_process_primitive(outside_call_p, args, params)
|
||||
transforms = params.get("transforms", ())
|
||||
@ -1492,7 +1521,7 @@ def _outside_call_transpose_rule(cts, *args, **params):
|
||||
*cts_instantiated,
|
||||
**_add_transform(params, "transpose"))
|
||||
|
||||
if not FLAGS.jax_host_callback_ad_transforms:
|
||||
if not _HOST_CALLBACK_AD_TRANSFORMS.value:
|
||||
assert False
|
||||
|
||||
assert len(args) % 2 == 0
|
||||
|
@ -65,7 +65,7 @@ def main(_):
|
||||
tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)
|
||||
keras_model.fit(train_ds, epochs=FLAGS.num_epochs, validation_data=test_ds)
|
||||
|
||||
if FLAGS.show_images:
|
||||
if saved_model_main.SHOW_IMAGES.value:
|
||||
mnist_lib.plot_images(
|
||||
test_ds,
|
||||
1,
|
||||
|
@ -38,8 +38,8 @@ import optax
|
||||
import tensorflow as tf # type: ignore
|
||||
import tensorflow_datasets as tfds # type: ignore
|
||||
|
||||
flags.DEFINE_boolean("mock_data", False, "Use fake data, for testing.")
|
||||
FLAGS = flags.FLAGS
|
||||
_MOCK_DATA = flags.DEFINE_boolean("mock_data", False,
|
||||
"Use fake data, for testing.")
|
||||
|
||||
#### Model parameters
|
||||
|
||||
@ -64,7 +64,7 @@ def load_mnist(split: tfds.Split, batch_size: int):
|
||||
an iterator with pairs (images, labels). The images have shape
|
||||
(B, 28, 28, 1) and the labels have shape (B, 10), where B is the batch_size.
|
||||
"""
|
||||
if FLAGS.mock_data:
|
||||
if _MOCK_DATA.value:
|
||||
with tfds.testing.mock_data(num_examples=batch_size):
|
||||
try:
|
||||
ds = tfds.load("mnist", split=split)
|
||||
|
@ -37,46 +37,49 @@ import numpy as np
|
||||
import tensorflow as tf # type: ignore
|
||||
import tensorflow_datasets as tfds # type: ignore
|
||||
|
||||
flags.DEFINE_enum("model", "mnist_flax", ["mnist_flax", "mnist_pure_jax"],
|
||||
"Which model to use.")
|
||||
flags.DEFINE_boolean("model_classifier_layer", True,
|
||||
_MODEL = flags.DEFINE_enum(
|
||||
"model", "mnist_flax", ["mnist_flax", "mnist_pure_jax"],
|
||||
"Which model to use.")
|
||||
_MODEL_CLASSIFIER_LAYER = flags.DEFINE_boolean("model_classifier_layer", True,
|
||||
("The model should include the classifier layer, or just "
|
||||
"the last layer of logits. Set this to False when you "
|
||||
"want to reuse the classifier-less model in a larger "
|
||||
"model. See keras_reuse_main.py and README.md."))
|
||||
flags.DEFINE_string("model_path", "/tmp/jax2tf/saved_models",
|
||||
_MODEL_PATH = flags.DEFINE_string("model_path", "/tmp/jax2tf/saved_models",
|
||||
"Path under which to save the SavedModel.")
|
||||
flags.DEFINE_integer("model_version", 1,
|
||||
_MODEL_VERSION = flags.DEFINE_integer("model_version", 1,
|
||||
("The version number for the SavedModel. Needed for "
|
||||
"serving, larger versions will take precedence"),
|
||||
lower_bound=1)
|
||||
flags.DEFINE_integer("serving_batch_size", 1,
|
||||
_SERVING_BATCH_SIZE = flags.DEFINE_integer("serving_batch_size", 1,
|
||||
"For what batch size to prepare the serving signature. "
|
||||
"Use -1 for converting and saving with batch polymorphism.")
|
||||
flags.register_validator(
|
||||
"serving_batch_size",
|
||||
lambda serving_batch_size: serving_batch_size > 0 or serving_batch_size == -1,
|
||||
message="--serving_batch_size must be either -1 or a positive integer.")
|
||||
lambda serving_batch_size: serving_batch_size > 0
|
||||
or serving_batch_size == -1,
|
||||
message="--serving_batch_size must be either -1 or a positive integer.",
|
||||
)
|
||||
|
||||
flags.DEFINE_integer("num_epochs", 3, "For how many epochs to train.",
|
||||
lower_bound=1)
|
||||
flags.DEFINE_boolean(
|
||||
_NUM_EPOCHS = flags.DEFINE_integer("num_epochs", 3,
|
||||
"For how many epochs to train.",
|
||||
lower_bound=1)
|
||||
_GENERATE_MODEL = flags.DEFINE_boolean(
|
||||
"generate_model", True,
|
||||
"Train and save a new model. Otherwise, use an existing SavedModel.")
|
||||
flags.DEFINE_boolean(
|
||||
_COMPILE_MODEL = flags.DEFINE_boolean(
|
||||
"compile_model", True,
|
||||
"Enable TensorFlow jit_compiler for the SavedModel. This is "
|
||||
"necessary if you want to use the model for TensorFlow serving.")
|
||||
flags.DEFINE_boolean("show_model", True, "Show details of saved SavedModel.")
|
||||
flags.DEFINE_boolean(
|
||||
_SHOW_MODEL = flags.DEFINE_boolean("show_model", True,
|
||||
"Show details of saved SavedModel.")
|
||||
SHOW_IMAGES = flags.DEFINE_boolean(
|
||||
"show_images", False,
|
||||
"Plot some sample images with labels and inference results.")
|
||||
flags.DEFINE_boolean(
|
||||
_TEST_SAVEDMODEL = flags.DEFINE_boolean(
|
||||
"test_savedmodel", True,
|
||||
"Test TensorFlow inference using the SavedModel w.r.t. the JAX model.")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def train_and_save():
|
||||
logging.info("Loading the MNIST TensorFlow dataset")
|
||||
@ -85,22 +88,22 @@ def train_and_save():
|
||||
test_ds = mnist_lib.load_mnist(
|
||||
tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)
|
||||
|
||||
if FLAGS.show_images:
|
||||
if SHOW_IMAGES.value:
|
||||
mnist_lib.plot_images(train_ds, 1, 5, "Training images", inference_fn=None)
|
||||
|
||||
the_model_class = pick_model_class()
|
||||
model_dir = savedmodel_dir(with_version=True)
|
||||
|
||||
if FLAGS.generate_model:
|
||||
if _GENERATE_MODEL.value:
|
||||
model_descr = model_description()
|
||||
logging.info("Generating model for %s", model_descr)
|
||||
(predict_fn, predict_params) = the_model_class.train(
|
||||
train_ds,
|
||||
test_ds,
|
||||
FLAGS.num_epochs,
|
||||
with_classifier=FLAGS.model_classifier_layer)
|
||||
num_epochs=_NUM_EPOCHS.value,
|
||||
with_classifier=_MODEL_CLASSIFIER_LAYER.value)
|
||||
|
||||
if FLAGS.serving_batch_size == -1:
|
||||
if _SERVING_BATCH_SIZE.value == -1:
|
||||
# Batch-polymorphic SavedModel
|
||||
input_signatures = [
|
||||
tf.TensorSpec((None,) + mnist_lib.input_shape, tf.float32),
|
||||
@ -109,7 +112,7 @@ def train_and_save():
|
||||
else:
|
||||
input_signatures = [
|
||||
# The first one will be the serving signature
|
||||
tf.TensorSpec((FLAGS.serving_batch_size,) + mnist_lib.input_shape,
|
||||
tf.TensorSpec((_SERVING_BATCH_SIZE.value,) + mnist_lib.input_shape,
|
||||
tf.float32),
|
||||
tf.TensorSpec((mnist_lib.train_batch_size,) + mnist_lib.input_shape,
|
||||
tf.float32),
|
||||
@ -126,15 +129,15 @@ def train_and_save():
|
||||
with_gradient=True,
|
||||
input_signatures=input_signatures,
|
||||
polymorphic_shapes=polymorphic_shapes,
|
||||
compile_model=FLAGS.compile_model)
|
||||
compile_model=_COMPILE_MODEL.value)
|
||||
|
||||
if FLAGS.test_savedmodel:
|
||||
if _TEST_SAVEDMODEL.value:
|
||||
tf_accelerator, tolerances = tf_accelerator_and_tolerances()
|
||||
with tf.device(tf_accelerator):
|
||||
logging.info("Testing savedmodel")
|
||||
pure_restored_model = tf.saved_model.load(model_dir)
|
||||
|
||||
if FLAGS.show_images and FLAGS.model_classifier_layer:
|
||||
if SHOW_IMAGES.value and _MODEL_CLASSIFIER_LAYER.value:
|
||||
mnist_lib.plot_images(
|
||||
test_ds,
|
||||
1,
|
||||
@ -149,7 +152,7 @@ def train_and_save():
|
||||
pure_restored_model(tf.convert_to_tensor(test_input)),
|
||||
predict_fn(predict_params, test_input), **tolerances)
|
||||
|
||||
if FLAGS.show_model:
|
||||
if _SHOW_MODEL.value:
|
||||
def print_model(model_dir: str):
|
||||
cmd = f"saved_model_cli show --all --dir {model_dir}"
|
||||
print(cmd)
|
||||
@ -160,18 +163,18 @@ def train_and_save():
|
||||
|
||||
def pick_model_class():
|
||||
"""Picks one of PureJaxMNIST or FlaxMNIST."""
|
||||
if FLAGS.model == "mnist_pure_jax":
|
||||
if _MODEL.value == "mnist_pure_jax":
|
||||
return mnist_lib.PureJaxMNIST
|
||||
elif FLAGS.model == "mnist_flax":
|
||||
elif _MODEL.value == "mnist_flax":
|
||||
return mnist_lib.FlaxMNIST
|
||||
else:
|
||||
raise ValueError(f"Unrecognized model: {FLAGS.model}")
|
||||
raise ValueError(f"Unrecognized model: {_MODEL.value}")
|
||||
|
||||
|
||||
def model_description() -> str:
|
||||
"""A short description of the picked model."""
|
||||
res = pick_model_class().name
|
||||
if not FLAGS.model_classifier_layer:
|
||||
if not _MODEL_CLASSIFIER_LAYER.value:
|
||||
res += " (features_only)"
|
||||
return res
|
||||
|
||||
@ -179,11 +182,11 @@ def model_description() -> str:
|
||||
def savedmodel_dir(with_version: bool = True) -> str:
|
||||
"""The directory where we save the SavedModel."""
|
||||
model_dir = os.path.join(
|
||||
FLAGS.model_path,
|
||||
FLAGS.model + ('' if FLAGS.model_classifier_layer else '_features')
|
||||
_MODEL_PATH.value,
|
||||
_MODEL.value + ('' if _MODEL_CLASSIFIER_LAYER.value else '_features')
|
||||
)
|
||||
if with_version:
|
||||
model_dir = os.path.join(model_dir, str(FLAGS.model_version))
|
||||
model_dir = os.path.join(model_dir, str(_MODEL_VERSION.value))
|
||||
return model_dir
|
||||
|
||||
|
||||
|
@ -32,31 +32,32 @@ from tensorflow_serving.apis import predict_pb2 # type: ignore[import]
|
||||
from tensorflow_serving.apis import prediction_service_pb2_grpc
|
||||
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_boolean(
|
||||
_USE_GRPC = flags.DEFINE_boolean(
|
||||
"use_grpc", True,
|
||||
"Use the gRPC API (default), or the HTTP REST API.")
|
||||
|
||||
flags.DEFINE_string(
|
||||
_MODEL_SPEC_NAME = flags.DEFINE_string(
|
||||
"model_spec_name", "",
|
||||
"The name you used to export your model to model server (e.g., mnist_flax).")
|
||||
|
||||
flags.DEFINE_string(
|
||||
_PREDICTION_SERVICE_ADDR = flags.DEFINE_string(
|
||||
"prediction_service_addr",
|
||||
"localhost:8500",
|
||||
"Stubby endpoint for the prediction service. If you serve your model "
|
||||
"locally using TensorFlow model server, then you can use \"localhost:8500\""
|
||||
"for the gRPC server and \"localhost:8501\" for the HTTP REST server.")
|
||||
|
||||
flags.DEFINE_integer("serving_batch_size", 1,
|
||||
"Batch size for the serving request. Must match the "
|
||||
"batch size at which the model was saved. Must divide "
|
||||
"--count_images",
|
||||
lower_bound=1)
|
||||
flags.DEFINE_integer("count_images", 16,
|
||||
"How many images to test.",
|
||||
lower_bound=1)
|
||||
_SERVING_BATCH_SIZE = flags.DEFINE_integer(
|
||||
"serving_batch_size",
|
||||
1,
|
||||
"Batch size for the serving request. Must match the "
|
||||
"batch size at which the model was saved. Must divide "
|
||||
"--count_images",
|
||||
lower_bound=1,
|
||||
)
|
||||
_COUNT_IMAGES = flags.DEFINE_integer(
|
||||
"count_images", 16, "How many images to test.", lower_bound=1
|
||||
)
|
||||
|
||||
|
||||
def serving_call_mnist(images):
|
||||
@ -69,12 +70,12 @@ def serving_call_mnist(images):
|
||||
Returns:
|
||||
A numpy.ndarray of shape [B, 10] with the one-hot inference response.
|
||||
"""
|
||||
if FLAGS.use_grpc:
|
||||
channel = grpc.insecure_channel(FLAGS.prediction_service_addr)
|
||||
if _USE_GRPC.value:
|
||||
channel = grpc.insecure_channel(_PREDICTION_SERVICE_ADDR.value)
|
||||
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
|
||||
|
||||
request = predict_pb2.PredictRequest()
|
||||
request.model_spec.name = FLAGS.model_spec_name
|
||||
request.model_spec.name = _MODEL_SPEC_NAME.value
|
||||
request.model_spec.signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
||||
# You can see the name of the input ("inputs") in the SavedModel dump.
|
||||
request.inputs["inputs"].CopyFrom(
|
||||
@ -90,7 +91,7 @@ def serving_call_mnist(images):
|
||||
images_json = json.dumps(images.tolist())
|
||||
# You can see the name of the input ("inputs") in the SavedModel dump.
|
||||
data = f'{{"inputs": {images_json}}}'
|
||||
predict_url = f"http://{FLAGS.prediction_service_addr}/v1/models/{FLAGS.model_spec_name}:predict"
|
||||
predict_url = f"http://{_PREDICTION_SERVICE_ADDR.value}/v1/models/{_MODEL_SPEC_NAME.value}:predict"
|
||||
response = requests.post(predict_url, data=data)
|
||||
if response.status_code != 200:
|
||||
msg = (f"Received error response {response.status_code} from model "
|
||||
@ -101,14 +102,14 @@ def serving_call_mnist(images):
|
||||
|
||||
|
||||
def main(_):
|
||||
if FLAGS.count_images % FLAGS.serving_batch_size != 0:
|
||||
raise ValueError(f"The count_images ({FLAGS.count_images}) must be a "
|
||||
if _COUNT_IMAGES.value % _SERVING_BATCH_SIZE.value != 0:
|
||||
raise ValueError(f"The count_images ({_COUNT_IMAGES.value}) must be a "
|
||||
"multiple of "
|
||||
f"serving_batch_size ({FLAGS.serving_batch_size})")
|
||||
f"serving_batch_size ({_SERVING_BATCH_SIZE.value})")
|
||||
test_ds = mnist_lib.load_mnist(tfds.Split.TEST,
|
||||
batch_size=FLAGS.serving_batch_size)
|
||||
batch_size=_SERVING_BATCH_SIZE.value)
|
||||
images_and_labels = tfds.as_numpy(test_ds.take(
|
||||
FLAGS.count_images // FLAGS.serving_batch_size))
|
||||
_COUNT_IMAGES.value // _SERVING_BATCH_SIZE.value))
|
||||
|
||||
accurate_count = 0
|
||||
for batch_idx, (images, labels) in enumerate(images_and_labels):
|
||||
@ -117,7 +118,7 @@ def main(_):
|
||||
labels_digit = np.argmax(labels, axis=1)
|
||||
accurate_count += np.sum(labels_digit == predictions_digit)
|
||||
running_accuracy = (
|
||||
100. * accurate_count / (1 + batch_idx) / FLAGS.serving_batch_size)
|
||||
100. * accurate_count / (1 + batch_idx) / _SERVING_BATCH_SIZE.value)
|
||||
logging.info(
|
||||
" predicted digits = %s labels %s. Running accuracy %.3f%%",
|
||||
predictions_digit, labels_digit, running_accuracy)
|
||||
|
@ -33,15 +33,17 @@ import tensorflowjs as tfjs
|
||||
import input_pipeline # type: ignore[import]
|
||||
|
||||
|
||||
flags.DEFINE_integer("num_epochs", 5,
|
||||
("Number of epochs to train for."))
|
||||
flags.DEFINE_integer("num_classes", 100, "Number of classification classes.")
|
||||
_NUM_EPOCHS = flags.DEFINE_integer(
|
||||
"num_epochs", 5, "Number of epochs to train for."
|
||||
)
|
||||
_NUM_CLASSES = flags.DEFINE_integer(
|
||||
"num_classes", 100, "Number of classification classes."
|
||||
)
|
||||
|
||||
flags.register_validator("num_classes",
|
||||
lambda value: value >= 1 and value <= 100,
|
||||
message="--num_classes must be in range [1, 100]")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
# The code below is an adaptation for Flax from the work published here:
|
||||
# https://blog.tensorflow.org/2018/07/train-model-in-tfkeras-with-colab-and-run-in-browser-tensorflowjs.html
|
||||
@ -65,7 +67,7 @@ class QuickDraw(nn.Module):
|
||||
x = nn.Dense(features=128)(x)
|
||||
x = nn.relu(x)
|
||||
|
||||
x = nn.Dense(features=FLAGS.num_classes)(x)
|
||||
x = nn.Dense(features=_NUM_CLASSES.value)(x)
|
||||
|
||||
return x
|
||||
|
||||
@ -75,7 +77,7 @@ def apply_model(state, inputs, labels):
|
||||
"""Computes gradients, loss and accuracy for a single batch."""
|
||||
def loss_fn(params):
|
||||
logits = state.apply_fn({'params': params}, inputs)
|
||||
one_hot = jax.nn.one_hot(labels, FLAGS.num_classes)
|
||||
one_hot = jax.nn.one_hot(labels, _NUM_CLASSES.value)
|
||||
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
|
||||
return loss, logits
|
||||
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
|
||||
@ -113,7 +115,7 @@ def create_train_state(rng):
|
||||
|
||||
|
||||
def train(state, train_ds, test_ds):
|
||||
for epoch in range(1, FLAGS.num_epochs+1):
|
||||
for epoch in range(1, _NUM_EPOCHS.value+1):
|
||||
start_time = time.time()
|
||||
|
||||
state, train_loss, train_accuracy = run_epoch(state, train_ds)
|
||||
@ -136,12 +138,12 @@ def main(argv):
|
||||
|
||||
base_model_path = "/tmp/jax2tf/tf_js_quickdraw"
|
||||
dataset_path = os.path.join(base_model_path, "data")
|
||||
classes = input_pipeline.download_dataset(dataset_path, FLAGS.num_classes)
|
||||
assert len(classes) == FLAGS.num_classes, "Incorrect number of classes"
|
||||
classes = input_pipeline.download_dataset(dataset_path, _NUM_CLASSES.value)
|
||||
assert len(classes) == _NUM_CLASSES.value, "Incorrect number of classes"
|
||||
print(f"Classes are: {classes}")
|
||||
print("Loading dataset into memory...")
|
||||
train_ds, test_ds = input_pipeline.get_datasets(dataset_path, classes)
|
||||
print(f"Starting training for {FLAGS.num_epochs} epochs...")
|
||||
print(f"Starting training for {_NUM_EPOCHS.value} epochs...")
|
||||
|
||||
state = create_train_state(jax.random.PRNGKey(0))
|
||||
state = train(state, train_ds, test_ds)
|
||||
|
@ -24,14 +24,19 @@ import numpy as np
|
||||
import tensorflow as tf # type: ignore[import]
|
||||
import tensorflow_datasets as tfds # type: ignore[import]
|
||||
|
||||
flags.DEFINE_string('tflite_file_path',
|
||||
'/usr/local/google/home/qiuminxu/jax2tf/mnist.tflite',
|
||||
'Path where to save the TensorFlow Lite file.')
|
||||
flags.DEFINE_integer('serving_batch_size', 4,
|
||||
('For what batch size to prepare the serving signature. '))
|
||||
flags.DEFINE_integer('num_epochs', 10, 'For how many epochs to train.')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
_TFLITE_FILE_PATH = flags.DEFINE_string(
|
||||
'tflite_file_path',
|
||||
'/tmp/mnist.tflite',
|
||||
'Path where to save the TensorFlow Lite file.',
|
||||
)
|
||||
_SERVING_BATCH_SIZE = flags.DEFINE_integer(
|
||||
'serving_batch_size',
|
||||
4,
|
||||
'For what batch size to prepare the serving signature. ',
|
||||
)
|
||||
_NUM_EPOCHS = flags.DEFINE_integer(
|
||||
'num_epochs', 10, 'For how many epochs to train.'
|
||||
)
|
||||
|
||||
|
||||
# A helper function to evaluate the TF Lite model using "test" dataset.
|
||||
@ -71,10 +76,11 @@ def main(_):
|
||||
train_ds = mnist_lib.load_mnist(
|
||||
tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
|
||||
test_ds = mnist_lib.load_mnist(
|
||||
tfds.Split.TEST, batch_size=FLAGS.serving_batch_size)
|
||||
tfds.Split.TEST, batch_size=_SERVING_BATCH_SIZE)
|
||||
|
||||
(flax_predict,
|
||||
flax_params) = mnist_lib.FlaxMNIST.train(train_ds, test_ds, FLAGS.num_epochs)
|
||||
(flax_predict, flax_params) = mnist_lib.FlaxMNIST.train(
|
||||
train_ds, test_ds, _NUM_EPOCHS.value
|
||||
)
|
||||
|
||||
def predict(image):
|
||||
return flax_predict(flax_params, image)
|
||||
@ -84,7 +90,7 @@ def main(_):
|
||||
jax2tf.convert(predict, enable_xla=False),
|
||||
input_signature=[
|
||||
tf.TensorSpec(
|
||||
shape=[FLAGS.serving_batch_size, 28, 28, 1],
|
||||
shape=[_SERVING_BATCH_SIZE, 28, 28, 1],
|
||||
dtype=tf.float32,
|
||||
name='input')
|
||||
],
|
||||
@ -126,7 +132,7 @@ def main(_):
|
||||
print('Quantized model accuracy = %.4f' % quantized_accuracy)
|
||||
print('Accuracy drop = %.4f' % (float_accuracy - quantized_accuracy))
|
||||
|
||||
f = open(FLAGS.tflite_file_path, 'wb')
|
||||
f = open(_TFLITE_FILE_PATH.value, 'wb')
|
||||
f.write(tflite_quantized_model)
|
||||
f.close()
|
||||
|
||||
|
@ -58,34 +58,32 @@ from jax.experimental.jax2tf.shape_poly import InconclusiveDimensionOperation
|
||||
from jax.experimental.jax2tf.tests.model_harness import ALL_HARNESSES
|
||||
from jax.experimental.jax2tf.tests.converters import ALL_CONVERTERS
|
||||
|
||||
flags.DEFINE_list("converters", [x.name for x in ALL_CONVERTERS],
|
||||
_CONVERTERS = flags.DEFINE_list("converters", [x.name for x in ALL_CONVERTERS],
|
||||
"Which converters to test.")
|
||||
|
||||
flags.DEFINE_list("examples", [],
|
||||
_EXAMPLES = flags.DEFINE_list("examples", [],
|
||||
("List of examples to test, e.g.: 'flax/mnist,flax/seq2seq'. "
|
||||
"If empty, will test all examples."))
|
||||
|
||||
flags.DEFINE_string("example_prefix", "",
|
||||
_EXAMPLE_PREFIX = flags.DEFINE_string("example_prefix", "",
|
||||
("Prefix for filtering tests. For instance 'flax/mnist' "
|
||||
"will test all examples starting with 'flax/mnist' "
|
||||
"(including all polymorphic tests)."))
|
||||
|
||||
flags.DEFINE_bool(
|
||||
_WRITE_MARKDOWN = flags.DEFINE_bool(
|
||||
"write_markdown", True,
|
||||
"If true, write results as Markdown. Otherwise, only output to stdout.")
|
||||
|
||||
flags.DEFINE_bool(
|
||||
_FAIL_ON_ERROR = flags.DEFINE_bool(
|
||||
"fail_on_error", False,
|
||||
("If true, exit with an error when a conversion fails. Useful for "
|
||||
"debugging because it will show the entire stack trace."))
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def _write_markdown(results: dict[str, list[tuple[str, str,]]]) -> None:
|
||||
"""Writes all results to Markdown file."""
|
||||
table_lines = []
|
||||
converters = FLAGS.converters
|
||||
converters = _CONVERTERS.value
|
||||
|
||||
table_lines.append("| Example | " + " ".join([f"{c} |" for c in converters]))
|
||||
table_lines.append("|" + (" --- |" * (len(converters) + 1)))
|
||||
@ -173,7 +171,7 @@ def test_converters():
|
||||
exit()
|
||||
|
||||
def _maybe_reraise(e):
|
||||
if FLAGS.fail_on_error:
|
||||
if _FAIL_ON_ERROR.value:
|
||||
raise e
|
||||
|
||||
def _format(e):
|
||||
@ -183,13 +181,13 @@ def test_converters():
|
||||
return msg
|
||||
|
||||
converters = list(
|
||||
filter(lambda x: x.name in FLAGS.converters, ALL_CONVERTERS))
|
||||
filter(lambda x: x.name in _CONVERTERS.value, ALL_CONVERTERS))
|
||||
_exit_if_empty(converters, "converters")
|
||||
|
||||
harnesses_to_test = {
|
||||
name: fn for name, fn in ALL_HARNESSES.items()
|
||||
if (not FLAGS.examples or name in FLAGS.examples) and
|
||||
(not FLAGS.example_prefix or name.startswith(FLAGS.example_prefix))
|
||||
if (not _EXAMPLES.value or name in _EXAMPLES.value) and
|
||||
(not _EXAMPLE_PREFIX.value or name.startswith(_EXAMPLE_PREFIX.value))
|
||||
}
|
||||
_exit_if_empty(harnesses_to_test, "harness")
|
||||
|
||||
@ -243,7 +241,7 @@ def test_converters():
|
||||
converter_results.append((converter.name, error_msg))
|
||||
results[harness.name] = converter_results
|
||||
|
||||
if FLAGS.write_markdown:
|
||||
if _WRITE_MARKDOWN:
|
||||
_write_markdown(results)
|
||||
else:
|
||||
print("=== NOT writing results to Markdown.")
|
||||
|
@ -140,7 +140,6 @@ def jax_to_ir(name, deps, fn, input_shapes, constants = None, format = "HLO"):
|
||||
from absl import app
|
||||
import jax.tools.jax_to_ir as jax_to_ir
|
||||
|
||||
jax_to_ir.set_up_flags()
|
||||
app.run(jax_to_ir.main)
|
||||
EOF
|
||||
""".format(runner = runner),
|
||||
|
@ -87,7 +87,30 @@ try:
|
||||
except ImportError:
|
||||
tf = None # type: ignore
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
_FN = flags.DEFINE_string(
|
||||
'fn', None, "Fully-qualified name of function that we're going to convert"
|
||||
)
|
||||
_INPUT_SHAPES = flags.DEFINE_string(
|
||||
'input_shapes', None, 'Python dict indicating XLA shapes of params'
|
||||
)
|
||||
_CONSTANTS = flags.DEFINE_string(
|
||||
'constants', '{}', 'Python dict giving constant values for some params'
|
||||
)
|
||||
_EVALED_CONSTANTS = flags.DEFINE_string(
|
||||
'evaled_constants',
|
||||
'{}',
|
||||
'Python dict giving constant values for some params. '
|
||||
'Values in this dict that are of type str are evaluated '
|
||||
'using ast.literal_eval.',
|
||||
)
|
||||
_IR_FORMAT = flags.DEFINE_enum(
|
||||
'ir_format', 'HLO', ('HLO', 'TF'), 'Output format.'
|
||||
)
|
||||
_IR_DEST = flags.DEFINE_string('ir_dest', None, 'File to write IR to')
|
||||
_IR_HUMAN_DEST = flags.DEFINE_string(
|
||||
'ir_human_dest', None, 'File to write human readable debug output'
|
||||
)
|
||||
|
||||
|
||||
def jax_to_ir(fn, input_shapes, *, constants=None, format):
|
||||
@ -163,25 +186,25 @@ def main(argv):
|
||||
if len(argv) != 1:
|
||||
raise app.UsageError('No positional arguments are accepted.')
|
||||
|
||||
if not FLAGS.ir_dest and not FLAGS.ir_human_dest:
|
||||
if not _IR_DEST.value and not _IR_HUMAN_DEST.value:
|
||||
raise app.Error('At least one of --ir_dest and '
|
||||
'--ir_human_dest is required.')
|
||||
|
||||
module_name, fn_name = FLAGS.fn.rsplit('.', 1)
|
||||
module_name, fn_name = _FN.value.rsplit('.', 1)
|
||||
module = importlib.import_module(module_name)
|
||||
fn = getattr(module, fn_name)
|
||||
|
||||
input_shapes = [(name, parse_shape_str(shape_str))
|
||||
for name, shape_str in literal_eval(FLAGS.input_shapes)]
|
||||
for name, shape_str in literal_eval(_INPUT_SHAPES.value)]
|
||||
|
||||
# Parse --constants and --evaled_constants.
|
||||
constants = {}
|
||||
for k, v in literal_eval(FLAGS.constants).items():
|
||||
for k, v in literal_eval(_CONSTANTS.value).items():
|
||||
if isinstance(v, list):
|
||||
v = jnp.asarray(v)
|
||||
constants[k] = v
|
||||
|
||||
for k, v in literal_eval(FLAGS.evaled_constants).items():
|
||||
for k, v in literal_eval(_EVALED_CONSTANTS.value).items():
|
||||
if isinstance(v, str):
|
||||
v = literal_eval(v)
|
||||
if isinstance(v, list):
|
||||
@ -192,14 +215,14 @@ def main(argv):
|
||||
constants[k] = v
|
||||
|
||||
ir, debug_ir = jax_to_ir(fn, input_shapes, constants=constants,
|
||||
format=FLAGS.ir_format)
|
||||
format=_IR_FORMAT.value)
|
||||
|
||||
if FLAGS.ir_dest:
|
||||
with open(FLAGS.ir_dest, 'wb') as f:
|
||||
if _IR_DEST.value:
|
||||
with open(_IR_DEST.value, 'wb') as f:
|
||||
f.write(ir)
|
||||
|
||||
if FLAGS.ir_human_dest:
|
||||
with open(FLAGS.ir_human_dest, 'w') as f:
|
||||
if _IR_HUMAN_DEST.value:
|
||||
with open(_IR_HUMAN_DEST.value, 'w') as f:
|
||||
f.write(debug_ir)
|
||||
|
||||
|
||||
@ -225,21 +248,6 @@ _SHAPE_RE = re.compile(f"^({'|'.join(_DT)})\\[\\s*(\\d*[\\s*,\\d+]*)\\s*\\]$")
|
||||
|
||||
|
||||
def set_up_flags():
|
||||
flags.DEFINE_string(
|
||||
'fn', None,
|
||||
"Fully-qualified name of function that we're going to convert")
|
||||
flags.DEFINE_string('input_shapes', None,
|
||||
'Python dict indicating XLA shapes of params')
|
||||
flags.DEFINE_string('constants', '{}',
|
||||
'Python dict giving constant values for some params')
|
||||
flags.DEFINE_string('evaled_constants', '{}',
|
||||
'Python dict giving constant values for some params. '
|
||||
'Values in this dict that are of type str are evaluated '
|
||||
'using ast.literal_eval.')
|
||||
flags.DEFINE_enum('ir_format', 'HLO', ('HLO', 'TF'), 'Output format.')
|
||||
flags.DEFINE_string('ir_dest', None, 'File to write IR to')
|
||||
flags.DEFINE_string('ir_human_dest', None,
|
||||
'File to write human readable debug output')
|
||||
flags.mark_flag_as_required('fn')
|
||||
flags.mark_flag_as_required('input_shapes')
|
||||
|
||||
|
@ -22,7 +22,6 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import core
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.config import flags
|
||||
from jax.experimental.pjit import pjit
|
||||
from jax.experimental.serialize_executable import (
|
||||
serialize, deserialize_and_load)
|
||||
@ -73,7 +72,7 @@ class JaxAotTest(jtu.JaxTestCase):
|
||||
except NotImplementedError:
|
||||
raise unittest.SkipTest('PJRT Topology not supported')
|
||||
|
||||
if flags.FLAGS.jax_test_with_persistent_compilation_cache:
|
||||
if jtu.TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
|
||||
raise unittest.SkipTest('Compilation caching not yet supported.')
|
||||
|
||||
@jax.jit
|
||||
|
Loading…
x
Reference in New Issue
Block a user