Fixed multiprocessing with CUDA. Added entropy importance as a config option.
This commit is contained in:
parent
9d32a9edd1
commit
a99ca66b4f
2 changed files with 9 additions and 3 deletions
|
@ -5,12 +5,13 @@ import torch
|
|||
from torch.distributions import Categorical
|
||||
import rltorch
|
||||
import rltorch.memory as M
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Q-Evolutionary Policy Agent
|
||||
# Maximizes the policy with respect to the Q-Value function.
|
||||
# Since function is non-differentiabile, depends on the Evolutionary Strategy algorithm
|
||||
class QEPAgent:
|
||||
def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None):
|
||||
def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None, entropy_importance = 0):
|
||||
self.policy_net = policy_net
|
||||
assert isinstance(self.policy_net, rltorch.network.ESNetwork) or isinstance(self.policy_net, rltorch.network.ESNetworkMP)
|
||||
self.policy_net.fitness = self.fitness
|
||||
|
@ -20,6 +21,7 @@ class QEPAgent:
|
|||
self.config = deepcopy(config)
|
||||
self.logger = logger
|
||||
self.policy_skip = 4
|
||||
self.entropy_importance = entropy_importance
|
||||
|
||||
def save(self, file_location):
|
||||
torch.save({
|
||||
|
@ -36,7 +38,8 @@ class QEPAgent:
|
|||
|
||||
def fitness(self, policy_net, value_net, state_batch):
|
||||
batch_size = len(state_batch)
|
||||
action_probabilities = policy_net(state_batch)
|
||||
with torch.no_grad():
|
||||
action_probabilities = policy_net(state_batch)
|
||||
action_size = action_probabilities.shape[1]
|
||||
distributions = list(map(Categorical, action_probabilities))
|
||||
actions = torch.stack([d.sample() for d in distributions])
|
||||
|
@ -53,6 +56,7 @@ class QEPAgent:
|
|||
obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
|
||||
# return -obtained_values.mean().item()
|
||||
entropy_importance = 0 # Entropy accounting for 1% of loss seems to work well
|
||||
entropy_importance = next(self.entropy_importance) if isinstance(self.entropy_importance, collections.Iterable) else self.entropy_importance
|
||||
value_importance = 1 - entropy_importance
|
||||
|
||||
# entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
|
||||
|
@ -131,6 +135,7 @@ class QEPAgent:
|
|||
self.policy_skip -= 1
|
||||
return
|
||||
self.policy_skip = 4
|
||||
|
||||
if self.target_value_net is not None:
|
||||
self.policy_net.calc_gradients(self.target_value_net, state_batch)
|
||||
else:
|
||||
|
|
|
@ -26,7 +26,8 @@ class ESNetworkMP(Network):
|
|||
self.fitness = fitness_fn
|
||||
self.sigma = sigma
|
||||
assert self.sigma > 0
|
||||
self.pool = mp.Pool(processes=3) #[TODO] Probably should make number of processes a config variable
|
||||
mp_ctx = mp.get_context("spawn")
|
||||
self.pool = mp_ctx.Pool(processes=2) #[TODO] Probably should make number of processes a config variable
|
||||
|
||||
# We're not going to be calculating gradients in the traditional way
|
||||
# So there's no need to waste computation time keeping track
|
||||
|
|
Loading…
Reference in a new issue