Mirror of https://github.com/ROCm/jax.git
fix prng key reuse in differential privacy example (#3646)
parent 166e795d63
commit d10cf0e38f
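The bug being fixed: jax.random is deterministic given a key, so the old private_grad, which sampled noise for every gradient leaf from the single rng it was handed, reused that key across leaves (for leaves of equal shape the samples are literally identical). The fix below splits the key once per leaf before sampling. A minimal, self-contained sketch of the difference, not taken from the example (the seed and shapes are made up):

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
shapes = [(3,), (3,)]  # two gradient leaves of the same shape

# Reusing one key: identical "noise" for every same-shaped leaf.
reused = [random.normal(key, s) for s in shapes]
assert bool(jnp.all(reused[0] == reused[1]))

# Splitting the key first: independent noise per leaf, as in the fixed code.
subkeys = random.split(key, len(shapes))
independent = [random.normal(k, s) for k, s in zip(subkeys, shapes)]
assert not bool(jnp.all(independent[0] == independent[1]))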
@@ -64,7 +64,6 @@ Example invocations:
   --learning_rate=.25 \
 """
 
-from functools import partial
 import itertools
 import time
 import warnings
@@ -75,18 +74,17 @@ from absl import flags
 from jax import grad
 from jax import jit
 from jax import random
-from jax import tree_util
 from jax import vmap
 from jax.experimental import optimizers
 from jax.experimental import stax
-from jax.lax import stop_gradient
+from jax.tree_util import tree_flatten, tree_unflatten
 import jax.numpy as jnp
 from examples import datasets
 import numpy.random as npr
 
 # https://github.com/tensorflow/privacy
-from privacy.analysis.rdp_accountant import compute_rdp
-from privacy.analysis.rdp_accountant import get_privacy_spent
+from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
+from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
 
 FLAGS = flags.FLAGS
 
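The accountant imports now come from the standalone tensorflow_privacy package (https://github.com/tensorflow/privacy). For orientation, a hedged sketch of how rdp_accountant APIs of this vintage are typically called; the sampling rate, noise multiplier, step count, orders, and delta below are made-up illustration values, not taken from the example:

from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent

q = 256 / 60000           # Poisson sampling rate (batch_size / num_examples)
noise_multiplier = 1.1    # Gaussian noise scale relative to the clip norm
steps = 1000              # number of DP-SGD updates taken so far
orders = [1.25, 1.5, 2., 4., 8., 16., 32., 64.]  # Renyi orders to track

rdp = compute_rdp(q, noise_multiplier, steps, orders)   # RDP at each order
eps, _, opt_order = get_privacy_spent(orders, rdp, target_delta=1e-5)
print('epsilon = {:.2f} at delta = 1e-5 (optimal order {})'.format(eps, opt_order))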
@@ -134,33 +132,30 @@ def accuracy(params, batch):
   return jnp.mean(predicted_class == target_class)
 
 
+def clipped_grad(params, l2_norm_clip, single_example_batch):
+  """Evaluate gradient for a single-example batch and clip its grad norm."""
+  grads = grad(loss)(params, single_example_batch)
+  nonempty_grads, tree_def = tree_flatten(grads)
+  total_grad_norm = jnp.linalg.norm(
+      [jnp.linalg.norm(neg.ravel()) for neg in nonempty_grads])
+  divisor = jnp.max((total_grad_norm / l2_norm_clip, 1.))
+  normalized_nonempty_grads = [g / divisor for g in nonempty_grads]
+  return tree_unflatten(tree_def, normalized_nonempty_grads)
+
+
 def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
                  batch_size):
   """Return differentially private gradients for params, evaluated on batch."""
-
-  def _clipped_grad(params, single_example_batch):
-    """Evaluate gradient for a single-example batch and clip its grad norm."""
-    grads = grad(loss)(params, single_example_batch)
-
-    nonempty_grads, tree_def = tree_util.tree_flatten(grads)
-    total_grad_norm = jnp.linalg.norm(
-        [jnp.linalg.norm(neg.ravel()) for neg in nonempty_grads])
-    divisor = stop_gradient(jnp.amax((total_grad_norm / l2_norm_clip, 1.)))
-    normalized_nonempty_grads = [g / divisor for g in nonempty_grads]
-    return tree_util.tree_unflatten(tree_def, normalized_nonempty_grads)
-
-  px_clipped_grad_fn = vmap(partial(_clipped_grad, params))
-  std_dev = l2_norm_clip * noise_multiplier
-  noise_ = lambda n: n + std_dev * random.normal(rng, n.shape)
-  normalize_ = lambda n: n / float(batch_size)
-  tree_map = tree_util.tree_map
-  sum_ = lambda n: jnp.sum(n, 0)  # aggregate
-  aggregated_clipped_grads = tree_map(sum_, px_clipped_grad_fn(batch))
-  noised_aggregated_clipped_grads = tree_map(noise_, aggregated_clipped_grads)
-  normalized_noised_aggregated_clipped_grads = (
-      tree_map(normalize_, noised_aggregated_clipped_grads)
-  )
-  return normalized_noised_aggregated_clipped_grads
+  clipped_grads = vmap(clipped_grad, (None, None, 0))(params, l2_norm_clip, batch)
+  clipped_grads_flat, grads_treedef = tree_flatten(clipped_grads)
+  aggregated_clipped_grads = [g.sum(0) for g in clipped_grads_flat]
+  rngs = random.split(rng, len(aggregated_clipped_grads))
+  noised_aggregated_clipped_grads = [
+      g + l2_norm_clip * noise_multiplier * random.normal(r, g.shape)
+      for r, g in zip(rngs, aggregated_clipped_grads)]
+  normalized_noised_aggregated_clipped_grads = [
+      g / batch_size for g in noised_aggregated_clipped_grads]
+  return tree_unflatten(grads_treedef, normalized_noised_aggregated_clipped_grads)
 
 
 def shape_as_image(images, labels, dummy_dim=False):
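The rewritten private_grad computes per-example clipped gradients with vmap(clipped_grad, (None, None, 0)): params and l2_norm_clip are broadcast unchanged while the batch is mapped over its leading axis, so clipped_grad only ever sees a single example. A self-contained sketch of that mapping pattern with a toy quadratic loss (the loss, names, and shapes are illustrative, not from the example):

import jax.numpy as jnp
from jax import grad, vmap
from jax.tree_util import tree_flatten, tree_unflatten

def toy_loss(params, example):
  w, b = params
  x, y = example
  return jnp.sum((jnp.dot(x, w) + b - y) ** 2)

def clip_single(params, l2_norm_clip, example):
  # Gradient for one example, rescaled so its global l2 norm is at most the clip.
  grads = grad(toy_loss)(params, example)
  leaves, treedef = tree_flatten(grads)
  total_norm = jnp.linalg.norm(jnp.array([jnp.linalg.norm(g) for g in leaves]))
  divisor = jnp.maximum(total_norm / l2_norm_clip, 1.)
  return tree_unflatten(treedef, [g / divisor for g in leaves])

params = (jnp.ones((4, 2)), jnp.zeros(2))
batch = (jnp.ones((8, 4)), jnp.zeros((8, 2)))   # leading axis = 8 examples
per_example_grads = vmap(clip_single, (None, None, 0))(params, 1.0, batch)
# Each gradient leaf now carries a leading batch dimension of 8.
print([g.shape for g in tree_flatten(per_example_grads)[0]])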
@@ -225,7 +220,6 @@ def main(_):
   print('\nStarting training...')
   for epoch in range(1, FLAGS.epochs + 1):
     start_time = time.time()
-    # pylint: disable=no-value-for-parameter
     for _ in range(num_batches):
       if FLAGS.dpsgd:
         opt_state = \
@@ -235,7 +229,6 @@ def main(_):
       else:
         opt_state = update(
             key, next(itercount), opt_state, shape_as_image(*next(batches)))
-    # pylint: enable=no-value-for-parameter
     epoch_time = time.time() - start_time
     print('Epoch {} in {:0.2f} sec'.format(epoch, epoch_time))
 