diff --git a/play_env.py b/play_env.py index e47041a..952d663 100644 --- a/play_env.py +++ b/play_env.py @@ -16,7 +16,7 @@ from torch.optim import Adam # Import my custom RL library import rltorch -from rltorch.memory import PrioritizedReplayMemory, ReplayMemory +from rltorch.memory import PrioritizedReplayMemory, ReplayMemory, DQfDMemory from rltorch.action_selector import EpsilonGreedySelector, ArgMaxSelector import rltorch.env as E import rltorch.network as rn @@ -152,7 +152,7 @@ net = rn.Network(Value(state_size, action_size), target_net = rn.TargetNetwork(net, device = device) # Relevant components from RLTorch -memory = M.DQfDMemory(capacity= config['memory_size'], alpha = config['prioritized_replay_sampling_priority'], max_demo = config['memory_size'] // 2) +memory = DQfDMemory(capacity= config['memory_size'], alpha = config['prioritized_replay_sampling_priority'], max_demo = config['memory_size'] // 2) actor = ArgMaxSelector(net, action_size, device = device) agent = rltorch.agents.DQfDAgent(net, memory, config, target_net = target_net)