Implemented Schedulers and Prioritized Replay
parent 8c78f47c0c
commit 013d40a4f9
10 changed files with 348 additions and 12 deletions
@@ -1,7 +1,9 @@
+import collections
 import rltorch.memory as M
 import torch
 import torch.nn.functional as F
 from copy import deepcopy
+import numpy as np
 
 class DQNAgent:
     def __init__(self, net , memory, config, target_net = None, logger = None):
@@ -14,9 +16,16 @@ class DQNAgent:
     def learn(self):
         if len(self.memory) < self.config['batch_size']:
             return
 
-        minibatch = self.memory.sample(self.config['batch_size'])
-        state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch)
+        if (isinstance(self.memory, M.PrioritizedReplayMemory)):
+            weight_importance = self.config['prioritized_replay_weight_importance']
+            # If it's a scheduler then get the next value by calling next, otherwise just use its value
+            beta = next(weight_importance) if isinstance(weight_importance, collections.Iterable) else weight_importance
+            minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
+            state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
+        else:
+            minibatch = self.memory.sample(self.config['batch_size'])
+            state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch)
 
         # Send to their appropriate devices
         state_batch = state_batch.to(self.net.device)
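The comment in this hunk treats config['prioritized_replay_weight_importance'] as either a plain number or a scheduler that is advanced with next() on every learn() call, which is the usual way beta is annealed toward 1.0 in prioritized replay. Below is a minimal sketch of such a scheduler; the class name and constructor are illustrative assumptions, not the actual rltorch scheduler API. (Note that collections.Iterable was removed in Python 3.10, so newer code would test against collections.abc.Iterable instead.)

class LinearSchedule:
    """Hypothetical iterator that linearly anneals a value; any object with
    __iter__/__next__ passes the isinstance(..., Iterable) check above."""
    def __init__(self, start, end, num_steps):
        self.start = start
        self.end = end
        self.num_steps = num_steps
        self.step = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Clamp at the final value once the schedule is exhausted
        fraction = min(self.step / self.num_steps, 1.0)
        self.step += 1
        return self.start + fraction * (self.end - self.start)

# Example: anneal beta from 0.4 to 1.0 over 100000 learn() calls
# config['prioritized_replay_weight_importance'] = LinearSchedule(0.4, 1.0, 100000)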
@@ -44,7 +53,10 @@ class DQNAgent:
 
         expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
 
-        loss = F.mse_loss(obtained_values, expected_values)
+        if (isinstance(self.memory, M.PrioritizedReplayMemory)):
+            loss = (torch.as_tensor(importance_weights) * (obtained_values - expected_values)**2).mean()
+        else:
+            loss = F.mse_loss(obtained_values, expected_values)
 
         if self.logger is not None:
             self.logger.append("Loss", loss.item())
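For context on the loss change in this hunk: prioritized sampling draws high-error transitions more often, which biases the gradient, and the importance-sampling weights returned by zip_batch(minibatch, priority = True) correct for that bias, so F.mse_loss is replaced by a per-sample weighted squared error. A small self-contained check (with made-up tensors, not the agent's real batches) showing that the weighted form collapses to F.mse_loss when every weight is 1:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
obtained = torch.randn(32, 1)   # stands in for Q(s, a) from the policy network
expected = torch.randn(32, 1)   # stands in for r + discount * max_a' Q_target(s', a')
weights = torch.ones(32, 1)     # importance-sampling weights, all 1 for the check

weighted = (weights * (obtained - expected) ** 2).mean()
plain = F.mse_loss(obtained, expected)
assert torch.allclose(weighted, plain)  # identical when every weight is 1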
@@ -59,3 +71,9 @@ class DQNAgent:
             self.target_net.partial_sync(self.config['target_sync_tau'])
         else:
             self.target_net.sync()
+
+        if (isinstance(self.memory, M.PrioritizedReplayMemory)):
+            td_error = (obtained_values - expected_values).detach().abs()
+            self.memory.update_priorities(batch_indexes, td_error)
+
+
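The last hunk writes the absolute TD error of each sampled transition back into the memory, so transitions the network currently predicts poorly are replayed more often. A hypothetical sketch of a buffer exposing the same sample(batch_size, beta=...) and update_priorities(indexes, priorities) interface the agent relies on, using plain proportional prioritization; the real rltorch.memory.PrioritizedReplayMemory may be implemented differently (e.g. with a sum tree):

import numpy as np

class TinyPrioritizedBuffer:
    """Illustrative proportional prioritized replay buffer (not the rltorch one)."""
    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha                # how strongly priorities skew sampling
        self.storage = []
        self.priorities = np.zeros(capacity, dtype=np.float64)
        self.pos = 0

    def append(self, transition):
        # New transitions get the current max priority so they are sampled at least once
        max_p = self.priorities.max() if self.storage else 1.0
        if len(self.storage) < self.capacity:
            self.storage.append(transition)
        else:
            self.storage[self.pos] = transition
        self.priorities[self.pos] = max_p
        self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        scaled = self.priorities[:len(self.storage)] ** self.alpha
        probs = scaled / scaled.sum()
        indexes = np.random.choice(len(self.storage), batch_size, p=probs)
        # Importance-sampling weights correct the bias from non-uniform sampling
        weights = (len(self.storage) * probs[indexes]) ** (-beta)
        weights /= weights.max()
        return [self.storage[i] for i in indexes], weights, indexes

    def update_priorities(self, indexes, priorities):
        # priorities are the absolute TD errors computed after the gradient step
        for i, p in zip(indexes, priorities):
            self.priorities[i] = float(p) + 1e-6  # small epsilon keeps every transition sampleable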