d.sample() returns a tensor, so we stack the samples to avoid losing the device
This commit is contained in:
parent
714443192d
commit
9740c40527
1 changed files with 1 additions and 1 deletions
|
@ -23,7 +23,7 @@ class QEPAgent:
|
|||
def fitness(self, policy_net, value_net, state_batch):
|
||||
action_probabilities = policy_net(state_batch)
|
||||
distributions = list(map(Categorical, action_probabilities))
|
||||
actions = torch.tensor([d.sample() for d in distributions])
|
||||
actions = torch.stack([d.sample() for d in distributions])
|
||||
|
||||
with torch.no_grad():
|
||||
state_values = value_net(state_batch)
|
||||
|
|
Loading…
Reference in a new issue