Skip to content

Commit

Permalink
fix dp-sgd example (#873)
Browse files Browse the repository at this point in the history
see [this issue](#467)
  • Loading branch information
spliew authored Jun 14, 2022
1 parent 056ff1f commit 2b16530
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions examples/dp_cifar10/cifar10_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ def compute_norms(sample_grads):
batch_size = sample_grads[0].shape[0]
norms = [sample_grad.view(batch_size, -1).norm(2, dim=-1) for sample_grad in sample_grads]
norms = torch.stack(norms, dim=0).norm(2, dim=0)
return norms
return norms, batch_size


def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise_multiplier=1.0):
sample_grads = tuple(param.grad_sample for param in model.parameters())

# step 0: compute the norms
sample_norms = compute_norms(sample_grads)
sample_norms, batch_size = compute_norms(sample_grads)

# step 1: compute clipping factors
clip_factor = max_per_sample_grad_norm / (sample_norms + 1e-6)
Expand All @@ -76,7 +76,7 @@ def clip_and_accumulate_and_add_noise(model, max_per_sample_grad_norm=1.0, noise

# step 4: assign the new grads, delete the sample grads
for param, param_grad in zip(model.parameters(), grads):
param.grad = param_grad
param.grad = param_grad/batch_size
del param.grad_sample


Expand Down Expand Up @@ -492,4 +492,4 @@ def parse_args():


if __name__ == "__main__":
main()
main()

0 comments on commit 2b16530

Please sign in to comment.