Simplified in place guassian code
This commit is contained in:
		
							parent
							
								
									a03abe2bb1
								
							
						
					
					
						commit
						39643f04e1
					
				
					 1 changed files with 2 additions and 6 deletions
				
			
		| 
						 | 
				
			
			@ -28,17 +28,13 @@ class NoisyLinear(nn.Linear):
 | 
			
		|||
  
 | 
			
		||||
  def forward(self, x):
 | 
			
		||||
    # Fill s_normal_weight with values from the standard normal distribution
 | 
			
		||||
    torch.randn(self.s_normal_weight.size(), out = self.s_normal_weight, 
 | 
			
		||||
                dtype = self.s_normal_weight.dtype, layout = self.s_normal_weight.layout, device = self.s_normal_weight.device)
 | 
			
		||||
    # Multiply by the standard deviation to correct the spread of Gaussian noise
 | 
			
		||||
    self.s_normal_weight.normal_()
 | 
			
		||||
    weight_noise = self.sigma_weight * self.s_normal_weight.clone().requires_grad_()
 | 
			
		||||
    
 | 
			
		||||
    bias = None
 | 
			
		||||
    if self.bias is not None:
 | 
			
		||||
      # Fill s_normal_bias with values from standard normal
 | 
			
		||||
      torch.randn(self.s_normal_bias.size(), out = self.s_normal_bias, 
 | 
			
		||||
                  dtype = self.s_normal_bias.dtype, layout = self.s_normal_bias.layout, device = self.s_normal_bias.device)
 | 
			
		||||
      # Add guassian noise to original bias
 | 
			
		||||
      self.s_normal_bias.normal_()
 | 
			
		||||
      bias = self.bias + self.sigma_bias * self.s_normal_bias.clone().requires_grad_()
 | 
			
		||||
    
 | 
			
		||||
    return F.linear(x, self.weight + weight_noise, bias)
 | 
			
		||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue