Some work on multiprocessing evolutionary strategies from last semester

Brandon Rozek 2019-09-13 19:53:19 -04:00
parent 6d3a78cd20
commit da83f1470c

@@ -12,7 +12,7 @@ import rltorch.memory as M
 class QEPAgent:
     def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None):
         self.policy_net = policy_net
-        assert isinstance(self.policy_net, rltorch.network.ESNetwork)
+        assert isinstance(self.policy_net, rltorch.network.ESNetwork) or isinstance(self.policy_net, rltorch.network.ESNetworkMP)
         self.policy_net.fitness = self.fitness
         self.value_net = value_net
         self.target_value_net = target_value_net
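
The widened assertion accepts either the single-process ESNetwork or the multiprocessing ESNetworkMP variant. As a side note (a sketch only, not part of the commit): isinstance also accepts a tuple of types, so the two or-ed checks could be collapsed into one call with identical behavior inside __init__:

        # Sketch: equivalent to the two or-ed isinstance checks above
        assert isinstance(self.policy_net, (rltorch.network.ESNetwork, rltorch.network.ESNetworkMP))
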
@@ -22,6 +22,7 @@ class QEPAgent:
         self.policy_skip = 4
 
     def fitness(self, policy_net, value_net, state_batch):
+        # print("Worker started")
         batch_size = len(state_batch)
         action_probabilities = policy_net(state_batch)
         action_size = action_probabilities.shape[1]
@@ -30,15 +31,21 @@ class QEPAgent:
         with torch.no_grad():
             state_values = value_net(state_batch)
+        # Weird hacky solution where in multiprocess, it sometimes spits out nans
+        # So have it try again
+        while torch.isnan(state_values).any():
+            with torch.no_grad():
+                state_values = value_net(state_batch)
         obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
         # return -obtained_values.mean().item()
-        entropy_importance = 0 # Entropy accounting for 1% of loss seems to work well
+        entropy_importance = 0.01 # Entropy accounting for 1% of loss seems to work well
         value_importance = 1 - entropy_importance
         # entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
         entropy_loss = (action_probabilities - torch.tensor(1 / action_size).repeat(len(state_batch), action_size)).abs().sum(1)
+        # print("END WORKER")
         return (entropy_importance * entropy_loss - value_importance * obtained_values).mean().item()
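
For reference, the fitness this hunk adjusts can be exercised on its own. The sketch below reproduces the scoring with toy tensors: a 1% penalty measured as L1 distance from the uniform distribution (standing in for an entropy term) plus 99% of the negated Q-values of the sampled actions. The toy_fitness helper, the dummy probs/qvals, and the Categorical draw for actions are illustrative assumptions (the hunk does not show how actions is produced), and the sign convention follows the commented-out return above, where lower scores appear to be better.

    import torch

    def toy_fitness(action_probabilities, state_values, actions, entropy_importance=0.01):
        # Mirrors the weighting in the hunk: 1% entropy-style penalty, 99% action value
        batch_size, action_size = action_probabilities.shape
        value_importance = 1 - entropy_importance
        obtained_values = state_values.gather(1, actions.view(batch_size, 1)).squeeze(1)
        # L1 distance from uniform is zero for a uniform policy and grows as it becomes deterministic
        uniform = torch.tensor(1 / action_size).repeat(batch_size, action_size)
        entropy_loss = (action_probabilities - uniform).abs().sum(1)
        return (entropy_importance * entropy_loss - value_importance * obtained_values).mean().item()

    probs = torch.softmax(torch.randn(4, 3), dim=1)        # toy policy output: 4 states, 3 actions
    qvals = torch.randn(4, 3)                              # toy Q-values from a value net
    acts = torch.distributions.Categorical(probs).sample() # assumed sampling step, not shown in the hunk
    print(toy_fitness(probs, qvals, acts))

One note on the retry-on-NaN loop added above: it assumes the NaNs are transient multiprocessing artifacts, since a value_net that deterministically returned NaN would keep the loop from ever exiting.
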
@@ -112,10 +119,11 @@ class QEPAgent:
             self.policy_skip -= 1
             return
         self.policy_skip = 4
 
         if self.target_value_net is not None:
             self.policy_net.calc_gradients(self.target_value_net, state_batch)
         else:
             self.policy_net.calc_gradients(self.value_net, state_batch)
+        # self.policy_net.clamp_gradients()
         self.policy_net.step()
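
The update path above calls calc_gradients on the ES policy net and then step. As a point of reference for the "multiprocessing evolutionary strategies" in the commit message, here is a generic, single-process sketch of how an ES gradient estimate is commonly formed (antithetic Gaussian perturbations, fitness-difference weighting). It is an illustration of the technique only, not the actual ESNetwork/ESNetworkMP implementation; the es_gradient helper and its parameters are hypothetical.

    import copy
    import torch

    def es_gradient(net, fitness_fn, sigma=0.05, population=10):
        # Accumulate a fitness-weighted sum of sampled perturbation directions
        grads = [torch.zeros_like(p) for p in net.parameters()]
        for _ in range(population):
            eps = [torch.randn_like(p) for p in net.parameters()]
            scores = []
            for sign in (1.0, -1.0):
                perturbed = copy.deepcopy(net)
                with torch.no_grad():
                    for p, e in zip(perturbed.parameters(), eps):
                        p.add_(sign * sigma * e)
                scores.append(fitness_fn(perturbed))
            diff = scores[0] - scores[1]  # f(theta + sigma*eps) - f(theta - sigma*eps)
            for g, e in zip(grads, eps):
                g.add_(diff * e)
        # Write the averaged estimate into .grad so a normal optimizer step can consume it
        with torch.no_grad():
            for p, g in zip(net.parameters(), grads):
                p.grad = g / (2 * sigma * population)
        return net

    # Toy usage: minimize a made-up fitness, then take one optimizer step on the estimate
    net = torch.nn.Linear(4, 2)
    es_gradient(net, lambda m: m(torch.randn(8, 4)).pow(2).mean().item())
    torch.optim.SGD(net.parameters(), lr=0.01).step()

A multiprocessing variant like ESNetworkMP would typically farm the fitness_fn evaluations out to worker processes, which is consistent with the commented worker prints and the NaN workaround added to the fitness function in this commit.
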