Small import error fix
This commit is contained in:
parent
8dd9ca617e
commit
a44b981e55
1 changed files with 2 additions and 2 deletions
|
@ -16,7 +16,7 @@ from torch.optim import Adam
|
||||||
|
|
||||||
# Import my custom RL library
|
# Import my custom RL library
|
||||||
import rltorch
|
import rltorch
|
||||||
from rltorch.memory import PrioritizedReplayMemory, ReplayMemory
|
from rltorch.memory import PrioritizedReplayMemory, ReplayMemory, DQfDMemory
|
||||||
from rltorch.action_selector import EpsilonGreedySelector, ArgMaxSelector
|
from rltorch.action_selector import EpsilonGreedySelector, ArgMaxSelector
|
||||||
import rltorch.env as E
|
import rltorch.env as E
|
||||||
import rltorch.network as rn
|
import rltorch.network as rn
|
||||||
|
@ -152,7 +152,7 @@ net = rn.Network(Value(state_size, action_size),
|
||||||
target_net = rn.TargetNetwork(net, device = device)
|
target_net = rn.TargetNetwork(net, device = device)
|
||||||
|
|
||||||
# Relevant components from RLTorch
|
# Relevant components from RLTorch
|
||||||
memory = M.DQfDMemory(capacity= config['memory_size'], alpha = config['prioritized_replay_sampling_priority'], max_demo = config['memory_size'] // 2)
|
memory = DQfDMemory(capacity= config['memory_size'], alpha = config['prioritized_replay_sampling_priority'], max_demo = config['memory_size'] // 2)
|
||||||
actor = ArgMaxSelector(net, action_size, device = device)
|
actor = ArgMaxSelector(net, action_size, device = device)
|
||||||
agent = rltorch.agents.DQfDAgent(net, memory, config, target_net = target_net)
|
agent = rltorch.agents.DQfDAgent(net, memory, config, target_net = target_net)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue