Playing around with QEP

parent 8683b75ad9
commit cdfd3ab6b9

1 changed file with 13 additions and 6 deletions
@@ -1,5 +1,6 @@
 from copy import deepcopy
 import collections
+import numpy as np
 import torch
 from torch.distributions import Categorical
 import rltorch
@@ -18,21 +19,27 @@ class QEPAgent:
         self.memory = memory
         self.config = deepcopy(config)
         self.logger = logger
-        self.policy_skip = 10
+        self.policy_skip = 4
 
     def fitness(self, policy_net, value_net, state_batch):
+        batch_size = len(state_batch)
         action_probabilities = policy_net(state_batch)
+        action_size = action_probabilities.shape[1]
         distributions = list(map(Categorical, action_probabilities))
         actions = torch.stack([d.sample() for d in distributions])
 
         with torch.no_grad():
             state_values = value_net(state_batch)
         obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
 
         # return -obtained_values.mean().item()
 
         entropy_importance = 0.01 # Entropy accounting for 1% of loss seems to work well
-        entropy_loss = (action_probabilities * torch.log(action_probabilities)).sum(1)
-        return (entropy_importance * entropy_loss - (1 - entropy_importance) * obtained_values).mean().item()
+        value_importance = 1 - entropy_importance
+        # entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
+        entropy_loss = (action_probabilities - torch.tensor(1 / action_size).repeat(len(state_batch), action_size)).abs().sum(1)
+        return (entropy_importance * entropy_loss - value_importance * obtained_values).mean().item()
+
+
 
     def learn(self, logger = None):
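
For reference, a minimal stand-alone sketch of what the revised fitness objective computes. The probabilities, Q-values, and the argmax action choice below are made up to stand in for the policy_net and value_net outputs; none of it is part of the commit:

import torch

# Made-up policy probabilities and Q-values for 3 states and 2 actions.
action_probabilities = torch.tensor([[0.9, 0.1],
                                     [0.5, 0.5],
                                     [0.2, 0.8]])
state_values = torch.tensor([[1.0, 0.0],
                             [0.3, 0.7],
                             [0.2, 2.0]])
actions = action_probabilities.argmax(1)              # stand-in for the sampled actions
obtained_values = state_values.gather(1, actions.view(-1, 1)).squeeze(1)

batch_size, action_size = action_probabilities.shape
entropy_importance = 0.01
value_importance = 1 - entropy_importance

# Old penalty: sum of p*log(p), i.e. negative Shannon entropy; smallest for a
# uniform policy, but yields NaN once any probability hits exactly zero (0 * log 0).
old_entropy_loss = (action_probabilities * torch.log(action_probabilities)).sum(1)

# New penalty from this commit: L1 distance to the uniform distribution; also
# smallest (exactly 0) for a uniform policy, and stays finite everywhere.
uniform = torch.full((batch_size, action_size), 1 / action_size)
new_entropy_loss = (action_probabilities - uniform).abs().sum(1)

fitness = (entropy_importance * new_entropy_loss
           - value_importance * obtained_values).mean().item()
print(old_entropy_loss, new_entropy_loss, fitness)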
@@ -75,7 +82,7 @@
         best_next_state_value = torch.zeros(self.config['batch_size'], device = self.value_net.device)
         best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
 
-        expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
+        expected_values = (reward_batch.float() + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
 
         if (isinstance(self.memory, M.PrioritizedReplayMemory)):
             value_loss = (torch.as_tensor(importance_weights, device = self.value_net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
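
The only change in this hunk is the added reward_batch.float(). The dtype of the replay buffer's reward batch is not visible in this diff, but the cast is the usual guard against mixing an integer reward tensor with the float-valued discounted state values; a hypothetical reproduction (the reward values and discount rate below are made up):

import torch

reward_batch = torch.tensor([1, 0, -1])                  # int64, e.g. raw game scores
best_next_state_value = torch.tensor([0.5, 1.2, 0.0])    # float32
discount_rate = 0.99

# int64 + float32: raises a dtype error on older PyTorch releases and relies on
# implicit type promotion on newer ones; the explicit cast keeps it float32.
expected_values = (reward_batch.float()
                   + (discount_rate * best_next_state_value)).unsqueeze(1)
print(expected_values.dtype)  # torch.float32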
@@ -104,7 +111,7 @@
         if self.policy_skip > 0:
             self.policy_skip -= 1
             return
-        self.policy_skip = 10
+        self.policy_skip = 4
         if self.target_value_net is not None:
             self.policy_net.calc_gradients(self.target_value_net, state_batch)
         else:
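
The reset value of policy_skip drops from 10 to 4 in both places it appears. Isolating just the counter logic (the SkipCounter class below is invented for illustration and is not part of rltorch), the policy update now fires on every 5th call to learn() instead of every 11th:

class SkipCounter:
    """Stand-alone trace of the policy_skip cadence used in QEPAgent.learn()."""
    def __init__(self, skip=4):
        self.policy_skip = skip

    def learn(self):
        if self.policy_skip > 0:
            self.policy_skip -= 1
            return False   # policy update skipped on this call
        self.policy_skip = 4
        return True        # policy update runs on this call

agent = SkipCounter()
print([agent.learn() for _ in range(10)])
# [False, False, False, False, True, False, False, False, False, True]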