PEP8 Conformance
This commit is contained in:
parent
9b81188a77
commit
8fa4691511
29 changed files with 652 additions and 755 deletions
|
@ -1,7 +1,7 @@
|
||||||
from random import randrange
|
from random import randrange
|
||||||
import torch
|
import torch
|
||||||
class ArgMaxSelector:
|
class ArgMaxSelector:
|
||||||
def __init__(self, model, action_size, device = None):
|
def __init__(self, model, action_size, device=None):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.action_size = action_size
|
self.action_size = action_size
|
||||||
self.device = device
|
self.device = device
|
||||||
|
@ -12,7 +12,8 @@ class ArgMaxSelector:
|
||||||
if self.device is not None:
|
if self.device is not None:
|
||||||
state = state.to(self.device)
|
state = state.to(self.device)
|
||||||
action_values = self.model(state).squeeze(0)
|
action_values = self.model(state).squeeze(0)
|
||||||
action = self.random_act() if (action_values[0] == action_values).all() else action_values.argmax().item()
|
action = self.random_act() if (action_values[0] == action_values).all() \
|
||||||
|
else action_values.argmax().item()
|
||||||
return action
|
return action
|
||||||
def act(self, state):
|
def act(self, state):
|
||||||
return self.best_act(state)
|
return self.best_act(state)
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
from .ArgMaxSelector import ArgMaxSelector
|
|
||||||
import numpy as np
|
|
||||||
import collections
|
import collections
|
||||||
|
import numpy as np
|
||||||
|
from .ArgMaxSelector import ArgMaxSelector
|
||||||
|
|
||||||
class EpsilonGreedySelector(ArgMaxSelector):
|
class EpsilonGreedySelector(ArgMaxSelector):
|
||||||
def __init__(self, model, action_size, device = None, epsilon = 0.1):
|
def __init__(self, model, action_size, device=None, epsilon=0.1):
|
||||||
super(EpsilonGreedySelector, self).__init__(model, action_size, device = device)
|
super(EpsilonGreedySelector, self).__init__(model, action_size, device=device)
|
||||||
self.epsilon = epsilon
|
self.epsilon = epsilon
|
||||||
# random_act is already implemented in ArgMaxSelector
|
# random_act is already implemented in ArgMaxSelector
|
||||||
# best_act is already implemented in ArgMaxSelector
|
# best_act is already implemented in ArgMaxSelector
|
||||||
def act(self, state):
|
def act(self, state):
|
||||||
eps = next(self.epsilon) if isinstance(self.epsilon, collections.Iterable) else self.epsilon
|
eps = next(self.epsilon) if isinstance(self.epsilon, collections.Iterable) else self.epsilon
|
||||||
action = self.random_act() if np.random.rand() < eps else self.best_act(state)
|
action = self.random_act() if np.random.rand() < eps else self.best_act(state)
|
||||||
return action
|
return action
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
from .ArgMaxSelector import ArgMaxSelector
|
|
||||||
import torch
|
import torch
|
||||||
|
from .ArgMaxSelector import ArgMaxSelector
|
||||||
|
|
||||||
class IdentitySelector(ArgMaxSelector):
|
class IdentitySelector(ArgMaxSelector):
|
||||||
def __init__(self, model, action_size, device = None):
|
def __init__(self, model, action_size, device=None):
|
||||||
super(IdentitySelector, self).__init__(model, action_size, device = device)
|
super(IdentitySelector, self).__init__(model, action_size, device=device)
|
||||||
# random_act is already implemented in ArgMaxSelector
|
# random_act is already implemented in ArgMaxSelector
|
||||||
def best_act(self, state):
|
def best_act(self, state):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
@ -11,4 +12,4 @@ class IdentitySelector(ArgMaxSelector):
|
||||||
action = self.model(state).squeeze(0).item()
|
action = self.model(state).squeeze(0).item()
|
||||||
return action
|
return action
|
||||||
def act(self, state):
|
def act(self, state):
|
||||||
return self.best_act(state)
|
return self.best_act(state)
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
from random import randrange
|
from random import randrange
|
||||||
class RandomSelector():
|
class RandomSelector:
|
||||||
def __init__(self, action_size):
|
def __init__(self, action_size):
|
||||||
self.action_size = action_size
|
self.action_size = action_size
|
||||||
def random_act(self):
|
def random_act(self):
|
||||||
return randrange(action_size)
|
return randrange(self.action_size)
|
||||||
def best_act(self, state):
|
def best_act(self, _):
|
||||||
return self.random_act()
|
return self.random_act()
|
||||||
def act(self, state):
|
def act(self, _):
|
||||||
return self.random_act()
|
return self.random_act()
|
||||||
|
|
|
@ -1,22 +1,19 @@
|
||||||
from random import randrange
|
|
||||||
import torch
|
|
||||||
from torch.distributions import Categorical
|
from torch.distributions import Categorical
|
||||||
import rltorch
|
from .ArgMaxSelector import ArgMaxSelector
|
||||||
from rltorch.action_selector import ArgMaxSelector
|
from ..memory.EpisodeMemory import EpisodeMemory
|
||||||
|
|
||||||
class StochasticSelector(ArgMaxSelector):
|
class StochasticSelector(ArgMaxSelector):
|
||||||
def __init__(self, model, action_size, memory = None, device = None):
|
def __init__(self, model, action_size, memory=None, device=None):
|
||||||
super(StochasticSelector, self).__init__(model, action_size, device = device)
|
super(StochasticSelector, self).__init__(model, action_size, device=device)
|
||||||
self.model = model
|
self.model = model
|
||||||
self.action_size = action_size
|
self.action_size = action_size
|
||||||
self.device = device
|
self.device = device
|
||||||
self.memory = memory
|
self.memory = memory
|
||||||
def best_act(self, state, log_prob = True):
|
def best_act(self, state, log_prob=True):
|
||||||
if self.device is not None:
|
if self.device is not None:
|
||||||
state = state.to(self.device)
|
state = state.to(self.device)
|
||||||
action_probabilities = self.model(state)
|
action_probabilities = self.model(state)
|
||||||
distribution = Categorical(action_probabilities)
|
distribution = Categorical(action_probabilities)
|
||||||
action = distribution.sample()
|
action = distribution.sample()
|
||||||
if log_prob and isinstance(self.memory, rltorch.memory.EpisodeMemory):
|
if log_prob and isinstance(self.memory, EpisodeMemory):
|
||||||
self.memory.append_log_probs(distribution.log_prob(action))
|
self.memory.append_log_probs(distribution.log_prob(action))
|
||||||
return action.item()
|
return action.item()
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from .ArgMaxSelector import *
|
from .ArgMaxSelector import ArgMaxSelector
|
||||||
from .EpsilonGreedySelector import *
|
from .EpsilonGreedySelector import EpsilonGreedySelector
|
||||||
from .IdentitySelector import *
|
from .IdentitySelector import IdentitySelector
|
||||||
from .RandomSelector import *
|
from .RandomSelector import RandomSelector
|
||||||
from .StochasticSelector import *
|
from .StochasticSelector import StochasticSelector
|
||||||
|
|
|
@ -2,89 +2,91 @@ from copy import deepcopy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import rltorch
|
|
||||||
import rltorch.memory as M
|
|
||||||
|
|
||||||
class A2CSingleAgent:
|
class A2CSingleAgent:
|
||||||
def __init__(self, policy_net, value_net, memory, config, logger = None):
|
def __init__(self, policy_net, value_net, memory, config, logger=None):
|
||||||
self.policy_net = policy_net
|
self.policy_net = policy_net
|
||||||
self.value_net = value_net
|
self.value_net = value_net
|
||||||
self.memory = memory
|
self.memory = memory
|
||||||
self.config = deepcopy(config)
|
self.config = deepcopy(config)
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
def _discount_rewards(self, rewards):
|
def _discount_rewards(self, rewards):
|
||||||
gammas = torch.ones_like(rewards)
|
gammas = torch.ones_like(rewards)
|
||||||
if len(rewards) > 1:
|
if len(rewards) > 1:
|
||||||
gammas[1:] = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards) - 1), dim = 0)
|
discount_tensor = torch.tensor(self.config['discount_rate'])
|
||||||
return gammas * rewards
|
gammas[1:] = torch.cumprod(
|
||||||
|
discount_tensor.repeat(len(rewards) - 1),
|
||||||
# This function is currently not used since the performance gains haven't been shown
|
dim=0
|
||||||
# May be due to a faulty implementation, need to investigate more..
|
)
|
||||||
def _generalized_advantage_estimation(self, states, rewards, next_states, not_done):
|
return gammas * rewards
|
||||||
tradeoff = 0.5
|
|
||||||
with torch.no_grad():
|
|
||||||
next_values = torch.zeros_like(rewards)
|
|
||||||
next_values[not_done] = self.value_net(next_states[not_done]).squeeze(1)
|
|
||||||
values = self.value_net(states).squeeze(1)
|
|
||||||
|
|
||||||
generalized_advantages = torch.zeros_like(rewards)
|
# This function is currently not used since the performance gains haven't been shown
|
||||||
for i in range(len(generalized_advantages)):
|
# May be due to a faulty implementation, need to investigate more..
|
||||||
weights = torch.ones_like(rewards[i:])
|
def _generalized_advantage_estimation(self, states, rewards, next_states, not_done):
|
||||||
if i != len(generalized_advantages) - 1:
|
tradeoff = 0.5
|
||||||
weights[1:] = torch.cumprod(torch.tensor(self.config['discount_rate'] * tradeoff).repeat(len(rewards) - i - 1), dim = 0)
|
with torch.no_grad():
|
||||||
generalized_advantages[i] = (weights * (rewards[i:] + self.config['discount_rate'] * next_values[i:] - values[i:])).sum()
|
next_values = torch.zeros_like(rewards)
|
||||||
|
next_values[not_done] = self.value_net(next_states[not_done]).squeeze(1)
|
||||||
|
values = self.value_net(states).squeeze(1)
|
||||||
|
|
||||||
return generalized_advantages
|
generalized_advantages = torch.zeros_like(rewards)
|
||||||
|
discount_tensor = torch.tensor(self.config['discount_rate']) * tradeoff
|
||||||
|
for i, _ in enumerate(generalized_advantages):
|
||||||
|
weights = torch.ones_like(rewards[i:])
|
||||||
|
if i != len(generalized_advantages) - 1:
|
||||||
|
weights[1:] = torch.cumprod(discount_tensor.repeat(len(rewards) - i - 1), dim=0)
|
||||||
|
generalized_advantages[i] = (weights * (rewards[i:] + self.config['discount_rate'] * next_values[i:] - values[i:])).sum()
|
||||||
|
|
||||||
|
return generalized_advantages
|
||||||
def learn(self):
|
|
||||||
episode_batch = self.memory.recall()
|
|
||||||
state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)
|
|
||||||
|
|
||||||
# Send batches to the appropriate device
|
def learn(self):
|
||||||
state_batch = torch.cat(state_batch).to(self.value_net.device)
|
episode_batch = self.memory.recall()
|
||||||
reward_batch = torch.tensor(reward_batch).to(self.value_net.device).float()
|
state_batch, _, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)
|
||||||
not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
|
|
||||||
next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
|
|
||||||
log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
|
|
||||||
|
|
||||||
## Value Loss
|
# Send batches to the appropriate device
|
||||||
# In A2C, the value loss is the difference between the discounted reward and the value from the first state
|
state_batch = torch.cat(state_batch).to(self.value_net.device)
|
||||||
# The value of the first state is supposed to tell us the expected reward from the current policy of the whole episode
|
reward_batch = torch.tensor(reward_batch).to(self.value_net.device).float()
|
||||||
discounted_reward = self._discount_rewards(reward_batch)
|
not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
|
||||||
observed_value = discounted_reward.sum()
|
next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
|
||||||
value_loss = F.mse_loss(observed_value, self.value_net(state_batch[0]))
|
log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
|
||||||
self.value_net.zero_grad()
|
|
||||||
value_loss.backward()
|
|
||||||
self.value_net.step()
|
|
||||||
|
|
||||||
## Policy Loss
|
## Value Loss
|
||||||
# Increase probabilities of advantageous states
|
# In A2C, the value loss is the difference between the discounted reward
|
||||||
# and decrease the probabilities of non-advantageous ones
|
# and the value from the first state.
|
||||||
with torch.no_grad():
|
# The value of the first state is supposed to tell us
|
||||||
state_values = self.value_net(state_batch)
|
# the expected reward from the current policy of the whole episode
|
||||||
next_state_values = torch.zeros_like(state_values)
|
discounted_reward = self._discount_rewards(reward_batch)
|
||||||
next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch])
|
observed_value = discounted_reward.sum()
|
||||||
advantages = (reward_batch.unsqueeze(1) + self.config['discount_rate'] * next_state_values) - state_values
|
value_loss = F.mse_loss(observed_value, self.value_net(state_batch[0]))
|
||||||
advantages = advantages.squeeze(1)
|
self.value_net.zero_grad()
|
||||||
|
value_loss.backward()
|
||||||
|
self.value_net.step()
|
||||||
|
|
||||||
# advantages = self._generalized_advantage_estimation(state_batch, reward_batch, next_state_batch, not_done_batch)
|
## Policy Loss
|
||||||
# Scale for more stable learning
|
# Increase probabilities of advantageous states
|
||||||
advantages = advantages / (advantages.std() + np.finfo('float').eps)
|
# and decrease the probabilities of non-advantageous ones
|
||||||
|
with torch.no_grad():
|
||||||
|
state_values = self.value_net(state_batch)
|
||||||
|
next_state_values = torch.zeros_like(state_values)
|
||||||
|
next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch])
|
||||||
|
advantages = (reward_batch.unsqueeze(1) + self.config['discount_rate'] * next_state_values) - state_values
|
||||||
|
advantages = advantages.squeeze(1)
|
||||||
|
|
||||||
policy_loss = (-log_prob_batch * advantages).sum()
|
# advantages = self._generalized_advantage_estimation(state_batch, reward_batch, next_state_batch, not_done_batch)
|
||||||
|
# Scale for more stable learning
|
||||||
|
advantages = advantages / (advantages.std() + np.finfo('float').eps)
|
||||||
|
|
||||||
|
policy_loss = (-log_prob_batch * advantages).sum()
|
||||||
|
|
||||||
if self.logger is not None:
|
if self.logger is not None:
|
||||||
self.logger.append("Loss/Policy", policy_loss.item())
|
self.logger.append("Loss/Policy", policy_loss.item())
|
||||||
self.logger.append("Loss/Value", value_loss.item())
|
self.logger.append("Loss/Value", value_loss.item())
|
||||||
|
|
||||||
|
|
||||||
self.policy_net.zero_grad()
|
self.policy_net.zero_grad()
|
||||||
policy_loss.backward()
|
policy_loss.backward()
|
||||||
self.policy_net.step()
|
self.policy_net.step()
|
||||||
|
|
||||||
# Memory under the old policy is not needed for future training
|
|
||||||
self.memory.clear()
|
|
||||||
|
|
||||||
|
|
||||||
|
# Memory under the old policy is not needed for future training
|
||||||
|
self.memory.clear()
|
||||||
|
|
|
@ -1,13 +1,11 @@
|
||||||
import collections
|
import collections
|
||||||
|
from copy import deepcopy
|
||||||
import rltorch.memory as M
|
import rltorch.memory as M
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from copy import deepcopy
|
|
||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
class DQNAgent:
|
class DQNAgent:
|
||||||
def __init__(self, net , memory, config, target_net = None, logger = None):
|
def __init__(self, net, memory, config, target_net=None, logger=None):
|
||||||
self.net = net
|
self.net = net
|
||||||
self.target_net = target_net
|
self.target_net = target_net
|
||||||
self.memory = memory
|
self.memory = memory
|
||||||
|
@ -20,16 +18,16 @@ class DQNAgent:
|
||||||
self.net.model.to(self.net.device)
|
self.net.model.to(self.net.device)
|
||||||
self.target_net.sync()
|
self.target_net.sync()
|
||||||
|
|
||||||
def learn(self, logger = None):
|
def learn(self, logger=None):
|
||||||
if len(self.memory) < self.config['batch_size']:
|
if len(self.memory) < self.config['batch_size']:
|
||||||
return
|
return
|
||||||
|
|
||||||
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
if isinstance(self.memory, M.PrioritizedReplayMemory):
|
||||||
weight_importance = self.config['prioritized_replay_weight_importance']
|
weight_importance = self.config['prioritized_replay_weight_importance']
|
||||||
# If it's a scheduler then get the next value by calling next, otherwise just use its value
|
# If it's a scheduler then get the next value by calling next, otherwise just use its value
|
||||||
beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
|
beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
|
||||||
minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
|
minibatch = self.memory.sample(self.config['batch_size'], beta=beta)
|
||||||
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
|
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority=True)
|
||||||
else:
|
else:
|
||||||
minibatch = self.memory.sample(self.config['batch_size'])
|
minibatch = self.memory.sample(self.config['batch_size'])
|
||||||
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch)
|
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch)
|
||||||
|
@ -49,7 +47,7 @@ class DQNAgent:
|
||||||
# and the regular net to select the action
|
# and the regular net to select the action
|
||||||
# That way we decouple the value and action selecting processes (DOUBLE DQN)
|
# That way we decouple the value and action selecting processes (DOUBLE DQN)
|
||||||
not_done_size = not_done_batch.sum()
|
not_done_size = not_done_batch.sum()
|
||||||
next_state_values = torch.zeros_like(state_values, device = self.net.device)
|
next_state_values = torch.zeros_like(state_values, device=self.net.device)
|
||||||
if self.target_net is not None:
|
if self.target_net is not None:
|
||||||
next_state_values[not_done_batch] = self.target_net(next_state_batch[not_done_batch])
|
next_state_values[not_done_batch] = self.target_net(next_state_batch[not_done_batch])
|
||||||
next_best_action = self.net(next_state_batch[not_done_batch]).argmax(1)
|
next_best_action = self.net(next_state_batch[not_done_batch]).argmax(1)
|
||||||
|
@ -57,15 +55,15 @@ class DQNAgent:
|
||||||
next_state_values[not_done_batch] = self.net(next_state_batch[not_done_batch])
|
next_state_values[not_done_batch] = self.net(next_state_batch[not_done_batch])
|
||||||
next_best_action = next_state_values[not_done_batch].argmax(1)
|
next_best_action = next_state_values[not_done_batch].argmax(1)
|
||||||
|
|
||||||
best_next_state_value = torch.zeros(self.config['batch_size'], device = self.net.device)
|
best_next_state_value = torch.zeros(self.config['batch_size'], device=self.net.device)
|
||||||
best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
|
best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
|
||||||
|
|
||||||
expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
||||||
|
|
||||||
# If we're sampling by TD error, multiply loss by an importance weight which helps decrease overfitting
|
# If we're sampling by TD error, multiply loss by an importance weight which helps decrease overfitting
|
||||||
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
if isinstance(self.memory, M.PrioritizedReplayMemory):
|
||||||
# loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.smooth_l1_loss(obtained_values, expected_values, reduction = 'none').squeeze(1)).mean()
|
# loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.smooth_l1_loss(obtained_values, expected_values, reduction = 'none').squeeze(1)).mean()
|
||||||
loss = (torch.as_tensor(importance_weights, device = self.net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
|
loss = (torch.as_tensor(importance_weights, device=self.net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
|
||||||
else:
|
else:
|
||||||
# loss = F.smooth_l1_loss(obtained_values, expected_values)
|
# loss = F.smooth_l1_loss(obtained_values, expected_values)
|
||||||
loss = F.mse_loss(obtained_values, expected_values)
|
loss = F.mse_loss(obtained_values, expected_values)
|
||||||
|
@ -85,8 +83,6 @@ class DQNAgent:
|
||||||
self.target_net.sync()
|
self.target_net.sync()
|
||||||
|
|
||||||
# If we're sampling by TD error, readjust the weights of the experiences
|
# If we're sampling by TD error, readjust the weights of the experiences
|
||||||
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
if isinstance(self.memory, M.PrioritizedReplayMemory):
|
||||||
td_error = (obtained_values - expected_values).detach().abs()
|
td_error = (obtained_values - expected_values).detach().abs()
|
||||||
self.memory.update_priorities(batch_indexes, td_error)
|
self.memory.update_priorities(batch_indexes, td_error)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,12 @@
|
||||||
import collections
|
import collections
|
||||||
|
from copy import deepcopy
|
||||||
import rltorch.memory as M
|
import rltorch.memory as M
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from copy import deepcopy
|
|
||||||
import numpy as np
|
|
||||||
from pathlib import Path
|
|
||||||
from rltorch.action_selector import ArgMaxSelector
|
|
||||||
|
|
||||||
class DQfDAgent:
|
class DQfDAgent:
|
||||||
def __init__(self, net, memory, config, target_net = None, logger = None):
|
def __init__(self, net, memory, config, target_net=None, logger=None):
|
||||||
self.net = net
|
self.net = net
|
||||||
self.target_net = target_net
|
self.target_net = target_net
|
||||||
self.memory = memory
|
self.memory = memory
|
||||||
|
@ -21,7 +19,7 @@ class DQfDAgent:
|
||||||
self.net.model.to(self.net.device)
|
self.net.model.to(self.net.device)
|
||||||
self.target_net.sync()
|
self.target_net.sync()
|
||||||
|
|
||||||
def learn(self, logger = None):
|
def learn(self, logger=None):
|
||||||
if len(self.memory) < self.config['batch_size']:
|
if len(self.memory) < self.config['batch_size']:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -32,29 +30,19 @@ class DQfDAgent:
|
||||||
batch_size = self.config['batch_size']
|
batch_size = self.config['batch_size']
|
||||||
steps = None
|
steps = None
|
||||||
|
|
||||||
if isinstance(self.memory, M.DQfDMemory):
|
weight_importance = self.config['prioritized_replay_weight_importance']
|
||||||
weight_importance = self.config['prioritized_replay_weight_importance']
|
# If it's a scheduler then get the next value by calling next, otherwise just use its value
|
||||||
# If it's a scheduler then get the next value by calling next, otherwise just use its value
|
beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) \
|
||||||
beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
|
else weight_importance
|
||||||
|
|
||||||
# Check to see if we are doing N-Step DQN
|
# Check to see if we are doing N-Step DQN
|
||||||
if steps is not None:
|
if steps is not None:
|
||||||
minibatch = self.memory.sample_n_steps(batch_size, steps, beta)
|
minibatch = self.memory.sample_n_steps(batch_size, steps, beta)
|
||||||
else:
|
|
||||||
minibatch = self.memory.sample(batch_size, beta = beta)
|
|
||||||
|
|
||||||
# Process batch
|
|
||||||
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Check to see if we're doing N-Step DQN
|
minibatch = self.memory.sample(batch_size, beta=beta)
|
||||||
if steps is not None:
|
|
||||||
minibatch = self.memory.sample_n_steps(batch_size, steps)
|
|
||||||
else:
|
|
||||||
minibatch = self.memory.sample(batch_size)
|
|
||||||
|
|
||||||
# Process batch
|
# Process batch
|
||||||
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, batch_indexes = M.zip_batch(minibatch, want_indices = True)
|
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority=True)
|
||||||
|
|
||||||
batch_index_tensors = torch.tensor(batch_indexes)
|
batch_index_tensors = torch.tensor(batch_indexes)
|
||||||
demo_mask = batch_index_tensors < self.memory.demo_position
|
demo_mask = batch_index_tensors < self.memory.demo_position
|
||||||
|
@ -75,7 +63,7 @@ class DQfDAgent:
|
||||||
# and the regular net to select the action
|
# and the regular net to select the action
|
||||||
# That way we decouple the value and action selecting processes (DOUBLE DQN)
|
# That way we decouple the value and action selecting processes (DOUBLE DQN)
|
||||||
not_done_size = not_done_batch.sum()
|
not_done_size = not_done_batch.sum()
|
||||||
next_state_values = torch.zeros_like(state_values, device = self.net.device)
|
next_state_values = torch.zeros_like(state_values, device=self.net.device)
|
||||||
if self.target_net is not None:
|
if self.target_net is not None:
|
||||||
next_state_values[not_done_batch] = self.target_net(next_state_batch[not_done_batch])
|
next_state_values[not_done_batch] = self.target_net(next_state_batch[not_done_batch])
|
||||||
next_best_action = self.net(next_state_batch[not_done_batch]).argmax(1)
|
next_best_action = self.net(next_state_batch[not_done_batch]).argmax(1)
|
||||||
|
@ -83,14 +71,14 @@ class DQfDAgent:
|
||||||
next_state_values[not_done_batch] = self.net(next_state_batch[not_done_batch])
|
next_state_values[not_done_batch] = self.net(next_state_batch[not_done_batch])
|
||||||
next_best_action = next_state_values[not_done_batch].argmax(1)
|
next_best_action = next_state_values[not_done_batch].argmax(1)
|
||||||
|
|
||||||
best_next_state_value = torch.zeros(batch_size, device = self.net.device)
|
best_next_state_value = torch.zeros(batch_size, device=self.net.device)
|
||||||
best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
|
best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
|
||||||
|
|
||||||
expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
expected_values = (reward_batch + (batch_size * best_next_state_value)).unsqueeze(1)
|
||||||
|
|
||||||
# N-Step DQN Loss
|
# N-Step DQN Loss
|
||||||
# num_steps captures how many steps actually exist before the end of the episode
|
# num_steps captures how many steps actually exist before the end of the episode
|
||||||
if steps != None:
|
if steps is not None:
|
||||||
expected_n_step_values = []
|
expected_n_step_values = []
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for i in range(0, len(state_batch), steps):
|
for i in range(0, len(state_batch), steps):
|
||||||
|
@ -127,7 +115,7 @@ class DQfDAgent:
|
||||||
l = torch.ones_like(state_values[demo_mask])
|
l = torch.ones_like(state_values[demo_mask])
|
||||||
expert_actions = action_batch[demo_mask]
|
expert_actions = action_batch[demo_mask]
|
||||||
# l(s, a) is zero for the action the expert takes and 0.8 for every other action
|
# l(s, a) is zero for the action the expert takes and 0.8 for every other action
|
||||||
for i,a in zip(range(len(l)), expert_actions):
|
for i, _, a in zip(enumerate(l), expert_actions):
|
||||||
l[i].fill_(0.8) # According to paper
|
l[i].fill_(0.8) # According to paper
|
||||||
l[i, a] = 0
|
l[i, a] = 0
|
||||||
if self.target_net is not None:
|
if self.target_net is not None:
|
||||||
|
@ -139,35 +127,26 @@ class DQfDAgent:
|
||||||
# Iterate through hyperparameters
|
# Iterate through hyperparameters
|
||||||
if isinstance(self.config['dqfd_demo_loss_weight'], collections.Iterable):
|
if isinstance(self.config['dqfd_demo_loss_weight'], collections.Iterable):
|
||||||
demo_importance = next(self.config['dqfd_demo_loss_weight'])
|
demo_importance = next(self.config['dqfd_demo_loss_weight'])
|
||||||
else:
|
else:
|
||||||
demo_importance = self.config['dqfd_demo_loss_weight']
|
demo_importance = self.config['dqfd_demo_loss_weight']
|
||||||
if isinstance(self.config['dqfd_td_loss_weight'], collections.Iterable):
|
if isinstance(self.config['dqfd_td_loss_weight'], collections.Iterable):
|
||||||
td_importance = next(self.config['dqfd_td_loss_weight'])
|
td_importance = next(self.config['dqfd_td_loss_weight'])
|
||||||
else:
|
else:
|
||||||
td_importance = self.config['dqfd_td_loss_weight']
|
td_importance = self.config['dqfd_td_loss_weight']
|
||||||
|
|
||||||
|
|
||||||
# Since dqn_loss and demo_loss are different sizes, the reduction has to happen before they are combined
|
# Since dqn_loss and demo_loss are different sizes, the reduction has to happen before they are combined
|
||||||
if isinstance(self.memory, M.DQfDMemory):
|
dqn_loss = (torch.as_tensor(importance_weights, device=self.net.device) * F.mse_loss(obtained_values, expected_values, reduction='none').squeeze(1)).mean()
|
||||||
dqn_loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.mse_loss(obtained_values, expected_values, reduction = 'none').squeeze(1)).mean()
|
|
||||||
else:
|
|
||||||
dqn_loss = F.mse_loss(obtained_values, expected_values)
|
|
||||||
|
|
||||||
if steps != None:
|
if steps is not None:
|
||||||
if isinstance(self.memory, M.DQfDMemory):
|
dqn_n_step_loss = (torch.as_tensor(importance_weights[::steps], device=self.net.device) * F.mse_loss(observed_n_step_values, expected_n_step_values, reduction='none')).mean()
|
||||||
dqn_n_step_loss = (torch.as_tensor(importance_weights[::steps], device = self.net.device) * F.mse_loss(observed_n_step_values, expected_n_step_values, reduction = 'none')).mean()
|
|
||||||
else:
|
|
||||||
dqn_n_step_loss = F.mse_loss(observed_n_step_values, expected_n_step_values, reduction = 'none').mean()
|
|
||||||
else:
|
else:
|
||||||
dqn_n_step_loss = torch.tensor(0, device = self.net.device)
|
dqn_n_step_loss = torch.tensor(0, device=self.net.device)
|
||||||
|
|
||||||
if demo_mask.sum() > 0:
|
if demo_mask.sum() > 0:
|
||||||
if isinstance(self.memory, M.DQfDMemory):
|
demo_loss = (torch.as_tensor(importance_weights, device=self.net.device)[demo_mask] * F.mse_loss((state_values[demo_mask] + l).max(1)[0].unsqueeze(1), expert_value, reduction='none').squeeze(1)).mean()
|
||||||
demo_loss = (torch.as_tensor(importance_weights, device = self.net.device)[demo_mask] * F.mse_loss((state_values[demo_mask] + l).max(1)[0].unsqueeze(1), expert_value, reduction = 'none').squeeze(1)).mean()
|
|
||||||
else:
|
|
||||||
demo_loss = F.mse_loss((state_values[demo_mask] + l).max(1)[0].unsqueeze(1), expert_value, reduction = 'none').squeeze(1).mean()
|
|
||||||
else:
|
else:
|
||||||
demo_loss = 0.
|
demo_loss = 0
|
||||||
loss = td_importance * dqn_loss + td_importance * dqn_n_step_loss + demo_importance * demo_loss
|
loss = td_importance * dqn_loss + td_importance * dqn_n_step_loss + demo_importance * demo_loss
|
||||||
|
|
||||||
if self.logger is not None:
|
if self.logger is not None:
|
||||||
|
|
|
@ -1,81 +1,72 @@
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.distributions import Categorical
|
from torch.distributions import Categorical
|
||||||
import rltorch
|
import rltorch
|
||||||
import rltorch.memory as M
|
|
||||||
import collections
|
|
||||||
import random
|
|
||||||
|
|
||||||
class PPOAgent:
|
class PPOAgent:
|
||||||
def __init__(self, policy_net, value_net, memory, config, logger = None):
|
def __init__(self, policy_net, value_net, memory, config, logger=None):
|
||||||
self.policy_net = policy_net
|
self.policy_net = policy_net
|
||||||
self.old_policy_net = rltorch.network.TargetNetwork(policy_net)
|
self.old_policy_net = rltorch.network.TargetNetwork(policy_net)
|
||||||
self.value_net = value_net
|
self.value_net = value_net
|
||||||
self.memory = memory
|
self.memory = memory
|
||||||
self.config = deepcopy(config)
|
self.config = deepcopy(config)
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
def _discount_rewards(self, rewards):
|
def _discount_rewards(self, rewards):
|
||||||
gammas = torch.ones_like(rewards)
|
gammas = torch.ones_like(rewards)
|
||||||
if len(rewards) > 1:
|
if len(rewards) > 1:
|
||||||
gammas[1:] = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards) - 1), dim = 0)
|
gammas[1:] = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards) - 1), dim=0)
|
||||||
return gammas * rewards
|
return gammas * rewards
|
||||||
|
|
||||||
|
def learn(self):
|
||||||
def learn(self):
|
episode_batch = self.memory.recall()
|
||||||
episode_batch = self.memory.recall()
|
state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)
|
||||||
state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)
|
|
||||||
|
|
||||||
# Send batches to the appropriate device
|
# Send batches to the appropriate device
|
||||||
state_batch = torch.cat(state_batch).to(self.value_net.device)
|
state_batch = torch.cat(state_batch).to(self.value_net.device)
|
||||||
action_batch = torch.tensor(action_batch).to(self.value_net.device)
|
action_batch = torch.tensor(action_batch).to(self.value_net.device)
|
||||||
reward_batch = torch.tensor(reward_batch).to(self.value_net.device).float()
|
reward_batch = torch.tensor(reward_batch).to(self.value_net.device).float()
|
||||||
not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
|
not_done_batch = ~torch.tensor(done_batch).to(self.value_net.device)
|
||||||
next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
|
next_state_batch = torch.cat(next_state_batch).to(self.value_net.device)
|
||||||
log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
|
log_prob_batch = torch.cat(log_prob_batch).to(self.value_net.device)
|
||||||
|
|
||||||
## Value Loss
|
## Value Loss
|
||||||
# In PPO, the value loss is the difference between the discounted reward and the value from the first state
|
# In PPO, the value loss is the difference between the discounted reward and the value from the first state
|
||||||
# The value of the first state is supposed to tell us the expected reward from the current policy of the whole episode
|
# The value of the first state is supposed to tell us the expected reward from the current policy of the whole episode
|
||||||
value_loss = F.mse_loss(self._discount_rewards(reward_batch).sum(), self.value_net(state_batch[0]))
|
value_loss = F.mse_loss(self._discount_rewards(reward_batch).sum(), self.value_net(state_batch[0]))
|
||||||
self.value_net.zero_grad()
|
self.value_net.zero_grad()
|
||||||
value_loss.backward()
|
value_loss.backward()
|
||||||
self.value_net.step()
|
self.value_net.step()
|
||||||
|
|
||||||
## Policy Loss
|
## Policy Loss
|
||||||
# Increase probabilities of advantageous states
|
# Increase probabilities of advantageous states
|
||||||
# and decrease the probabilities of non-advantageous ones
|
# and decrease the probabilities of non-advantageous ones
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
state_values = self.value_net(state_batch)
|
state_values = self.value_net(state_batch)
|
||||||
next_state_values = torch.zeros_like(state_values)
|
next_state_values = torch.zeros_like(state_values)
|
||||||
next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch])
|
next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch])
|
||||||
advantages = (reward_batch.unsqueeze(1) + self.config['discount_rate'] * next_state_values) - state_values
|
advantages = (reward_batch.unsqueeze(1) + self.config['discount_rate'] * next_state_values) - state_values
|
||||||
advantages = advantages.squeeze(1)
|
advantages = advantages.squeeze(1)
|
||||||
|
|
||||||
action_probabilities = self.old_policy_net(state_batch).detach()
|
action_probabilities = self.old_policy_net(state_batch).detach()
|
||||||
distributions = list(map(Categorical, action_probabilities))
|
distributions = list(map(Categorical, action_probabilities))
|
||||||
old_log_probs = torch.stack(list(map(lambda distribution, action: distribution.log_prob(action), distributions, action_batch)))
|
old_log_probs = torch.stack(list(map(lambda distribution, action: distribution.log_prob(action), distributions, action_batch)))
|
||||||
|
|
||||||
# For PPO we want to stay within a certain ratio of the old policy
|
# For PPO we want to stay within a certain ratio of the old policy
|
||||||
policy_ratio = torch.exp(log_prob_batch - old_log_probs) # Equivalent to (log_prob / old_log_prob)
|
policy_ratio = torch.exp(log_prob_batch - old_log_probs) # Equivalent to (log_prob / old_log_prob)
|
||||||
policy_loss1 = policy_ratio * advantages
|
policy_loss1 = policy_ratio * advantages
|
||||||
policy_loss2 = policy_ratio.clamp(min = 0.8, max = 1.2) * advantages # From original paper
|
policy_loss2 = policy_ratio.clamp(min=0.8, max=1.2) * advantages # From original paper
|
||||||
policy_loss = -torch.min(policy_loss1, policy_loss2).sum()
|
policy_loss = -torch.min(policy_loss1, policy_loss2).sum()
|
||||||
|
|
||||||
if self.logger is not None:
|
if self.logger is not None:
|
||||||
self.logger.append("Loss/Policy", policy_loss.item())
|
self.logger.append("Loss/Policy", policy_loss.item())
|
||||||
self.logger.append("Loss/Value", value_loss.item())
|
self.logger.append("Loss/Value", value_loss.item())
|
||||||
|
|
||||||
|
self.old_policy_net.sync()
|
||||||
self.old_policy_net.sync()
|
self.policy_net.zero_grad()
|
||||||
self.policy_net.zero_grad()
|
policy_loss.backward()
|
||||||
policy_loss.backward()
|
self.policy_net.step()
|
||||||
self.policy_net.step()
|
|
||||||
|
|
||||||
|
|
||||||
# Memory under the old policy is not needed for future training
|
|
||||||
self.memory.clear()
|
|
||||||
|
|
||||||
|
|
||||||
|
# Memory under the old policy is not needed for future training
|
||||||
|
self.memory.clear()
|
||||||
|
|
|
@ -2,16 +2,17 @@ from copy import deepcopy
|
||||||
import collections
|
import collections
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from torch.distributions import Categorical
|
from torch.distributions import Categorical
|
||||||
import rltorch
|
import rltorch
|
||||||
import rltorch.memory as M
|
import rltorch.memory as M
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
# Q-Evolutionary Policy Agent
|
# Q-Evolutionary Policy Agent
|
||||||
# Maximizes the policy with respect to the Q-Value function.
|
# Maximizes the policy with respect to the Q-Value function.
|
||||||
# Since the function is non-differentiable, it depends on the Evolutionary Strategy algorithm
|
# Since the function is non-differentiable, it depends on the Evolutionary Strategy algorithm
|
||||||
class QEPAgent:
|
class QEPAgent:
|
||||||
def __init__(self, policy_net, value_net, memory, config, target_value_net = None, logger = None, entropy_importance = 0, policy_skip = 4, after_value_train = None):
|
def __init__(self, policy_net, value_net, memory, config, target_value_net=None, logger=None, entropy_importance=0, policy_skip=4):
|
||||||
self.policy_net = policy_net
|
self.policy_net = policy_net
|
||||||
assert isinstance(self.policy_net, rltorch.network.ESNetwork) or isinstance(self.policy_net, rltorch.network.ESNetworkMP)
|
assert isinstance(self.policy_net, rltorch.network.ESNetwork) or isinstance(self.policy_net, rltorch.network.ESNetworkMP)
|
||||||
self.policy_net.fitness = self.fitness
|
self.policy_net.fitness = self.fitness
|
||||||
|
@ -22,7 +23,6 @@ class QEPAgent:
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.policy_skip = policy_skip
|
self.policy_skip = policy_skip
|
||||||
self.entropy_importance = entropy_importance
|
self.entropy_importance = entropy_importance
|
||||||
self.after_value_train = after_value_train
|
|
||||||
|
|
||||||
def save(self, file_location):
|
def save(self, file_location):
|
||||||
torch.save({
|
torch.save({
|
||||||
|
@ -42,43 +42,41 @@ class QEPAgent:
|
||||||
batch_size = len(state_batch)
|
batch_size = len(state_batch)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
action_probabilities = policy_net(state_batch)
|
action_probabilities = policy_net(state_batch)
|
||||||
|
|
||||||
action_size = action_probabilities.shape[1]
|
action_size = action_probabilities.shape[1]
|
||||||
distributions = list(map(Categorical, action_probabilities))
|
distributions = list(map(Categorical, action_probabilities))
|
||||||
|
|
||||||
actions = torch.stack([d.sample() for d in distributions])
|
actions = torch.stack([d.sample() for d in distributions])
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
state_values = value_net(state_batch)
|
state_values = value_net(state_batch)
|
||||||
|
|
||||||
# Weird hacky solution where in multiprocess, it sometimes spits out nans
|
# Weird hacky solution where in multiprocess, it sometimes spits out nans
|
||||||
# So have it try again
|
# So have it try again
|
||||||
while torch.isnan(state_values).any():
|
while torch.isnan(state_values).any():
|
||||||
print("NAN DETECTED")
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
state_values = value_net(state_batch)
|
state_values = value_net(state_batch)
|
||||||
|
|
||||||
obtained_values = state_values.gather(1, actions.view(batch_size, 1)).squeeze(1)
|
obtained_values = state_values.gather(1, actions.view(len(state_batch), 1)).squeeze(1)
|
||||||
|
# return -obtained_values.mean().item()
|
||||||
|
entropy_importance = 0 # Entropy accounting for 1% of loss seems to work well
|
||||||
entropy_importance = next(self.entropy_importance) if isinstance(self.entropy_importance, collections.Iterable) else self.entropy_importance
|
entropy_importance = next(self.entropy_importance) if isinstance(self.entropy_importance, collections.Iterable) else self.entropy_importance
|
||||||
value_importance = 1 - entropy_importance
|
value_importance = 1 - entropy_importance
|
||||||
|
|
||||||
# entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
|
# entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
|
||||||
entropy_loss = (action_probabilities - torch.tensor(1 / action_size, device = state_batch.device).repeat(batch_size, action_size)).abs().sum(1)
|
entropy_loss = (action_probabilities - torch.tensor(1 / action_size, device=state_batch.device).repeat(len(state_batch), action_size)).abs().sum(1)
|
||||||
|
|
||||||
return (entropy_importance * entropy_loss - value_importance * obtained_values).mean().item()
|
return (entropy_importance * entropy_loss - value_importance * obtained_values).mean().item()
|
||||||
|
|
||||||
|
|
||||||
def learn(self, logger = None):
|
def learn(self, logger=None):
|
||||||
if len(self.memory) < self.config['batch_size']:
|
if len(self.memory) < self.config['batch_size']:
|
||||||
return
|
return
|
||||||
|
|
||||||
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
if isinstance(self.memory, M.PrioritizedReplayMemory):
|
||||||
weight_importance = self.config['prioritized_replay_weight_importance']
|
weight_importance = self.config['prioritized_replay_weight_importance']
|
||||||
# If it's a scheduler, get the next value by calling next; otherwise just use its value
|
# If it's a scheduler, get the next value by calling next; otherwise just use its value
|
||||||
beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
|
beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
|
||||||
minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
|
minibatch = self.memory.sample(self.config['batch_size'], beta=beta)
|
||||||
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
|
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority=True)
|
||||||
else:
|
else:
|
||||||
minibatch = self.memory.sample(self.config['batch_size'])
|
minibatch = self.memory.sample(self.config['batch_size'])
|
||||||
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch)
|
state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch)
|
||||||
|
@ -98,7 +96,7 @@ class QEPAgent:
|
||||||
# and the regular net to select the action
|
# and the regular net to select the action
|
||||||
# That way we decouple the value and action selecting processes (DOUBLE DQN)
|
# That way we decouple the value and action selecting processes (DOUBLE DQN)
|
||||||
not_done_size = not_done_batch.sum()
|
not_done_size = not_done_batch.sum()
|
||||||
next_state_values = torch.zeros_like(state_values, device = self.value_net.device)
|
next_state_values = torch.zeros_like(state_values, device=self.value_net.device)
|
||||||
if self.target_value_net is not None:
|
if self.target_value_net is not None:
|
||||||
next_state_values[not_done_batch] = self.target_value_net(next_state_batch[not_done_batch])
|
next_state_values[not_done_batch] = self.target_value_net(next_state_batch[not_done_batch])
|
||||||
next_best_action = self.value_net(next_state_batch[not_done_batch]).argmax(1)
|
next_best_action = self.value_net(next_state_batch[not_done_batch]).argmax(1)
|
||||||
|
@ -106,13 +104,13 @@ class QEPAgent:
|
||||||
next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch])
|
next_state_values[not_done_batch] = self.value_net(next_state_batch[not_done_batch])
|
||||||
next_best_action = next_state_values[not_done_batch].argmax(1)
|
next_best_action = next_state_values[not_done_batch].argmax(1)
|
||||||
|
|
||||||
best_next_state_value = torch.zeros(self.config['batch_size'], device = self.value_net.device)
|
best_next_state_value = torch.zeros(self.config['batch_size'], device=self.value_net.device)
|
||||||
best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
|
best_next_state_value[not_done_batch] = next_state_values[not_done_batch].gather(1, next_best_action.view((not_done_size, 1))).squeeze(1)
|
||||||
|
|
||||||
expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
|
||||||
|
|
||||||
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
if isinstance(self.memory, M.PrioritizedReplayMemory):
|
||||||
value_loss = (torch.as_tensor(importance_weights, device = self.value_net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
|
value_loss = (torch.as_tensor(importance_weights, device=self.value_net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
|
||||||
else:
|
else:
|
||||||
value_loss = F.mse_loss(obtained_values, expected_values)
|
value_loss = F.mse_loss(obtained_values, expected_values)
|
||||||
|
|
||||||
|
@ -124,28 +122,26 @@ class QEPAgent:
|
||||||
self.value_net.clamp_gradients()
|
self.value_net.clamp_gradients()
|
||||||
self.value_net.step()
|
self.value_net.step()
|
||||||
|
|
||||||
if callable(self.after_value_train):
|
|
||||||
self.after_value_train()
|
|
||||||
|
|
||||||
if self.target_value_net is not None:
|
if self.target_value_net is not None:
|
||||||
if 'target_sync_tau' in self.config:
|
if 'target_sync_tau' in self.config:
|
||||||
self.target_value_net.partial_sync(self.config['target_sync_tau'])
|
self.target_value_net.partial_sync(self.config['target_sync_tau'])
|
||||||
else:
|
else:
|
||||||
self.target_value_net.sync()
|
self.target_value_net.sync()
|
||||||
|
|
||||||
if (isinstance(self.memory, M.PrioritizedReplayMemory)):
|
if isinstance(self.memory, M.PrioritizedReplayMemory):
|
||||||
td_error = (obtained_values - expected_values).detach().abs()
|
td_error = (obtained_values - expected_values).detach().abs()
|
||||||
self.memory.update_priorities(batch_indexes, td_error)
|
self.memory.update_priorities(batch_indexes, td_error)
|
||||||
|
|
||||||
## Policy Training
|
## Policy Training
|
||||||
if self.policy_skip > 0:
|
if self.policy_skip > 0:
|
||||||
self.policy_skip -= 1
|
self.policy_skip -= 1
|
||||||
return
|
return
|
||||||
self.policy_skip = self.config['policy_skip']
|
self.policy_skip = 4
|
||||||
|
|
||||||
if self.target_value_net is not None:
|
if self.target_value_net is not None:
|
||||||
self.policy_net.calc_gradients(self.target_value_net, state_batch)
|
self.policy_net.calc_gradients(self.target_value_net, state_batch)
|
||||||
else:
|
else:
|
||||||
self.policy_net.calc_gradients(self.value_net, state_batch)
|
self.policy_net.calc_gradients(self.value_net, state_batch)
|
||||||
|
|
||||||
self.policy_net.step()
|
self.policy_net.step()
|
||||||
|
|
||||||
|
|
|
@ -1,60 +1,60 @@
|
||||||
import rltorch
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import rltorch
|
||||||
|
|
||||||
class REINFORCEAgent:
|
class REINFORCEAgent:
|
||||||
def __init__(self, net , memory, config, target_net = None, logger = None):
|
def __init__(self, net, memory, config, target_net=None, logger=None):
|
||||||
self.net = net
|
self.net = net
|
||||||
if not isinstance(memory, rltorch.memory.EpisodeMemory):
|
if not isinstance(memory, rltorch.memory.EpisodeMemory):
|
||||||
raise ValueError("Memory must be of instance EpisodeMemory")
|
raise ValueError("Memory must be of instance EpisodeMemory")
|
||||||
self.memory = memory
|
self.memory = memory
|
||||||
self.config = deepcopy(config)
|
self.config = deepcopy(config)
|
||||||
self.target_net = target_net
|
self.target_net = target_net
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
# Shaped rewards implements three improvements to REINFORCE
|
# Shaped rewards implements three improvements to REINFORCE
|
||||||
# 1) Discounted rewards, future rewards matter less than current
|
# 1) Discounted rewards, future rewards matter less than current
|
||||||
# 2) Baselines: We use the mean reward to see if the current reward is advantageous or not
|
# 2) Baselines: We use the mean reward to see if the current reward is advantageous or not
|
||||||
# 3) Causality: Your current actions do not affect your past. Only the present and future.
|
# 3) Causality: Your current actions do not affect your past. Only the present and future.
|
||||||
def _shape_rewards(self, rewards):
|
def _shape_rewards(self, rewards):
|
||||||
shaped_rewards = torch.zeros_like(rewards)
|
shaped_rewards = torch.zeros_like(rewards)
|
||||||
baseline = rewards.mean()
|
baseline = rewards.mean()
|
||||||
for i in range(len(rewards)):
|
for i in range(len(rewards)):
|
||||||
gammas = torch.ones_like(rewards[i:])
|
gammas = torch.ones_like(rewards[i:])
|
||||||
if i != len(rewards) - 1:
|
if i != len(rewards) - 1:
|
||||||
gammas[1:] = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards) - i - 1), dim = 0)
|
gammas[1:] = torch.cumprod(torch.tensor(self.config['discount_rate']).repeat(len(rewards) - i - 1), dim=0)
|
||||||
advantages = rewards[i:] - baseline
|
advantages = rewards[i:] - baseline
|
||||||
shaped_rewards[i] = (gammas * advantages).sum()
|
shaped_rewards[i] = (gammas * advantages).sum()
|
||||||
return shaped_rewards
|
return shaped_rewards
|
||||||
|
|
||||||
def learn(self):
|
def learn(self):
|
||||||
episode_batch = self.memory.recall()
|
episode_batch = self.memory.recall()
|
||||||
state_batch, action_batch, reward_batch, next_state_batch, done_batch, log_prob_batch = zip(*episode_batch)
|
_, _, reward_batch, _, _, log_prob_batch = zip(*episode_batch)
|
||||||
|
|
||||||
# Calculate discounted rewards to place more importance on recent rewards
|
# Calculate discounted rewards to place more importance on recent rewards
|
||||||
shaped_reward_batch = self._shape_rewards(torch.tensor(reward_batch))
|
shaped_reward_batch = self._shape_rewards(torch.tensor(reward_batch))
|
||||||
|
|
||||||
# Scale discounted rewards to have variance 1 (stabilizes training)
|
# Scale discounted rewards to have variance 1 (stabilizes training)
|
||||||
shaped_reward_batch = shaped_reward_batch / (shaped_reward_batch.std() + np.finfo('float').eps)
|
shaped_reward_batch = shaped_reward_batch / (shaped_reward_batch.std() + np.finfo('float').eps)
|
||||||
|
|
||||||
log_prob_batch = torch.cat(log_prob_batch)
|
log_prob_batch = torch.cat(log_prob_batch)
|
||||||
|
|
||||||
policy_loss = (-log_prob_batch * shaped_reward_batch).sum()
|
policy_loss = (-log_prob_batch * shaped_reward_batch).sum()
|
||||||
|
|
||||||
if self.logger is not None:
|
if self.logger is not None:
|
||||||
self.logger.append("Loss", policy_loss.item())
|
self.logger.append("Loss", policy_loss.item())
|
||||||
|
|
||||||
self.net.zero_grad()
|
self.net.zero_grad()
|
||||||
policy_loss.backward()
|
policy_loss.backward()
|
||||||
self.net.clamp_gradients()
|
self.net.clamp_gradients()
|
||||||
self.net.step()
|
self.net.step()
|
||||||
|
|
||||||
if self.target_net is not None:
|
if self.target_net is not None:
|
||||||
if 'target_sync_tau' in self.config:
|
if 'target_sync_tau' in self.config:
|
||||||
self.target_net.partial_sync(self.config['target_sync_tau'])
|
self.target_net.partial_sync(self.config['target_sync_tau'])
|
||||||
else:
|
else:
|
||||||
self.target_net.sync()
|
self.target_net.sync()
|
||||||
|
|
||||||
# Memory under the old policy is not needed for future training
|
# Memory under the old policy is not needed for future training
|
||||||
self.memory.clear()
|
self.memory.clear()
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from .A2CSingleAgent import *
|
from .A2CSingleAgent import A2CSingleAgent
|
||||||
from .DQNAgent import *
|
from .DQNAgent import DQNAgent
|
||||||
from .DQfDAgent import *
|
from .DQfDAgent import DQfDAgent
|
||||||
from .PPOAgent import *
|
from .PPOAgent import PPOAgent
|
||||||
from .QEPAgent import *
|
from .QEPAgent import QEPAgent
|
||||||
from .REINFORCEAgent import *
|
from .REINFORCEAgent import REINFORCEAgent
|
||||||
|
|
186
rltorch/env/simulate.py
vendored
186
rltorch/env/simulate.py
vendored
|
@ -1,108 +1,108 @@
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
import rltorch
|
|
||||||
import time
|
import time
|
||||||
|
import rltorch
|
||||||
|
|
||||||
def simulateEnvEps(env, actor, config, total_episodes = 1, memory = None, logger = None, name = "", render = False):
|
def simulateEnvEps(env, actor, config, total_episodes=1, memory=None, logger=None, name="", render=False):
|
||||||
for episode in range(total_episodes):
|
for episode in range(total_episodes):
|
||||||
state = env.reset()
|
state = env.reset()
|
||||||
done = False
|
done = False
|
||||||
episode_reward = 0
|
episode_reward = 0
|
||||||
while not done:
|
while not done:
|
||||||
action = actor.act(state)
|
action = actor.act(state)
|
||||||
next_state, reward, done, _ = env.step(action)
|
next_state, reward, done, _ = env.step(action)
|
||||||
if render:
|
if render:
|
||||||
env.render()
|
env.render()
|
||||||
time.sleep(0.01)
|
time.sleep(0.01)
|
||||||
|
|
||||||
episode_reward = episode_reward + reward
|
episode_reward = episode_reward + reward
|
||||||
if memory is not None:
|
if memory is not None:
|
||||||
memory.append(state, action, reward, next_state, done)
|
memory.append(state, action, reward, next_state, done)
|
||||||
state = next_state
|
state = next_state
|
||||||
|
|
||||||
if episode % config['print_stat_n_eps'] == 0:
|
if episode % config['print_stat_n_eps'] == 0:
|
||||||
print("episode: {}/{}, score: {}"
|
print("episode: {}/{}, score: {}"
|
||||||
.format(episode, total_episodes, episode_reward), flush=True)
|
.format(episode, total_episodes, episode_reward), flush=True)
|
||||||
|
|
||||||
if logger is not None:
|
if logger is not None:
|
||||||
logger.append(name + '/EpisodeReward', episode_reward)
|
logger.append(name + '/EpisodeReward', episode_reward)
|
||||||
|
|
||||||
|
|
||||||
class EnvironmentRunSync():
|
class EnvironmentRunSync:
|
||||||
def __init__(self, env, actor, config, memory = None, logwriter = None, name = "", render = False):
|
def __init__(self, env, actor, config, memory=None, logwriter=None, name="", render=False):
|
||||||
self.env = env
|
self.env = env
|
||||||
self.name = name
|
self.name = name
|
||||||
self.actor = actor
|
self.actor = actor
|
||||||
self.config = deepcopy(config)
|
self.config = deepcopy(config)
|
||||||
self.logwriter = logwriter
|
self.logwriter = logwriter
|
||||||
self.memory = memory
|
self.memory = memory
|
||||||
self.episode_num = 1
|
self.episode_num = 1
|
||||||
self.episode_reward = 0
|
self.episode_reward = 0
|
||||||
self.last_state = env.reset()
|
self.last_state = env.reset()
|
||||||
self.render = render
|
self.render = render
|
||||||
|
|
||||||
def run(self, iterations):
|
def run(self, iterations):
|
||||||
state = self.last_state
|
state = self.last_state
|
||||||
logger = rltorch.log.Logger() if self.logwriter is not None else None
|
logger = rltorch.log.Logger() if self.logwriter is not None else None
|
||||||
for _ in range(iterations):
|
for _ in range(iterations):
|
||||||
action = self.actor.act(state)
|
action = self.actor.act(state)
|
||||||
next_state, reward, done, _ = self.env.step(action)
|
next_state, reward, done, _ = self.env.step(action)
|
||||||
if self.render:
|
if self.render:
|
||||||
self.env.render()
|
self.env.render()
|
||||||
|
|
||||||
self.episode_reward += reward
|
self.episode_reward += reward
|
||||||
if self.memory is not None:
|
if self.memory is not None:
|
||||||
self.memory.append(state, action, reward, next_state, done)
|
self.memory.append(state, action, reward, next_state, done)
|
||||||
|
|
||||||
state = next_state
|
state = next_state
|
||||||
|
|
||||||
|
if done:
|
||||||
|
if self.episode_num % self.config['print_stat_n_eps'] == 0:
|
||||||
|
print("episode: {}/{}, score: {}"
|
||||||
|
.format(self.episode_num, self.config['total_training_episodes'], self.episode_reward), flush=True)
|
||||||
|
|
||||||
|
if self.logwriter is not None:
|
||||||
|
logger.append(self.name + '/EpisodeReward', self.episode_reward)
|
||||||
|
self.episode_reward = 0
|
||||||
|
state = self.env.reset()
|
||||||
|
self.episode_num += 1
|
||||||
|
|
||||||
|
if self.logwriter is not None:
|
||||||
|
self.logwriter.write(logger)
|
||||||
|
|
||||||
|
self.last_state = state
|
||||||
|
|
||||||
|
|
||||||
|
class EnvironmentEpisodeSync:
|
||||||
|
def __init__(self, env, actor, config, memory=None, logwriter=None, name=""):
|
||||||
|
self.env = env
|
||||||
|
self.name = name
|
||||||
|
self.actor = actor
|
||||||
|
self.config = deepcopy(config)
|
||||||
|
self.logwriter = logwriter
|
||||||
|
self.memory = memory
|
||||||
|
self.episode_num = 1
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
state = self.env.reset()
|
||||||
|
done = False
|
||||||
|
episodeReward = 0
|
||||||
|
logger = rltorch.log.Logger() if self.logwriter is not None else None
|
||||||
|
while not done:
|
||||||
|
action = self.actor.act(state)
|
||||||
|
next_state, reward, done, _ = self.env.step(action)
|
||||||
|
|
||||||
|
episodeReward += reward
|
||||||
|
if self.memory is not None:
|
||||||
|
self.memory.append(state, action, reward, next_state, done)
|
||||||
|
|
||||||
|
state = next_state
|
||||||
|
|
||||||
if done:
|
|
||||||
if self.episode_num % self.config['print_stat_n_eps'] == 0:
|
if self.episode_num % self.config['print_stat_n_eps'] == 0:
|
||||||
print("episode: {}/{}, score: {}"
|
print("episode: {}/{}, score: {}"
|
||||||
.format(self.episode_num, self.config['total_training_episodes'], self.episode_reward), flush=True)
|
.format(self.episode_num, self.config['total_training_episodes'], episodeReward), flush=True)
|
||||||
|
|
||||||
if self.logwriter is not None:
|
if self.logwriter is not None:
|
||||||
logger.append(self.name + '/EpisodeReward', self.episode_reward)
|
logger.append(self.name + '/EpisodeReward', episodeReward)
|
||||||
self.episode_reward = 0
|
self.logwriter.write(logger)
|
||||||
state = self.env.reset()
|
|
||||||
self.episode_num += 1
|
|
||||||
|
|
||||||
if self.logwriter is not None:
|
|
||||||
self.logwriter.write(logger)
|
|
||||||
|
|
||||||
self.last_state = state
|
self.episode_num += 1
|
||||||
|
|
||||||
|
|
||||||
class EnvironmentEpisodeSync():
|
|
||||||
def __init__(self, env, actor, config, memory = None, logwriter = None, name = ""):
|
|
||||||
self.env = env
|
|
||||||
self.name = name
|
|
||||||
self.actor = actor
|
|
||||||
self.config = deepcopy(config)
|
|
||||||
self.logwriter = logwriter
|
|
||||||
self.memory = memory
|
|
||||||
self.episode_num = 1
|
|
||||||
|
|
||||||
def run(self):
|
|
||||||
state = self.env.reset()
|
|
||||||
done = False
|
|
||||||
episodeReward = 0
|
|
||||||
logger = rltorch.log.Logger() if self.logwriter is not None else None
|
|
||||||
while not done:
|
|
||||||
action = self.actor.act(state)
|
|
||||||
next_state, reward, done, _ = self.env.step(action)
|
|
||||||
|
|
||||||
episodeReward += reward
|
|
||||||
if self.memory is not None:
|
|
||||||
self.memory.append(state, action, reward, next_state, done)
|
|
||||||
|
|
||||||
state = next_state
|
|
||||||
|
|
||||||
if self.episode_num % self.config['print_stat_n_eps'] == 0:
|
|
||||||
print("episode: {}/{}, score: {}"
|
|
||||||
.format(self.episode_num, self.config['total_training_episodes'], episodeReward), flush=True)
|
|
||||||
|
|
||||||
if self.logwriter is not None:
|
|
||||||
logger.append(self.name + '/EpisodeReward', episodeReward)
|
|
||||||
self.logwriter.write(logger)
|
|
||||||
|
|
||||||
self.episode_num += 1
|
|
||||||
|
|
204
rltorch/env/wrappers.py
vendored
204
rltorch/env/wrappers.py
vendored
|
@ -1,8 +1,8 @@
|
||||||
|
from collections import deque
|
||||||
import gym
|
import gym
|
||||||
import torch
|
import torch
|
||||||
from gym import spaces
|
from gym import spaces
|
||||||
import cv2
|
import cv2
|
||||||
from collections import deque
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
class EpisodicLifeEnv(gym.Wrapper):
|
class EpisodicLifeEnv(gym.Wrapper):
|
||||||
|
@ -111,129 +111,134 @@ class ClippedRewardsWrapper(gym.RewardWrapper):
|
||||||
|
|
||||||
# Mostly derived from OpenAI baselines
|
# Mostly derived from OpenAI baselines
|
||||||
class FireResetEnv(gym.Wrapper):
|
class FireResetEnv(gym.Wrapper):
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
"""Take action on reset for environments that are fixed until firing."""
|
"""Take action on reset for environments that are fixed until firing."""
|
||||||
gym.Wrapper.__init__(self, env)
|
gym.Wrapper.__init__(self, env)
|
||||||
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
|
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
|
||||||
assert len(env.unwrapped.get_action_meanings()) >= 3
|
assert len(env.unwrapped.get_action_meanings()) >= 3
|
||||||
|
|
||||||
def reset(self, **kwargs):
|
def reset(self, **kwargs):
|
||||||
self.env.reset(**kwargs)
|
self.env.reset(**kwargs)
|
||||||
obs, _, done, _ = self.env.step(1)
|
obs, _, done, _ = self.env.step(1)
|
||||||
if done:
|
if done:
|
||||||
self.env.reset(**kwargs)
|
self.env.reset(**kwargs)
|
||||||
obs, _, done, _ = self.env.step(2)
|
obs, _, done, _ = self.env.step(2)
|
||||||
if done:
|
if done:
|
||||||
self.env.reset(**kwargs)
|
self.env.reset(**kwargs)
|
||||||
return obs
|
return obs
|
||||||
|
|
||||||
def step(self, ac):
|
def step(self, ac):
|
||||||
return self.env.step(ac)
|
return self.env.step(ac)
|
||||||
|
|
||||||
class LazyFrames(object):
|
class LazyFrames(object):
|
||||||
def __init__(self, frames):
|
def __init__(self, frames):
|
||||||
"""This object ensures that common frames between the observations are only stored once.
|
"""This object ensures that common frames between the observations are only stored once.
|
||||||
It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
|
It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
|
||||||
buffers.
|
buffers.
|
||||||
This object should only be converted to numpy array before being passed to the model.
|
This object should only be converted to numpy array before being passed to the model.
|
||||||
You'd not believe how complex the previous solution was."""
|
You'd not believe how complex the previous solution was."""
|
||||||
self._frames = frames
|
self._frames = frames
|
||||||
self._out = None
|
self._out = None
|
||||||
|
|
||||||
def _force(self):
|
def _force(self):
|
||||||
if self._out is None:
|
if self._out is None:
|
||||||
self._out = torch.stack(self._frames)
|
self._out = torch.stack(self._frames)
|
||||||
self._frames = None
|
self._frames = None
|
||||||
return self._out
|
return self._out
|
||||||
|
|
||||||
def __array__(self, dtype=None):
|
def __array__(self, dtype=None):
|
||||||
out = self._force()
|
out = self._force()
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
out = out.astype(dtype)
|
out = out.astype(dtype)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self._force())
|
return len(self._force())
|
||||||
|
|
||||||
def __getitem__(self, i):
|
def __getitem__(self, i):
|
||||||
return self._force()[i]
|
return self._force()[i]
|
||||||
|
|
||||||
class FrameStack(gym.Wrapper):
|
class FrameStack(gym.Wrapper):
|
||||||
def __init__(self, env, k):
|
def __init__(self, env, k):
|
||||||
"""Stack k last frames.
|
"""Stack k last frames.
|
||||||
Returns lazy array, which is much more memory efficient.
|
Returns lazy array, which is much more memory efficient.
|
||||||
See Also
|
See Also
|
||||||
--------
|
--------
|
||||||
baselines.common.atari_wrappers.LazyFrames
|
baselines.common.atari_wrappers.LazyFrames
|
||||||
"""
|
"""
|
||||||
gym.Wrapper.__init__(self, env)
|
gym.Wrapper.__init__(self, env)
|
||||||
self.k = k
|
self.k = k
|
||||||
self.frames = deque([], maxlen=k)
|
self.frames = deque([], maxlen=k)
|
||||||
shp = env.observation_space.shape
|
shp = env.observation_space.shape
|
||||||
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[:-1] + (shp[-1] * k,)), dtype=env.observation_space.dtype)
|
self.observation_space = spaces.Box(
|
||||||
|
low=0,
|
||||||
|
high=255,
|
||||||
|
shape=(shp[:-1] + (shp[-1] * k,)),
|
||||||
|
dtype=env.observation_space.dtype
|
||||||
|
)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
ob = self.env.reset()
|
ob = self.env.reset()
|
||||||
for _ in range(self.k):
|
for _ in range(self.k):
|
||||||
self.frames.append(ob)
|
self.frames.append(ob)
|
||||||
return self._get_ob()
|
return self._get_ob()
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
ob, reward, done, info = self.env.step(action)
|
ob, reward, done, info = self.env.step(action)
|
||||||
self.frames.append(ob)
|
self.frames.append(ob)
|
||||||
return self._get_ob(), reward, done, info
|
return self._get_ob(), reward, done, info
|
||||||
|
|
||||||
def _get_ob(self):
|
def _get_ob(self):
|
||||||
assert len(self.frames) == self.k
|
assert len(self.frames) == self.k
|
||||||
# return LazyFrames(list(self.frames))
|
# return LazyFrames(list(self.frames))
|
||||||
return torch.cat(list(self.frames)).unsqueeze(0)
|
return torch.cat(list(self.frames)).unsqueeze(0)
|
||||||
|
|
||||||
class ProcessFrame(gym.Wrapper):
|
class ProcessFrame(gym.Wrapper):
|
||||||
def __init__(self, env, resize_shape = None, crop_bounds = None, grayscale = False):
|
def __init__(self, env, resize_shape=None, crop_bounds=None, grayscale=False):
|
||||||
gym.Wrapper.__init__(self, env)
|
gym.Wrapper.__init__(self, env)
|
||||||
self.resize_shape = resize_shape
|
self.resize_shape = resize_shape
|
||||||
self.crop_bounds = crop_bounds
|
self.crop_bounds = crop_bounds
|
||||||
self.grayscale = grayscale
|
self.grayscale = grayscale
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
return self._preprocess(self.env.reset())
|
return self._preprocess(self.env.reset())
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
next_state, reward, done, info = self.env.step(action)
|
next_state, reward, done, info = self.env.step(action)
|
||||||
next_state = self._preprocess(next_state)
|
next_state = self._preprocess(next_state)
|
||||||
return next_state, reward, done, info
|
return next_state, reward, done, info
|
||||||
|
|
||||||
def _preprocess(self, frame):
|
def _preprocess(self, frame):
|
||||||
if self.grayscale:
|
if self.grayscale:
|
||||||
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
||||||
if self.crop_bounds is not None and len(self.crop_bounds) == 4:
|
if self.crop_bounds is not None and len(self.crop_bounds) == 4:
|
||||||
frame = frame[self.crop_bounds[0]:self.crop_bounds[1], self.crop_bounds[2]:self.crop_bounds[3]]
|
frame = frame[
|
||||||
if self.resize_shape is not None and len(self.resize_shape) == 2:
|
self.crop_bounds[0]:self.crop_bounds[1],
|
||||||
frame = cv2.resize(frame, self.resize_shape, interpolation=cv2.INTER_AREA)
|
self.crop_bounds[2]:self.crop_bounds[3]
|
||||||
# Normalize
|
]
|
||||||
frame = frame / 255
|
if self.resize_shape is not None and len(self.resize_shape) == 2:
|
||||||
return frame
|
frame = cv2.resize(frame, self.resize_shape, interpolation=cv2.INTER_AREA)
|
||||||
|
# Normalize
|
||||||
|
frame = frame / 255
|
||||||
|
return frame
|
||||||
|
|
||||||
# Turns observations into torch tensors
|
# Turns observations into torch tensors
|
||||||
# Adds an additional dimension that's suppose to represent the batch dim
|
# Adds an additional dimension that's suppose to represent the batch dim
|
||||||
class TorchWrap(gym.Wrapper):
|
class TorchWrap(gym.Wrapper):
|
||||||
def __init__(self, env):
|
def __init__(self, env):
|
||||||
gym.Wrapper.__init__(self, env)
|
gym.Wrapper.__init__(self, env)
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
return self._convert(self.env.reset())
|
return self._convert(self.env.reset())
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
next_state, reward, done, info = self.env.step(action)
|
next_state, reward, done, info = self.env.step(action)
|
||||||
next_state = self._convert(next_state)
|
next_state = self._convert(next_state)
|
||||||
return next_state, reward, done, info
|
return next_state, reward, done, info
|
||||||
|
|
||||||
def _convert(self, frame):
|
|
||||||
frame = torch.from_numpy(frame).unsqueeze(0).float()
|
|
||||||
return frame
|
|
||||||
|
|
||||||
|
|
||||||
|
def _convert(self, frame):
|
||||||
|
frame = torch.from_numpy(frame).unsqueeze(0).float()
|
||||||
|
return frame
|
||||||
|
|
||||||
class ProcessFrame84(gym.ObservationWrapper):
|
class ProcessFrame84(gym.ObservationWrapper):
|
||||||
def __init__(self, env=None):
|
def __init__(self, env=None):
|
||||||
|
@ -256,4 +261,3 @@ class ProcessFrame84(gym.ObservationWrapper):
|
||||||
x_t = resized_screen[18:102, :]
|
x_t = resized_screen[18:102, :]
|
||||||
x_t = np.reshape(x_t, [84, 84])
|
x_t = np.reshape(x_t, [84, 84])
|
||||||
return x_t.astype(np.uint8)
|
return x_t.astype(np.uint8)
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
from .PrioritizedReplayMemory import PrioritizedReplayMemory
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from .PrioritizedReplayMemory import PrioritizedReplayMemory
|
||||||
|
|
||||||
Transition = namedtuple('Transition',
|
Transition = namedtuple('Transition',
|
||||||
('state', 'action', 'reward', 'next_state', 'done'))
|
('state', 'action', 'reward', 'next_state', 'done'))
|
||||||
|
|
||||||
|
|
||||||
class DQfDMemory(PrioritizedReplayMemory):
|
class DQfDMemory(PrioritizedReplayMemory):
|
||||||
def __init__(self, capacity, alpha, max_demo = -1):
|
def __init__(self, capacity, alpha, max_demo=-1):
|
||||||
assert max_demo <= capacity
|
assert max_demo <= capacity
|
||||||
super().__init__(capacity, alpha)
|
super().__init__(capacity, alpha)
|
||||||
self.demo_position = 0
|
self.demo_position = 0
|
||||||
|
@ -47,7 +47,8 @@ class DQfDMemory(PrioritizedReplayMemory):
|
||||||
idxes = self._sample_proportional(sample_size)
|
idxes = self._sample_proportional(sample_size)
|
||||||
step_idxes = []
|
step_idxes = []
|
||||||
for i in idxes:
|
for i in idxes:
|
||||||
# If the interval of experiences fall between demonstration and obtained, move it over to the demonstration half
|
# If the interval of experiences fall between demonstration and obtained,
|
||||||
|
# move it over to the demonstration half
|
||||||
if i < self.demo_position and i + steps > self.demo_position:
|
if i < self.demo_position and i + steps > self.demo_position:
|
||||||
diff = i + steps - self.demo_position
|
diff = i + steps - self.demo_position
|
||||||
step_idxes += range(i - diff, i + steps - diff)
|
step_idxes += range(i - diff, i + steps - diff)
|
||||||
|
|
|
@ -1,6 +1,4 @@
|
||||||
import random
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import torch
|
|
||||||
Transition = namedtuple('Transition',
|
Transition = namedtuple('Transition',
|
||||||
('state', 'action', 'reward', 'next_state', 'done'))
|
('state', 'action', 'reward', 'next_state', 'done'))
|
||||||
|
|
||||||
|
@ -39,7 +37,7 @@ class EpisodeMemory(object):
|
||||||
|
|
||||||
def recall(self):
|
def recall(self):
|
||||||
"""
|
"""
|
||||||
Return a list of the transitions with their
|
Return a list of the transitions with their
|
||||||
associated log-based probabilities.
|
associated log-based probabilities.
|
||||||
"""
|
"""
|
||||||
if len(self.memory) != len(self.log_probs):
|
if len(self.memory) != len(self.log_probs):
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
# From OpenAI Baselines https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py
|
# From OpenAI Baselines https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py
|
||||||
|
|
||||||
from .ReplayMemory import ReplayMemory
|
|
||||||
import operator
|
import operator
|
||||||
import random
|
import random
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from numba import jit
|
from numba import jit
|
||||||
|
from .ReplayMemory import ReplayMemory
|
||||||
|
|
||||||
class SegmentTree(object):
|
class SegmentTree(object):
|
||||||
def __init__(self, capacity, operation, neutral_element):
|
def __init__(self, capacity, operation, neutral_element):
|
||||||
|
@ -34,7 +33,7 @@ class SegmentTree(object):
|
||||||
self._value = [neutral_element for _ in range(2 * capacity)]
|
self._value = [neutral_element for _ in range(2 * capacity)]
|
||||||
self._operation = operation
|
self._operation = operation
|
||||||
|
|
||||||
@jit(forceobj = True)
|
@jit(forceobj=True)
|
||||||
def _reduce_helper(self, start, end, node, node_start, node_end):
|
def _reduce_helper(self, start, end, node, node_start, node_end):
|
||||||
if start == node_start and end == node_end:
|
if start == node_start and end == node_end:
|
||||||
return self._value[node]
|
return self._value[node]
|
||||||
|
@ -50,7 +49,7 @@ class SegmentTree(object):
|
||||||
self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
|
self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
|
||||||
)
|
)
|
||||||
|
|
||||||
@jit(forceobj = True)
|
@jit(forceobj=True)
|
||||||
def reduce(self, start=0, end=None):
|
def reduce(self, start=0, end=None):
|
||||||
"""Returns result of applying `self.operation`
|
"""Returns result of applying `self.operation`
|
||||||
to a contiguous subsequence of the array.
|
to a contiguous subsequence of the array.
|
||||||
|
@ -73,7 +72,7 @@ class SegmentTree(object):
|
||||||
end -= 1
|
end -= 1
|
||||||
return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
|
return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
|
||||||
|
|
||||||
@jit(forceobj = True)
|
@jit(forceobj=True)
|
||||||
def __setitem__(self, idx, val):
|
def __setitem__(self, idx, val):
|
||||||
# index of the leaf
|
# index of the leaf
|
||||||
idx += self._capacity
|
idx += self._capacity
|
||||||
|
@ -86,7 +85,7 @@ class SegmentTree(object):
|
||||||
)
|
)
|
||||||
idx //= 2
|
idx //= 2
|
||||||
|
|
||||||
@jit(forceobj = True)
|
@jit(forceobj=True)
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
assert 0 <= idx < self._capacity
|
assert 0 <= idx < self._capacity
|
||||||
return self._value[self._capacity + idx]
|
return self._value[self._capacity + idx]
|
||||||
|
@ -100,12 +99,12 @@ class SumSegmentTree(SegmentTree):
|
||||||
neutral_element=0.0
|
neutral_element=0.0
|
||||||
)
|
)
|
||||||
|
|
||||||
@jit(forceobj = True)
|
@jit(forceobj=True)
|
||||||
def sum(self, start=0, end=None):
|
def sum(self, start=0, end=None):
|
||||||
"""Returns arr[start] + ... + arr[end]"""
|
"""Returns arr[start] + ... + arr[end]"""
|
||||||
return super(SumSegmentTree, self).reduce(start, end)
|
return super(SumSegmentTree, self).reduce(start, end)
|
||||||
|
|
||||||
@jit(forceobj = True, parallel = True)
|
@jit(forceobj=True, parallel=True)
|
||||||
def find_prefixsum_idx(self, prefixsum):
|
def find_prefixsum_idx(self, prefixsum):
|
||||||
"""Find the highest index `i` in the array such that
|
"""Find the highest index `i` in the array such that
|
||||||
sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum
|
sum(arr[0] + arr[1] + ... + arr[i - i]) <= prefixsum
|
||||||
|
@ -140,7 +139,7 @@ class MinSegmentTree(SegmentTree):
|
||||||
neutral_element=float('inf')
|
neutral_element=float('inf')
|
||||||
)
|
)
|
||||||
|
|
||||||
@jit(forceobj = True)
|
@jit(forceobj=True)
|
||||||
def min(self, start=0, end=None):
|
def min(self, start=0, end=None):
|
||||||
"""Returns min(arr[start], ..., arr[end])"""
|
"""Returns min(arr[start], ..., arr[end])"""
|
||||||
return super(MinSegmentTree, self).reduce(start, end)
|
return super(MinSegmentTree, self).reduce(start, end)
|
||||||
|
@ -185,7 +184,7 @@ class PrioritizedReplayMemory(ReplayMemory):
|
||||||
self._it_sum[idx] = self._max_priority ** self._alpha
|
self._it_sum[idx] = self._max_priority ** self._alpha
|
||||||
self._it_min[idx] = self._max_priority ** self._alpha
|
self._it_min[idx] = self._max_priority ** self._alpha
|
||||||
|
|
||||||
@jit(forceobj = True)
|
@jit(forceobj=True)
|
||||||
def _sample_proportional(self, batch_size):
|
def _sample_proportional(self, batch_size):
|
||||||
res = []
|
res = []
|
||||||
p_total = self._it_sum.sum(0, len(self.memory) - 1)
|
p_total = self._it_sum.sum(0, len(self.memory) - 1)
|
||||||
|
@ -294,7 +293,7 @@ class PrioritizedReplayMemory(ReplayMemory):
|
||||||
batch = list(zip(*encoded_sample, weights, step_idxes))
|
batch = list(zip(*encoded_sample, weights, step_idxes))
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
@jit(forceobj = True)
|
@jit(forceobj=True)
|
||||||
def update_priorities(self, idxes, priorities):
|
def update_priorities(self, idxes, priorities):
|
||||||
"""
|
"""
|
||||||
Update priorities of sampled transitions.
|
Update priorities of sampled transitions.
|
||||||
|
@ -320,4 +319,3 @@ class PrioritizedReplayMemory(ReplayMemory):
|
||||||
self._it_min[idx] = priority ** self._alpha
|
self._it_min[idx] = priority ** self._alpha
|
||||||
|
|
||||||
self._max_priority = max(self._max_priority, priority)
|
self._max_priority = max(self._max_priority, priority)
|
||||||
|
|
||||||
|
|
|
@ -80,7 +80,7 @@ class ReplayMemory(object):
|
||||||
The number of observations after the one selected to sample.
|
The number of observations after the one selected to sample.
|
||||||
"""
|
"""
|
||||||
idxes = random.sample(
|
idxes = random.sample(
|
||||||
range(len(self.memory) - steps),
|
range(len(self.memory) - steps),
|
||||||
batch_size // steps
|
batch_size // steps
|
||||||
)
|
)
|
||||||
step_idxes = []
|
step_idxes = []
|
||||||
|
@ -106,11 +106,9 @@ class ReplayMemory(object):
|
||||||
def __reversed__(self):
|
def __reversed__(self):
|
||||||
return reversed(self.memory)
|
return reversed(self.memory)
|
||||||
|
|
||||||
def zip_batch(minibatch, priority = False, want_indices = False):
|
def zip_batch(minibatch, priority=False):
|
||||||
if priority:
|
if priority:
|
||||||
state_batch, action_batch, reward_batch, next_state_batch, done_batch, weights, indexes = zip(*minibatch)
|
state_batch, action_batch, reward_batch, next_state_batch, done_batch, weights, indexes = zip(*minibatch)
|
||||||
elif want_indices:
|
|
||||||
state_batch, action_batch, reward_batch, next_state_batch, done_batch, indexes = zip(*minibatch)
|
|
||||||
else:
|
else:
|
||||||
state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*minibatch)
|
state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*minibatch)
|
||||||
|
|
||||||
|
@ -122,7 +120,5 @@ def zip_batch(minibatch, priority = False, want_indices = False):
|
||||||
|
|
||||||
if priority:
|
if priority:
|
||||||
return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, weights, indexes
|
return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, weights, indexes
|
||||||
elif want_indices:
|
|
||||||
return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, indexes
|
|
||||||
else:
|
else:
|
||||||
return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch
|
return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch
|
||||||
|
|
|
@ -5,115 +5,32 @@ from copy import deepcopy
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
class EnvironmentEpisode(mp.Process):
|
class EnvironmentEpisode(mp.Process):
|
||||||
def __init__(self, env, actor, config, logger = None, name = ""):
|
def __init__(self, env, actor, config, logger=None, name=""):
|
||||||
super(EnvironmentEpisode, self).__init__()
|
super(EnvironmentEpisode, self).__init__()
|
||||||
self.env = env
|
self.env = env
|
||||||
self.actor = actor
|
self.actor = actor
|
||||||
self.config = deepcopy(config)
|
self.config = deepcopy(config)
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.name = name
|
self.name = name
|
||||||
self.episode_num = 1
|
self.episode_num = 1
|
||||||
|
|
||||||
def run(self, printstat = False, memory = None):
|
def run(self, printstat=False, memory=None):
|
||||||
state = self.env.reset()
|
state = self.env.reset()
|
||||||
done = False
|
done = False
|
||||||
episode_reward = 0
|
episode_reward = 0
|
||||||
while not done:
|
while not done:
|
||||||
action = self.actor.act(state)
|
action = self.actor.act(state)
|
||||||
next_state, reward, done, _ = self.env.step(action)
|
next_state, reward, done, _ = self.env.step(action)
|
||||||
|
|
||||||
episode_reward = episode_reward + reward
|
episode_reward = episode_reward + reward
|
||||||
if memory is not None:
|
if memory is not None:
|
||||||
memory.put((state, action, reward, next_state, done))
|
memory.put((state, action, reward, next_state, done))
|
||||||
state = next_state
|
state = next_state
|
||||||
|
|
||||||
if printstat:
|
if printstat:
|
||||||
print("episode: {}/{}, score: {}"
|
print("episode: {}/{}, score: {}"
|
||||||
.format(self.episode_num, self.config['total_training_episodes'], episode_reward))
|
.format(self.episode_num, self.config['total_training_episodes'], episode_reward))
|
||||||
if self.logger is not None:
|
if self.logger is not None:
|
||||||
self.logger.append(self.name + '/EpisodeReward', episode_reward)
|
self.logger.append(self.name + '/EpisodeReward', episode_reward)
|
||||||
|
|
||||||
self.episode_num += 1
|
self.episode_num += 1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# from copy import deepcopy
|
|
||||||
# import torch.multiprocessing as mp
|
|
||||||
# from ctypes import *
|
|
||||||
# import rltorch.log
|
|
||||||
|
|
||||||
# def envepisode(actor, env, episode_num, config, runcondition, memoryqueue = None, logqueue = None, name = ""):
|
|
||||||
# # Wait for signal to start running through the environment
|
|
||||||
# while runcondition.wait():
|
|
||||||
# # Start a logger to log the rewards
|
|
||||||
# logger = rltorch.log.Logger()
|
|
||||||
# state = env.reset()
|
|
||||||
# episode_reward = 0
|
|
||||||
# done = False
|
|
||||||
# while not done:
|
|
||||||
# action = actor.act(state)
|
|
||||||
# next_state, reward, done, _ = env.step(action)
|
|
||||||
|
|
||||||
# episode_reward += reward
|
|
||||||
# if memoryqueue is not None:
|
|
||||||
# memoryqueue.put((state, action, reward, next_state, done))
|
|
||||||
|
|
||||||
# state = next_state
|
|
||||||
|
|
||||||
# if done:
|
|
||||||
# with episode_num.get_lock():
|
|
||||||
# if episode_num.value % config['print_stat_n_eps'] == 0:
|
|
||||||
# print("episode: {}/{}, score: {}"
|
|
||||||
# .format(episode_num.value, config['total_training_episodes'], episode_reward))
|
|
||||||
|
|
||||||
# if logger is not None:
|
|
||||||
# logger.append(name + '/EpisodeReward', episode_reward)
|
|
||||||
# episode_reward = 0
|
|
||||||
# state = env.reset()
|
|
||||||
# with episode_num.get_lock():
|
|
||||||
# episode_num.value += 1
|
|
||||||
|
|
||||||
# logqueue.put(logger)
|
|
||||||
|
|
||||||
# class EnvironmentRun():
|
|
||||||
# def __init__(self, env_func, actor, config, memory = None, name = ""):
|
|
||||||
# self.config = deepcopy(config)
|
|
||||||
# self.memory = memory
|
|
||||||
# self.episode_num = mp.Value(c_uint)
|
|
||||||
# self.runcondition = mp.Event()
|
|
||||||
# # Interestingly enough, there isn't a good reliable way to know how many states an episode will have
|
|
||||||
# # Perhaps we can share a uint to keep track...
|
|
||||||
# self.memory_queue = mp.Queue(maxsize = config['replay_skip'] + 1)
|
|
||||||
# self.logqueue = mp.Queue(maxsize = 1)
|
|
||||||
# with self.episode_num.get_lock():
|
|
||||||
# self.episode_num.value = 1
|
|
||||||
# self.runner = mp.Process(target=envrun,
|
|
||||||
# args=(actor, env_func, self.episode_num, config, self.runcondition),
|
|
||||||
# kwargs = {'iterations': config['replay_skip'] + 1,
|
|
||||||
# 'memoryqueue' : self.memory_queue, 'logqueue' : self.logqueue, 'name' : name})
|
|
||||||
# self.runner.start()
|
|
||||||
|
|
||||||
# def run(self):
|
|
||||||
# self.runcondition.set()
|
|
||||||
|
|
||||||
# def join(self):
|
|
||||||
# self._sync_memory()
|
|
||||||
# if self.logwriter is not None:
|
|
||||||
# self.logwriter.write(self._get_reward_logger())
|
|
||||||
|
|
||||||
# def sync_memory(self):
|
|
||||||
# if self.memory is not None:
|
|
||||||
# for i in range(self.config['replay_skip'] + 1):
|
|
||||||
# self.memory.append(*self.memory_queue.get())
|
|
||||||
|
|
||||||
# def get_reward_logger(self):
|
|
||||||
# return self.logqueue.get()
|
|
||||||
|
|
||||||
# def terminate(self):
|
|
||||||
# self.runner.terminate()
|
|
||||||
|
|
||||||
|
|
|
@ -1,40 +1,40 @@
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
from ctypes import c_uint
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
from ctypes import *
|
|
||||||
import rltorch.log
|
import rltorch.log
|
||||||
|
|
||||||
def envrun(actor, env, episode_num, config, runcondition, iterations = 1, memoryqueue = None, logqueue = None, name = ""):
|
def envrun(actor, env, episode_num, config, runcondition, iterations=1, memoryqueue=None, logqueue=None, name=""):
|
||||||
state = env.reset()
|
state = env.reset()
|
||||||
episode_reward = 0
|
episode_reward = 0
|
||||||
# Wait for signal to start running through the environment
|
# Wait for signal to start running through the environment
|
||||||
while runcondition.wait():
|
while runcondition.wait():
|
||||||
# Start a logger to log the rewards
|
# Start a logger to log the rewards
|
||||||
logger = rltorch.log.Logger() if logqueue is not None else None
|
logger = rltorch.log.Logger() if logqueue is not None else None
|
||||||
for _ in range(iterations):
|
for _ in range(iterations):
|
||||||
action = actor.act(state)
|
action = actor.act(state)
|
||||||
next_state, reward, done, _ = env.step(action)
|
next_state, reward, done, _ = env.step(action)
|
||||||
|
|
||||||
episode_reward += reward
|
episode_reward += reward
|
||||||
if memoryqueue is not None:
|
if memoryqueue is not None:
|
||||||
memoryqueue.put((state, action, reward, next_state, done))
|
memoryqueue.put((state, action, reward, next_state, done))
|
||||||
|
|
||||||
state = next_state
|
state = next_state
|
||||||
|
|
||||||
if done:
|
if done:
|
||||||
with episode_num.get_lock():
|
with episode_num.get_lock():
|
||||||
if episode_num.value % config['print_stat_n_eps'] == 0:
|
if episode_num.value % config['print_stat_n_eps'] == 0:
|
||||||
print("episode: {}/{}, score: {}"
|
print("episode: {}/{}, score: {}"
|
||||||
.format(episode_num.value, config['total_training_episodes'], episode_reward))
|
.format(episode_num.value, config['total_training_episodes'], episode_reward))
|
||||||
|
|
||||||
if logger is not None:
|
if logger is not None:
|
||||||
logger.append(name + '/EpisodeReward', episode_reward)
|
logger.append(name + '/EpisodeReward', episode_reward)
|
||||||
episode_reward = 0
|
episode_reward = 0
|
||||||
state = env.reset()
|
state = env.reset()
|
||||||
with episode_num.get_lock():
|
with episode_num.get_lock():
|
||||||
episode_num.value += 1
|
episode_num.value += 1
|
||||||
|
|
||||||
if logqueue is not None:
|
if logqueue is not None:
|
||||||
logqueue.put(logger)
|
logqueue.put(logger)
|
||||||
|
|
||||||
class EnvironmentRun():
|
class EnvironmentRun():
|
||||||
def __init__(self, env, actor, config, memory = None, logwriter = None, name = ""):
|
def __init__(self, env, actor, config, memory = None, logwriter = None, name = ""):
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
|
from copy import deepcopy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from .Network import Network
|
from .Network import Network
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
# [TODO] Should we torch.no_grad the __call__?
|
# [TODO] Should we torch.no_grad the __call__?
|
||||||
# What if we want to sometimes do gradient descent as well?
|
# What if we want to sometimes do gradient descent as well?
|
||||||
|
@ -11,8 +12,8 @@ class ESNetwork(Network):
|
||||||
|
|
||||||
Notes
|
Notes
|
||||||
-----
|
-----
|
||||||
Derived from the paper
|
Derived from the paper
|
||||||
Evolutionary Strategies
|
Evolutionary Strategies
|
||||||
(https://arxiv.org/abs/1703.03864)
|
(https://arxiv.org/abs/1703.03864)
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
|
@ -38,7 +39,7 @@ class ESNetwork(Network):
|
||||||
name
|
name
|
||||||
For use in logger to differentiate in analysis.
|
For use in logger to differentiate in analysis.
|
||||||
"""
|
"""
|
||||||
def __init__(self, model, optimizer, population_size, fitness_fn, config, sigma = 0.05, device = None, logger = None, name = ""):
|
def __init__(self, model, optimizer, population_size, fitness_fn, config, sigma=0.05, device=None, logger=None, name=""):
|
||||||
super(ESNetwork, self).__init__(model, optimizer, config, device, logger, name)
|
super(ESNetwork, self).__init__(model, optimizer, config, device, logger, name)
|
||||||
self.population_size = population_size
|
self.population_size = population_size
|
||||||
self.fitness = fitness_fn
|
self.fitness = fitness_fn
|
||||||
|
@ -49,7 +50,7 @@ class ESNetwork(Network):
|
||||||
"""
|
"""
|
||||||
Notes
|
Notes
|
||||||
-----
|
-----
|
||||||
Since gradients aren't going to be computed in the
|
Since gradients aren't going to be computed in the
|
||||||
traditional fashion, there is no need to keep
|
traditional fashion, there is no need to keep
|
||||||
track of the computations performed on the
|
track of the computations performed on the
|
||||||
tensors.
|
tensors.
|
||||||
|
@ -64,7 +65,11 @@ class ESNetwork(Network):
|
||||||
white_noise_dict = {}
|
white_noise_dict = {}
|
||||||
noise_dict = {}
|
noise_dict = {}
|
||||||
for key in model_dict.keys():
|
for key in model_dict.keys():
|
||||||
white_noise_dict[key] = torch.randn(self.population_size, *model_dict[key].shape, device = self.device)
|
white_noise_dict[key] = torch.randn(
|
||||||
|
self.population_size,
|
||||||
|
*model_dict[key].shape,
|
||||||
|
device=self.device
|
||||||
|
)
|
||||||
noise_dict[key] = self.sigma * white_noise_dict[key]
|
noise_dict[key] = self.sigma * white_noise_dict[key]
|
||||||
return white_noise_dict, noise_dict
|
return white_noise_dict, noise_dict
|
||||||
|
|
||||||
|
@ -87,7 +92,7 @@ class ESNetwork(Network):
|
||||||
|
|
||||||
This is calculated by evaluating the fitness of multiple
|
This is calculated by evaluating the fitness of multiple
|
||||||
networks according to the fitness function specified in
|
networks according to the fitness function specified in
|
||||||
the class.
|
the class.
|
||||||
"""
|
"""
|
||||||
## Generate Noise
|
## Generate Noise
|
||||||
white_noise_dict, noise_dict = self._generate_noise_dicts()
|
white_noise_dict, noise_dict = self._generate_noise_dicts()
|
||||||
|
@ -96,7 +101,10 @@ class ESNetwork(Network):
|
||||||
candidate_solutions = self._generate_candidate_solutions(noise_dict)
|
candidate_solutions = self._generate_candidate_solutions(noise_dict)
|
||||||
|
|
||||||
## Calculate fitness then mean shift, scale
|
## Calculate fitness then mean shift, scale
|
||||||
fitness_values = torch.tensor([self.fitness(x, *args) for x in candidate_solutions], device = self.device)
|
fitness_values = torch.tensor(
|
||||||
|
[self.fitness(x, *args) for x in candidate_solutions],
|
||||||
|
device=self.device
|
||||||
|
)
|
||||||
if self.logger is not None:
|
if self.logger is not None:
|
||||||
self.logger.append(self.name + "/" + "fitness_value", fitness_values.mean().item())
|
self.logger.append(self.name + "/" + "fitness_value", fitness_values.mean().item())
|
||||||
fitness_values = (fitness_values - fitness_values.mean()) / (fitness_values.std() + np.finfo('float').eps)
|
fitness_values = (fitness_values - fitness_values.mean()) / (fitness_values.std() + np.finfo('float').eps)
|
||||||
|
@ -107,4 +115,4 @@ class ESNetwork(Network):
|
||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
noise_dim_n = len(white_noise_dict[name].shape)
|
noise_dim_n = len(white_noise_dict[name].shape)
|
||||||
dim = np.repeat(1, noise_dim_n - 1).tolist() if noise_dim_n > 0 else []
|
dim = np.repeat(1, noise_dim_n - 1).tolist() if noise_dim_n > 0 else []
|
||||||
param.grad = (white_noise_dict[name] * fitness_values.float().reshape(self.population_size, *dim)).mean(0) / self.sigma
|
param.grad = (white_noise_dict[name] * fitness_values.float().reshape(self.population_size, *dim)).mean(0) / self.sigma
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
|
from copy import deepcopy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from .Network import Network
|
|
||||||
from copy import deepcopy
|
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import functools
|
from .Network import Network
|
||||||
|
|
||||||
class fn_copy:
|
class fn_copy:
|
||||||
def __init__(self, fn, args):
|
def __init__(self, fn, args):
|
||||||
|
@ -20,14 +19,15 @@ class ESNetworkMP(Network):
|
||||||
fitness_fun := model, *args -> fitness_value (float)
|
fitness_fun := model, *args -> fitness_value (float)
|
||||||
We wish to find a model that maximizes the fitness function
|
We wish to find a model that maximizes the fitness function
|
||||||
"""
|
"""
|
||||||
def __init__(self, model, optimizer, population_size, fitness_fn, config, sigma = 0.05, device = None, logger = None, name = ""):
|
def __init__(self, model, optimizer, population_size, fitness_fn, config, sigma=0.05, device=None, logger=None, name=""):
|
||||||
super(ESNetworkMP, self).__init__(model, optimizer, config, device, logger, name)
|
super(ESNetworkMP, self).__init__(model, optimizer, config, device, logger, name)
|
||||||
self.population_size = population_size
|
self.population_size = population_size
|
||||||
self.fitness = fitness_fn
|
self.fitness = fitness_fn
|
||||||
self.sigma = sigma
|
self.sigma = sigma
|
||||||
assert self.sigma > 0
|
assert self.sigma > 0
|
||||||
mp_ctx = mp.get_context("spawn")
|
mp_ctx = mp.get_context("spawn")
|
||||||
self.pool = mp_ctx.Pool(processes=2) #[TODO] Probably should make number of processes a config variable
|
#[TODO] Probably should make number of processes a config variable
|
||||||
|
self.pool = mp_ctx.Pool(processes=2)
|
||||||
|
|
||||||
# We're not going to be calculating gradients in the traditional way
|
# We're not going to be calculating gradients in the traditional way
|
||||||
# So there's no need to waste computation time keeping track
|
# So there's no need to waste computation time keeping track
|
||||||
|
@ -42,7 +42,11 @@ class ESNetworkMP(Network):
|
||||||
white_noise_dict = {}
|
white_noise_dict = {}
|
||||||
noise_dict = {}
|
noise_dict = {}
|
||||||
for key in model_dict.keys():
|
for key in model_dict.keys():
|
||||||
white_noise_dict[key] = torch.randn(self.population_size, *model_dict[key].shape, device = self.device)
|
white_noise_dict[key] = torch.randn(
|
||||||
|
self.population_size,
|
||||||
|
*model_dict[key].shape,
|
||||||
|
device=self.device
|
||||||
|
)
|
||||||
noise_dict[key] = self.sigma * white_noise_dict[key]
|
noise_dict[key] = self.sigma * white_noise_dict[key]
|
||||||
return white_noise_dict, noise_dict
|
return white_noise_dict, noise_dict
|
||||||
|
|
||||||
|
@ -67,7 +71,10 @@ class ESNetworkMP(Network):
|
||||||
candidate_solutions = self._generate_candidate_solutions(noise_dict)
|
candidate_solutions = self._generate_candidate_solutions(noise_dict)
|
||||||
|
|
||||||
## Calculate fitness then mean shift, scale
|
## Calculate fitness then mean shift, scale
|
||||||
fitness_values = torch.tensor(list(self.pool.map(fn_copy(self.fitness, args), candidate_solutions)), device = self.device)
|
fitness_values = torch.tensor(
|
||||||
|
list(self.pool.map(fn_copy(self.fitness, args), candidate_solutions)),
|
||||||
|
device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
if self.logger is not None:
|
if self.logger is not None:
|
||||||
self.logger.append(self.name + "/" + "fitness_value", fitness_values.mean().item())
|
self.logger.append(self.name + "/" + "fitness_value", fitness_values.mean().item())
|
||||||
|
@ -87,4 +94,4 @@ class ESNetworkMP(Network):
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
self_dict = self.__dict__.copy()
|
self_dict = self.__dict__.copy()
|
||||||
del self_dict['pool']
|
del self_dict['pool']
|
||||||
return self_dict
|
return self_dict
|
||||||
|
|
|
@ -17,12 +17,16 @@ class Network:
|
||||||
name
|
name
|
||||||
For use in logger to differentiate in analysis.
|
For use in logger to differentiate in analysis.
|
||||||
"""
|
"""
|
||||||
def __init__(self, model, optimizer, config, device = None, logger = None, name = ""):
|
def __init__(self, model, optimizer, config, device=None, logger=None, name=""):
|
||||||
self.model = model
|
self.model = model
|
||||||
if 'weight_decay' in config:
|
if 'weight_decay' in config:
|
||||||
self.optimizer = optimizer(model.parameters(), lr = config['learning_rate'], weight_decay = config['weight_decay'])
|
self.optimizer = optimizer(
|
||||||
|
model.parameters(),
|
||||||
|
lr=config['learning_rate'],
|
||||||
|
weight_decay=config['weight_decay']
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.optimizer = optimizer(model.parameters(), lr = config['learning_rate'])
|
self.optimizer = optimizer(model.parameters(), lr=config['learning_rate'])
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.name = name
|
self.name = name
|
||||||
self.device = device
|
self.device = device
|
||||||
|
@ -32,7 +36,7 @@ class Network:
|
||||||
def __call__(self, *args):
|
def __call__(self, *args):
|
||||||
return self.model(*args)
|
return self.model(*args)
|
||||||
|
|
||||||
def clamp_gradients(self, x = 1):
|
def clamp_gradients(self, x=1):
|
||||||
"""
|
"""
|
||||||
Forcing gradients to stay within a certain interval
|
Forcing gradients to stay within a certain interval
|
||||||
by setting it to the bound if it goes over it.
|
by setting it to the bound if it goes over it.
|
||||||
|
|
|
@ -1,67 +1,67 @@
|
||||||
|
import math
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import math
|
|
||||||
|
|
||||||
# This class utilizes this property of the normal distribution
|
# This class utilizes this property of the normal distribution
|
||||||
# N(mu, sigma) = mu + sigma * N(0, 1)
|
# N(mu, sigma) = mu + sigma * N(0, 1)
|
||||||
class NoisyLinear(nn.Linear):
|
class NoisyLinear(nn.Linear):
|
||||||
"""
|
|
||||||
Draws the parameters of nn.Linear from a normal distribution.
|
|
||||||
The parameters of the normal distribution are registered as
|
|
||||||
learnable parameters in the neural network.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
in_features
|
|
||||||
Size of each input sample.
|
|
||||||
out_features
|
|
||||||
Size of each output sample.
|
|
||||||
sigma_init
|
|
||||||
The starting standard deviation of guassian noise.
|
|
||||||
bias
|
|
||||||
If set to False, the layer will not
|
|
||||||
learn an additive bias.
|
|
||||||
Default: True
|
|
||||||
"""
|
|
||||||
def __init__(self, in_features, out_features, sigma_init = 0.017, bias = True):
|
|
||||||
super(NoisyLinear, self).__init__(in_features, out_features, bias = bias)
|
|
||||||
# One of the parameters the network is going to tune is the
|
|
||||||
# standard deviation of the gaussian noise on the weights
|
|
||||||
self.sigma_weight = nn.Parameter(torch.Tensor(out_features, in_features).fill_(sigma_init))
|
|
||||||
# Reserve space for N(0, 1) of weights in the forward() call
|
|
||||||
self.register_buffer("s_normal_weight", torch.zeros(out_features, in_features))
|
|
||||||
if bias:
|
|
||||||
# If a bias exists, then we manipulate the standard deviation of the
|
|
||||||
# gaussion noise on them as well
|
|
||||||
self.sigma_bias = nn.Parameter(torch.Tensor(out_features).fill_(sigma_init))
|
|
||||||
# Reserve space for N(0, 1) of bias in the foward() call
|
|
||||||
self.register_buffer("s_normal_bias", torch.zeros(out_features))
|
|
||||||
self.reset_parameters()
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
|
||||||
std = math.sqrt(3 / self.in_features)
|
|
||||||
nn.init.uniform_(self.weight, -std, std)
|
|
||||||
nn.init.uniform_(self.bias, -std, std)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
r"""
|
|
||||||
Calculates the output :math:`y` through the following:
|
|
||||||
|
|
||||||
:math:`sigma \sim N(mu_1, std_1)`
|
|
||||||
|
|
||||||
:math:`bias \sim N(mu_2, std_2)`
|
|
||||||
|
|
||||||
:math:`y = sigma \cdot x + bias`
|
|
||||||
"""
|
"""
|
||||||
|
Draws the parameters of nn.Linear from a normal distribution.
|
||||||
|
The parameters of the normal distribution are registered as
|
||||||
|
learnable parameters in the neural network.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
in_features
|
||||||
|
Size of each input sample.
|
||||||
|
out_features
|
||||||
|
Size of each output sample.
|
||||||
|
sigma_init
|
||||||
|
The starting standard deviation of gaussian noise.
|
||||||
|
bias
|
||||||
|
If set to False, the layer will not
|
||||||
|
learn an additive bias.
|
||||||
|
Default: True
|
||||||
|
"""
|
||||||
|
def __init__(self, in_features, out_features, sigma_init=0.017, bias=True):
|
||||||
|
super(NoisyLinear, self).__init__(in_features, out_features, bias=bias)
|
||||||
|
# One of the parameters the network is going to tune is the
|
||||||
|
# standard deviation of the gaussian noise on the weights
|
||||||
|
self.sigma_weight = nn.Parameter(torch.Tensor(out_features, in_features).fill_(sigma_init))
|
||||||
|
# Reserve space for N(0, 1) of weights in the forward() call
|
||||||
|
self.register_buffer("s_normal_weight", torch.zeros(out_features, in_features))
|
||||||
|
if bias:
|
||||||
|
# If a bias exists, then we manipulate the standard deviation of the
|
||||||
|
# gaussian noise on them as well
|
||||||
|
self.sigma_bias = nn.Parameter(torch.Tensor(out_features).fill_(sigma_init))
|
||||||
|
# Reserve space for N(0, 1) of bias in the forward() call
|
||||||
|
self.register_buffer("s_normal_bias", torch.zeros(out_features))
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def reset_parameters(self):
|
||||||
|
std = math.sqrt(3 / self.in_features)
|
||||||
|
nn.init.uniform_(self.weight, -std, std)
|
||||||
|
nn.init.uniform_(self.bias, -std, std)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
r"""
|
||||||
|
Calculates the output :math:`y` through the following:
|
||||||
|
|
||||||
|
:math:`sigma \sim N(mu_1, std_1)`
|
||||||
|
|
||||||
|
:math:`bias \sim N(mu_2, std_2)`
|
||||||
|
|
||||||
|
:math:`y = sigma \cdot x + bias`
|
||||||
|
"""
|
||||||
# Fill s_normal_weight with values from the standard normal distribution
|
# Fill s_normal_weight with values from the standard normal distribution
|
||||||
self.s_normal_weight.normal_()
|
self.s_normal_weight.normal_()
|
||||||
weight_noise = self.sigma_weight * self.s_normal_weight.clone().requires_grad_()
|
weight_noise = self.sigma_weight * self.s_normal_weight.clone().requires_grad_()
|
||||||
|
|
||||||
bias = None
|
bias = None
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
# Fill s_normal_bias with values from standard normal
|
# Fill s_normal_bias with values from standard normal
|
||||||
self.s_normal_bias.normal_()
|
self.s_normal_bias.normal_()
|
||||||
bias = self.bias + self.sigma_bias * self.s_normal_bias.clone().requires_grad_()
|
bias = self.bias + self.sigma_bias * self.s_normal_bias.clone().requires_grad_()
|
||||||
|
|
||||||
return F.linear(x, self.weight + weight_noise, bias)
|
return F.linear(x, self.weight + weight_noise, bias)
|
||||||
|
|
|
@ -11,7 +11,7 @@ class TargetNetwork:
|
||||||
device
|
device
|
||||||
The device to put the cloned parameters in.
|
The device to put the cloned parameters in.
|
||||||
"""
|
"""
|
||||||
def __init__(self, network, device = None):
|
def __init__(self, network, device=None):
|
||||||
self.model = network.model
|
self.model = network.model
|
||||||
self.target_model = deepcopy(network.model)
|
self.target_model = deepcopy(network.model)
|
||||||
if device is not None:
|
if device is not None:
|
||||||
|
@ -37,7 +37,8 @@ class TargetNetwork:
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
tau : number
|
tau : number
|
||||||
A number between 0-1 which indicates the proportion of the originator and clone in the new clone.
|
A number between 0-1 which indicates
|
||||||
|
the proportion of the originator and clone in the new clone.
|
||||||
"""
|
"""
|
||||||
assert isinstance(tau, float)
|
assert isinstance(tau, float)
|
||||||
assert 0.0 < tau <= 1.0
|
assert 0.0 < tau <= 1.0
|
||||||
|
@ -45,4 +46,4 @@ class TargetNetwork:
|
||||||
target_state = self.target_model.state_dict()
|
target_state = self.target_model.state_dict()
|
||||||
for grad_index, grad in model_state.items():
|
for grad_index, grad in model_state.items():
|
||||||
target_state[grad_index].copy_((1 - tau) * target_state[grad_index] + tau * grad)
|
target_state[grad_index].copy_((1 - tau) * target_state[grad_index] + tau * grad)
|
||||||
self.target_model.load_state_dict(target_state)
|
self.target_model.load_state_dict(target_state)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from .ESNetwork import *
|
from .ESNetwork import ESNetwork
|
||||||
from .ESNetworkMP import *
|
from .ESNetworkMP import ESNetworkMP
|
||||||
from .Network import *
|
from .Network import Network
|
||||||
from .NoisyLinear import *
|
from .NoisyLinear import NoisyLinear
|
||||||
from .TargetNetwork import *
|
from .TargetNetwork import TargetNetwork
|
|
@ -36,4 +36,3 @@ class ExponentialScheduler(Scheduler):
|
||||||
return self.initial_value * (self.base ** (self.current_iteration - 1))
|
return self.initial_value * (self.base ** (self.current_iteration - 1))
|
||||||
else:
|
else:
|
||||||
return self.end_value
|
return self.end_value
|
||||||
|
|
||||||
|
|
|
@ -7,4 +7,4 @@ class Scheduler():
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return self
|
return self
|
||||||
def __next__(self):
|
def __next__(self):
|
||||||
raise NotImplementedError("Scheduler does not have it's function to create a value implemented")
|
raise NotImplementedError("__next__ not implemented in Scheduler")
|
||||||
|
|
Loading…
Reference in a new issue