Implemented epsilon as a scheduler
This commit is contained in:
parent
04e54cddc2
commit
b2ab2ee132
1 changed file with 3 additions and 6 deletions
|
@@ -1,15 +1,12 @@
|
||||||
from .ArgMaxSelector import ArgMaxSelector
|
from .ArgMaxSelector import ArgMaxSelector
|
||||||
import numpy as np
|
import numpy as np
|
||||||
class EpsilonGreedySelector(ArgMaxSelector):
|
class EpsilonGreedySelector(ArgMaxSelector):
|
||||||
def __init__(self, model, action_size, device = None, epsilon = 0.1, epsilon_decay = 1, epsilon_min = 0.1):
|
def __init__(self, model, action_size, device = None, epsilon = 0.1):
|
||||||
super(EpsilonGreedySelector, self).__init__(model, action_size, device = device)
|
super(EpsilonGreedySelector, self).__init__(model, action_size, device = device)
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
self.epsilon_decay = epsilon_decay
|
|
||||||
self.epsilon_min = epsilon_min
|
|
||||||
# random_act is already implemented in ArgMaxSelector
|
# random_act is already implemented in ArgMaxSelector
|
||||||
# best_act is already implemented in ArgMaxSelector
|
# best_act is already implemented in ArgMaxSelector
|
||||||
def act(self, state):
|
def act(self, state):
|
||||||
action = self.random_act() if np.random.rand() < self.epsilon else self.best_act(state)
|
eps = next(self.epsilon) if isinstance(self.epsilon, collections.Iterable) else self.epsilon
|
||||||
if self.epsilon > self.epsilon_min:
|
action = self.random_act() if np.random.rand() < epsilon else self.best_act(state)
|
||||||
self.epsilon = self.epsilon * self.epsilon_decay
|
|
||||||
return action
|
return action
|
Loading…
Add table
Reference in a new issue