From 7aa698c349e84740c846017d7b81665328d557db Mon Sep 17 00:00:00 2001
From: Brandon Rozek
Date: Fri, 13 Sep 2019 19:49:04 -0400
Subject: [PATCH] Added save and load functionality

---
 rltorch/agents/DQNAgent.py | 11 ++++++++++-
 rltorch/agents/QEPAgent.py | 18 ++++++++++++++++--
 2 files changed, 26 insertions(+), 3 deletions(-)

diff --git a/rltorch/agents/DQNAgent.py b/rltorch/agents/DQNAgent.py
index 3f20b52..6f47c05 100644
--- a/rltorch/agents/DQNAgent.py
+++ b/rltorch/agents/DQNAgent.py
@@ -4,6 +4,7 @@ import torch
 import torch.nn.functional as F
 from copy import deepcopy
 import numpy as np
+from pathlib import Path
 
 class DQNAgent:
     def __init__(self, net , memory, config, target_net = None, logger = None):
@@ -12,6 +13,12 @@ class DQNAgent:
         self.memory = memory
         self.config = deepcopy(config)
         self.logger = logger
+    def save(self, file_location):
+        torch.save(self.net.model.state_dict(), file_location)
+    def load(self, file_location):
+        self.net.model.load_state_dict(torch.load(file_location))
+        self.net.model.to(self.net.device)
+        self.target_net.sync()
 
     def learn(self, logger = None):
         if len(self.memory) < self.config['batch_size']:
@@ -57,8 +64,10 @@ class DQNAgent:
 
         # If we're sampling by TD error, multiply loss by a importance weight which helps decrease overfitting
         if (isinstance(self.memory, M.PrioritizedReplayMemory)):
-            loss = (torch.as_tensor(importance_weights, device = self.net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
+            # loss = (torch.as_tensor(importance_weights, device = self.net.device) * F.smooth_l1_loss(obtained_values, expected_values, reduction = 'none').squeeze(1)).mean()
+            loss = (torch.as_tensor(importance_weights, device = self.net.device) * ((obtained_values - expected_values)**2).squeeze(1)).mean()
         else:
+            # loss = F.smooth_l1_loss(obtained_values, expected_values)
             loss = F.mse_loss(obtained_values, expected_values)
 
         if self.logger is not None:
diff --git a/rltorch/agents/QEPAgent.py b/rltorch/agents/QEPAgent.py
index 6040e28..a21d886 100644
--- a/rltorch/agents/QEPAgent.py
+++ b/rltorch/agents/QEPAgent.py
@@ -3,6 +3,7 @@ import collections
 import numpy as np
 import torch
 from torch.distributions import Categorical
+import torch.nn.functional as F
 import rltorch
 import rltorch.memory as M
 
@@ -20,6 +21,19 @@ class QEPAgent:
         self.config = deepcopy(config)
         self.logger = logger
         self.policy_skip = 4
+
+    def save(self, file_location):
+        torch.save({
+            'policy': self.policy_net.model.state_dict(),
+            'value': self.value_net.model.state_dict()
+        }, file_location)
+    def load(self, file_location):
+        checkpoint = torch.load(file_location)
+        self.value_net.model.load_state_dict(checkpoint['value'])
+        self.value_net.model.to(self.value_net.device)
+        self.policy_net.model.load_state_dict(checkpoint['policy'])
+        self.policy_net.model.to(self.policy_net.device)
+        self.target_value_net.sync()
 
     def fitness(self, policy_net, value_net, state_batch):
         batch_size = len(state_batch)
@@ -37,7 +51,7 @@ class QEPAgent:
         value_importance = 1 - entropy_importance
 
         # entropy_loss = (action_probabilities * torch.log2(action_probabilities)).sum(1) # Standard entropy loss from information theory
-        entropy_loss = (action_probabilities - torch.tensor(1 / action_size).repeat(len(state_batch), action_size)).abs().sum(1)
+        entropy_loss = (action_probabilities - torch.tensor(1 / action_size, device = state_batch.device).repeat(len(state_batch), action_size)).abs().sum(1)
 
         return (entropy_importance * entropy_loss - value_importance * obtained_values).mean().item()
 
@@ -116,6 +130,6 @@ class QEPAgent:
             self.policy_net.calc_gradients(self.target_value_net, state_batch)
         else:
             self.policy_net.calc_gradients(self.value_net, state_batch)
-        # self.policy_net.clamp_gradients()
+        ##### self.policy_net.clamp_gradients()
 
         self.policy_net.step()
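
For reference, a minimal self-contained sketch of the checkpointing pattern these save()/load()
methods wrap. It uses plain torch.nn modules and hypothetical file names rather than rltorch's
network wrappers, so it only illustrates the underlying torch.save / load_state_dict calls,
not the agents' own API:

    import torch
    import torch.nn as nn

    # Single-network checkpoint, as in DQNAgent.save()/load().
    model = nn.Linear(4, 2)
    torch.save(model.state_dict(), "dqn_checkpoint.pt")           # hypothetical path

    restored = nn.Linear(4, 2)                                    # same architecture
    restored.load_state_dict(torch.load("dqn_checkpoint.pt"))     # load_state_dict, not state_dict
    restored.to(torch.device("cpu"))                              # move to the desired device

    # Multi-network checkpoint keyed by name, as in QEPAgent.save()/load().
    policy, value = nn.Linear(4, 2), nn.Linear(4, 1)
    torch.save({'policy': policy.state_dict(),
                'value': value.state_dict()}, "qep_checkpoint.pt")

    checkpoint = torch.load("qep_checkpoint.pt")
    policy.load_state_dict(checkpoint['policy'])
    value.load_state_dict(checkpoint['value'])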