Added entropy into QEP (1% of loss)
Made random numbers generated in ESNetwork happen in the same device
This commit is contained in:
parent
76a044ace9
commit
714443192d
2 changed files with 13 additions and 9 deletions
|
@ -21,15 +21,19 @@ class QEPAgent:
|
|||
self.policy_skip = 10
|
||||
|
||||
def fitness(self, policy_net, value_net, state_batch):
|
||||
action_probabilities = policy_net(state_batch)
|
||||
distributions = list(map(Categorical, action_probabilities))
|
||||
actions = torch.tensor([d.sample() for d in distributions])
|
||||
action_probabilities = policy_net(state_batch)
|
||||
distributions = list(map(Categorical, action_probabilities))
|
||||
actions = torch.tensor([d.sample() for d in distributions])
|
||||
|
||||
with torch.no_grad():
|
||||
state_values = value_net(state_batch)
|
||||
obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
|
||||
with torch.no_grad():
|
||||
state_values = value_net(state_batch)
|
||||
obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
|
||||
|
||||
return -obtained_values.mean().item()
|
||||
# return -obtained_values.mean().item()
|
||||
entropy_importance = 0.01 # Entropy accounting for 1% of loss seems to work well
|
||||
entropy_loss = (action_probabilities * torch.log(action_probabilities)).sum(1)
|
||||
return (entropy_importance * entropy_loss - (1 - entropy_importance) * obtained_values).mean().item()
|
||||
|
||||
|
||||
def learn(self, logger = None):
|
||||
if len(self.memory) < self.config['batch_size']:
|
||||
|
@ -105,6 +109,6 @@ class QEPAgent:
|
|||
self.policy_net.calc_gradients(self.target_value_net, state_batch)
|
||||
else:
|
||||
self.policy_net.calc_gradients(self.value_net, state_batch)
|
||||
self.policy_net.clamp_gradients()
|
||||
# self.policy_net.clamp_gradients()
|
||||
self.policy_net.step()
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ class ESNetwork(Network):
|
|||
white_noise_dict = {}
|
||||
noise_dict = {}
|
||||
for key in model_dict.keys():
|
||||
white_noise_dict[key] = torch.randn(self.population_size, *model_dict[key].shape)
|
||||
white_noise_dict[key] = torch.randn(self.population_size, *model_dict[key].shape, device = self.device)
|
||||
noise_dict[key] = self.sigma * white_noise_dict[key]
|
||||
return white_noise_dict, noise_dict
|
||||
|
||||
|
|
Loading…
Reference in a new issue