diff --git a/rltorch/agents/QEPAgent.py b/rltorch/agents/QEPAgent.py
index 71e312d..de203b0 100644
--- a/rltorch/agents/QEPAgent.py
+++ b/rltorch/agents/QEPAgent.py
@@ -5,12 +5,13 @@ import torch
 from torch.distributions import Categorical
 import rltorch
 import rltorch.memory as M
+import torch.nn.functional as F
 
 # Q-Evolutionary Policy Agent
 # Maximizes the policy with respect to the Q-Value function.
 # Since function is non-differentiabile, depends on the Evolutionary Strategy algorithm
 class QEPAgent:
-    def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None):
+    def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None, entropy_importance = 0):
         self.policy_net = policy_net
         assert isinstance(self.policy_net, rltorch.network.ESNetwork) or isinstance(self.policy_net, rltorch.network.ESNetworkMP)
         self.policy_net.fitness = self.fitness
@@ -20,6 +21,7 @@ class QEPAgent:
         self.config = deepcopy(config)
         self.logger = logger
         self.policy_skip = 4
+        self.entropy_importance = entropy_importance
 
     def save(self, file_location):
         torch.save({
@@ -36,7 +38,8 @@ class QEPAgent:
 
     def fitness(self, policy_net, value_net, state_batch):
         batch_size = len(state_batch)
-        action_probabilities = policy_net(state_batch)
+        with torch.no_grad():
+            action_probabilities = policy_net(state_batch)
         action_size = action_probabilities.shape[1]
         distributions = list(map(Categorical, action_probabilities))
         actions = torch.stack([d.sample() for d in distributions])
@@ -53,6 +56,7 @@ class QEPAgent:
         obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
         # return -obtained_values.mean().item()
         entropy_importance = 0 # Entropy accounting for 1% of loss seems to work well
+        entropy_importance = next(self.entropy_importance) if isinstance(self.entropy_importance, collections.Iterable) else self.entropy_importance
         value_importance = 1 - entropy_importance
 
         # entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
@@ -131,6 +135,7 @@ class QEPAgent:
             self.policy_skip -= 1
             return
         self.policy_skip = 4
+
         if self.target_value_net is not None:
             self.policy_net.calc_gradients(self.target_value_net, state_batch)
         else:
diff --git a/rltorch/network/ESNetworkMP.py b/rltorch/network/ESNetworkMP.py
index ec954c2..69b0d21 100644
--- a/rltorch/network/ESNetworkMP.py
+++ b/rltorch/network/ESNetworkMP.py
@@ -26,7 +26,8 @@ class ESNetworkMP(Network):
         self.fitness = fitness_fn
         self.sigma = sigma
         assert self.sigma > 0
-        self.pool = mp.Pool(processes=3) #[TODO] Probably should make number of processes a config variable
+        mp_ctx = mp.get_context("spawn")
+        self.pool = mp_ctx.Pool(processes=2) #[TODO] Probably should make number of processes a config variable
 
         # We're not going to be calculating gradients in the traditional way
         # So there's no need to waste computation time keeping track
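
The new entropy_importance argument is read with next(...) whenever it is an iterable, so it accepts either a fixed float or a per-evaluation schedule. Below is a minimal sketch of both usages; the linear_decay helper and the variable names (policy_net, memory, target_net, ...) are illustrative only, while the QEPAgent signature matches the one added in this diff. Any iterator-style scheduler that supports next() would work the same way.

import itertools

def linear_decay(start, end, steps):
    # Hypothetical helper (not part of rltorch): yields `steps` values
    # interpolated from `start` to `end`, then repeats `end` forever.
    for i in range(steps):
        yield start + (end - start) * i / max(steps - 1, 1)
    yield from itertools.repeat(end)

# Fixed weighting: entropy contributes a constant 1% of the fitness score.
# agent = QEPAgent(policy_net, value_net, memory, config,
#                  target_value_net=target_net, entropy_importance=0.01)

# Scheduled weighting: anneal the entropy share from 10% down to 1%
# over the first 10,000 fitness evaluations.
# agent = QEPAgent(policy_net, value_net, memory, config,
#                  target_value_net=target_net,
#                  entropy_importance=linear_decay(0.1, 0.01, 10_000))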