diff --git a/rltorch/network/NoisyLinear.py b/rltorch/network/NoisyLinear.py index 066069b..578457b 100644 --- a/rltorch/network/NoisyLinear.py +++ b/rltorch/network/NoisyLinear.py @@ -28,17 +28,13 @@ class NoisyLinear(nn.Linear): def forward(self, x): # Fill s_normal_weight with values from the standard normal distribution - torch.randn(self.s_normal_weight.size(), out = self.s_normal_weight, - dtype = self.s_normal_weight.dtype, layout = self.s_normal_weight.layout, device = self.s_normal_weight.device) - # Multiply by the standard deviation to correct the spread of Gaussian noise + self.s_normal_weight.normal_() weight_noise = self.sigma_weight * self.s_normal_weight.clone().requires_grad_() bias = None if self.bias is not None: # Fill s_normal_bias with values from standard normal - torch.randn(self.s_normal_bias.size(), out = self.s_normal_bias, - dtype = self.s_normal_bias.dtype, layout = self.s_normal_bias.layout, device = self.s_normal_bias.device) - # Add guassian noise to original bias + self.s_normal_bias.normal_() bias = self.bias + self.sigma_bias * self.s_normal_bias.clone().requires_grad_() return F.linear(x, self.weight + weight_noise, bias) \ No newline at end of file