Fixed multiprocessing with CUDA. Added entropy importance as a config option.

Brandon Rozek 2019-09-18 07:26:32 -04:00
parent 9d32a9edd1
commit a99ca66b4f
2 changed files with 9 additions and 3 deletions


@@ -5,12 +5,13 @@ import torch
 from torch.distributions import Categorical
 import rltorch
 import rltorch.memory as M
 import torch.nn.functional as F
 # Q-Evolutionary Policy Agent
 # Maximizes the policy with respect to the Q-Value function.
 # Since the function is non-differentiable, this depends on the Evolutionary Strategy algorithm
 class QEPAgent:
-    def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None):
+    def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None, entropy_importance = 0):
         self.policy_net = policy_net
         assert isinstance(self.policy_net, rltorch.network.ESNetwork) or isinstance(self.policy_net, rltorch.network.ESNetworkMP)
         self.policy_net.fitness = self.fitness
@@ -20,6 +21,7 @@ class QEPAgent:
         self.config = deepcopy(config)
         self.logger = logger
         self.policy_skip = 4
+        self.entropy_importance = entropy_importance

     def save(self, file_location):
         torch.save({
@@ -36,7 +38,8 @@
     def fitness(self, policy_net, value_net, state_batch):
         batch_size = len(state_batch)
-        action_probabilities = policy_net(state_batch)
+        with torch.no_grad():
+            action_probabilities = policy_net(state_batch)
         action_size = action_probabilities.shape[1]
         distributions = list(map(Categorical, action_probabilities))
         actions = torch.stack([d.sample() for d in distributions])
@@ -53,6 +56,7 @@
         obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
         # return -obtained_values.mean().item()
         entropy_importance = 0 # Entropy accounting for 1% of loss seems to work well
+        entropy_importance = next(self.entropy_importance) if isinstance(self.entropy_importance, collections.Iterable) else self.entropy_importance
         value_importance = 1 - entropy_importance
         # entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
@@ -131,6 +135,7 @@ class QEPAgent:
             self.policy_skip -= 1
             return
         self.policy_skip = 4
         if self.target_value_net is not None:
             self.policy_net.calc_gradients(self.target_value_net, state_batch)
         else:
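
With this change, entropy_importance can be either a plain number or an iterator that yields one weight per fitness evaluation (the value is read with next() above, so a bare list will not work). A minimal usage sketch; policy_net, value_net, memory, config, target_net, and logger are placeholder objects assumed to be built elsewhere with the rest of rltorch:

import itertools

# Fixed weight: entropy contributes a constant 1% of the fitness signal.
agent = QEPAgent(policy_net, value_net, memory, config,
                 target_value_net=target_net, logger=logger,
                 entropy_importance=0.01)

# Decaying schedule: one value is drawn with next() per fitness call,
# so the entropy weight can anneal toward zero over training.
schedule = itertools.chain((0.1 * 0.99 ** i for i in range(1000)),
                           itertools.repeat(0))
agent = QEPAgent(policy_net, value_net, memory, config,
                 entropy_importance=schedule)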


@@ -26,7 +26,8 @@ class ESNetworkMP(Network):
         self.fitness = fitness_fn
         self.sigma = sigma
         assert self.sigma > 0
-        self.pool = mp.Pool(processes=3) #[TODO] Probably should make number of processes a config variable
+        mp_ctx = mp.get_context("spawn")
+        self.pool = mp_ctx.Pool(processes=2) #[TODO] Probably should make number of processes a config variable
         # We're not going to be calculating gradients in the traditional way
         # So there's no need to waste computation time keeping track
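
On the multiprocessing fix: a pool built from the default fork start method inherits the parent's already-initialized CUDA state, and CUDA cannot be re-initialized in a forked subprocess, so the workers fail once the parent has touched the GPU. Creating the pool from a spawn context, as above, starts each worker in a fresh interpreter instead. A self-contained sketch of the same pattern (the score function and tensors are illustrative, not part of rltorch):

import torch
import torch.multiprocessing as mp

def score(x):
    # Runs inside a pool worker. Kept on CPU so the sketch runs anywhere;
    # the point is the pool construction, which stays usable even after
    # the parent process has initialized CUDA.
    return float(x.sum())

if __name__ == "__main__":
    # "spawn" starts workers in fresh interpreters; the default "fork"
    # on Linux would copy an initialized CUDA context, which CUDA forbids.
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=2) as pool:
        results = pool.map(score, [torch.randn(4) for _ in range(8)])
    print(results)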