From 714443192d51a880ff3acccf870451536eedc37c Mon Sep 17 00:00:00 2001
From: Brandon Rozek
Date: Thu, 28 Feb 2019 12:17:35 -0500
Subject: [PATCH] Added entropy into QEP (1% of loss) Made random numbers generated in ESNetwork happen in the same device

---
Note: the line breaks and indentation of this patch were reconstructed from a
whitespace-collapsed copy; the exact original whitespace is not preserved.

 rltorch/agents/QEPAgent.py   | 20 ++++++++++++--------
 rltorch/network/ESNetwork.py |  2 +-
 2 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/rltorch/agents/QEPAgent.py b/rltorch/agents/QEPAgent.py
index fe11e36..307db94 100644
--- a/rltorch/agents/QEPAgent.py
+++ b/rltorch/agents/QEPAgent.py
@@ -21,15 +21,19 @@ class QEPAgent:
         self.policy_skip = 10
 
     def fitness(self, policy_net, value_net, state_batch):
-        action_probabilities = policy_net(state_batch)
-        distributions = list(map(Categorical, action_probabilities))
-        actions = torch.tensor([d.sample() for d in distributions])
+        action_probabilities = policy_net(state_batch)
+        distributions = list(map(Categorical, action_probabilities))
+        actions = torch.tensor([d.sample() for d in distributions])
 
-        with torch.no_grad():
-            state_values = value_net(state_batch)
-        obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
+        with torch.no_grad():
+            state_values = value_net(state_batch)
+        obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
 
-        return -obtained_values.mean().item()
+        # return -obtained_values.mean().item()
+        entropy_importance = 0.01 # Entropy accounting for 1% of loss seems to work well
+        entropy_loss = (action_probabilities * torch.log(action_probabilities)).sum(1)
+        return (entropy_importance * entropy_loss - (1 - entropy_importance) * obtained_values).mean().item()
+
 
     def learn(self, logger = None):
         if len(self.memory) < self.config['batch_size']:
@@ -105,6 +109,6 @@ class QEPAgent:
             self.policy_net.calc_gradients(self.target_value_net, state_batch)
         else:
             self.policy_net.calc_gradients(self.value_net, state_batch)
-        self.policy_net.clamp_gradients()
+        # self.policy_net.clamp_gradients()
         self.policy_net.step()
 
diff --git a/rltorch/network/ESNetwork.py b/rltorch/network/ESNetwork.py
index d360cb1..c5ae4b7 100644
--- a/rltorch/network/ESNetwork.py
+++ b/rltorch/network/ESNetwork.py
@@ -28,7 +28,7 @@ class ESNetwork(Network):
         white_noise_dict = {}
         noise_dict = {}
         for key in model_dict.keys():
-            white_noise_dict[key] = torch.randn(self.population_size, *model_dict[key].shape)
+            white_noise_dict[key] = torch.randn(self.population_size, *model_dict[key].shape, device = self.device)
             noise_dict[key] = self.sigma * white_noise_dict[key]
         return white_noise_dict, noise_dict
 