Simplified in place guassian code
This commit is contained in:
parent
a03abe2bb1
commit
39643f04e1
1 changed files with 2 additions and 6 deletions
|
@ -28,17 +28,13 @@ class NoisyLinear(nn.Linear):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# Fill s_normal_weight with values from the standard normal distribution
|
# Fill s_normal_weight with values from the standard normal distribution
|
||||||
torch.randn(self.s_normal_weight.size(), out = self.s_normal_weight,
|
self.s_normal_weight.normal_()
|
||||||
dtype = self.s_normal_weight.dtype, layout = self.s_normal_weight.layout, device = self.s_normal_weight.device)
|
|
||||||
# Multiply by the standard deviation to correct the spread of Gaussian noise
|
|
||||||
weight_noise = self.sigma_weight * self.s_normal_weight.clone().requires_grad_()
|
weight_noise = self.sigma_weight * self.s_normal_weight.clone().requires_grad_()
|
||||||
|
|
||||||
bias = None
|
bias = None
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
# Fill s_normal_bias with values from standard normal
|
# Fill s_normal_bias with values from standard normal
|
||||||
torch.randn(self.s_normal_bias.size(), out = self.s_normal_bias,
|
self.s_normal_bias.normal_()
|
||||||
dtype = self.s_normal_bias.dtype, layout = self.s_normal_bias.layout, device = self.s_normal_bias.device)
|
|
||||||
# Add guassian noise to original bias
|
|
||||||
bias = self.bias + self.sigma_bias * self.s_normal_bias.clone().requires_grad_()
|
bias = self.bias + self.sigma_bias * self.s_normal_bias.clone().requires_grad_()
|
||||||
|
|
||||||
return F.linear(x, self.weight + weight_noise, bias)
|
return F.linear(x, self.weight + weight_noise, bias)
|
Loading…
Reference in a new issue