From da83f1470c96ad81d8debb3ec58776f7dc4c4e2a Mon Sep 17 00:00:00 2001
From: Brandon Rozek
Date: Fri, 13 Sep 2019 19:53:19 -0400
Subject: [PATCH] Some work on multiprocessing evolutionary strategies from
 last semester

---
 rltorch/agents/QEPAgent.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/rltorch/agents/QEPAgent.py b/rltorch/agents/QEPAgent.py
index 6040e28..aa1dc89 100644
--- a/rltorch/agents/QEPAgent.py
+++ b/rltorch/agents/QEPAgent.py
@@ -12,7 +12,7 @@ import rltorch.memory as M
 class QEPAgent:
     def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None):
         self.policy_net = policy_net
-        assert isinstance(self.policy_net, rltorch.network.ESNetwork)
+        assert isinstance(self.policy_net, rltorch.network.ESNetwork) or isinstance(self.policy_net, rltorch.network.ESNetworkMP)
         self.policy_net.fitness = self.fitness
         self.value_net = value_net
         self.target_value_net = target_value_net
@@ -22,6 +22,7 @@ class QEPAgent:
         self.policy_skip = 4
 
     def fitness(self, policy_net, value_net, state_batch):
+        # print("Worker started")
         batch_size = len(state_batch)
         action_probabilities = policy_net(state_batch)
         action_size = action_probabilities.shape[1]
@@ -30,15 +31,21 @@ class QEPAgent:
 
         with torch.no_grad():
             state_values = value_net(state_batch)
+
+        # Weird hacky solution where in multiprocess, it sometimes spits out nans
+        # So have it try again
+        while torch.isnan(state_values).any():
+            with torch.no_grad():
+                state_values = value_net(state_batch)
+
         obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
         # return -obtained_values.mean().item()
-
-        entropy_importance = 0.01 # Entropy accounting for 1% of loss seems to work well
+        entropy_importance = 0 # Entropy accounting for 1% of loss seems to work well
         value_importance = 1 - entropy_importance
 
         # entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
         entropy_loss = (action_probabilities - torch.tensor(1 / action_size).repeat(len(state_batch), action_size)).abs().sum(1)
-
+        # print("END WORKER")
         return (entropy_importance * entropy_loss - value_importance * obtained_values).mean().item()
@@ -112,10 +119,11 @@ class QEPAgent:
             self.policy_skip -= 1
             return
         self.policy_skip = 4
+
         if self.target_value_net is not None:
             self.policy_net.calc_gradients(self.target_value_net, state_batch)
         else:
             self.policy_net.calc_gradients(self.value_net, state_batch)
-        # self.policy_net.clamp_gradients()
+
         self.policy_net.step()
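
For context on the technique: evolution strategies score noise-perturbed copies of the policy parameters with a fitness function like QEPAgent.fitness above, then combine those scores into a parameter update. The following is a minimal, single-process sketch of that update, not rltorch's ESNetwork/ESNetworkMP implementation; the helper name es_gradient_step and the population_size/sigma/lr hyperparameters are assumptions made for illustration only.

    # Minimal sketch of an evolution-strategies update driven by a fitness function
    # with the same signature as QEPAgent.fitness(policy_net, value_net, state_batch).
    # Assumed/illustrative: es_gradient_step, population_size, sigma, lr.
    import copy
    import torch

    def es_gradient_step(policy_net, fitness_fn, value_net, state_batch,
                         population_size=20, sigma=0.05, lr=0.01):
        # Flatten the current policy parameters into a single vector.
        params = torch.nn.utils.parameters_to_vector(policy_net.parameters())
        noise = torch.randn(population_size, params.numel())
        fitnesses = torch.zeros(population_size)

        # Score each noise-perturbed candidate. fitness() in the patch returns a
        # loss-like scalar (entropy term minus obtained Q-values), so lower is better.
        for i in range(population_size):
            candidate = copy.deepcopy(policy_net)
            torch.nn.utils.vector_to_parameters(params + sigma * noise[i],
                                                candidate.parameters())
            fitnesses[i] = fitness_fn(candidate, value_net, state_batch)

        # Standardized scores weight the noise directions; step downhill on the estimate.
        advantages = (fitnesses - fitnesses.mean()) / (fitnesses.std() + 1e-8)
        grad_estimate = (noise * advantages.unsqueeze(1)).mean(0) / sigma
        torch.nn.utils.vector_to_parameters(params - lr * grad_estimate,
                                            policy_net.parameters())

A multiprocessing variant would evaluate the per-candidate fitness calls in worker processes (e.g. via torch.multiprocessing), which is presumably where the commented-out "Worker started" print and the NaN retry loop in this patch come from.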