Small import error fix

This commit is contained in:
Brandon Rozek 2019-11-17 18:39:28 -05:00
parent 8dd9ca617e
commit a44b981e55

View file

@ -16,7 +16,7 @@ from torch.optim import Adam
# Import my custom RL library
import rltorch
from rltorch.memory import PrioritizedReplayMemory, ReplayMemory
from rltorch.memory import PrioritizedReplayMemory, ReplayMemory, DQfDMemory
from rltorch.action_selector import EpsilonGreedySelector, ArgMaxSelector
import rltorch.env as E
import rltorch.network as rn
@ -152,7 +152,7 @@ net = rn.Network(Value(state_size, action_size),
target_net = rn.TargetNetwork(net, device = device)
# Relevant components from RLTorch
memory = M.DQfDMemory(capacity= config['memory_size'], alpha = config['prioritized_replay_sampling_priority'], max_demo = config['memory_size'] // 2)
memory = DQfDMemory(capacity= config['memory_size'], alpha = config['prioritized_replay_sampling_priority'], max_demo = config['memory_size'] // 2)
actor = ArgMaxSelector(net, action_size, device = device)
agent = rltorch.agents.DQfDAgent(net, memory, config, target_net = target_net)