From cdfd3ab6b998f6fb4eb34230da09ec8fa2d304d4 Mon Sep 17 00:00:00 2001
From: Brandon Rozek
Date: Thu, 14 Mar 2019 00:53:51 -0400
Subject: [PATCH] Playing around with QEP

---
 rltorch/agents/QEPAgent.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/rltorch/agents/QEPAgent.py b/rltorch/agents/QEPAgent.py
index 86f8091..d636cd2 100644
--- a/rltorch/agents/QEPAgent.py
+++ b/rltorch/agents/QEPAgent.py
@@ -1,5 +1,6 @@
 from copy import deepcopy
 import collections
+import numpy as np
 import torch
 from torch.distributions import Categorical
 import rltorch
@@ -18,21 +19,27 @@ class QEPAgent:
         self.memory = memory
         self.config = deepcopy(config)
         self.logger = logger
-        self.policy_skip = 10
+        self.policy_skip = 4
 
 
     def fitness(self, policy_net, value_net, state_batch):
+        batch_size = len(state_batch)
         action_probabilities = policy_net(state_batch)
+        action_size = action_probabilities.shape[1]
         distributions = list(map(Categorical, action_probabilities))
         actions = torch.stack([d.sample() for d in distributions])
 
         with torch.no_grad():
             state_values = value_net(state_batch)
         obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
-        # return -obtained_values.mean().item()
+
         entropy_importance = 0.01 # Entropy accounting for 1% of loss seems to work well
-        entropy_loss = (action_probabilities * torch.log(action_probabilities)).sum(1)
-        return (entropy_importance * entropy_loss - (1 - entropy_importance) * obtained_values).mean().item()
+        value_importance = 1 - entropy_importance
+
+        # entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
+        entropy_loss = (action_probabilities - torch.tensor(1 / action_size).repeat(len(state_batch), action_size)).abs().sum(1)
+
+        return (entropy_importance * entropy_loss - value_importance * obtained_values).mean().item()
 
 
     def learn(self, logger = None):
@@ -75,7 +82,7 @@ class QEPAgent:
         best_next_state_value = torch.zeros(self.config['batch_size'], device = self.value_net.device)
         best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
 
-        expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
+        expected_values = (reward_batch.float() + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
 
         if (isinstance(self.memory, M.PrioritizedReplayMemory)):
             value_loss = (torch.as_tensor(importance_weights, device = self.value_net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
@@ -104,7 +111,7 @@ class QEPAgent:
         if self.policy_skip > 0:
             self.policy_skip -= 1
             return
-        self.policy_skip = 10
+        self.policy_skip = 4
         if self.target_value_net is not None:
             self.policy_net.calc_gradients(self.target_value_net, state_batch)
         else:
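
Note (illustration only, not part of the patch): the sketch below reproduces, in isolation, the fitness computation this commit moves to. The names fitness_sketch, toy_policy, and toy_value are hypothetical stand-ins invented for the example; only the arithmetic mirrors the patched fitness(): actions are sampled from per-state Categorical distributions, scored by the value network under torch.no_grad(), and the Shannon-entropy regularizer is replaced by the L1 distance of the action probabilities from the uniform distribution 1/action_size, weighted at 1% against 99% for the value term.

import torch
from torch.distributions import Categorical

def fitness_sketch(policy_net, value_net, state_batch, entropy_importance=0.01):
    # Action probabilities for every state in the batch: shape (batch, actions)
    action_probabilities = policy_net(state_batch)
    batch_size, action_size = action_probabilities.shape

    # Sample one action per state from its categorical distribution
    distributions = list(map(Categorical, action_probabilities))
    actions = torch.stack([d.sample() for d in distributions])

    # Score the sampled actions with the value network (no gradients needed)
    with torch.no_grad():
        state_values = value_net(state_batch)
    obtained_values = state_values.gather(1, actions.view(batch_size, 1)).squeeze(1)

    # Regularizer used by the patch: L1 distance of the policy's probabilities
    # from the uniform distribution 1/action_size, instead of Shannon entropy
    uniform = torch.full((batch_size, action_size), 1.0 / action_size)
    entropy_loss = (action_probabilities - uniform).abs().sum(1)

    value_importance = 1 - entropy_importance
    return (entropy_importance * entropy_loss
            - value_importance * obtained_values).mean().item()

# Hypothetical stand-ins so the sketch runs on its own
if __name__ == "__main__":
    torch.manual_seed(0)
    toy_policy = torch.nn.Sequential(torch.nn.Linear(4, 2), torch.nn.Softmax(dim=1))
    toy_value = torch.nn.Linear(4, 2)
    states = torch.randn(8, 4)
    print(fitness_sketch(toy_policy, toy_value, states))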