fix prng key reuse in differential privacy example (#3646)

Matthew Johnson 2020-07-02 14:29:17 -07:00 committed by GitHub
parent 166e795d63
commit d10cf0e38f
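
Not part of the commit, but useful context for the title: in JAX, drawing random numbers from the same PRNG key always produces the same values, so reusing one key to noise every gradient leaf (as the old `private_grad` did via the single `rng` closed over by `noise_`) yields noise that is identical across same-shaped leaves rather than independent. Below is a minimal sketch of that failure mode and of the `random.split` pattern the diff adopts; the leaf shapes are illustrative stand-ins, not taken from the example.

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
# Stand-ins for two same-shaped gradient leaves of a parameter pytree.
leaves = [jnp.zeros((3,)), jnp.zeros((3,))]

# Key reuse: every leaf receives the *same* noise draw.
reused = [random.normal(key, leaf.shape) for leaf in leaves]
assert bool(jnp.all(reused[0] == reused[1]))

# Fix: split the key once per leaf so the draws are independent.
subkeys = random.split(key, len(leaves))
independent = [random.normal(k, leaf.shape) for k, leaf in zip(subkeys, leaves)]
assert not bool(jnp.all(independent[0] == independent[1]))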

@@ -64,7 +64,6 @@ Example invocations:
   --learning_rate=.25 \
 """
-from functools import partial
 import itertools
 import time
 import warnings
@@ -75,18 +74,17 @@ from absl import flags
 from jax import grad
 from jax import jit
 from jax import random
-from jax import tree_util
 from jax import vmap
 from jax.experimental import optimizers
 from jax.experimental import stax
-from jax.lax import stop_gradient
+from jax.tree_util import tree_flatten, tree_unflatten
 import jax.numpy as jnp
 from examples import datasets
 import numpy.random as npr
 # https://github.com/tensorflow/privacy
-from privacy.analysis.rdp_accountant import compute_rdp
-from privacy.analysis.rdp_accountant import get_privacy_spent
+from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
+from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
 FLAGS = flags.FLAGS
@@ -134,33 +132,30 @@ def accuracy(params, batch):
   return jnp.mean(predicted_class == target_class)
+def clipped_grad(params, l2_norm_clip, single_example_batch):
+  """Evaluate gradient for a single-example batch and clip its grad norm."""
+  grads = grad(loss)(params, single_example_batch)
+  nonempty_grads, tree_def = tree_flatten(grads)
+  total_grad_norm = jnp.linalg.norm(
+      [jnp.linalg.norm(neg.ravel()) for neg in nonempty_grads])
+  divisor = jnp.max((total_grad_norm / l2_norm_clip, 1.))
+  normalized_nonempty_grads = [g / divisor for g in nonempty_grads]
+  return tree_unflatten(tree_def, normalized_nonempty_grads)
 def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
                  batch_size):
   """Return differentially private gradients for params, evaluated on batch."""
-  def _clipped_grad(params, single_example_batch):
-    """Evaluate gradient for a single-example batch and clip its grad norm."""
-    grads = grad(loss)(params, single_example_batch)
-    nonempty_grads, tree_def = tree_util.tree_flatten(grads)
-    total_grad_norm = jnp.linalg.norm(
-        [jnp.linalg.norm(neg.ravel()) for neg in nonempty_grads])
-    divisor = stop_gradient(jnp.amax((total_grad_norm / l2_norm_clip, 1.)))
-    normalized_nonempty_grads = [g / divisor for g in nonempty_grads]
-    return tree_util.tree_unflatten(tree_def, normalized_nonempty_grads)
-  px_clipped_grad_fn = vmap(partial(_clipped_grad, params))
-  std_dev = l2_norm_clip * noise_multiplier
-  noise_ = lambda n: n + std_dev * random.normal(rng, n.shape)
-  normalize_ = lambda n: n / float(batch_size)
-  tree_map = tree_util.tree_map
-  sum_ = lambda n: jnp.sum(n, 0)  # aggregate
-  aggregated_clipped_grads = tree_map(sum_, px_clipped_grad_fn(batch))
-  noised_aggregated_clipped_grads = tree_map(noise_, aggregated_clipped_grads)
-  normalized_noised_aggregated_clipped_grads = (
-      tree_map(normalize_, noised_aggregated_clipped_grads)
-  )
-  return normalized_noised_aggregated_clipped_grads
+  clipped_grads = vmap(clipped_grad, (None, None, 0))(params, l2_norm_clip, batch)
+  clipped_grads_flat, grads_treedef = tree_flatten(clipped_grads)
+  aggregated_clipped_grads = [g.sum(0) for g in clipped_grads_flat]
+  rngs = random.split(rng, len(aggregated_clipped_grads))
+  noised_aggregated_clipped_grads = [
+      g + l2_norm_clip * noise_multiplier * random.normal(r, g.shape)
+      for r, g in zip(rngs, aggregated_clipped_grads)]
+  normalized_noised_aggregated_clipped_grads = [
+      g / batch_size for g in noised_aggregated_clipped_grads]
+  return tree_unflatten(grads_treedef, normalized_noised_aggregated_clipped_grads)
 def shape_as_image(images, labels, dummy_dim=False):
@@ -225,7 +220,6 @@ def main(_):
   print('\nStarting training...')
   for epoch in range(1, FLAGS.epochs + 1):
     start_time = time.time()
-    # pylint: disable=no-value-for-parameter
     for _ in range(num_batches):
       if FLAGS.dpsgd:
         opt_state = \
@@ -235,7 +229,6 @@ def main(_):
       else:
         opt_state = update(
             key, next(itercount), opt_state, shape_as_image(*next(batches)))
-    # pylint: enable=no-value-for-parameter
     epoch_time = time.time() - start_time
     print('Epoch {} in {:0.2f} sec'.format(epoch, epoch_time))
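
To exercise the updated `private_grad` outside the example, here is a self-contained sketch that copies the post-change `clipped_grad`/`private_grad` from the hunk above and drives them with a toy linear model. The `loss`, parameter shapes, and batch below are stand-ins invented for illustration (the real example builds its model with stax and loads data via examples.datasets, neither of which appears in this diff), so treat this as a sketch rather than the example's actual training path. Note that `random.split(rng, ...)` gives each gradient leaf its own noise key.

import jax.numpy as jnp
from jax import grad, random, vmap
from jax.tree_util import tree_flatten, tree_unflatten

def loss(params, batch):
  # Stand-in least-squares loss for a single (inputs, targets) example.
  (w, b), = params
  inputs, targets = batch
  preds = jnp.dot(inputs, w) + b
  return jnp.mean((preds - targets) ** 2)

def clipped_grad(params, l2_norm_clip, single_example_batch):
  """Evaluate gradient for a single-example batch and clip its grad norm."""
  grads = grad(loss)(params, single_example_batch)
  nonempty_grads, tree_def = tree_flatten(grads)
  total_grad_norm = jnp.linalg.norm(
      [jnp.linalg.norm(neg.ravel()) for neg in nonempty_grads])
  divisor = jnp.max((total_grad_norm / l2_norm_clip, 1.))
  normalized_nonempty_grads = [g / divisor for g in nonempty_grads]
  return tree_unflatten(tree_def, normalized_nonempty_grads)

def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
                 batch_size):
  """Return differentially private gradients for params, evaluated on batch."""
  clipped_grads = vmap(clipped_grad, (None, None, 0))(params, l2_norm_clip, batch)
  clipped_grads_flat, grads_treedef = tree_flatten(clipped_grads)
  aggregated_clipped_grads = [g.sum(0) for g in clipped_grads_flat]
  rngs = random.split(rng, len(aggregated_clipped_grads))
  noised_aggregated_clipped_grads = [
      g + l2_norm_clip * noise_multiplier * random.normal(r, g.shape)
      for r, g in zip(rngs, aggregated_clipped_grads)]
  normalized_noised_aggregated_clipped_grads = [
      g / batch_size for g in noised_aggregated_clipped_grads]
  return tree_unflatten(grads_treedef, normalized_noised_aggregated_clipped_grads)

# Toy data: a batch of 8 examples with 4 features and scalar targets.
key = random.PRNGKey(0)
k1, k2, k3 = random.split(key, 3)
params = [(random.normal(k1, (4, 1)), jnp.zeros((1,)))]
batch = (random.normal(k2, (8, 4)), jnp.zeros((8, 1)))

grads = private_grad(params, batch, k3, l2_norm_clip=1.0,
                     noise_multiplier=1.1, batch_size=8)
print([g.shape for g in tree_flatten(grads)[0]])  # shapes match the parameter leaves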