mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Update flags to use the ABSL typed flag API.
Change flags to use the newer definition style where the flag is read via a typed FlagHolder object returned by the DEFINE_... function. The advantage of doing this is that `flag.value` has a type known to the type checker, rather than reading it as an attr out of a gigantic config dictionary. For jax.config flags, define a typed FlagHolder object that is returned when defining a flag, matching the ABSL API. Move a number of flags into the file that consumes them. There's no reason we're defining every flag in `config.py`. This PR does not change the similar "state" objects in `jax.config`. Changing those is for a future PR. PiperOrigin-RevId: 551604974
This commit is contained in:
parent
f35f226b44
commit
76cda0ae07
@ -27,8 +27,7 @@ from absl import app
|
|||||||
from absl import flags
|
from absl import flags
|
||||||
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
_SET_ENV = flags.DEFINE_multi_string(
|
||||||
flags.DEFINE_multi_string(
|
|
||||||
"set_env", None,
|
"set_env", None,
|
||||||
"Specifies additional environment variables to be injected into the "
|
"Specifies additional environment variables to be injected into the "
|
||||||
"environment (via --set_env=variable=value or --set_env=variable). "
|
"environment (via --set_env=variable=value or --set_env=variable). "
|
||||||
@ -140,8 +139,8 @@ def jax_binary_op(state, **kwargs):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def main(argv):
|
def main(argv):
|
||||||
if FLAGS.set_env:
|
if _SET_ENV.value:
|
||||||
for env_str in FLAGS.set_env:
|
for env_str in _SET_ENV.value:
|
||||||
# Stop matching at the first '=' since we want to capture
|
# Stop matching at the first '=' since we want to capture
|
||||||
# --set_env='FOO=--foo_a=1 --foo_b=2' all as part of FOO.
|
# --set_env='FOO=--foo_a=1 --foo_b=2' all as part of FOO.
|
||||||
env_list = env_str.split('=', 1)
|
env_list = env_str.split('=', 1)
|
||||||
|
@ -83,22 +83,21 @@ import numpy.random as npr
|
|||||||
from dp_accounting import dp_event
|
from dp_accounting import dp_event
|
||||||
from dp_accounting import rdp
|
from dp_accounting import rdp
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
flags.DEFINE_boolean(
|
_DPSGD = flags.DEFINE_boolean(
|
||||||
'dpsgd', True, 'If True, train with DP-SGD. If False, '
|
'dpsgd', True, 'If True, train with DP-SGD. If False, '
|
||||||
'train with vanilla SGD.')
|
'train with vanilla SGD.')
|
||||||
flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
|
_LEARNING_RATE = flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
|
||||||
flags.DEFINE_float('noise_multiplier', 1.1,
|
_NOISE_MULTIPLIER = flags.DEFINE_float('noise_multiplier', 1.1,
|
||||||
'Ratio of the standard deviation to the clipping norm')
|
'Ratio of the standard deviation to the clipping norm')
|
||||||
flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
|
_L2_NORM_CLIP = flags.DEFINE_float('l2_norm_clip', 1.0, 'Clipping norm')
|
||||||
flags.DEFINE_integer('batch_size', 256, 'Batch size')
|
_BATCH_SIZE = flags.DEFINE_integer('batch_size', 256, 'Batch size')
|
||||||
flags.DEFINE_integer('epochs', 60, 'Number of epochs')
|
_EPOCHS = flags.DEFINE_integer('epochs', 60, 'Number of epochs')
|
||||||
flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG')
|
_SEED = flags.DEFINE_integer('seed', 0, 'Seed for jax PRNG')
|
||||||
flags.DEFINE_integer(
|
_MICROBATCHES = flags.DEFINE_integer(
|
||||||
'microbatches', None, 'Number of microbatches '
|
'microbatches', None, 'Number of microbatches '
|
||||||
'(must evenly divide batch_size)')
|
'(must evenly divide batch_size)')
|
||||||
flags.DEFINE_string('model_dir', None, 'Model directory')
|
_MODEL_DIR = flags.DEFINE_string('model_dir', None, 'Model directory')
|
||||||
|
|
||||||
|
|
||||||
init_random_params, predict = stax.serial(
|
init_random_params, predict = stax.serial(
|
||||||
@ -163,39 +162,39 @@ def shape_as_image(images, labels, dummy_dim=False):
|
|||||||
def compute_epsilon(steps, num_examples=60000, target_delta=1e-5):
|
def compute_epsilon(steps, num_examples=60000, target_delta=1e-5):
|
||||||
if num_examples * target_delta > 1.:
|
if num_examples * target_delta > 1.:
|
||||||
warnings.warn('Your delta might be too high.')
|
warnings.warn('Your delta might be too high.')
|
||||||
q = FLAGS.batch_size / float(num_examples)
|
q = _BATCH_SIZE.value / float(num_examples)
|
||||||
orders = list(jnp.linspace(1.1, 10.9, 99)) + list(range(11, 64))
|
orders = list(jnp.linspace(1.1, 10.9, 99)) + list(range(11, 64))
|
||||||
accountant = rdp.rdp_privacy_accountant.RdpAccountant(orders)
|
accountant = rdp.rdp_privacy_accountant.RdpAccountant(orders)
|
||||||
accountant.compose(
|
accountant.compose(
|
||||||
dp_event.PoissonSampledDpEvent(
|
dp_event.PoissonSampledDpEvent(
|
||||||
q, dp_event.GaussianDpEvent(FLAGS.noise_multiplier)), steps)
|
q, dp_event.GaussianDpEvent(_NOISE_MULTIPLIER.value)), steps)
|
||||||
return accountant.get_epsilon(target_delta)
|
return accountant.get_epsilon(target_delta)
|
||||||
|
|
||||||
|
|
||||||
def main(_):
|
def main(_):
|
||||||
|
|
||||||
if FLAGS.microbatches:
|
if _MICROBATCHES.value:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
'Microbatches < batch size not currently supported'
|
'Microbatches < batch size not currently supported'
|
||||||
)
|
)
|
||||||
|
|
||||||
train_images, train_labels, test_images, test_labels = datasets.mnist()
|
train_images, train_labels, test_images, test_labels = datasets.mnist()
|
||||||
num_train = train_images.shape[0]
|
num_train = train_images.shape[0]
|
||||||
num_complete_batches, leftover = divmod(num_train, FLAGS.batch_size)
|
num_complete_batches, leftover = divmod(num_train, _BATCH_SIZE.value)
|
||||||
num_batches = num_complete_batches + bool(leftover)
|
num_batches = num_complete_batches + bool(leftover)
|
||||||
key = random.PRNGKey(FLAGS.seed)
|
key = random.PRNGKey(_SEED.value)
|
||||||
|
|
||||||
def data_stream():
|
def data_stream():
|
||||||
rng = npr.RandomState(FLAGS.seed)
|
rng = npr.RandomState(_SEED.value)
|
||||||
while True:
|
while True:
|
||||||
perm = rng.permutation(num_train)
|
perm = rng.permutation(num_train)
|
||||||
for i in range(num_batches):
|
for i in range(num_batches):
|
||||||
batch_idx = perm[i * FLAGS.batch_size:(i + 1) * FLAGS.batch_size]
|
batch_idx = perm[i * _BATCH_SIZE.value:(i + 1) * _BATCH_SIZE.value]
|
||||||
yield train_images[batch_idx], train_labels[batch_idx]
|
yield train_images[batch_idx], train_labels[batch_idx]
|
||||||
|
|
||||||
batches = data_stream()
|
batches = data_stream()
|
||||||
|
|
||||||
opt_init, opt_update, get_params = optimizers.sgd(FLAGS.learning_rate)
|
opt_init, opt_update, get_params = optimizers.sgd(_LEARNING_RATE.value)
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def update(_, i, opt_state, batch):
|
def update(_, i, opt_state, batch):
|
||||||
@ -208,19 +207,19 @@ def main(_):
|
|||||||
rng = random.fold_in(rng, i) # get new key for new random numbers
|
rng = random.fold_in(rng, i) # get new key for new random numbers
|
||||||
return opt_update(
|
return opt_update(
|
||||||
i,
|
i,
|
||||||
private_grad(params, batch, rng, FLAGS.l2_norm_clip,
|
private_grad(params, batch, rng, _L2_NORM_CLIP.value,
|
||||||
FLAGS.noise_multiplier, FLAGS.batch_size), opt_state)
|
_NOISE_MULTIPLIER.value, _BATCH_SIZE.value), opt_state)
|
||||||
|
|
||||||
_, init_params = init_random_params(key, (-1, 28, 28, 1))
|
_, init_params = init_random_params(key, (-1, 28, 28, 1))
|
||||||
opt_state = opt_init(init_params)
|
opt_state = opt_init(init_params)
|
||||||
itercount = itertools.count()
|
itercount = itertools.count()
|
||||||
|
|
||||||
steps_per_epoch = 60000 // FLAGS.batch_size
|
steps_per_epoch = 60000 // _BATCH_SIZE.value
|
||||||
print('\nStarting training...')
|
print('\nStarting training...')
|
||||||
for epoch in range(1, FLAGS.epochs + 1):
|
for epoch in range(1, _EPOCHS.value + 1):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
for _ in range(num_batches):
|
for _ in range(num_batches):
|
||||||
if FLAGS.dpsgd:
|
if _DPSGD.value:
|
||||||
opt_state = \
|
opt_state = \
|
||||||
private_update(
|
private_update(
|
||||||
key, next(itercount), opt_state,
|
key, next(itercount), opt_state,
|
||||||
@ -239,7 +238,7 @@ def main(_):
|
|||||||
test_loss, 100 * test_acc))
|
test_loss, 100 * test_acc))
|
||||||
|
|
||||||
# determine privacy loss so far
|
# determine privacy loss so far
|
||||||
if FLAGS.dpsgd:
|
if _DPSGD.value:
|
||||||
delta = 1e-5
|
delta = 1e-5
|
||||||
num_examples = 60000
|
num_examples = 60000
|
||||||
eps = compute_epsilon(epoch * steps_per_epoch, num_examples, delta)
|
eps = compute_epsilon(epoch * steps_per_epoch, num_examples, delta)
|
||||||
|
@ -16,7 +16,6 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from absl import app
|
from absl import app
|
||||||
from absl import flags
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from jax import grad
|
from jax import grad
|
||||||
from jax import jit
|
from jax import jit
|
||||||
@ -27,8 +26,6 @@ import jax.random as random
|
|||||||
import jax.scipy as scipy
|
import jax.scipy as scipy
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
|
|
||||||
def main(unused_argv):
|
def main(unused_argv):
|
||||||
|
|
||||||
|
@ -12,6 +12,8 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Hashable, Iterator
|
from collections.abc import Hashable, Iterator
|
||||||
import contextlib
|
import contextlib
|
||||||
import functools
|
import functools
|
||||||
@ -20,7 +22,7 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
from typing import Any, Callable, NamedTuple, Optional
|
from typing import Any, Callable, Generic, NamedTuple, Optional, TypeVar
|
||||||
|
|
||||||
from jax._src import lib
|
from jax._src import lib
|
||||||
from jax._src.lib import jax_jit
|
from jax._src.lib import jax_jit
|
||||||
@ -29,6 +31,8 @@ from jax._src.lib import xla_client
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_T = TypeVar('_T')
|
||||||
|
|
||||||
|
|
||||||
def bool_env(varname: str, default: bool) -> bool:
|
def bool_env(varname: str, default: bool) -> bool:
|
||||||
"""Read an environment variable and interpret it as a boolean.
|
"""Read an environment variable and interpret it as a boolean.
|
||||||
@ -64,6 +68,16 @@ UPGRADE_BOOL_HELP = (
|
|||||||
UPGRADE_BOOL_EXTRA_DESC = " (transient)"
|
UPGRADE_BOOL_EXTRA_DESC = " (transient)"
|
||||||
|
|
||||||
|
|
||||||
|
class FlagHolder(Generic[_T]):
|
||||||
|
def __init__(self, flags: NameSpace, name: str):
|
||||||
|
self._flags = flags
|
||||||
|
self._name = name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value(self) -> _T:
|
||||||
|
return getattr(self._flags, self._name)
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
_HAS_DYNAMIC_ATTRIBUTES = True
|
_HAS_DYNAMIC_ATTRIBUTES = True
|
||||||
|
|
||||||
@ -112,26 +126,31 @@ class Config:
|
|||||||
if name not in self.values:
|
if name not in self.values:
|
||||||
raise AttributeError(f"Unrecognized config option: {name}")
|
raise AttributeError(f"Unrecognized config option: {name}")
|
||||||
|
|
||||||
def DEFINE_bool(self, name, default, *args, **kwargs):
|
def DEFINE_bool(self, name, default, *args, **kwargs) -> FlagHolder[bool]:
|
||||||
update_hook = kwargs.pop("update_hook", None)
|
update_hook = kwargs.pop("update_hook", None)
|
||||||
self.add_option(name, default, bool, args, kwargs, update_hook=update_hook)
|
self.add_option(name, default, bool, args, kwargs, update_hook=update_hook)
|
||||||
|
return FlagHolder(self.FLAGS, name)
|
||||||
|
|
||||||
def DEFINE_integer(self, name, default, *args, **kwargs):
|
def DEFINE_integer(self, name, default, *args, **kwargs) -> FlagHolder[int]:
|
||||||
update_hook = kwargs.pop("update_hook", None)
|
update_hook = kwargs.pop("update_hook", None)
|
||||||
self.add_option(name, default, int, args, kwargs, update_hook=update_hook)
|
self.add_option(name, default, int, args, kwargs, update_hook=update_hook)
|
||||||
|
return FlagHolder(self.FLAGS, name)
|
||||||
|
|
||||||
def DEFINE_float(self, name, default, *args, **kwargs):
|
def DEFINE_float(self, name, default, *args, **kwargs) -> FlagHolder[float]:
|
||||||
update_hook = kwargs.pop("update_hook", None)
|
update_hook = kwargs.pop("update_hook", None)
|
||||||
self.add_option(name, default, float, args, kwargs, update_hook=update_hook)
|
self.add_option(name, default, float, args, kwargs, update_hook=update_hook)
|
||||||
|
return FlagHolder(self.FLAGS, name)
|
||||||
|
|
||||||
def DEFINE_string(self, name, default, *args, **kwargs):
|
def DEFINE_string(self, name, default, *args, **kwargs) -> FlagHolder[str]:
|
||||||
update_hook = kwargs.pop("update_hook", None)
|
update_hook = kwargs.pop("update_hook", None)
|
||||||
self.add_option(name, default, str, args, kwargs, update_hook=update_hook)
|
self.add_option(name, default, str, args, kwargs, update_hook=update_hook)
|
||||||
|
return FlagHolder(self.FLAGS, name)
|
||||||
|
|
||||||
def DEFINE_enum(self, name, default, *args, **kwargs):
|
def DEFINE_enum(self, name, default, *args, **kwargs) -> FlagHolder[str]:
|
||||||
update_hook = kwargs.pop("update_hook", None)
|
update_hook = kwargs.pop("update_hook", None)
|
||||||
self.add_option(name, default, 'enum', args, kwargs,
|
self.add_option(name, default, 'enum', args, kwargs,
|
||||||
update_hook=update_hook)
|
update_hook=update_hook)
|
||||||
|
return FlagHolder(self.FLAGS, name)
|
||||||
|
|
||||||
def config_with_absl(self):
|
def config_with_absl(self):
|
||||||
# Run this before calling `app.run(main)` etc
|
# Run this before calling `app.run(main)` etc
|
||||||
@ -551,6 +570,22 @@ config = Config()
|
|||||||
flags = config
|
flags = config
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
|
|
||||||
|
def DEFINE_bool(name, default, *args, **kwargs):
|
||||||
|
return flags.DEFINE_bool(name, default, *args, **kwargs)
|
||||||
|
|
||||||
|
def DEFINE_integer(name, default, *args, **kwargs):
|
||||||
|
return flags.DEFINE_integer(name, default, *args, **kwargs)
|
||||||
|
|
||||||
|
def DEFINE_float(name, default, *args, **kwargs):
|
||||||
|
return flags.DEFINE_float(name, default, *args, **kwargs)
|
||||||
|
|
||||||
|
def DEFINE_string(name, default, *args, **kwargs):
|
||||||
|
return flags.DEFINE_string(name, default, *args, **kwargs)
|
||||||
|
|
||||||
|
def DEFINE_enum(name, default, *args, **kwargs):
|
||||||
|
return flags.DEFINE_enum(name, default, *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
already_configured_with_absl = False
|
already_configured_with_absl = False
|
||||||
|
|
||||||
|
|
||||||
@ -617,51 +652,6 @@ def update_thread_local_jit_state(**kw):
|
|||||||
tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp)
|
tls.extra_jit_context = _thread_local_state_cache.canonicalize(tmp)
|
||||||
|
|
||||||
|
|
||||||
flags.DEFINE_integer(
|
|
||||||
'jax_tracer_error_num_traceback_frames',
|
|
||||||
int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
|
|
||||||
help='Set the number of stack frames in JAX tracer error messages.'
|
|
||||||
)
|
|
||||||
|
|
||||||
flags.DEFINE_bool(
|
|
||||||
'jax_pprint_use_color',
|
|
||||||
bool_env('JAX_PPRINT_USE_COLOR', True),
|
|
||||||
help='Enable jaxpr pretty-printing with colorful syntax highlighting.'
|
|
||||||
)
|
|
||||||
|
|
||||||
flags.DEFINE_bool(
|
|
||||||
'jax_host_callback_inline',
|
|
||||||
bool_env('JAX_HOST_CALLBACK_INLINE', False),
|
|
||||||
help='Inline the host_callback, if not in a staged context.'
|
|
||||||
)
|
|
||||||
flags.DEFINE_integer(
|
|
||||||
'jax_host_callback_max_queue_byte_size',
|
|
||||||
int_env('JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE', int(256 * 1e6)),
|
|
||||||
help=('The size in bytes of the buffer used to hold outfeeds from each '
|
|
||||||
'device. When this capacity is reached consuming outfeeds from the '
|
|
||||||
'device is paused, thus potentially pausing the device computation, '
|
|
||||||
'until the Python callback consume more outfeeds.'),
|
|
||||||
lower_bound=int(16 * 1e6)
|
|
||||||
)
|
|
||||||
flags.DEFINE_bool(
|
|
||||||
'jax_host_callback_outfeed',
|
|
||||||
bool_env('JAX_HOST_CALLBACK_OUTFEED', False),
|
|
||||||
help=(
|
|
||||||
'Use outfeed implementation for host_callback, even on CPU and GPU. '
|
|
||||||
'If false, use the CustomCall implementation. '
|
|
||||||
'Has no effect on TPU, since only the outfeed mechanism is implemented.'
|
|
||||||
)
|
|
||||||
)
|
|
||||||
flags.DEFINE_bool(
|
|
||||||
'jax_host_callback_ad_transforms',
|
|
||||||
bool_env('JAX_HOST_CALLBACK_AD_TRANSFORMS', False),
|
|
||||||
help=(
|
|
||||||
'Enable support for jvp/vjp for the host_callback primitives. Default is '
|
|
||||||
'False, which means that host_callback operates only on primals. '
|
|
||||||
'The flag exists only temporarily, for backward compatibility.'
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO(b/214340779): remove flag when XLA:CPU is improved.
|
# TODO(b/214340779): remove flag when XLA:CPU is improved.
|
||||||
jax2tf_associative_scan_reductions = config.define_bool_state(
|
jax2tf_associative_scan_reductions = config.define_bool_state(
|
||||||
name='jax2tf_associative_scan_reductions',
|
name='jax2tf_associative_scan_reductions',
|
||||||
|
@ -38,7 +38,7 @@ import numpy as np
|
|||||||
from jax._src import dtypes
|
from jax._src import dtypes
|
||||||
from jax._src import config as jax_config
|
from jax._src import config as jax_config
|
||||||
from jax._src import effects
|
from jax._src import effects
|
||||||
from jax._src.config import FLAGS, config
|
from jax._src.config import config
|
||||||
from jax._src.errors import (
|
from jax._src.errors import (
|
||||||
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
|
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
|
||||||
TracerIntegerConversionError, UnexpectedTracerError)
|
TracerIntegerConversionError, UnexpectedTracerError)
|
||||||
@ -60,6 +60,13 @@ zip, unsafe_zip = safe_zip, zip
|
|||||||
map, unsafe_map = safe_map, map
|
map, unsafe_map = safe_map, map
|
||||||
|
|
||||||
|
|
||||||
|
_TRACER_ERROR_NUM_TRACEBACK_FRAMES = jax_config.DEFINE_integer(
|
||||||
|
'jax_tracer_error_num_traceback_frames',
|
||||||
|
jax_config.int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
|
||||||
|
help='Set the number of stack frames in JAX tracer error messages.'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# -------------------- jaxprs --------------------
|
# -------------------- jaxprs --------------------
|
||||||
|
|
||||||
Effect = effects.Effect
|
Effect = effects.Effect
|
||||||
@ -560,7 +567,7 @@ def raise_as_much_as_possible(tracer) -> Tracer:
|
|||||||
|
|
||||||
|
|
||||||
def escaped_tracer_error(tracer, detail=None):
|
def escaped_tracer_error(tracer, detail=None):
|
||||||
num_frames = FLAGS.jax_tracer_error_num_traceback_frames
|
num_frames = _TRACER_ERROR_NUM_TRACEBACK_FRAMES.value
|
||||||
msg = ('Encountered an unexpected tracer. A function transformed by JAX '
|
msg = ('Encountered an unexpected tracer. A function transformed by JAX '
|
||||||
'had a side effect, allowing for a reference to an intermediate value '
|
'had a side effect, allowing for a reference to an intermediate value '
|
||||||
f'with type {tracer.aval.str_short()} wrapped in a '
|
f'with type {tracer.aval.str_short()} wrapped in a '
|
||||||
|
@ -32,6 +32,7 @@ import warnings
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from jax._src import compilation_cache
|
from jax._src import compilation_cache
|
||||||
|
from jax._src import config as jax_config
|
||||||
from jax._src import core
|
from jax._src import core
|
||||||
from jax._src import dtypes
|
from jax._src import dtypes
|
||||||
from jax._src import linear_util as lu
|
from jax._src import linear_util as lu
|
||||||
@ -43,7 +44,7 @@ from jax._src import traceback_util
|
|||||||
from jax._src import util
|
from jax._src import util
|
||||||
from jax._src import op_shardings
|
from jax._src import op_shardings
|
||||||
from jax._src import xla_bridge as xb
|
from jax._src import xla_bridge as xb
|
||||||
from jax._src.config import config, flags
|
from jax._src.config import config
|
||||||
from jax._src.interpreters import ad
|
from jax._src.interpreters import ad
|
||||||
from jax._src.interpreters import batching
|
from jax._src.interpreters import batching
|
||||||
from jax._src.interpreters import mlir
|
from jax._src.interpreters import mlir
|
||||||
@ -64,9 +65,7 @@ JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
|
|||||||
JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
|
JAXPR_TO_MLIR_MODULE_EVENT = "/jax/core/compile/jaxpr_to_mlir_module_duration"
|
||||||
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
|
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
_DUMP_IR_TO = jax_config.DEFINE_string(
|
||||||
|
|
||||||
flags.DEFINE_string(
|
|
||||||
'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
|
'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
|
||||||
help="Path to which the IR that is emitted by JAX as input to the "
|
help="Path to which the IR that is emitted by JAX as input to the "
|
||||||
"compiler should be dumped as text files. Optional. If omitted, JAX "
|
"compiler should be dumped as text files. Optional. If omitted, JAX "
|
||||||
@ -472,7 +471,7 @@ def _make_string_safe_for_filename(s: str) -> str:
|
|||||||
def _dump_ir_to_file(name: str, ir: str):
|
def _dump_ir_to_file(name: str, ir: str):
|
||||||
id = next(_ir_dump_counter)
|
id = next(_ir_dump_counter)
|
||||||
name = f"jax_ir{id}_{_make_string_safe_for_filename(name)}.mlir"
|
name = f"jax_ir{id}_{_make_string_safe_for_filename(name)}.mlir"
|
||||||
name = path.Path(FLAGS.jax_dump_ir_to) / name
|
name = path.Path(_DUMP_IR_TO.value) / name
|
||||||
name.write_text(ir)
|
name.write_text(ir)
|
||||||
|
|
||||||
|
|
||||||
@ -481,7 +480,7 @@ def compile_or_get_cached(backend, computation: ir.Module, devices: np.ndarray,
|
|||||||
sym_name = computation.operation.attributes['sym_name']
|
sym_name = computation.operation.attributes['sym_name']
|
||||||
module_name = ir.StringAttr(sym_name).value
|
module_name = ir.StringAttr(sym_name).value
|
||||||
|
|
||||||
if FLAGS.jax_dump_ir_to:
|
if _DUMP_IR_TO.value:
|
||||||
_dump_ir_to_file(module_name, mlir.module_to_string(computation))
|
_dump_ir_to_file(module_name, mlir.module_to_string(computation))
|
||||||
|
|
||||||
# Persistent compilation cache only implemented on TPU and GPU.
|
# Persistent compilation cache only implemented on TPU and GPU.
|
||||||
|
@ -28,7 +28,7 @@ import warnings
|
|||||||
import ml_dtypes
|
import ml_dtypes
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from jax._src.config import flags, config
|
from jax._src.config import config
|
||||||
from jax._src.typing import DType, DTypeLike
|
from jax._src.typing import DType, DTypeLike
|
||||||
|
|
||||||
from jax._src import traceback_util
|
from jax._src import traceback_util
|
||||||
@ -43,8 +43,6 @@ else:
|
|||||||
raise ValueError("JAX requires ml_dtypes version 0.2.0 or newer; "
|
raise ValueError("JAX requires ml_dtypes version 0.2.0 or newer; "
|
||||||
f"installed version is {ml_dtypes.__version__}.")
|
f"installed version is {ml_dtypes.__version__}.")
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
|
|
||||||
class extended(np.generic):
|
class extended(np.generic):
|
||||||
"""Scalar class for extended dtypes.
|
"""Scalar class for extended dtypes.
|
||||||
|
@ -32,13 +32,20 @@ from functools import partial
|
|||||||
import sys
|
import sys
|
||||||
from typing import NamedTuple, Optional, Union
|
from typing import NamedTuple, Optional, Union
|
||||||
|
|
||||||
from jax._src.config import config
|
from jax._src import config
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import colorama # pytype: disable=import-error
|
import colorama # pytype: disable=import-error
|
||||||
except ImportError:
|
except ImportError:
|
||||||
colorama = None
|
colorama = None
|
||||||
|
|
||||||
|
|
||||||
|
_PPRINT_USE_COLOR = config.DEFINE_bool(
|
||||||
|
'jax_pprint_use_color',
|
||||||
|
config.bool_env('JAX_PPRINT_USE_COLOR', True),
|
||||||
|
help='Enable jaxpr pretty-printing with colorful syntax highlighting.'
|
||||||
|
)
|
||||||
|
|
||||||
def _can_use_color() -> bool:
|
def _can_use_color() -> bool:
|
||||||
try:
|
try:
|
||||||
# Check if we're in IPython or Colab
|
# Check if we're in IPython or Colab
|
||||||
@ -63,7 +70,7 @@ class Doc(abc.ABC):
|
|||||||
def format(self, width: int = 80, use_color: Optional[bool] = None,
|
def format(self, width: int = 80, use_color: Optional[bool] = None,
|
||||||
annotation_prefix=" # ") -> str:
|
annotation_prefix=" # ") -> str:
|
||||||
if use_color is None:
|
if use_color is None:
|
||||||
use_color = CAN_USE_COLOR and config.FLAGS.jax_pprint_use_color
|
use_color = CAN_USE_COLOR and _PPRINT_USE_COLOR.value
|
||||||
return _format(self, width, use_color=use_color,
|
return _format(self, width, use_color=use_color,
|
||||||
annotation_prefix=annotation_prefix)
|
annotation_prefix=annotation_prefix)
|
||||||
|
|
||||||
|
@ -16,9 +16,10 @@ from functools import partial
|
|||||||
import operator
|
import operator
|
||||||
|
|
||||||
from jax._src import api
|
from jax._src import api
|
||||||
|
from jax._src import config as jax_config
|
||||||
from jax._src import dtypes as _dtypes
|
from jax._src import dtypes as _dtypes
|
||||||
from jax._src import xla_bridge
|
from jax._src import xla_bridge
|
||||||
from jax._src.config import config, flags
|
from jax._src.config import config
|
||||||
from jax._src.tree_util import tree_map, tree_reduce
|
from jax._src.tree_util import tree_map, tree_reduce
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -30,7 +31,11 @@ import numpy as np
|
|||||||
__all__ = ['check_grads', 'check_jvp', 'check_vjp']
|
__all__ = ['check_grads', 'check_jvp', 'check_vjp']
|
||||||
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
_TEST_DUT = jax_config.DEFINE_string(
|
||||||
|
'jax_test_dut', '',
|
||||||
|
help=
|
||||||
|
'Describes the device under test in case special consideration is required.'
|
||||||
|
)
|
||||||
|
|
||||||
EPS = 1e-4
|
EPS = 1e-4
|
||||||
|
|
||||||
@ -292,4 +297,4 @@ def check_grads(f, args, order,
|
|||||||
|
|
||||||
|
|
||||||
def device_under_test():
|
def device_under_test():
|
||||||
return getattr(FLAGS, 'jax_test_dut', None) or xla_bridge.get_backend().platform
|
return _TEST_DUT.value or xla_bridge.get_backend().platform
|
||||||
|
@ -41,11 +41,12 @@ from jax._src.interpreters import mlir
|
|||||||
from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten
|
from jax.tree_util import tree_map, tree_all, tree_flatten, tree_unflatten
|
||||||
from jax._src import api
|
from jax._src import api
|
||||||
from jax._src import pjit as pjit_lib
|
from jax._src import pjit as pjit_lib
|
||||||
|
from jax._src import config as jax_config
|
||||||
from jax._src import core
|
from jax._src import core
|
||||||
from jax._src import dispatch
|
from jax._src import dispatch
|
||||||
from jax._src import dtypes as _dtypes
|
from jax._src import dtypes as _dtypes
|
||||||
from jax._src.interpreters import pxla
|
from jax._src.interpreters import pxla
|
||||||
from jax._src.config import (flags, bool_env, config,
|
from jax._src.config import (bool_env, config,
|
||||||
raise_persistent_cache_errors,
|
raise_persistent_cache_errors,
|
||||||
persistent_cache_min_compile_time_secs)
|
persistent_cache_min_compile_time_secs)
|
||||||
from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
|
from jax._src.numpy.util import promote_dtypes, promote_dtypes_inexact
|
||||||
@ -60,19 +61,12 @@ from jax._src import xla_bridge
|
|||||||
# jax.test_util. Functionality appearing here is for internal use only, and
|
# jax.test_util. Functionality appearing here is for internal use only, and
|
||||||
# may be changed or removed at any time and without any deprecation cycle.
|
# may be changed or removed at any time and without any deprecation cycle.
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
_NUM_GENERATED_CASES = jax_config.DEFINE_integer(
|
||||||
flags.DEFINE_string(
|
|
||||||
'jax_test_dut', '',
|
|
||||||
help=
|
|
||||||
'Describes the device under test in case special consideration is required.'
|
|
||||||
)
|
|
||||||
|
|
||||||
flags.DEFINE_integer(
|
|
||||||
'jax_num_generated_cases',
|
'jax_num_generated_cases',
|
||||||
int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
|
int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
|
||||||
help='Number of generated cases to test')
|
help='Number of generated cases to test')
|
||||||
|
|
||||||
flags.DEFINE_integer(
|
_MAX_CASES_SAMPLING_RETRIES = jax_config.DEFINE_integer(
|
||||||
'max_cases_sampling_retries',
|
'max_cases_sampling_retries',
|
||||||
int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')),
|
int(os.getenv('JAX_MAX_CASES_SAMPLING_RETRIES', '100')),
|
||||||
'Number of times a failed test sample should be retried. '
|
'Number of times a failed test sample should be retried. '
|
||||||
@ -80,24 +74,23 @@ flags.DEFINE_integer(
|
|||||||
'sampling process is terminated.'
|
'sampling process is terminated.'
|
||||||
)
|
)
|
||||||
|
|
||||||
flags.DEFINE_bool(
|
_SKIP_SLOW_TESTS = jax_config.DEFINE_bool(
|
||||||
'jax_skip_slow_tests',
|
'jax_skip_slow_tests',
|
||||||
bool_env('JAX_SKIP_SLOW_TESTS', False),
|
bool_env('JAX_SKIP_SLOW_TESTS', False),
|
||||||
help='Skip tests marked as slow (> 5 sec).'
|
help='Skip tests marked as slow (> 5 sec).'
|
||||||
)
|
)
|
||||||
|
|
||||||
flags.DEFINE_string(
|
_TEST_TARGETS = jax_config.DEFINE_string(
|
||||||
'test_targets', os.getenv('JAX_TEST_TARGETS', ''),
|
'test_targets', os.getenv('JAX_TEST_TARGETS', ''),
|
||||||
'Regular expression specifying which tests to run, called via re.search on '
|
'Regular expression specifying which tests to run, called via re.search on '
|
||||||
'the test name. If empty or unspecified, run all tests.'
|
'the test name. If empty or unspecified, run all tests.'
|
||||||
)
|
)
|
||||||
flags.DEFINE_string(
|
_EXCLUDE_TEST_TARGETS = jax_config.DEFINE_string(
|
||||||
'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
|
'exclude_test_targets', os.getenv('JAX_EXCLUDE_TEST_TARGETS', ''),
|
||||||
'Regular expression specifying which tests NOT to run, called via re.search '
|
'Regular expression specifying which tests NOT to run, called via re.search '
|
||||||
'on the test name. If empty or unspecified, run all tests.'
|
'on the test name. If empty or unspecified, run all tests.'
|
||||||
)
|
)
|
||||||
|
TEST_WITH_PERSISTENT_COMPILATION_CACHE = jax_config.DEFINE_bool(
|
||||||
flags.DEFINE_bool(
|
|
||||||
'jax_test_with_persistent_compilation_cache',
|
'jax_test_with_persistent_compilation_cache',
|
||||||
bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False),
|
bool_env('JAX_TEST_WITH_PERSISTENT_COMPILATION_CACHE', False),
|
||||||
help='If enabled, the persistent compilation cache will be enabled for all '
|
help='If enabled, the persistent compilation cache will be enabled for all '
|
||||||
@ -731,7 +724,7 @@ def assert_dot_precision(expected_precision, fun, *args):
|
|||||||
|
|
||||||
def cases_from_gens(*gens):
|
def cases_from_gens(*gens):
|
||||||
sizes = [1, 3, 10]
|
sizes = [1, 3, 10]
|
||||||
cases_per_size = int(FLAGS.jax_num_generated_cases / len(sizes)) + 1
|
cases_per_size = int(_NUM_GENERATED_CASES.value / len(sizes)) + 1
|
||||||
for size in sizes:
|
for size in sizes:
|
||||||
for i in range(cases_per_size):
|
for i in range(cases_per_size):
|
||||||
yield (f'_{size}_{i}',) + tuple(gen(size) for gen in gens)
|
yield (f'_{size}_{i}',) + tuple(gen(size) for gen in gens)
|
||||||
@ -744,8 +737,8 @@ def named_cases_from_sampler(gen):
|
|||||||
if not isinstance(x, (list, tuple)):
|
if not isinstance(x, (list, tuple)):
|
||||||
x = list(x)
|
x = list(x)
|
||||||
return [x[rng.randint(len(x))]]
|
return [x[rng.randint(len(x))]]
|
||||||
while (len(seen) < FLAGS.jax_num_generated_cases and
|
while (len(seen) < _NUM_GENERATED_CASES.value and
|
||||||
retries < FLAGS.max_cases_sampling_retries):
|
retries < _MAX_CASES_SAMPLING_RETRIES.value):
|
||||||
retries += 1
|
retries += 1
|
||||||
cases = list(gen(choose_one))
|
cases = list(gen(choose_one))
|
||||||
if not cases:
|
if not cases:
|
||||||
@ -773,7 +766,7 @@ def sample_product_testcases(*args, **kw):
|
|||||||
kw = [(k, list(v)) for k, v in kw.items()]
|
kw = [(k, list(v)) for k, v in kw.items()]
|
||||||
n = math.prod(len(a) for a in args) * math.prod(len(v) for _, v in kw)
|
n = math.prod(len(a) for a in args) * math.prod(len(v) for _, v in kw)
|
||||||
testcases = []
|
testcases = []
|
||||||
for i in _choice(n, min(n, FLAGS.jax_num_generated_cases)):
|
for i in _choice(n, min(n, _NUM_GENERATED_CASES.value)):
|
||||||
testcase = {}
|
testcase = {}
|
||||||
for a in args:
|
for a in args:
|
||||||
testcase.update(a[i % len(a)])
|
testcase.update(a[i % len(a)])
|
||||||
@ -804,12 +797,12 @@ def sample_product(*args, **kw):
|
|||||||
class JaxTestLoader(absltest.TestLoader):
|
class JaxTestLoader(absltest.TestLoader):
|
||||||
def getTestCaseNames(self, testCaseClass):
|
def getTestCaseNames(self, testCaseClass):
|
||||||
names = super().getTestCaseNames(testCaseClass)
|
names = super().getTestCaseNames(testCaseClass)
|
||||||
if FLAGS.test_targets:
|
if _TEST_TARGETS.value:
|
||||||
pattern = re.compile(FLAGS.test_targets)
|
pattern = re.compile(_TEST_TARGETS.value)
|
||||||
names = [name for name in names
|
names = [name for name in names
|
||||||
if pattern.search(f"{testCaseClass.__name__}.{name}")]
|
if pattern.search(f"{testCaseClass.__name__}.{name}")]
|
||||||
if FLAGS.exclude_test_targets:
|
if _EXCLUDE_TEST_TARGETS.value:
|
||||||
pattern = re.compile(FLAGS.exclude_test_targets)
|
pattern = re.compile(_EXCLUDE_TEST_TARGETS.value)
|
||||||
names = [name for name in names
|
names = [name for name in names
|
||||||
if not pattern.search(f"{testCaseClass.__name__}.{name}")]
|
if not pattern.search(f"{testCaseClass.__name__}.{name}")]
|
||||||
return names
|
return names
|
||||||
@ -874,7 +867,7 @@ class JaxTestCase(parameterized.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
if FLAGS.jax_test_with_persistent_compilation_cache:
|
if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
|
||||||
cls._compilation_cache_exit_stack = ExitStack()
|
cls._compilation_cache_exit_stack = ExitStack()
|
||||||
stack = cls._compilation_cache_exit_stack
|
stack = cls._compilation_cache_exit_stack
|
||||||
stack.enter_context(raise_persistent_cache_errors(True))
|
stack.enter_context(raise_persistent_cache_errors(True))
|
||||||
@ -887,7 +880,7 @@ class JaxTestCase(parameterized.TestCase):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
if FLAGS.jax_test_with_persistent_compilation_cache:
|
if TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
|
||||||
cls._compilation_cache_exit_stack.close()
|
cls._compilation_cache_exit_stack.close()
|
||||||
|
|
||||||
def rng(self):
|
def rng(self):
|
||||||
|
@ -37,7 +37,8 @@ import numpy as np
|
|||||||
|
|
||||||
from jax._src import lib
|
from jax._src import lib
|
||||||
from jax._src import distributed
|
from jax._src import distributed
|
||||||
from jax._src.config import flags, bool_env, config, int_env
|
from jax._src import config as jax_config
|
||||||
|
from jax._src.config import bool_env, config, int_env
|
||||||
from jax._src.lib import xla_client
|
from jax._src.lib import xla_client
|
||||||
from jax._src.lib import xla_extension_version
|
from jax._src.lib import xla_extension_version
|
||||||
from jax._src import traceback_util
|
from jax._src import traceback_util
|
||||||
@ -59,37 +60,35 @@ traceback_util.register_exclusion(__file__)
|
|||||||
|
|
||||||
XlaBackend = xla_client.Client
|
XlaBackend = xla_client.Client
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(phawkins): Remove jax_xla_backend.
|
# TODO(phawkins): Remove jax_xla_backend.
|
||||||
flags.DEFINE_string(
|
_XLA_BACKEND = jax_config.DEFINE_string(
|
||||||
'jax_xla_backend', '',
|
'jax_xla_backend', '',
|
||||||
'Deprecated, please use --jax_platforms instead.')
|
'Deprecated, please use --jax_platforms instead.')
|
||||||
flags.DEFINE_string(
|
BACKEND_TARGET = jax_config.DEFINE_string(
|
||||||
'jax_backend_target',
|
'jax_backend_target',
|
||||||
os.getenv('JAX_BACKEND_TARGET', '').lower(),
|
os.getenv('JAX_BACKEND_TARGET', '').lower(),
|
||||||
'Either "local" or "rpc:address" to connect to a remote service target.')
|
'Either "local" or "rpc:address" to connect to a remote service target.')
|
||||||
# TODO(skye): warn when this is used once we test out --jax_platforms a bit
|
# TODO(skye): warn when this is used once we test out --jax_platforms a bit
|
||||||
flags.DEFINE_string(
|
_PLATFORM_NAME = jax_config.DEFINE_string(
|
||||||
'jax_platform_name',
|
'jax_platform_name',
|
||||||
os.getenv('JAX_PLATFORM_NAME', '').lower(),
|
os.getenv('JAX_PLATFORM_NAME', '').lower(),
|
||||||
'Deprecated, please use --jax_platforms instead.')
|
'Deprecated, please use --jax_platforms instead.')
|
||||||
flags.DEFINE_bool(
|
_DISABLE_MOST_OPTIMIZATIONS = jax_config.DEFINE_bool(
|
||||||
'jax_disable_most_optimizations',
|
'jax_disable_most_optimizations',
|
||||||
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
|
bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
|
||||||
'Try not to do much optimization work. This can be useful if the cost of '
|
'Try not to do much optimization work. This can be useful if the cost of '
|
||||||
'optimization is greater than that of running a less-optimized program.')
|
'optimization is greater than that of running a less-optimized program.')
|
||||||
flags.DEFINE_integer(
|
_XLA_PROFILE_VERSION = jax_config.DEFINE_integer(
|
||||||
'jax_xla_profile_version', int_env('JAX_XLA_PROFILE_VERSION', 0),
|
'jax_xla_profile_version', int_env('JAX_XLA_PROFILE_VERSION', 0),
|
||||||
'Optional profile version for XLA compilation. '
|
'Optional profile version for XLA compilation. '
|
||||||
'This is meaningful only when XLA is configured to '
|
'This is meaningful only when XLA is configured to '
|
||||||
'support the remote compilation profile feature.')
|
'support the remote compilation profile feature.')
|
||||||
flags.DEFINE_string(
|
CUDA_VISIBLE_DEVICES = jax_config.DEFINE_string(
|
||||||
'jax_cuda_visible_devices', 'all',
|
'jax_cuda_visible_devices', 'all',
|
||||||
'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
|
'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
|
||||||
'comma-separate list of integer device IDs.')
|
'comma-separate list of integer device IDs.')
|
||||||
flags.DEFINE_string(
|
_ROCM_VISIBLE_DEVICES = jax_config.DEFINE_string(
|
||||||
'jax_rocm_visible_devices', 'all',
|
'jax_rocm_visible_devices', 'all',
|
||||||
'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
|
'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
|
||||||
'comma-separate list of integer device IDs.')
|
'comma-separate list of integer device IDs.')
|
||||||
@ -171,13 +170,13 @@ def get_compile_options(
|
|||||||
if lib.cuda_path is not None:
|
if lib.cuda_path is not None:
|
||||||
debug_options.xla_gpu_cuda_data_dir = lib.cuda_path
|
debug_options.xla_gpu_cuda_data_dir = lib.cuda_path
|
||||||
|
|
||||||
if FLAGS.jax_disable_most_optimizations:
|
if _DISABLE_MOST_OPTIMIZATIONS.value:
|
||||||
|
|
||||||
debug_options.xla_backend_optimization_level = 0
|
debug_options.xla_backend_optimization_level = 0
|
||||||
debug_options.xla_llvm_disable_expensive_passes = True
|
debug_options.xla_llvm_disable_expensive_passes = True
|
||||||
debug_options.xla_test_all_input_layouts = False
|
debug_options.xla_test_all_input_layouts = False
|
||||||
|
|
||||||
compile_options.profile_version = FLAGS.jax_xla_profile_version
|
compile_options.profile_version = _XLA_PROFILE_VERSION.value
|
||||||
return compile_options
|
return compile_options
|
||||||
|
|
||||||
|
|
||||||
@ -264,9 +263,9 @@ register_backend_factory('cpu',
|
|||||||
|
|
||||||
|
|
||||||
def make_gpu_client(
|
def make_gpu_client(
|
||||||
*, platform_name: str, visible_devices_flag: str
|
*, platform_name: str, visible_devices_flag: jax_config.FlagHolder[str]
|
||||||
) -> xla_client.Client:
|
) -> xla_client.Client:
|
||||||
visible_devices = getattr(FLAGS, visible_devices_flag, "all")
|
visible_devices = visible_devices_flag.value
|
||||||
allowed_devices = None
|
allowed_devices = None
|
||||||
if visible_devices != "all":
|
if visible_devices != "all":
|
||||||
allowed_devices = {int(x) for x in visible_devices.split(",")}
|
allowed_devices = {int(x) for x in visible_devices.split(",")}
|
||||||
@ -292,15 +291,25 @@ def make_gpu_client(
|
|||||||
|
|
||||||
if hasattr(xla_client, "make_gpu_client"):
|
if hasattr(xla_client, "make_gpu_client"):
|
||||||
register_backend_factory(
|
register_backend_factory(
|
||||||
'cuda', partial(make_gpu_client, platform_name='cuda',
|
"cuda",
|
||||||
visible_devices_flag='jax_cuda_visible_devices'),
|
partial(
|
||||||
|
make_gpu_client,
|
||||||
|
platform_name="cuda",
|
||||||
|
visible_devices_flag=CUDA_VISIBLE_DEVICES,
|
||||||
|
),
|
||||||
priority=200,
|
priority=200,
|
||||||
fail_quietly=True)
|
fail_quietly=True,
|
||||||
|
)
|
||||||
register_backend_factory(
|
register_backend_factory(
|
||||||
'rocm', partial(make_gpu_client, platform_name='rocm',
|
"rocm",
|
||||||
visible_devices_flag='jax_rocm_visible_devices'),
|
partial(
|
||||||
|
make_gpu_client,
|
||||||
|
platform_name="rocm",
|
||||||
|
visible_devices_flag=_ROCM_VISIBLE_DEVICES,
|
||||||
|
),
|
||||||
priority=200,
|
priority=200,
|
||||||
fail_quietly=True)
|
fail_quietly=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if hasattr(xla_client, "make_tpu_client"):
|
if hasattr(xla_client, "make_tpu_client"):
|
||||||
@ -623,7 +632,7 @@ def backends() -> dict[str, xla_client.Client]:
|
|||||||
# support anything else there at the moment and warning would be pointless.
|
# support anything else there at the moment and warning would be pointless.
|
||||||
if (py_platform.system() != "Darwin" and
|
if (py_platform.system() != "Darwin" and
|
||||||
_default_backend.platform == "cpu" and
|
_default_backend.platform == "cpu" and
|
||||||
FLAGS.jax_platform_name != 'cpu'):
|
_PLATFORM_NAME.value != 'cpu'):
|
||||||
logger.warning('No GPU/TPU found, falling back to CPU. '
|
logger.warning('No GPU/TPU found, falling back to CPU. '
|
||||||
'(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
|
'(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
|
||||||
return _backends
|
return _backends
|
||||||
@ -677,8 +686,7 @@ def _get_backend_uncached(
|
|||||||
if platform is not None and not isinstance(platform, str):
|
if platform is not None and not isinstance(platform, str):
|
||||||
return platform
|
return platform
|
||||||
|
|
||||||
platform = (platform or FLAGS.jax_xla_backend or FLAGS.jax_platform_name
|
platform = (platform or _XLA_BACKEND.value or _PLATFORM_NAME.value or None)
|
||||||
or None)
|
|
||||||
|
|
||||||
bs = backends()
|
bs = backends()
|
||||||
if platform is not None:
|
if platform is not None:
|
||||||
|
@ -507,7 +507,7 @@ import warnings
|
|||||||
|
|
||||||
from jax._src import api
|
from jax._src import api
|
||||||
from jax._src import core
|
from jax._src import core
|
||||||
from jax import config
|
from jax._src import config
|
||||||
from jax import custom_derivatives
|
from jax import custom_derivatives
|
||||||
from jax._src import dtypes
|
from jax._src import dtypes
|
||||||
from jax import lax
|
from jax import lax
|
||||||
@ -531,17 +531,46 @@ from jax._src.lib.mlir.dialects import hlo
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
FLAGS = config.FLAGS
|
_HOST_CALLBACK_INLINE = config.DEFINE_bool(
|
||||||
|
'jax_host_callback_inline',
|
||||||
|
config.bool_env('JAX_HOST_CALLBACK_INLINE', False),
|
||||||
|
help='Inline the host_callback, if not in a staged context.'
|
||||||
|
)
|
||||||
|
_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE = config.DEFINE_integer(
|
||||||
|
'jax_host_callback_max_queue_byte_size',
|
||||||
|
config.int_env('JAX_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE', int(256 * 1e6)),
|
||||||
|
help=('The size in bytes of the buffer used to hold outfeeds from each '
|
||||||
|
'device. When this capacity is reached consuming outfeeds from the '
|
||||||
|
'device is paused, thus potentially pausing the device computation, '
|
||||||
|
'until the Python callback consume more outfeeds.'),
|
||||||
|
lower_bound=int(16 * 1e6)
|
||||||
|
)
|
||||||
|
_HOST_CALLBACK_OUTFEED = config.DEFINE_bool(
|
||||||
|
'jax_host_callback_outfeed',
|
||||||
|
config.bool_env('JAX_HOST_CALLBACK_OUTFEED', False),
|
||||||
|
help=(
|
||||||
|
'Use outfeed implementation for host_callback, even on CPU and GPU. '
|
||||||
|
'If false, use the CustomCall implementation. '
|
||||||
|
'Has no effect on TPU, since only the outfeed mechanism is implemented.'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_HOST_CALLBACK_AD_TRANSFORMS = config.DEFINE_bool(
|
||||||
|
'jax_host_callback_ad_transforms',
|
||||||
|
config.bool_env('JAX_HOST_CALLBACK_AD_TRANSFORMS', False),
|
||||||
|
help=(
|
||||||
|
'Enable support for jvp/vjp for the host_callback primitives. Default is '
|
||||||
|
'False, which means that host_callback operates only on primals. '
|
||||||
|
'The flag exists only temporarily, for backward compatibility.'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _inline_host_callback() -> bool:
|
|
||||||
return FLAGS.jax_host_callback_inline
|
|
||||||
|
|
||||||
|
|
||||||
def _use_outfeed(platform: str) -> bool:
|
def _use_outfeed(platform: str) -> bool:
|
||||||
return (platform in ("tpu", "gpu", "cuda", "rocm") or FLAGS.jax_host_callback_outfeed)
|
return (platform in ("tpu", "gpu", "cuda", "rocm") or
|
||||||
|
_HOST_CALLBACK_OUTFEED.value)
|
||||||
|
|
||||||
|
|
||||||
def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend):
|
def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend):
|
||||||
@ -620,7 +649,7 @@ def id_tap(tap_func,
|
|||||||
"pre-apply keyword arguments, either by using a closure or by passing "
|
"pre-apply keyword arguments, either by using a closure or by passing "
|
||||||
"``functools.partial(tap_func, **kwargs)``.")
|
"``functools.partial(tap_func, **kwargs)``.")
|
||||||
raise TypeError(msg)
|
raise TypeError(msg)
|
||||||
if FLAGS.jax_host_callback_ad_transforms:
|
if _HOST_CALLBACK_AD_TRANSFORMS.value:
|
||||||
warnings.warn('The flag jax_host_callback_ad_transforms is for temporary '
|
warnings.warn('The flag jax_host_callback_ad_transforms is for temporary '
|
||||||
'backwards compatibility mode. This flag, and the behavior '
|
'backwards compatibility mode. This flag, and the behavior '
|
||||||
'it enabled will be removed soon.',
|
'it enabled will be removed soon.',
|
||||||
@ -642,7 +671,7 @@ def id_tap(tap_func,
|
|||||||
if result is not None:
|
if result is not None:
|
||||||
# Return the results, but add a dependency on the call, to ensure it
|
# Return the results, but add a dependency on the call, to ensure it
|
||||||
# is kept in the graph.
|
# is kept in the graph.
|
||||||
if FLAGS.jax_host_callback_ad_transforms:
|
if _HOST_CALLBACK_AD_TRANSFORMS.value:
|
||||||
call_flat_results, _ = tree_util.tree_flatten(call_res)
|
call_flat_results, _ = tree_util.tree_flatten(call_res)
|
||||||
if call_flat_results:
|
if call_flat_results:
|
||||||
call_flat_results = [id_tap_dep_p.bind(r, call_flat_results[0])
|
call_flat_results = [id_tap_dep_p.bind(r, call_flat_results[0])
|
||||||
@ -782,7 +811,7 @@ def _call(callback_func: Callable,
|
|||||||
identity=False):
|
identity=False):
|
||||||
# Lazy initialization
|
# Lazy initialization
|
||||||
_initialize_outfeed_receiver(
|
_initialize_outfeed_receiver(
|
||||||
max_callback_queue_size_bytes=FLAGS.jax_host_callback_max_queue_byte_size)
|
max_callback_queue_size_bytes=_HOST_CALLBACK_MAX_QUEUE_BYTE_SIZE.value)
|
||||||
api.check_callable(callback_func)
|
api.check_callable(callback_func)
|
||||||
flat_args, arg_treedef = tree_util.tree_flatten(arg)
|
flat_args, arg_treedef = tree_util.tree_flatten(arg)
|
||||||
for arg in flat_args:
|
for arg in flat_args:
|
||||||
@ -909,7 +938,7 @@ xla.register_translation(id_tap_dep_p,
|
|||||||
id_tap_dep_p.def_abstract_eval(lambda r_a, _: r_a)
|
id_tap_dep_p.def_abstract_eval(lambda r_a, _: r_a)
|
||||||
|
|
||||||
def _id_tap_dep_jvp_rule(primals, tangents):
|
def _id_tap_dep_jvp_rule(primals, tangents):
|
||||||
if FLAGS.jax_host_callback_ad_transforms:
|
if _HOST_CALLBACK_AD_TRANSFORMS.value:
|
||||||
assert False
|
assert False
|
||||||
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
|
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
|
||||||
return (id_tap_dep_p.bind(primals[0], primals[1]),
|
return (id_tap_dep_p.bind(primals[0], primals[1]),
|
||||||
@ -918,7 +947,7 @@ def _id_tap_dep_jvp_rule(primals, tangents):
|
|||||||
ad.primitive_jvps[id_tap_dep_p] = _id_tap_dep_jvp_rule
|
ad.primitive_jvps[id_tap_dep_p] = _id_tap_dep_jvp_rule
|
||||||
|
|
||||||
def _id_tap_dep_transpose_rule(cts, arg_res, arg_tap):
|
def _id_tap_dep_transpose_rule(cts, arg_res, arg_tap):
|
||||||
if FLAGS.jax_host_callback_ad_transforms:
|
if _HOST_CALLBACK_AD_TRANSFORMS.value:
|
||||||
assert False
|
assert False
|
||||||
if ad.is_undefined_primal(arg_res):
|
if ad.is_undefined_primal(arg_res):
|
||||||
ct_res = _instantiate_zeros(cts, arg_res)
|
ct_res = _instantiate_zeros(cts, arg_res)
|
||||||
@ -934,7 +963,7 @@ ad.primitive_transposes[id_tap_dep_p] = _id_tap_dep_transpose_rule
|
|||||||
|
|
||||||
|
|
||||||
def _id_tap_dep_batching_rule(batched_args, batch_dims):
|
def _id_tap_dep_batching_rule(batched_args, batch_dims):
|
||||||
if FLAGS.jax_host_callback_ad_transforms:
|
if _HOST_CALLBACK_AD_TRANSFORMS.value:
|
||||||
assert False
|
assert False
|
||||||
arg_res, arg_tap = batched_args
|
arg_res, arg_tap = batched_args
|
||||||
return id_tap_dep_p.bind(arg_res, arg_tap), batch_dims[0]
|
return id_tap_dep_p.bind(arg_res, arg_tap), batch_dims[0]
|
||||||
@ -1013,7 +1042,7 @@ outside_call_p.def_abstract_eval(_outside_call_abstract_eval)
|
|||||||
|
|
||||||
def _outside_call_impl(*args, **params):
|
def _outside_call_impl(*args, **params):
|
||||||
assert "has_token" not in params
|
assert "has_token" not in params
|
||||||
if _inline_host_callback():
|
if _HOST_CALLBACK_INLINE.value:
|
||||||
device_index = params["device_index"]
|
device_index = params["device_index"]
|
||||||
device = xb.devices()[device_index]
|
device = xb.devices()[device_index]
|
||||||
results = _outside_call_run_callback(args, device, send_infeed=False, **params)
|
results = _outside_call_run_callback(args, device, send_infeed=False, **params)
|
||||||
@ -1400,7 +1429,7 @@ def _outside_call_jvp_rule(primals, tangents, **params):
|
|||||||
assert "has_token" not in params
|
assert "has_token" not in params
|
||||||
if not params["identity"]:
|
if not params["identity"]:
|
||||||
raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
|
raise NotImplementedError("JVP rule is implemented only for id_tap, not for call.")
|
||||||
if FLAGS.jax_host_callback_ad_transforms:
|
if _HOST_CALLBACK_AD_TRANSFORMS.value:
|
||||||
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
|
tangents_instantiated = tuple(map(_instantiate_zeros, tangents, primals))
|
||||||
|
|
||||||
arg_treedef = params["arg_treedef"]
|
arg_treedef = params["arg_treedef"]
|
||||||
@ -1425,7 +1454,7 @@ ad.primitive_jvps[outside_call_p] = _outside_call_jvp_rule
|
|||||||
|
|
||||||
def _outside_call_partial_eval_rule(trace, *args, **params):
|
def _outside_call_partial_eval_rule(trace, *args, **params):
|
||||||
# partial eval is used after jvp and before transpose.
|
# partial eval is used after jvp and before transpose.
|
||||||
if not FLAGS.jax_host_callback_ad_transforms:
|
if not _HOST_CALLBACK_AD_TRANSFORMS.value:
|
||||||
# TODO: just remote the partial eval rule
|
# TODO: just remote the partial eval rule
|
||||||
return trace.default_process_primitive(outside_call_p, args, params)
|
return trace.default_process_primitive(outside_call_p, args, params)
|
||||||
transforms = params.get("transforms", ())
|
transforms = params.get("transforms", ())
|
||||||
@ -1492,7 +1521,7 @@ def _outside_call_transpose_rule(cts, *args, **params):
|
|||||||
*cts_instantiated,
|
*cts_instantiated,
|
||||||
**_add_transform(params, "transpose"))
|
**_add_transform(params, "transpose"))
|
||||||
|
|
||||||
if not FLAGS.jax_host_callback_ad_transforms:
|
if not _HOST_CALLBACK_AD_TRANSFORMS.value:
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
assert len(args) % 2 == 0
|
assert len(args) % 2 == 0
|
||||||
|
@ -65,7 +65,7 @@ def main(_):
|
|||||||
tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)
|
tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)
|
||||||
keras_model.fit(train_ds, epochs=FLAGS.num_epochs, validation_data=test_ds)
|
keras_model.fit(train_ds, epochs=FLAGS.num_epochs, validation_data=test_ds)
|
||||||
|
|
||||||
if FLAGS.show_images:
|
if saved_model_main.SHOW_IMAGES.value:
|
||||||
mnist_lib.plot_images(
|
mnist_lib.plot_images(
|
||||||
test_ds,
|
test_ds,
|
||||||
1,
|
1,
|
||||||
|
@ -38,8 +38,8 @@ import optax
|
|||||||
import tensorflow as tf # type: ignore
|
import tensorflow as tf # type: ignore
|
||||||
import tensorflow_datasets as tfds # type: ignore
|
import tensorflow_datasets as tfds # type: ignore
|
||||||
|
|
||||||
flags.DEFINE_boolean("mock_data", False, "Use fake data, for testing.")
|
_MOCK_DATA = flags.DEFINE_boolean("mock_data", False,
|
||||||
FLAGS = flags.FLAGS
|
"Use fake data, for testing.")
|
||||||
|
|
||||||
#### Model parameters
|
#### Model parameters
|
||||||
|
|
||||||
@ -64,7 +64,7 @@ def load_mnist(split: tfds.Split, batch_size: int):
|
|||||||
an iterator with pairs (images, labels). The images have shape
|
an iterator with pairs (images, labels). The images have shape
|
||||||
(B, 28, 28, 1) and the labels have shape (B, 10), where B is the batch_size.
|
(B, 28, 28, 1) and the labels have shape (B, 10), where B is the batch_size.
|
||||||
"""
|
"""
|
||||||
if FLAGS.mock_data:
|
if _MOCK_DATA.value:
|
||||||
with tfds.testing.mock_data(num_examples=batch_size):
|
with tfds.testing.mock_data(num_examples=batch_size):
|
||||||
try:
|
try:
|
||||||
ds = tfds.load("mnist", split=split)
|
ds = tfds.load("mnist", split=split)
|
||||||
|
@ -37,46 +37,49 @@ import numpy as np
|
|||||||
import tensorflow as tf # type: ignore
|
import tensorflow as tf # type: ignore
|
||||||
import tensorflow_datasets as tfds # type: ignore
|
import tensorflow_datasets as tfds # type: ignore
|
||||||
|
|
||||||
flags.DEFINE_enum("model", "mnist_flax", ["mnist_flax", "mnist_pure_jax"],
|
_MODEL = flags.DEFINE_enum(
|
||||||
"Which model to use.")
|
"model", "mnist_flax", ["mnist_flax", "mnist_pure_jax"],
|
||||||
flags.DEFINE_boolean("model_classifier_layer", True,
|
"Which model to use.")
|
||||||
|
_MODEL_CLASSIFIER_LAYER = flags.DEFINE_boolean("model_classifier_layer", True,
|
||||||
("The model should include the classifier layer, or just "
|
("The model should include the classifier layer, or just "
|
||||||
"the last layer of logits. Set this to False when you "
|
"the last layer of logits. Set this to False when you "
|
||||||
"want to reuse the classifier-less model in a larger "
|
"want to reuse the classifier-less model in a larger "
|
||||||
"model. See keras_reuse_main.py and README.md."))
|
"model. See keras_reuse_main.py and README.md."))
|
||||||
flags.DEFINE_string("model_path", "/tmp/jax2tf/saved_models",
|
_MODEL_PATH = flags.DEFINE_string("model_path", "/tmp/jax2tf/saved_models",
|
||||||
"Path under which to save the SavedModel.")
|
"Path under which to save the SavedModel.")
|
||||||
flags.DEFINE_integer("model_version", 1,
|
_MODEL_VERSION = flags.DEFINE_integer("model_version", 1,
|
||||||
("The version number for the SavedModel. Needed for "
|
("The version number for the SavedModel. Needed for "
|
||||||
"serving, larger versions will take precedence"),
|
"serving, larger versions will take precedence"),
|
||||||
lower_bound=1)
|
lower_bound=1)
|
||||||
flags.DEFINE_integer("serving_batch_size", 1,
|
_SERVING_BATCH_SIZE = flags.DEFINE_integer("serving_batch_size", 1,
|
||||||
"For what batch size to prepare the serving signature. "
|
"For what batch size to prepare the serving signature. "
|
||||||
"Use -1 for converting and saving with batch polymorphism.")
|
"Use -1 for converting and saving with batch polymorphism.")
|
||||||
flags.register_validator(
|
flags.register_validator(
|
||||||
"serving_batch_size",
|
"serving_batch_size",
|
||||||
lambda serving_batch_size: serving_batch_size > 0 or serving_batch_size == -1,
|
lambda serving_batch_size: serving_batch_size > 0
|
||||||
message="--serving_batch_size must be either -1 or a positive integer.")
|
or serving_batch_size == -1,
|
||||||
|
message="--serving_batch_size must be either -1 or a positive integer.",
|
||||||
|
)
|
||||||
|
|
||||||
flags.DEFINE_integer("num_epochs", 3, "For how many epochs to train.",
|
_NUM_EPOCHS = flags.DEFINE_integer("num_epochs", 3,
|
||||||
lower_bound=1)
|
"For how many epochs to train.",
|
||||||
flags.DEFINE_boolean(
|
lower_bound=1)
|
||||||
|
_GENERATE_MODEL = flags.DEFINE_boolean(
|
||||||
"generate_model", True,
|
"generate_model", True,
|
||||||
"Train and save a new model. Otherwise, use an existing SavedModel.")
|
"Train and save a new model. Otherwise, use an existing SavedModel.")
|
||||||
flags.DEFINE_boolean(
|
_COMPILE_MODEL = flags.DEFINE_boolean(
|
||||||
"compile_model", True,
|
"compile_model", True,
|
||||||
"Enable TensorFlow jit_compiler for the SavedModel. This is "
|
"Enable TensorFlow jit_compiler for the SavedModel. This is "
|
||||||
"necessary if you want to use the model for TensorFlow serving.")
|
"necessary if you want to use the model for TensorFlow serving.")
|
||||||
flags.DEFINE_boolean("show_model", True, "Show details of saved SavedModel.")
|
_SHOW_MODEL = flags.DEFINE_boolean("show_model", True,
|
||||||
flags.DEFINE_boolean(
|
"Show details of saved SavedModel.")
|
||||||
|
SHOW_IMAGES = flags.DEFINE_boolean(
|
||||||
"show_images", False,
|
"show_images", False,
|
||||||
"Plot some sample images with labels and inference results.")
|
"Plot some sample images with labels and inference results.")
|
||||||
flags.DEFINE_boolean(
|
_TEST_SAVEDMODEL = flags.DEFINE_boolean(
|
||||||
"test_savedmodel", True,
|
"test_savedmodel", True,
|
||||||
"Test TensorFlow inference using the SavedModel w.r.t. the JAX model.")
|
"Test TensorFlow inference using the SavedModel w.r.t. the JAX model.")
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
|
|
||||||
def train_and_save():
|
def train_and_save():
|
||||||
logging.info("Loading the MNIST TensorFlow dataset")
|
logging.info("Loading the MNIST TensorFlow dataset")
|
||||||
@ -85,22 +88,22 @@ def train_and_save():
|
|||||||
test_ds = mnist_lib.load_mnist(
|
test_ds = mnist_lib.load_mnist(
|
||||||
tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)
|
tfds.Split.TEST, batch_size=mnist_lib.test_batch_size)
|
||||||
|
|
||||||
if FLAGS.show_images:
|
if SHOW_IMAGES.value:
|
||||||
mnist_lib.plot_images(train_ds, 1, 5, "Training images", inference_fn=None)
|
mnist_lib.plot_images(train_ds, 1, 5, "Training images", inference_fn=None)
|
||||||
|
|
||||||
the_model_class = pick_model_class()
|
the_model_class = pick_model_class()
|
||||||
model_dir = savedmodel_dir(with_version=True)
|
model_dir = savedmodel_dir(with_version=True)
|
||||||
|
|
||||||
if FLAGS.generate_model:
|
if _GENERATE_MODEL.value:
|
||||||
model_descr = model_description()
|
model_descr = model_description()
|
||||||
logging.info("Generating model for %s", model_descr)
|
logging.info("Generating model for %s", model_descr)
|
||||||
(predict_fn, predict_params) = the_model_class.train(
|
(predict_fn, predict_params) = the_model_class.train(
|
||||||
train_ds,
|
train_ds,
|
||||||
test_ds,
|
test_ds,
|
||||||
FLAGS.num_epochs,
|
num_epochs=_NUM_EPOCHS.value,
|
||||||
with_classifier=FLAGS.model_classifier_layer)
|
with_classifier=_MODEL_CLASSIFIER_LAYER.value)
|
||||||
|
|
||||||
if FLAGS.serving_batch_size == -1:
|
if _SERVING_BATCH_SIZE.value == -1:
|
||||||
# Batch-polymorphic SavedModel
|
# Batch-polymorphic SavedModel
|
||||||
input_signatures = [
|
input_signatures = [
|
||||||
tf.TensorSpec((None,) + mnist_lib.input_shape, tf.float32),
|
tf.TensorSpec((None,) + mnist_lib.input_shape, tf.float32),
|
||||||
@ -109,7 +112,7 @@ def train_and_save():
|
|||||||
else:
|
else:
|
||||||
input_signatures = [
|
input_signatures = [
|
||||||
# The first one will be the serving signature
|
# The first one will be the serving signature
|
||||||
tf.TensorSpec((FLAGS.serving_batch_size,) + mnist_lib.input_shape,
|
tf.TensorSpec((_SERVING_BATCH_SIZE.value,) + mnist_lib.input_shape,
|
||||||
tf.float32),
|
tf.float32),
|
||||||
tf.TensorSpec((mnist_lib.train_batch_size,) + mnist_lib.input_shape,
|
tf.TensorSpec((mnist_lib.train_batch_size,) + mnist_lib.input_shape,
|
||||||
tf.float32),
|
tf.float32),
|
||||||
@ -126,15 +129,15 @@ def train_and_save():
|
|||||||
with_gradient=True,
|
with_gradient=True,
|
||||||
input_signatures=input_signatures,
|
input_signatures=input_signatures,
|
||||||
polymorphic_shapes=polymorphic_shapes,
|
polymorphic_shapes=polymorphic_shapes,
|
||||||
compile_model=FLAGS.compile_model)
|
compile_model=_COMPILE_MODEL.value)
|
||||||
|
|
||||||
if FLAGS.test_savedmodel:
|
if _TEST_SAVEDMODEL.value:
|
||||||
tf_accelerator, tolerances = tf_accelerator_and_tolerances()
|
tf_accelerator, tolerances = tf_accelerator_and_tolerances()
|
||||||
with tf.device(tf_accelerator):
|
with tf.device(tf_accelerator):
|
||||||
logging.info("Testing savedmodel")
|
logging.info("Testing savedmodel")
|
||||||
pure_restored_model = tf.saved_model.load(model_dir)
|
pure_restored_model = tf.saved_model.load(model_dir)
|
||||||
|
|
||||||
if FLAGS.show_images and FLAGS.model_classifier_layer:
|
if SHOW_IMAGES.value and _MODEL_CLASSIFIER_LAYER.value:
|
||||||
mnist_lib.plot_images(
|
mnist_lib.plot_images(
|
||||||
test_ds,
|
test_ds,
|
||||||
1,
|
1,
|
||||||
@ -149,7 +152,7 @@ def train_and_save():
|
|||||||
pure_restored_model(tf.convert_to_tensor(test_input)),
|
pure_restored_model(tf.convert_to_tensor(test_input)),
|
||||||
predict_fn(predict_params, test_input), **tolerances)
|
predict_fn(predict_params, test_input), **tolerances)
|
||||||
|
|
||||||
if FLAGS.show_model:
|
if _SHOW_MODEL.value:
|
||||||
def print_model(model_dir: str):
|
def print_model(model_dir: str):
|
||||||
cmd = f"saved_model_cli show --all --dir {model_dir}"
|
cmd = f"saved_model_cli show --all --dir {model_dir}"
|
||||||
print(cmd)
|
print(cmd)
|
||||||
@ -160,18 +163,18 @@ def train_and_save():
|
|||||||
|
|
||||||
def pick_model_class():
|
def pick_model_class():
|
||||||
"""Picks one of PureJaxMNIST or FlaxMNIST."""
|
"""Picks one of PureJaxMNIST or FlaxMNIST."""
|
||||||
if FLAGS.model == "mnist_pure_jax":
|
if _MODEL.value == "mnist_pure_jax":
|
||||||
return mnist_lib.PureJaxMNIST
|
return mnist_lib.PureJaxMNIST
|
||||||
elif FLAGS.model == "mnist_flax":
|
elif _MODEL.value == "mnist_flax":
|
||||||
return mnist_lib.FlaxMNIST
|
return mnist_lib.FlaxMNIST
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unrecognized model: {FLAGS.model}")
|
raise ValueError(f"Unrecognized model: {_MODEL.value}")
|
||||||
|
|
||||||
|
|
||||||
def model_description() -> str:
|
def model_description() -> str:
|
||||||
"""A short description of the picked model."""
|
"""A short description of the picked model."""
|
||||||
res = pick_model_class().name
|
res = pick_model_class().name
|
||||||
if not FLAGS.model_classifier_layer:
|
if not _MODEL_CLASSIFIER_LAYER.value:
|
||||||
res += " (features_only)"
|
res += " (features_only)"
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@ -179,11 +182,11 @@ def model_description() -> str:
|
|||||||
def savedmodel_dir(with_version: bool = True) -> str:
|
def savedmodel_dir(with_version: bool = True) -> str:
|
||||||
"""The directory where we save the SavedModel."""
|
"""The directory where we save the SavedModel."""
|
||||||
model_dir = os.path.join(
|
model_dir = os.path.join(
|
||||||
FLAGS.model_path,
|
_MODEL_PATH.value,
|
||||||
FLAGS.model + ('' if FLAGS.model_classifier_layer else '_features')
|
_MODEL.value + ('' if _MODEL_CLASSIFIER_LAYER.value else '_features')
|
||||||
)
|
)
|
||||||
if with_version:
|
if with_version:
|
||||||
model_dir = os.path.join(model_dir, str(FLAGS.model_version))
|
model_dir = os.path.join(model_dir, str(_MODEL_VERSION.value))
|
||||||
return model_dir
|
return model_dir
|
||||||
|
|
||||||
|
|
||||||
|
@ -32,31 +32,32 @@ from tensorflow_serving.apis import predict_pb2 # type: ignore[import]
|
|||||||
from tensorflow_serving.apis import prediction_service_pb2_grpc
|
from tensorflow_serving.apis import prediction_service_pb2_grpc
|
||||||
|
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
_USE_GRPC = flags.DEFINE_boolean(
|
||||||
|
|
||||||
flags.DEFINE_boolean(
|
|
||||||
"use_grpc", True,
|
"use_grpc", True,
|
||||||
"Use the gRPC API (default), or the HTTP REST API.")
|
"Use the gRPC API (default), or the HTTP REST API.")
|
||||||
|
|
||||||
flags.DEFINE_string(
|
_MODEL_SPEC_NAME = flags.DEFINE_string(
|
||||||
"model_spec_name", "",
|
"model_spec_name", "",
|
||||||
"The name you used to export your model to model server (e.g., mnist_flax).")
|
"The name you used to export your model to model server (e.g., mnist_flax).")
|
||||||
|
|
||||||
flags.DEFINE_string(
|
_PREDICTION_SERVICE_ADDR = flags.DEFINE_string(
|
||||||
"prediction_service_addr",
|
"prediction_service_addr",
|
||||||
"localhost:8500",
|
"localhost:8500",
|
||||||
"Stubby endpoint for the prediction service. If you serve your model "
|
"Stubby endpoint for the prediction service. If you serve your model "
|
||||||
"locally using TensorFlow model server, then you can use \"localhost:8500\""
|
"locally using TensorFlow model server, then you can use \"localhost:8500\""
|
||||||
"for the gRPC server and \"localhost:8501\" for the HTTP REST server.")
|
"for the gRPC server and \"localhost:8501\" for the HTTP REST server.")
|
||||||
|
|
||||||
flags.DEFINE_integer("serving_batch_size", 1,
|
_SERVING_BATCH_SIZE = flags.DEFINE_integer(
|
||||||
"Batch size for the serving request. Must match the "
|
"serving_batch_size",
|
||||||
"batch size at which the model was saved. Must divide "
|
1,
|
||||||
"--count_images",
|
"Batch size for the serving request. Must match the "
|
||||||
lower_bound=1)
|
"batch size at which the model was saved. Must divide "
|
||||||
flags.DEFINE_integer("count_images", 16,
|
"--count_images",
|
||||||
"How many images to test.",
|
lower_bound=1,
|
||||||
lower_bound=1)
|
)
|
||||||
|
_COUNT_IMAGES = flags.DEFINE_integer(
|
||||||
|
"count_images", 16, "How many images to test.", lower_bound=1
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def serving_call_mnist(images):
|
def serving_call_mnist(images):
|
||||||
@ -69,12 +70,12 @@ def serving_call_mnist(images):
|
|||||||
Returns:
|
Returns:
|
||||||
A numpy.ndarray of shape [B, 10] with the one-hot inference response.
|
A numpy.ndarray of shape [B, 10] with the one-hot inference response.
|
||||||
"""
|
"""
|
||||||
if FLAGS.use_grpc:
|
if _USE_GRPC.value:
|
||||||
channel = grpc.insecure_channel(FLAGS.prediction_service_addr)
|
channel = grpc.insecure_channel(_PREDICTION_SERVICE_ADDR.value)
|
||||||
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
|
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
|
||||||
|
|
||||||
request = predict_pb2.PredictRequest()
|
request = predict_pb2.PredictRequest()
|
||||||
request.model_spec.name = FLAGS.model_spec_name
|
request.model_spec.name = _MODEL_SPEC_NAME.value
|
||||||
request.model_spec.signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
request.model_spec.signature_name = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
|
||||||
# You can see the name of the input ("inputs") in the SavedModel dump.
|
# You can see the name of the input ("inputs") in the SavedModel dump.
|
||||||
request.inputs["inputs"].CopyFrom(
|
request.inputs["inputs"].CopyFrom(
|
||||||
@ -90,7 +91,7 @@ def serving_call_mnist(images):
|
|||||||
images_json = json.dumps(images.tolist())
|
images_json = json.dumps(images.tolist())
|
||||||
# You can see the name of the input ("inputs") in the SavedModel dump.
|
# You can see the name of the input ("inputs") in the SavedModel dump.
|
||||||
data = f'{{"inputs": {images_json}}}'
|
data = f'{{"inputs": {images_json}}}'
|
||||||
predict_url = f"http://{FLAGS.prediction_service_addr}/v1/models/{FLAGS.model_spec_name}:predict"
|
predict_url = f"http://{_PREDICTION_SERVICE_ADDR.value}/v1/models/{_MODEL_SPEC_NAME.value}:predict"
|
||||||
response = requests.post(predict_url, data=data)
|
response = requests.post(predict_url, data=data)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
msg = (f"Received error response {response.status_code} from model "
|
msg = (f"Received error response {response.status_code} from model "
|
||||||
@ -101,14 +102,14 @@ def serving_call_mnist(images):
|
|||||||
|
|
||||||
|
|
||||||
def main(_):
|
def main(_):
|
||||||
if FLAGS.count_images % FLAGS.serving_batch_size != 0:
|
if _COUNT_IMAGES.value % _SERVING_BATCH_SIZE.value != 0:
|
||||||
raise ValueError(f"The count_images ({FLAGS.count_images}) must be a "
|
raise ValueError(f"The count_images ({_COUNT_IMAGES.value}) must be a "
|
||||||
"multiple of "
|
"multiple of "
|
||||||
f"serving_batch_size ({FLAGS.serving_batch_size})")
|
f"serving_batch_size ({_SERVING_BATCH_SIZE.value})")
|
||||||
test_ds = mnist_lib.load_mnist(tfds.Split.TEST,
|
test_ds = mnist_lib.load_mnist(tfds.Split.TEST,
|
||||||
batch_size=FLAGS.serving_batch_size)
|
batch_size=_SERVING_BATCH_SIZE.value)
|
||||||
images_and_labels = tfds.as_numpy(test_ds.take(
|
images_and_labels = tfds.as_numpy(test_ds.take(
|
||||||
FLAGS.count_images // FLAGS.serving_batch_size))
|
_COUNT_IMAGES.value // _SERVING_BATCH_SIZE.value))
|
||||||
|
|
||||||
accurate_count = 0
|
accurate_count = 0
|
||||||
for batch_idx, (images, labels) in enumerate(images_and_labels):
|
for batch_idx, (images, labels) in enumerate(images_and_labels):
|
||||||
@ -117,7 +118,7 @@ def main(_):
|
|||||||
labels_digit = np.argmax(labels, axis=1)
|
labels_digit = np.argmax(labels, axis=1)
|
||||||
accurate_count += np.sum(labels_digit == predictions_digit)
|
accurate_count += np.sum(labels_digit == predictions_digit)
|
||||||
running_accuracy = (
|
running_accuracy = (
|
||||||
100. * accurate_count / (1 + batch_idx) / FLAGS.serving_batch_size)
|
100. * accurate_count / (1 + batch_idx) / _SERVING_BATCH_SIZE.value)
|
||||||
logging.info(
|
logging.info(
|
||||||
" predicted digits = %s labels %s. Running accuracy %.3f%%",
|
" predicted digits = %s labels %s. Running accuracy %.3f%%",
|
||||||
predictions_digit, labels_digit, running_accuracy)
|
predictions_digit, labels_digit, running_accuracy)
|
||||||
|
@ -33,15 +33,17 @@ import tensorflowjs as tfjs
|
|||||||
import input_pipeline # type: ignore[import]
|
import input_pipeline # type: ignore[import]
|
||||||
|
|
||||||
|
|
||||||
flags.DEFINE_integer("num_epochs", 5,
|
_NUM_EPOCHS = flags.DEFINE_integer(
|
||||||
("Number of epochs to train for."))
|
"num_epochs", 5, "Number of epochs to train for."
|
||||||
flags.DEFINE_integer("num_classes", 100, "Number of classification classes.")
|
)
|
||||||
|
_NUM_CLASSES = flags.DEFINE_integer(
|
||||||
|
"num_classes", 100, "Number of classification classes."
|
||||||
|
)
|
||||||
|
|
||||||
flags.register_validator("num_classes",
|
flags.register_validator("num_classes",
|
||||||
lambda value: value >= 1 and value <= 100,
|
lambda value: value >= 1 and value <= 100,
|
||||||
message="--num_classes must be in range [1, 100]")
|
message="--num_classes must be in range [1, 100]")
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
# The code below is an adaptation for Flax from the work published here:
|
# The code below is an adaptation for Flax from the work published here:
|
||||||
# https://blog.tensorflow.org/2018/07/train-model-in-tfkeras-with-colab-and-run-in-browser-tensorflowjs.html
|
# https://blog.tensorflow.org/2018/07/train-model-in-tfkeras-with-colab-and-run-in-browser-tensorflowjs.html
|
||||||
@ -65,7 +67,7 @@ class QuickDraw(nn.Module):
|
|||||||
x = nn.Dense(features=128)(x)
|
x = nn.Dense(features=128)(x)
|
||||||
x = nn.relu(x)
|
x = nn.relu(x)
|
||||||
|
|
||||||
x = nn.Dense(features=FLAGS.num_classes)(x)
|
x = nn.Dense(features=_NUM_CLASSES.value)(x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -75,7 +77,7 @@ def apply_model(state, inputs, labels):
|
|||||||
"""Computes gradients, loss and accuracy for a single batch."""
|
"""Computes gradients, loss and accuracy for a single batch."""
|
||||||
def loss_fn(params):
|
def loss_fn(params):
|
||||||
logits = state.apply_fn({'params': params}, inputs)
|
logits = state.apply_fn({'params': params}, inputs)
|
||||||
one_hot = jax.nn.one_hot(labels, FLAGS.num_classes)
|
one_hot = jax.nn.one_hot(labels, _NUM_CLASSES.value)
|
||||||
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
|
loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
|
||||||
return loss, logits
|
return loss, logits
|
||||||
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
|
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
|
||||||
@ -113,7 +115,7 @@ def create_train_state(rng):
|
|||||||
|
|
||||||
|
|
||||||
def train(state, train_ds, test_ds):
|
def train(state, train_ds, test_ds):
|
||||||
for epoch in range(1, FLAGS.num_epochs+1):
|
for epoch in range(1, _NUM_EPOCHS.value+1):
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
state, train_loss, train_accuracy = run_epoch(state, train_ds)
|
state, train_loss, train_accuracy = run_epoch(state, train_ds)
|
||||||
@ -136,12 +138,12 @@ def main(argv):
|
|||||||
|
|
||||||
base_model_path = "/tmp/jax2tf/tf_js_quickdraw"
|
base_model_path = "/tmp/jax2tf/tf_js_quickdraw"
|
||||||
dataset_path = os.path.join(base_model_path, "data")
|
dataset_path = os.path.join(base_model_path, "data")
|
||||||
classes = input_pipeline.download_dataset(dataset_path, FLAGS.num_classes)
|
classes = input_pipeline.download_dataset(dataset_path, _NUM_CLASSES.value)
|
||||||
assert len(classes) == FLAGS.num_classes, "Incorrect number of classes"
|
assert len(classes) == _NUM_CLASSES.value, "Incorrect number of classes"
|
||||||
print(f"Classes are: {classes}")
|
print(f"Classes are: {classes}")
|
||||||
print("Loading dataset into memory...")
|
print("Loading dataset into memory...")
|
||||||
train_ds, test_ds = input_pipeline.get_datasets(dataset_path, classes)
|
train_ds, test_ds = input_pipeline.get_datasets(dataset_path, classes)
|
||||||
print(f"Starting training for {FLAGS.num_epochs} epochs...")
|
print(f"Starting training for {_NUM_EPOCHS.value} epochs...")
|
||||||
|
|
||||||
state = create_train_state(jax.random.PRNGKey(0))
|
state = create_train_state(jax.random.PRNGKey(0))
|
||||||
state = train(state, train_ds, test_ds)
|
state = train(state, train_ds, test_ds)
|
||||||
|
@ -24,14 +24,19 @@ import numpy as np
|
|||||||
import tensorflow as tf # type: ignore[import]
|
import tensorflow as tf # type: ignore[import]
|
||||||
import tensorflow_datasets as tfds # type: ignore[import]
|
import tensorflow_datasets as tfds # type: ignore[import]
|
||||||
|
|
||||||
flags.DEFINE_string('tflite_file_path',
|
_TFLITE_FILE_PATH = flags.DEFINE_string(
|
||||||
'/usr/local/google/home/qiuminxu/jax2tf/mnist.tflite',
|
'tflite_file_path',
|
||||||
'Path where to save the TensorFlow Lite file.')
|
'/tmp/mnist.tflite',
|
||||||
flags.DEFINE_integer('serving_batch_size', 4,
|
'Path where to save the TensorFlow Lite file.',
|
||||||
('For what batch size to prepare the serving signature. '))
|
)
|
||||||
flags.DEFINE_integer('num_epochs', 10, 'For how many epochs to train.')
|
_SERVING_BATCH_SIZE = flags.DEFINE_integer(
|
||||||
|
'serving_batch_size',
|
||||||
FLAGS = flags.FLAGS
|
4,
|
||||||
|
'For what batch size to prepare the serving signature. ',
|
||||||
|
)
|
||||||
|
_NUM_EPOCHS = flags.DEFINE_integer(
|
||||||
|
'num_epochs', 10, 'For how many epochs to train.'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# A helper function to evaluate the TF Lite model using "test" dataset.
|
# A helper function to evaluate the TF Lite model using "test" dataset.
|
||||||
@ -71,10 +76,11 @@ def main(_):
|
|||||||
train_ds = mnist_lib.load_mnist(
|
train_ds = mnist_lib.load_mnist(
|
||||||
tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
|
tfds.Split.TRAIN, batch_size=mnist_lib.train_batch_size)
|
||||||
test_ds = mnist_lib.load_mnist(
|
test_ds = mnist_lib.load_mnist(
|
||||||
tfds.Split.TEST, batch_size=FLAGS.serving_batch_size)
|
tfds.Split.TEST, batch_size=_SERVING_BATCH_SIZE.value)
|
||||||
|
|
||||||
(flax_predict,
|
(flax_predict, flax_params) = mnist_lib.FlaxMNIST.train(
|
||||||
flax_params) = mnist_lib.FlaxMNIST.train(train_ds, test_ds, FLAGS.num_epochs)
|
train_ds, test_ds, _NUM_EPOCHS.value
|
||||||
|
)
|
||||||
|
|
||||||
def predict(image):
|
def predict(image):
|
||||||
return flax_predict(flax_params, image)
|
return flax_predict(flax_params, image)
|
||||||
@ -84,7 +90,7 @@ def main(_):
|
|||||||
jax2tf.convert(predict, enable_xla=False),
|
jax2tf.convert(predict, enable_xla=False),
|
||||||
input_signature=[
|
input_signature=[
|
||||||
tf.TensorSpec(
|
tf.TensorSpec(
|
||||||
shape=[FLAGS.serving_batch_size, 28, 28, 1],
|
shape=[_SERVING_BATCH_SIZE.value, 28, 28, 1],
|
||||||
dtype=tf.float32,
|
dtype=tf.float32,
|
||||||
name='input')
|
name='input')
|
||||||
],
|
],
|
||||||
@ -126,7 +132,7 @@ def main(_):
|
|||||||
print('Quantized model accuracy = %.4f' % quantized_accuracy)
|
print('Quantized model accuracy = %.4f' % quantized_accuracy)
|
||||||
print('Accuracy drop = %.4f' % (float_accuracy - quantized_accuracy))
|
print('Accuracy drop = %.4f' % (float_accuracy - quantized_accuracy))
|
||||||
|
|
||||||
f = open(FLAGS.tflite_file_path, 'wb')
|
f = open(_TFLITE_FILE_PATH.value, 'wb')
|
||||||
f.write(tflite_quantized_model)
|
f.write(tflite_quantized_model)
|
||||||
f.close()
|
f.close()
|
||||||
|
|
||||||
|
@ -58,34 +58,32 @@ from jax.experimental.jax2tf.shape_poly import InconclusiveDimensionOperation
|
|||||||
from jax.experimental.jax2tf.tests.model_harness import ALL_HARNESSES
|
from jax.experimental.jax2tf.tests.model_harness import ALL_HARNESSES
|
||||||
from jax.experimental.jax2tf.tests.converters import ALL_CONVERTERS
|
from jax.experimental.jax2tf.tests.converters import ALL_CONVERTERS
|
||||||
|
|
||||||
flags.DEFINE_list("converters", [x.name for x in ALL_CONVERTERS],
|
_CONVERTERS = flags.DEFINE_list("converters", [x.name for x in ALL_CONVERTERS],
|
||||||
"Which converters to test.")
|
"Which converters to test.")
|
||||||
|
|
||||||
flags.DEFINE_list("examples", [],
|
_EXAMPLES = flags.DEFINE_list("examples", [],
|
||||||
("List of examples to test, e.g.: 'flax/mnist,flax/seq2seq'. "
|
("List of examples to test, e.g.: 'flax/mnist,flax/seq2seq'. "
|
||||||
"If empty, will test all examples."))
|
"If empty, will test all examples."))
|
||||||
|
|
||||||
flags.DEFINE_string("example_prefix", "",
|
_EXAMPLE_PREFIX = flags.DEFINE_string("example_prefix", "",
|
||||||
("Prefix for filtering tests. For instance 'flax/mnist' "
|
("Prefix for filtering tests. For instance 'flax/mnist' "
|
||||||
"will test all examples starting with 'flax/mnist' "
|
"will test all examples starting with 'flax/mnist' "
|
||||||
"(including all polymorphic tests)."))
|
"(including all polymorphic tests)."))
|
||||||
|
|
||||||
flags.DEFINE_bool(
|
_WRITE_MARKDOWN = flags.DEFINE_bool(
|
||||||
"write_markdown", True,
|
"write_markdown", True,
|
||||||
"If true, write results as Markdown. Otherwise, only output to stdout.")
|
"If true, write results as Markdown. Otherwise, only output to stdout.")
|
||||||
|
|
||||||
flags.DEFINE_bool(
|
_FAIL_ON_ERROR = flags.DEFINE_bool(
|
||||||
"fail_on_error", False,
|
"fail_on_error", False,
|
||||||
("If true, exit with an error when a conversion fails. Useful for "
|
("If true, exit with an error when a conversion fails. Useful for "
|
||||||
"debugging because it will show the entire stack trace."))
|
"debugging because it will show the entire stack trace."))
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
|
||||||
|
|
||||||
def _write_markdown(results: dict[str, list[tuple[str, str,]]]) -> None:
|
def _write_markdown(results: dict[str, list[tuple[str, str,]]]) -> None:
|
||||||
"""Writes all results to Markdown file."""
|
"""Writes all results to Markdown file."""
|
||||||
table_lines = []
|
table_lines = []
|
||||||
converters = FLAGS.converters
|
converters = _CONVERTERS.value
|
||||||
|
|
||||||
table_lines.append("| Example | " + " ".join([f"{c} |" for c in converters]))
|
table_lines.append("| Example | " + " ".join([f"{c} |" for c in converters]))
|
||||||
table_lines.append("|" + (" --- |" * (len(converters) + 1)))
|
table_lines.append("|" + (" --- |" * (len(converters) + 1)))
|
||||||
@ -173,7 +171,7 @@ def test_converters():
|
|||||||
exit()
|
exit()
|
||||||
|
|
||||||
def _maybe_reraise(e):
|
def _maybe_reraise(e):
|
||||||
if FLAGS.fail_on_error:
|
if _FAIL_ON_ERROR.value:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def _format(e):
|
def _format(e):
|
||||||
@ -183,13 +181,13 @@ def test_converters():
|
|||||||
return msg
|
return msg
|
||||||
|
|
||||||
converters = list(
|
converters = list(
|
||||||
filter(lambda x: x.name in FLAGS.converters, ALL_CONVERTERS))
|
filter(lambda x: x.name in _CONVERTERS.value, ALL_CONVERTERS))
|
||||||
_exit_if_empty(converters, "converters")
|
_exit_if_empty(converters, "converters")
|
||||||
|
|
||||||
harnesses_to_test = {
|
harnesses_to_test = {
|
||||||
name: fn for name, fn in ALL_HARNESSES.items()
|
name: fn for name, fn in ALL_HARNESSES.items()
|
||||||
if (not FLAGS.examples or name in FLAGS.examples) and
|
if (not _EXAMPLES.value or name in _EXAMPLES.value) and
|
||||||
(not FLAGS.example_prefix or name.startswith(FLAGS.example_prefix))
|
(not _EXAMPLE_PREFIX.value or name.startswith(_EXAMPLE_PREFIX.value))
|
||||||
}
|
}
|
||||||
_exit_if_empty(harnesses_to_test, "harness")
|
_exit_if_empty(harnesses_to_test, "harness")
|
||||||
|
|
||||||
@ -243,7 +241,7 @@ def test_converters():
|
|||||||
converter_results.append((converter.name, error_msg))
|
converter_results.append((converter.name, error_msg))
|
||||||
results[harness.name] = converter_results
|
results[harness.name] = converter_results
|
||||||
|
|
||||||
if FLAGS.write_markdown:
|
if _WRITE_MARKDOWN.value:
|
||||||
_write_markdown(results)
|
_write_markdown(results)
|
||||||
else:
|
else:
|
||||||
print("=== NOT writing results to Markdown.")
|
print("=== NOT writing results to Markdown.")
|
||||||
|
@ -140,7 +140,6 @@ def jax_to_ir(name, deps, fn, input_shapes, constants = None, format = "HLO"):
|
|||||||
from absl import app
|
from absl import app
|
||||||
import jax.tools.jax_to_ir as jax_to_ir
|
import jax.tools.jax_to_ir as jax_to_ir
|
||||||
|
|
||||||
jax_to_ir.set_up_flags()
|
|
||||||
app.run(jax_to_ir.main)
|
app.run(jax_to_ir.main)
|
||||||
EOF
|
EOF
|
||||||
""".format(runner = runner),
|
""".format(runner = runner),
|
||||||
|
@ -87,7 +87,30 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
tf = None # type: ignore
|
tf = None # type: ignore
|
||||||
|
|
||||||
FLAGS = flags.FLAGS
|
|
||||||
|
_FN = flags.DEFINE_string(
|
||||||
|
'fn', None, "Fully-qualified name of function that we're going to convert"
|
||||||
|
)
|
||||||
|
_INPUT_SHAPES = flags.DEFINE_string(
|
||||||
|
'input_shapes', None, 'Python dict indicating XLA shapes of params'
|
||||||
|
)
|
||||||
|
_CONSTANTS = flags.DEFINE_string(
|
||||||
|
'constants', '{}', 'Python dict giving constant values for some params'
|
||||||
|
)
|
||||||
|
_EVALED_CONSTANTS = flags.DEFINE_string(
|
||||||
|
'evaled_constants',
|
||||||
|
'{}',
|
||||||
|
'Python dict giving constant values for some params. '
|
||||||
|
'Values in this dict that are of type str are evaluated '
|
||||||
|
'using ast.literal_eval.',
|
||||||
|
)
|
||||||
|
_IR_FORMAT = flags.DEFINE_enum(
|
||||||
|
'ir_format', 'HLO', ('HLO', 'TF'), 'Output format.'
|
||||||
|
)
|
||||||
|
_IR_DEST = flags.DEFINE_string('ir_dest', None, 'File to write IR to')
|
||||||
|
_IR_HUMAN_DEST = flags.DEFINE_string(
|
||||||
|
'ir_human_dest', None, 'File to write human readable debug output'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def jax_to_ir(fn, input_shapes, *, constants=None, format):
|
def jax_to_ir(fn, input_shapes, *, constants=None, format):
|
||||||
@ -163,25 +186,25 @@ def main(argv):
|
|||||||
if len(argv) != 1:
|
if len(argv) != 1:
|
||||||
raise app.UsageError('No positional arguments are accepted.')
|
raise app.UsageError('No positional arguments are accepted.')
|
||||||
|
|
||||||
if not FLAGS.ir_dest and not FLAGS.ir_human_dest:
|
if not _IR_DEST.value and not _IR_HUMAN_DEST.value:
|
||||||
raise app.Error('At least one of --ir_dest and '
|
raise app.Error('At least one of --ir_dest and '
|
||||||
'--ir_human_dest is required.')
|
'--ir_human_dest is required.')
|
||||||
|
|
||||||
module_name, fn_name = FLAGS.fn.rsplit('.', 1)
|
module_name, fn_name = _FN.value.rsplit('.', 1)
|
||||||
module = importlib.import_module(module_name)
|
module = importlib.import_module(module_name)
|
||||||
fn = getattr(module, fn_name)
|
fn = getattr(module, fn_name)
|
||||||
|
|
||||||
input_shapes = [(name, parse_shape_str(shape_str))
|
input_shapes = [(name, parse_shape_str(shape_str))
|
||||||
for name, shape_str in literal_eval(FLAGS.input_shapes)]
|
for name, shape_str in literal_eval(_INPUT_SHAPES.value)]
|
||||||
|
|
||||||
# Parse --constants and --evaled_constants.
|
# Parse --constants and --evaled_constants.
|
||||||
constants = {}
|
constants = {}
|
||||||
for k, v in literal_eval(FLAGS.constants).items():
|
for k, v in literal_eval(_CONSTANTS.value).items():
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
v = jnp.asarray(v)
|
v = jnp.asarray(v)
|
||||||
constants[k] = v
|
constants[k] = v
|
||||||
|
|
||||||
for k, v in literal_eval(FLAGS.evaled_constants).items():
|
for k, v in literal_eval(_EVALED_CONSTANTS.value).items():
|
||||||
if isinstance(v, str):
|
if isinstance(v, str):
|
||||||
v = literal_eval(v)
|
v = literal_eval(v)
|
||||||
if isinstance(v, list):
|
if isinstance(v, list):
|
||||||
@ -192,14 +215,14 @@ def main(argv):
|
|||||||
constants[k] = v
|
constants[k] = v
|
||||||
|
|
||||||
ir, debug_ir = jax_to_ir(fn, input_shapes, constants=constants,
|
ir, debug_ir = jax_to_ir(fn, input_shapes, constants=constants,
|
||||||
format=FLAGS.ir_format)
|
format=_IR_FORMAT.value)
|
||||||
|
|
||||||
if FLAGS.ir_dest:
|
if _IR_DEST.value:
|
||||||
with open(FLAGS.ir_dest, 'wb') as f:
|
with open(_IR_DEST.value, 'wb') as f:
|
||||||
f.write(ir)
|
f.write(ir)
|
||||||
|
|
||||||
if FLAGS.ir_human_dest:
|
if _IR_HUMAN_DEST.value:
|
||||||
with open(FLAGS.ir_human_dest, 'w') as f:
|
with open(_IR_HUMAN_DEST.value, 'w') as f:
|
||||||
f.write(debug_ir)
|
f.write(debug_ir)
|
||||||
|
|
||||||
|
|
||||||
@ -225,21 +248,6 @@ _SHAPE_RE = re.compile(f"^({'|'.join(_DT)})\\[\\s*(\\d*[\\s*,\\d+]*)\\s*\\]$")
|
|||||||
|
|
||||||
|
|
||||||
def set_up_flags():
|
def set_up_flags():
|
||||||
flags.DEFINE_string(
|
|
||||||
'fn', None,
|
|
||||||
"Fully-qualified name of function that we're going to convert")
|
|
||||||
flags.DEFINE_string('input_shapes', None,
|
|
||||||
'Python dict indicating XLA shapes of params')
|
|
||||||
flags.DEFINE_string('constants', '{}',
|
|
||||||
'Python dict giving constant values for some params')
|
|
||||||
flags.DEFINE_string('evaled_constants', '{}',
|
|
||||||
'Python dict giving constant values for some params. '
|
|
||||||
'Values in this dict that are of type str are evaluated '
|
|
||||||
'using ast.literal_eval.')
|
|
||||||
flags.DEFINE_enum('ir_format', 'HLO', ('HLO', 'TF'), 'Output format.')
|
|
||||||
flags.DEFINE_string('ir_dest', None, 'File to write IR to')
|
|
||||||
flags.DEFINE_string('ir_human_dest', None,
|
|
||||||
'File to write human readable debug output')
|
|
||||||
flags.mark_flag_as_required('fn')
|
flags.mark_flag_as_required('fn')
|
||||||
flags.mark_flag_as_required('input_shapes')
|
flags.mark_flag_as_required('input_shapes')
|
||||||
|
|
||||||
|
@ -22,7 +22,6 @@ import jax
|
|||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from jax._src import core
|
from jax._src import core
|
||||||
from jax._src import test_util as jtu
|
from jax._src import test_util as jtu
|
||||||
from jax._src.config import flags
|
|
||||||
from jax.experimental.pjit import pjit
|
from jax.experimental.pjit import pjit
|
||||||
from jax.experimental.serialize_executable import (
|
from jax.experimental.serialize_executable import (
|
||||||
serialize, deserialize_and_load)
|
serialize, deserialize_and_load)
|
||||||
@ -73,7 +72,7 @@ class JaxAotTest(jtu.JaxTestCase):
|
|||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
raise unittest.SkipTest('PJRT Topology not supported')
|
raise unittest.SkipTest('PJRT Topology not supported')
|
||||||
|
|
||||||
if flags.FLAGS.jax_test_with_persistent_compilation_cache:
|
if jtu.TEST_WITH_PERSISTENT_COMPILATION_CACHE.value:
|
||||||
raise unittest.SkipTest('Compilation caching not yet supported.')
|
raise unittest.SkipTest('Compilation caching not yet supported.')
|
||||||
|
|
||||||
@jax.jit
|
@jax.jit
|
||||||
|
Loading…
x
Reference in New Issue
Block a user