Some work on multiprocessing evolutionary strategies from last semester
This commit is contained in:
parent 6d3a78cd20
commit da83f1470c
1 changed file with 13 additions and 5 deletions
@@ -12,7 +12,7 @@ import rltorch.memory as M
 class QEPAgent:
     def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None):
         self.policy_net = policy_net
-        assert isinstance(self.policy_net, rltorch.network.ESNetwork)
+        assert isinstance(self.policy_net, rltorch.network.ESNetwork) or isinstance(self.policy_net, rltorch.network.ESNetworkMP)
         self.policy_net.fitness = self.fitness
         self.value_net = value_net
         self.target_value_net = target_value_net
@@ -22,6 +22,7 @@ class QEPAgent:
         self.policy_skip = 4

     def fitness(self, policy_net, value_net, state_batch):
+        # print("Worker started")
         batch_size = len(state_batch)
         action_probabilities = policy_net(state_batch)
         action_size = action_probabilities.shape[1]
@@ -30,15 +31,21 @@

         with torch.no_grad():
             state_values = value_net(state_batch)

+        # Weird hacky solution where in multiprocess, it sometimes spits out nans
+        # So have it try again
+        while torch.isnan(state_values).any():
+            with torch.no_grad():
+                state_values = value_net(state_batch)
+
         obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
         # return -obtained_values.mean().item()

-        entropy_importance = 0.01 # Entropy accounting for 1% of loss seems to work well
+        entropy_importance = 0 # Entropy accounting for 1% of loss seems to work well
         value_importance = 1 - entropy_importance

         # entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
         entropy_loss = (action_probabilities - torch.tensor(1 / action_size).repeat(len(state_batch), action_size)).abs().sum(1)

+        # print("END WORKER")
         return (entropy_importance * entropy_loss - value_importance * obtained_values).mean().item()

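Note on the fitness hunk above: the active entropy_loss line is not Shannon entropy; it is the total absolute deviation (L1 distance) of each row of action_probabilities from the uniform distribution 1/action_size, so a uniform policy scores 0 and a one-hot policy scores the largest value. A quick standalone check with toy tensors (torch.full standing in for the scalar repeat; illustrative only, not code from this repository):

import torch

action_probabilities = torch.tensor([[0.25, 0.25, 0.25, 0.25],   # uniform row
                                     [1.00, 0.00, 0.00, 0.00]])  # one-hot row
batch_size, action_size = action_probabilities.shape

# Same quantity as the diff's (p - 1/action_size).abs().sum(1) term
uniform = torch.full((batch_size, action_size), 1 / action_size)
entropy_loss = (action_probabilities - uniform).abs().sum(1)
print(entropy_loss)  # 0.0 for the uniform row, 1.5 for the one-hot row

With entropy_importance set to 0 in this commit, the returned fitness reduces to -obtained_values.mean().item(), so the evolution strategy scores candidates only by the mean value of the selected actions.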
@@ -112,10 +119,11 @@ class QEPAgent:
             self.policy_skip -= 1
             return
         self.policy_skip = 4

         if self.target_value_net is not None:
             self.policy_net.calc_gradients(self.target_value_net, state_batch)
         else:
             self.policy_net.calc_gradients(self.value_net, state_batch)
         # self.policy_net.clamp_gradients()

         self.policy_net.step()
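For context on the "multiprocessing evolutionary strategies" part of this commit: the new assert allows the policy network to be an ESNetworkMP, which by its name evaluates the evolution-strategy fitness in multiple processes; the class itself is not shown in this diff. The sketch below is a generic, self-contained illustration of that pattern (NumPy plus multiprocessing.Pool, with a stand-in quadratic score in place of QEPAgent.fitness); it is not the rltorch implementation.

import numpy as np
from multiprocessing import Pool

def score(theta):
    # Stand-in fitness; in QEPAgent this role is played by
    # self.fitness(policy_net, value_net, state_batch).
    return -np.sum((theta - 3.0) ** 2)

def es_step(theta, sigma=0.1, lr=0.01, population=32, workers=4):
    # Sample Gaussian perturbations of the parameter vector.
    noise = np.random.randn(population, theta.size)
    candidates = [theta + sigma * n for n in noise]
    # Evaluate every candidate in a separate worker process.
    with Pool(workers) as pool:
        fitnesses = np.array(pool.map(score, candidates))
    # Standardise the scores and move along the fitness-weighted noise.
    fitnesses = (fitnesses - fitnesses.mean()) / (fitnesses.std() + 1e-8)
    return theta + lr / (population * sigma) * (noise.T @ fitnesses)

if __name__ == "__main__":
    theta = np.zeros(5)
    for _ in range(300):
        theta = es_step(theta)
    print(theta)  # each coordinate should end up near 3.0

The pool only parallelises the fitness evaluations, which is also the expensive part in QEPAgent (a forward pass of the policy and value networks per candidate), so this is where a multiprocessing ES variant pays off.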