Small import error fix

This commit is contained in:
Brandon Rozek 2019-11-17 18:39:28 -05:00
parent 8dd9ca617e
commit a44b981e55

View file

@ -16,7 +16,7 @@ from torch.optim import Adam
# Import my custom RL library # Import my custom RL library
import rltorch import rltorch
from rltorch.memory import PrioritizedReplayMemory, ReplayMemory from rltorch.memory import PrioritizedReplayMemory, ReplayMemory, DQfDMemory
from rltorch.action_selector import EpsilonGreedySelector, ArgMaxSelector from rltorch.action_selector import EpsilonGreedySelector, ArgMaxSelector
import rltorch.env as E import rltorch.env as E
import rltorch.network as rn import rltorch.network as rn
@ -152,7 +152,7 @@ net = rn.Network(Value(state_size, action_size),
target_net = rn.TargetNetwork(net, device = device) target_net = rn.TargetNetwork(net, device = device)
# Relevant components from RLTorch # Relevant components from RLTorch
memory = M.DQfDMemory(capacity= config['memory_size'], alpha = config['prioritized_replay_sampling_priority'], max_demo = config['memory_size'] // 2) memory = DQfDMemory(capacity= config['memory_size'], alpha = config['prioritized_replay_sampling_priority'], max_demo = config['memory_size'] // 2)
actor = ArgMaxSelector(net, action_size, device = device) actor = ArgMaxSelector(net, action_size, device = device)
agent = rltorch.agents.DQfDAgent(net, memory, config, target_net = target_net) agent = rltorch.agents.DQfDAgent(net, memory, config, target_net = target_net)