rltorch/rltorch/action_selector/ArgMaxSelector.py
2019-01-31 23:34:32 -05:00

18 lines
No EOL
691 B
Python

from random import randrange
import torch
class ArgMaxSelector:
def __init__(self, model, action_size, device = None):
self.model = model
self.action_size = action_size
self.device = device
def random_act(self):
return randrange(self.action_size)
def best_act(self, state):
with torch.no_grad():
if self.device is not None:
self.device.to(self.device)
action_values = self.model(state).squeeze(0)
action = self.random_act() if (action_values[0] == action_values).all() else action_values.argmax().item()
return action
def act(self, state):
return self.best_act(state)