Simplified in place guassian code

This commit is contained in:
Brandon Rozek 2019-02-03 00:31:35 -05:00
parent a03abe2bb1
commit 39643f04e1

View file

@ -28,17 +28,13 @@ class NoisyLinear(nn.Linear):
def forward(self, x):
# Fill s_normal_weight with values from the standard normal distribution
torch.randn(self.s_normal_weight.size(), out = self.s_normal_weight,
dtype = self.s_normal_weight.dtype, layout = self.s_normal_weight.layout, device = self.s_normal_weight.device)
# Multiply by the standard deviation to correct the spread of Gaussian noise
self.s_normal_weight.normal_()
weight_noise = self.sigma_weight * self.s_normal_weight.clone().requires_grad_()
bias = None
if self.bias is not None:
# Fill s_normal_bias with values from standard normal
torch.randn(self.s_normal_bias.size(), out = self.s_normal_bias,
dtype = self.s_normal_bias.dtype, layout = self.s_normal_bias.layout, device = self.s_normal_bias.device)
# Add guassian noise to original bias
self.s_normal_bias.normal_()
bias = self.bias + self.sigma_bias * self.s_normal_bias.clone().requires_grad_()
return F.linear(x, self.weight + weight_noise, bias)