From 9cd3625fd3f269ad422a70eb62ef81b88bb5e5aa Mon Sep 17 00:00:00 2001 From: Brandon Rozek Date: Sun, 3 Feb 2019 00:45:14 -0500 Subject: [PATCH] Made sure everything went to their appropriate devices --- examples/acrobot.py | 6 +++--- examples/pong.py | 6 +++--- rltorch/action_selector/ArgMaxSelector.py | 2 +- rltorch/network/Network.py | 5 ++++- rltorch/network/TargetNetwork.py | 4 +++- 5 files changed, 14 insertions(+), 9 deletions(-) diff --git a/examples/acrobot.py b/examples/acrobot.py index c3c337d..6c4251f 100644 --- a/examples/acrobot.py +++ b/examples/acrobot.py @@ -90,11 +90,11 @@ logwriter = rltorch.log.LogWriter(logger, SummaryWriter()) # Setting up the networks device = torch.device("cuda:0" if torch.cuda.is_available() and not config['disable_cuda'] else "cpu") net = rn.Network(Value(state_size, action_size), - torch.optim.Adam, config, logger = logger, name = "DQN") -target_net = rn.TargetNetwork(net) + torch.optim.Adam, config, device = device, logger = logger, name = "DQN") +target_net = rn.TargetNetwork(net, device = device) # Actor takes a net and uses it to produce actions from given states -actor = ArgMaxSelector(net, action_size) +actor = ArgMaxSelector(net, action_size, device = device) # Memory stores experiences for later training memory = M.ReplayMemory(capacity = config['memory_size']) diff --git a/examples/pong.py b/examples/pong.py index 074b9b1..d49e6c6 100644 --- a/examples/pong.py +++ b/examples/pong.py @@ -107,11 +107,11 @@ logwriter = rltorch.log.LogWriter(logger, SummaryWriter()) # Setting up the networks device = torch.device("cuda:0" if torch.cuda.is_available() and not config['disable_cuda'] else "cpu") net = rn.Network(Value(state_size, action_size), - torch.optim.Adam, config, logger = logger, name = "DQN") -target_net = rn.TargetNetwork(net) + torch.optim.Adam, config, device = device, logger = logger, name = "DQN") +target_net = rn.TargetNetwork(net, device = device) # Actor takes a network and uses it to produce actions from given states -actor = ArgMaxSelector(net, action_size) +actor = ArgMaxSelector(net, action_size, device = device) # Memory stores experiences for later training memory = M.ReplayMemory(capacity = config['memory_size']) diff --git a/rltorch/action_selector/ArgMaxSelector.py b/rltorch/action_selector/ArgMaxSelector.py index 2b7b2a1..3f374f9 100644 --- a/rltorch/action_selector/ArgMaxSelector.py +++ b/rltorch/action_selector/ArgMaxSelector.py @@ -10,7 +10,7 @@ class ArgMaxSelector: def best_act(self, state): with torch.no_grad(): if self.device is not None: - self.device.to(self.device) + state = state.to(self.device) action_values = self.model(state).squeeze(0) action = self.random_act() if (action_values[0] == action_values).all() else action_values.argmax().item() return action diff --git a/rltorch/network/Network.py b/rltorch/network/Network.py index 0fbe73f..06603ac 100644 --- a/rltorch/network/Network.py +++ b/rltorch/network/Network.py @@ -2,11 +2,14 @@ class Network: """ Wrapper around model which provides copy of it instead of trained weights """ - def __init__(self, model, optimizer, config, logger = None, name = ""): + def __init__(self, model, optimizer, config, device = None, logger = None, name = ""): self.model = model self.optimizer = optimizer(model.parameters(), lr = config['learning_rate'], weight_decay = config['weight_decay']) self.logger = logger self.name = name + self.device = device + if self.device is not None: + self.model = self.model.to(device) def __call__(self, *args): return self.model(*args) diff --git a/rltorch/network/TargetNetwork.py b/rltorch/network/TargetNetwork.py index c3d9184..dd80365 100644 --- a/rltorch/network/TargetNetwork.py +++ b/rltorch/network/TargetNetwork.py @@ -4,9 +4,11 @@ class TargetNetwork: """ Wrapper around model which provides copy of it instead of trained weights """ - def __init__(self, network): + def __init__(self, network, device = None): self.model = network.model self.target_model = deepcopy(network.model) + if network.device is not None: + self.target_model = self.target_model.to(network.device) def __call__(self, *args): return self.model(*args)