Added entropy into QEP (1% of loss)

Made random numbers generated in ESNetwork happen in the same device
This commit is contained in:
Brandon Rozek 2019-02-28 12:17:35 -05:00
parent 76a044ace9
commit 714443192d
2 changed files with 13 additions and 9 deletions

View file

@ -21,15 +21,19 @@ class QEPAgent:
self.policy_skip = 10
def fitness(self, policy_net, value_net, state_batch):
action_probabilities = policy_net(state_batch)
distributions = list(map(Categorical, action_probabilities))
actions = torch.tensor([d.sample() for d in distributions])
action_probabilities = policy_net(state_batch)
distributions = list(map(Categorical, action_probabilities))
actions = torch.tensor([d.sample() for d in distributions])
with torch.no_grad():
state_values = value_net(state_batch)
obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
with torch.no_grad():
state_values = value_net(state_batch)
obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
return -obtained_values.mean().item()
# return -obtained_values.mean().item()
entropy_importance = 0.01 # Entropy accounting for 1% of loss seems to work well
entropy_loss = (action_probabilities * torch.log(action_probabilities)).sum(1)
return (entropy_importance * entropy_loss - (1 - entropy_importance) * obtained_values).mean().item()
def learn(self, logger = None):
if len(self.memory) < self.config['batch_size']:
@ -105,6 +109,6 @@ class QEPAgent:
self.policy_net.calc_gradients(self.target_value_net, state_batch)
else:
self.policy_net.calc_gradients(self.value_net, state_batch)
self.policy_net.clamp_gradients()
# self.policy_net.clamp_gradients()
self.policy_net.step()

View file

@ -28,7 +28,7 @@ class ESNetwork(Network):
white_noise_dict = {}
noise_dict = {}
for key in model_dict.keys():
white_noise_dict[key] = torch.randn(self.population_size, *model_dict[key].shape)
white_noise_dict[key] = torch.randn(self.population_size, *model_dict[key].shape, device = self.device)
noise_dict[key] = self.sigma * white_noise_dict[key]
return white_noise_dict, noise_dict