Playing around with QEP

Brandon Rozek 2019-03-14 00:53:51 -04:00
parent 8683b75ad9
commit cdfd3ab6b9


@@ -1,5 +1,6 @@
 from copy import deepcopy
 import collections
+import numpy as np
 import torch
 from torch.distributions import Categorical
 import rltorch
@@ -18,21 +19,27 @@ class QEPAgent:
         self.memory = memory
         self.config = deepcopy(config)
         self.logger = logger
-        self.policy_skip = 10
+        self.policy_skip = 4
     def fitness(self, policy_net, value_net, state_batch):
+        batch_size = len(state_batch)
         action_probabilities = policy_net(state_batch)
+        action_size = action_probabilities.shape[1]
         distributions = list(map(Categorical, action_probabilities))
         actions = torch.stack([d.sample() for d in distributions])
         with torch.no_grad():
             state_values = value_net(state_batch)
         obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
         # return -obtained_values.mean().item()
         entropy_importance = 0.01 # Entropy accounting for 1% of loss seems to work well
-        entropy_loss = (action_probabilities * torch.log(action_probabilities)).sum(1)
-        return (entropy_importance * entropy_loss - (1 - entropy_importance) * obtained_values).mean().item()
+        value_importance = 1 - entropy_importance
+        # entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
+        entropy_loss = (action_probabilities - torch.tensor(1 / action_size).repeat(len(state_batch), action_size)).abs().sum(1)
+        return (entropy_importance * entropy_loss - value_importance * obtained_values).mean().item()
     def learn(self, logger = None):
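For reference, the new fitness term scores a candidate policy by the Q-values of its sampled actions, minus a small penalty for drifting away from a uniform action distribution (an L1 distance, replacing the commented-out log-entropy term). Below is a minimal standalone sketch of that computation; the function name and inputs are illustrative rather than part of the commit, and it assumes the score is treated as a loss to be minimized (hence the negative sign on the value term):

import torch
from torch.distributions import Categorical

def fitness_sketch(action_probabilities, state_values, entropy_importance=0.01):
    # action_probabilities: (batch, actions) softmax output of the policy network
    # state_values: (batch, actions) Q-value estimates from the value network
    batch_size, action_size = action_probabilities.shape
    actions = torch.stack([Categorical(p).sample() for p in action_probabilities])
    obtained_values = state_values.gather(1, actions.view(batch_size, 1)).squeeze(1)
    value_importance = 1 - entropy_importance
    # Penalize the L1 distance from the uniform distribution instead of a Shannon-entropy term
    uniform = torch.full((batch_size, action_size), 1 / action_size)
    entropy_loss = (action_probabilities - uniform).abs().sum(1)
    # Lower is better: high Q-values and near-uniform action probabilities both reduce the score
    return (entropy_importance * entropy_loss - value_importance * obtained_values).mean().item()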
@@ -75,7 +82,7 @@ class QEPAgent:
         best_next_state_value = torch.zeros(self.config['batch_size'], device = self.value_net.device)
         best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
-        expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
+        expected_values = (reward_batch.float() + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
         if (isinstance(self.memory, M.PrioritizedReplayMemory)):
             value_loss = (torch.as_tensor(importance_weights, device = self.value_net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
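The only functional change in this hunk is the reward_batch.float() cast. A plausible reading (not stated in the commit) is that rewards come out of replay memory as integer tensors, and PyTorch of this era would not add a LongTensor to a FloatTensor, so the cast keeps expected_values in floating point. A tiny made-up illustration:

import torch

reward_batch = torch.tensor([1, 0, -1])                 # integer rewards, dtype torch.int64
best_next_state_value = torch.tensor([0.5, 1.2, 0.0])   # float value estimates
discount_rate = 0.99

# Without the cast, adding int64 to float32 tensors raised a dtype error in older PyTorch.
expected_values = (reward_batch.float() + discount_rate * best_next_state_value).unsqueeze(1)
# expected_values has shape (3, 1): [[1.4950], [1.1880], [-1.0000]]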
@@ -104,7 +111,7 @@ class QEPAgent:
         if self.policy_skip > 0:
             self.policy_skip -= 1
             return
-        self.policy_skip = 10
+        self.policy_skip = 4
         if self.target_value_net is not None:
             self.policy_net.calc_gradients(self.target_value_net, state_batch)
         else:
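The final hunk just tightens the update schedule: the counter gating the evolutionary policy update now resets to 4 instead of 10, so the policy network is refreshed more often relative to value-network updates. A condensed, hypothetical restatement of that gating logic:

# Sketch only; the class name is invented, the counter pattern mirrors learn() above.
class PolicySkipGate:
    def __init__(self, skip=4):          # was 10 before this commit
        self.skip = skip
        self.policy_skip = skip

    def should_update_policy(self):
        if self.policy_skip > 0:
            self.policy_skip -= 1        # skip this call; only the value network trains
            return False
        self.policy_skip = self.skip     # reset the counter and run the policy update
        return True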