Added an entropy term to the QEP fitness (1% of the loss)
Made random number generation in ESNetwork happen on the same device as the network
parent 76a044ace9
commit 714443192d

2 changed files with 13 additions and 9 deletions
@@ -21,15 +21,19 @@ class QEPAgent:
         self.policy_skip = 10
 
     def fitness(self, policy_net, value_net, state_batch):
         action_probabilities = policy_net(state_batch)
         distributions = list(map(Categorical, action_probabilities))
         actions = torch.tensor([d.sample() for d in distributions])
 
         with torch.no_grad():
             state_values = value_net(state_batch)
         obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
 
-        return -obtained_values.mean().item()
+        # return -obtained_values.mean().item()
+        entropy_importance = 0.01 # Entropy accounting for 1% of loss seems to work well
+        entropy_loss = (action_probabilities * torch.log(action_probabilities)).sum(1)
+        return (entropy_importance * entropy_loss - (1 - entropy_importance) * obtained_values).mean().item()
 
     def learn(self, logger = None):
         if len(self.memory) < self.config['batch_size']:
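For reference, a minimal, self-contained sketch of what the new fitness computes. The random tensors, shapes, and seed below are illustrative stand-ins for policy_net(state_batch) and the gathered state values; only entropy_importance and the two formula lines mirror the diff.

import torch

torch.manual_seed(0)
batch_size, num_actions = 4, 3

# Stand-ins for the real network outputs: softmax action probabilities and the
# state values gathered for the sampled actions.
action_probabilities = torch.softmax(torch.randn(batch_size, num_actions), dim=1)
obtained_values = torch.randn(batch_size)

entropy_importance = 0.01  # entropy contributes 1% of the objective
# (p * log p).sum(1) is the negative entropy of each action distribution, so
# minimizing this fitness favors both high obtained values and high entropy.
entropy_loss = (action_probabilities * torch.log(action_probabilities)).sum(1)
fitness = (entropy_importance * entropy_loss
           - (1 - entropy_importance) * obtained_values).mean().item()
print(fitness)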
@@ -105,6 +109,6 @@ class QEPAgent:
             self.policy_net.calc_gradients(self.target_value_net, state_batch)
         else:
             self.policy_net.calc_gradients(self.value_net, state_batch)
-        self.policy_net.clamp_gradients()
+        # self.policy_net.clamp_gradients()
         self.policy_net.step()
 
@@ -28,7 +28,7 @@ class ESNetwork(Network):
         white_noise_dict = {}
         noise_dict = {}
         for key in model_dict.keys():
-            white_noise_dict[key] = torch.randn(self.population_size, *model_dict[key].shape)
+            white_noise_dict[key] = torch.randn(self.population_size, *model_dict[key].shape, device = self.device)
             noise_dict[key] = self.sigma * white_noise_dict[key]
         return white_noise_dict, noise_dict
 
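To see why the device argument matters, a self-contained sketch follows; the local device, population_size, param_shape, and sigma values stand in for ESNetwork's own attributes and are made up for illustration. Creating the noise directly on the parameters' device avoids an implicit CPU tensor that would either raise a device-mismatch error or force a host-to-device copy when the perturbations are applied.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
population_size, param_shape, sigma = 8, (16, 4), 0.05
param = torch.zeros(param_shape, device=device)  # stand-in for one model parameter

# Noise generated on the same device as the parameter it will perturb.
white_noise = torch.randn(population_size, *param_shape, device=device)
perturbed = param.unsqueeze(0) + sigma * white_noise  # one candidate per population member
print(perturbed.shape, perturbed.device)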