d.sample() returns a tensor, so we stack the samples to avoid losing the device
This commit is contained in:
parent 714443192d
commit 9740c40527
1 changed file with 1 addition and 1 deletion
@@ -23,7 +23,7 @@ class QEPAgent:
     def fitness(self, policy_net, value_net, state_batch):
         action_probabilities = policy_net(state_batch)
         distributions = list(map(Categorical, action_probabilities))
-        actions = torch.tensor([d.sample() for d in distributions])
+        actions = torch.stack([d.sample() for d in distributions])
 
         with torch.no_grad():
             state_values = value_net(state_batch)
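
A minimal standalone sketch of why the change matters, using a hypothetical probability batch not taken from the repository: Categorical.sample() returns a tensor on the same device as its input probabilities, and torch.stack keeps the samples on that device, whereas torch.tensor copies the values into a fresh tensor that defaults to the CPU. On a CPU-only machine the two results coincide; moving probs to "cuda" makes the difference observable.

import torch
from torch.distributions import Categorical

# Hypothetical batch of action probabilities (3 states, 4 actions).
# probs = probs.to("cuda") would expose the device difference on a GPU box.
probs = torch.softmax(torch.randn(3, 4), dim=-1)

distributions = [Categorical(p) for p in probs]

# torch.tensor copies the sampled values into a brand-new tensor,
# which lands on the default (CPU) device.
actions_copied = torch.tensor([d.sample() for d in distributions])

# torch.stack joins the sample tensors along a new dimension and
# keeps them on whatever device the probabilities live on.
actions_stacked = torch.stack([d.sample() for d in distributions])

print(actions_copied.device, actions_stacked.device)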