Fixed multiprocessing with CUDA. Added entropy importance as a config option.
parent 9d32a9edd1
commit a99ca66b4f
2 changed files with 9 additions and 3 deletions

@@ -5,12 +5,13 @@ import torch
 from torch.distributions import Categorical
 import rltorch
 import rltorch.memory as M
+import torch.nn.functional as F
 
 # Q-Evolutionary Policy Agent
 # Maximizes the policy with respect to the Q-Value function.
 # Since function is non-differentiabile, depends on the Evolutionary Strategy algorithm
 class QEPAgent:
-    def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None):
+    def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None, entropy_importance = 0):
         self.policy_net = policy_net
         assert isinstance(self.policy_net, rltorch.network.ESNetwork) or isinstance(self.policy_net, rltorch.network.ESNetworkMP)
         self.policy_net.fitness = self.fitness
@@ -20,6 +21,7 @@ class QEPAgent:
         self.config = deepcopy(config)
         self.logger = logger
         self.policy_skip = 4
+        self.entropy_importance = entropy_importance
 
     def save(self, file_location):
         torch.save({
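
The new entropy_importance argument is stored on the agent and, as the fitness hunk further down shows, may be either a plain number or an iterable that is advanced one step per fitness evaluation. The sketch below mirrors that dispatch under stated assumptions: the linear_decay schedule and its values are hypothetical, and on Python 3.10+ the check needs collections.abc.Iterable, since the collections.Iterable alias used in the diff was removed.

import collections.abc
import itertools

def linear_decay(start=0.1, end=0.01, steps=10_000):
    # Hypothetical schedule: decay linearly, then hold the final value forever.
    for i in range(steps):
        yield start + (end - start) * i / max(steps - 1, 1)
    yield from itertools.repeat(end)

def resolve_entropy_importance(entropy_importance):
    # Same shape as the dispatch in QEPAgent.fitness: iterators are advanced
    # one step per call, plain numbers pass through unchanged.
    if isinstance(entropy_importance, collections.abc.Iterable):
        return next(entropy_importance)
    return entropy_importance

schedule = linear_decay()
print(resolve_entropy_importance(schedule))  # 0.1 on the first call
print(resolve_entropy_importance(0.01))      # a fixed coefficient passes through
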
@@ -36,7 +38,8 @@ class QEPAgent:
 
     def fitness(self, policy_net, value_net, state_batch):
         batch_size = len(state_batch)
-        action_probabilities = policy_net(state_batch)
+        with torch.no_grad():
+            action_probabilities = policy_net(state_batch)
         action_size = action_probabilities.shape[1]
         distributions = list(map(Categorical, action_probabilities))
         actions = torch.stack([d.sample() for d in distributions])
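
Wrapping the policy forward pass in torch.no_grad() stops autograd from recording a graph during fitness evaluation, which the evolutionary update never backpropagates through, and it cuts memory use per perturbed candidate. A minimal standalone sketch of the pattern, where the two-layer softmax policy is a toy stand-in rather than the rltorch network:

import torch
import torch.nn as nn
from torch.distributions import Categorical

policy = nn.Sequential(nn.Linear(4, 16), nn.ReLU(),
                       nn.Linear(16, 2), nn.Softmax(dim=1))  # toy stand-in policy
state_batch = torch.randn(8, 4)

with torch.no_grad():                          # no autograd graph is recorded
    action_probabilities = policy(state_batch)

distributions = list(map(Categorical, action_probabilities))
actions = torch.stack([d.sample() for d in distributions])
print(actions.shape, action_probabilities.requires_grad)  # torch.Size([8]) False
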
@@ -53,6 +56,7 @@ class QEPAgent:
         obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
         # return -obtained_values.mean().item()
         entropy_importance = 0 # Entropy accounting for 1% of loss seems to work well
+        entropy_importance = next(self.entropy_importance) if isinstance(self.entropy_importance, collections.Iterable) else self.entropy_importance
         value_importance = 1 - entropy_importance
 
         # entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
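
The entropy coefficient now comes from the constructor (or its schedule) rather than the hard-coded 0, and value_importance takes the remaining weight. How the two terms are finally combined lies outside this hunk, so the following is only a hedged sketch that reuses the Shannon-entropy form from the commented-out line, with made-up numbers:

import torch

action_probabilities = torch.tensor([[0.7, 0.3],
                                     [0.5, 0.5]])   # made-up policy outputs
obtained_values = torch.tensor([1.2, 0.4])          # made-up Q-values of the sampled actions

entropy_importance = 0.01                           # "1% of the loss" from the old comment
value_importance = 1 - entropy_importance

# entropy_loss is negative entropy, matching the commented-out line in the diff
entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1)

# One plausible fitness: minimise weighted entropy loss while maximising weighted value
fitness = (entropy_importance * entropy_loss - value_importance * obtained_values).mean()
print(fitness.item())
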
@@ -131,6 +135,7 @@ class QEPAgent:
             self.policy_skip -= 1
             return
         self.policy_skip = 4
+
         if self.target_value_net is not None:
             self.policy_net.calc_gradients(self.target_value_net, state_batch)
         else:
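
For context, policy_skip throttles the evolutionary policy update so the expensive ES step runs only once every few calls to learn(). A stripped-down illustration of that counter pattern follows; the class name and interface are illustrative, not part of the commit:

class UpdateThrottle:
    """Run an expensive step only once every `every` calls."""
    def __init__(self, every=4):
        self.every = every
        self.skip = every

    def should_update(self):
        if self.skip > 0:
            self.skip -= 1
            return False
        self.skip = self.every
        return True

throttle = UpdateThrottle(every=4)
print([throttle.should_update() for _ in range(10)])
# [False, False, False, False, True, False, False, False, False, True]
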
@@ -26,7 +26,8 @@ class ESNetworkMP(Network):
         self.fitness = fitness_fn
         self.sigma = sigma
         assert self.sigma > 0
-        self.pool = mp.Pool(processes=3) #[TODO] Probably should make number of processes a config variable
+        mp_ctx = mp.get_context("spawn")
+        self.pool = mp_ctx.Pool(processes=2) #[TODO] Probably should make number of processes a config variable
 
         # We're not going to be calculating gradients in the traditional way
         # So there's no need to waste computation time keeping track
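
The worker pool is now built from an explicit "spawn" context because CUDA cannot be re-initialised inside fork-started child processes, which is what a default multiprocessing pool uses on Linux; spawn-started workers start fresh and may initialise CUDA themselves. A minimal standalone sketch of the same fix, using the standard-library multiprocessing module and a toy fitness function (both are assumptions, not the rltorch code):

import multiprocessing as mp
import torch

def toy_fitness(seed):
    # Work done inside the worker; a spawn-started child may safely
    # initialise CUDA here if it is available.
    torch.manual_seed(seed)
    candidate = torch.randn(8)
    return float((candidate ** 2).sum())

if __name__ == "__main__":
    # fork-started children inherit the parent's CUDA state and fail;
    # requesting a "spawn" context avoids that, mirroring the change above.
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=2) as pool:
        print(pool.map(toy_fitness, range(4)))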