Some work on multiprocessing evolutionary strategies from last semester
parent 6d3a78cd20
commit da83f1470c
1 changed file with 13 additions and 5 deletions
@@ -12,7 +12,7 @@ import rltorch.memory as M
 class QEPAgent:
     def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None):
         self.policy_net = policy_net
-        assert isinstance(self.policy_net, rltorch.network.ESNetwork)
+        assert isinstance(self.policy_net, rltorch.network.ESNetwork) or isinstance(self.policy_net, rltorch.network.ESNetworkMP)
         self.policy_net.fitness = self.fitness
         self.value_net = value_net
         self.target_value_net = target_value_net
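The relaxed assertion lets the agent drive either rltorch.network.ESNetwork or the new rltorch.network.ESNetworkMP, which by its name is the multiprocessing variant that farms fitness evaluations out to worker processes. The sketch below is only an illustration of that general idea, not rltorch's implementation; the helpers make_policy, evaluate_candidate, and parallel_fitness are hypothetical names.

# Illustrative sketch (assumption): scoring perturbed copies of a policy
# network in worker processes. Not rltorch's ESNetworkMP implementation.
import copy
from concurrent.futures import ProcessPoolExecutor

import torch
import torch.nn as nn

def make_policy():
    # Toy stand-in for the agent's real policy_net.
    return nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2), nn.Softmax(dim=1))

def evaluate_candidate(args):
    # Runs in a worker process: rebuild the network, apply one perturbation, score it.
    state_dict, noise, sigma, state_batch = args
    net = make_policy()
    net.load_state_dict(state_dict)
    with torch.no_grad():
        for param, eps in zip(net.parameters(), noise):
            param.add_(sigma * eps)
        action_probabilities = net(state_batch)
    # Placeholder score; the agent's real scoring lives in QEPAgent.fitness.
    return -action_probabilities.max(dim=1).values.mean().item()

def parallel_fitness(net, state_batch, population_size=8, sigma=0.05):
    # Evaluate a population of Gaussian perturbations in parallel.
    base = copy.deepcopy(net.state_dict())
    noises = [[torch.randn_like(p) for p in net.parameters()]
              for _ in range(population_size)]
    jobs = [(base, noise, sigma, state_batch) for noise in noises]
    with ProcessPoolExecutor() as pool:
        scores = list(pool.map(evaluate_candidate, jobs))
    return noises, scores

if __name__ == "__main__":
    torch.manual_seed(0)
    policy = make_policy()
    states = torch.randn(16, 4)
    noises, scores = parallel_fitness(policy, states)
    print(scores)

Whether ESNetworkMP shares the base parameters between processes or re-sends them per candidate is not visible in this diff; the sketch simply pickles the state_dict with each job.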
@@ -22,6 +22,7 @@ class QEPAgent:
         self.policy_skip = 4
 
     def fitness(self, policy_net, value_net, state_batch):
+        # print("Worker started")
         batch_size = len(state_batch)
         action_probabilities = policy_net(state_batch)
         action_size = action_probabilities.shape[1]
@@ -30,15 +31,21 @@ class QEPAgent:
 
         with torch.no_grad():
             state_values = value_net(state_batch)
 
+        # Weird hacky solution where in multiprocess, it sometimes spits out nans
+        # So have it try again
+        while torch.isnan(state_values).any():
+            with torch.no_grad():
+                state_values = value_net(state_batch)
+
         obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
         # return -obtained_values.mean().item()
-        entropy_importance = 0 # Entropy accounting for 1% of loss seems to work well
+        entropy_importance = 0.01 # Entropy accounting for 1% of loss seems to work well
         value_importance = 1 - entropy_importance
 
         # entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
         entropy_loss = (action_probabilities - torch.tensor(1 / action_size).repeat(len(state_batch), action_size)).abs().sum(1)
+        # print("END WORKER")
         return (entropy_importance * entropy_loss - value_importance * obtained_values).mean().item()
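Two things change inside fitness. The while torch.isnan(...) guard re-runs the value network until its output is finite, which the in-code comment flags as a workaround for NaNs that only appear in the multiprocessing path. And entropy_importance moves from 0 to 0.01, so the score is now 1% of an "entropy" penalty, measured as the L1 distance of each action distribution from the uniform distribution (the commented-out line shows the Shannon entropy it replaces), plus 99% of the negated state-action values. A small worked example with made-up numbers:

# Worked example of the fitness terms; the probabilities and Q-values are made up.
import torch

action_probabilities = torch.tensor([[0.90, 0.10],
                                     [0.50, 0.50]])
obtained_values = torch.tensor([1.2, 0.4])      # gathered Q(s, a) per state
batch_size, action_size = action_probabilities.shape

# Distance-from-uniform penalty, as in the new code:
# row 1: |0.9 - 0.5| + |0.1 - 0.5| = 0.8,  row 2: 0.0
entropy_loss = (action_probabilities
                - torch.tensor(1 / action_size).repeat(batch_size, action_size)).abs().sum(1)

entropy_importance = 0.01
value_importance = 1 - entropy_importance
score = (entropy_importance * entropy_loss - value_importance * obtained_values).mean().item()

print(entropy_loss)  # tensor([0.8000, 0.0000])
print(score)         # 0.01 * 0.4 - 0.99 * 0.8 = -0.788 (up to float rounding)

The sign convention suggests the evolutionary search treats this score as a loss to drive down: larger Q-values for the chosen actions make it more negative, while a near-deterministic action distribution adds a small positive penalty.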
@@ -112,10 +119,11 @@ class QEPAgent:
             self.policy_skip -= 1
             return
         self.policy_skip = 4
 
         if self.target_value_net is not None:
             self.policy_net.calc_gradients(self.target_value_net, state_batch)
         else:
             self.policy_net.calc_gradients(self.value_net, state_batch)
+        # self.policy_net.clamp_gradients()
         self.policy_net.step()
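calc_gradients followed by step is the evolutionary-strategies update itself, gated by the policy_skip counter so the relatively expensive search does not run on every learn call, and with the gradient-clamping call left commented out. A common way such an update is computed from the population's fitness scores is the score-function estimator grad ~ 1/(n * sigma) * sum_i f_i * eps_i; whether ESNetwork / ESNetworkMP uses exactly this form is an assumption, and es_step below is a hypothetical helper, not rltorch API.

# Illustrative ES parameter update from fitness scores; generic estimator only,
# not necessarily what ESNetwork.calc_gradients and step do internally.
import torch

def es_step(net, noises, scores, sigma=0.05, lr=1e-2):
    # noises: list (one entry per candidate) of per-parameter noise tensors.
    # scores: fitness values, lower is better (they are treated as a loss).
    scores = torch.tensor(scores)
    # Normalising keeps the update scale independent of the scores' magnitude.
    scores = (scores - scores.mean()) / (scores.std() + 1e-8)
    n = len(noises)
    with torch.no_grad():
        for i, param in enumerate(net.parameters()):
            grad = sum(scores[j] * noises[j][i] for j in range(n)) / (n * sigma)
            param.sub_(lr * grad)  # descend, since lower fitness is better

# Usage with the (equally hypothetical) parallel_fitness sketch above:
# noises, scores = parallel_fitness(policy, states)
# es_step(policy, noises, scores)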