Simplified in place guassian code
This commit is contained in:
parent
a03abe2bb1
commit
39643f04e1
1 changed files with 2 additions and 6 deletions
|
@ -28,17 +28,13 @@ class NoisyLinear(nn.Linear):
|
|||
|
||||
def forward(self, x):
|
||||
# Fill s_normal_weight with values from the standard normal distribution
|
||||
torch.randn(self.s_normal_weight.size(), out = self.s_normal_weight,
|
||||
dtype = self.s_normal_weight.dtype, layout = self.s_normal_weight.layout, device = self.s_normal_weight.device)
|
||||
# Multiply by the standard deviation to correct the spread of Gaussian noise
|
||||
self.s_normal_weight.normal_()
|
||||
weight_noise = self.sigma_weight * self.s_normal_weight.clone().requires_grad_()
|
||||
|
||||
bias = None
|
||||
if self.bias is not None:
|
||||
# Fill s_normal_bias with values from standard normal
|
||||
torch.randn(self.s_normal_bias.size(), out = self.s_normal_bias,
|
||||
dtype = self.s_normal_bias.dtype, layout = self.s_normal_bias.layout, device = self.s_normal_bias.device)
|
||||
# Add guassian noise to original bias
|
||||
self.s_normal_bias.normal_()
|
||||
bias = self.bias + self.sigma_bias * self.s_normal_bias.clone().requires_grad_()
|
||||
|
||||
return F.linear(x, self.weight + weight_noise, bias)
|
Loading…
Reference in a new issue