mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Migrate from deprecated tensorflow_privacy RDP accountant to differential_privacy.
PiperOrigin-RevId: 454724315
This commit is contained in:
parent
b174b7751b
commit
ef9036abf3
@ -82,9 +82,9 @@ import jax.numpy as jnp
|
||||
from jax.examples import datasets
|
||||
import numpy.random as npr
|
||||
|
||||
# https://github.com/tensorflow/privacy
|
||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
|
||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
|
||||
# https://github.com/google/differential-privacy
|
||||
from differential_privacy.python.accounting import dp_event
|
||||
from differential_privacy.python.accounting.rdp import rdp_privacy_accountant
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
@ -168,9 +168,11 @@ def compute_epsilon(steps, num_examples=60000, target_delta=1e-5):
|
||||
warnings.warn('Your delta might be too high.')
|
||||
q = FLAGS.batch_size / float(num_examples)
|
||||
orders = list(jnp.linspace(1.1, 10.9, 99)) + list(range(11, 64))
|
||||
rdp_const = compute_rdp(q, FLAGS.noise_multiplier, steps, orders)
|
||||
eps, _, _ = get_privacy_spent(orders, rdp_const, target_delta=target_delta)
|
||||
return eps
|
||||
accountant = rdp_privacy_accountant.RdpAccountant(orders)
|
||||
accountant.compose(
|
||||
dp_event.PoissonSampledDpEvent(
|
||||
q, dp_event.GaussianDpEvent(FLAGS.noise_multiplier)), steps)
|
||||
return accountant.get_epsilon(target_delta)
|
||||
|
||||
|
||||
def main(_):
|
||||
|
Loading…
x
Reference in New Issue
Block a user