d.sample() returns a tensor, so stack the samples to avoid losing the device

Brandon Rozek 2019-02-28 14:30:49 -05:00
parent 714443192d
commit 9740c40527


@@ -23,7 +23,7 @@ class QEPAgent:
     def fitness(self, policy_net, value_net, state_batch):
         action_probabilities = policy_net(state_batch)
         distributions = list(map(Categorical, action_probabilities))
-        actions = torch.tensor([d.sample() for d in distributions])
+        actions = torch.stack([d.sample() for d in distributions])
         with torch.no_grad():
             state_values = value_net(state_batch)
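A minimal sketch of why this matters, not taken from the repo: torch.tensor() called on a list of sampled tensors copies their values into a fresh tensor on the default device (CPU), whereas torch.stack() concatenates the existing tensors and keeps them on whatever device the distribution's parameters live on. The names `device`, `probs`, and the batch sizes below are illustrative assumptions.

# Illustrative sketch: torch.stack preserves device; torch.tensor does not.
import torch
from torch.distributions import Categorical

# Assumed setup: fall back to CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
probs = torch.softmax(torch.randn(4, 3, device=device), dim=-1)
distributions = [Categorical(p) for p in probs]

# Each sample is a 0-dim tensor on `device`.
samples = [d.sample() for d in distributions]

stacked = torch.stack(samples)   # shape (4,), stays on `device`
print(stacked.device)            # matches `device`

rebuilt = torch.tensor(samples)  # copies values to the default device (CPU)
print(rebuilt.device)            # cpu, regardless of `device`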