Implemented Schedulers and Prioritized Replay
parent 8c78f47c0c
commit 013d40a4f9

10 changed files with 348 additions and 12 deletions
@@ -4,5 +4,6 @@ from . import env
 from . import memory
 from . import network
 from . import mp
+from . import scheduler
 from .seed import *
 from . import log
@@ -1,7 +1,9 @@
+import collections.abc
+import rltorch.memory as M
 import torch
 import torch.nn.functional as F
 from copy import deepcopy
 import numpy as np
 
 class DQNAgent:
     def __init__(self, net, memory, config, target_net = None, logger = None):
@@ -15,6 +17,13 @@ class DQNAgent:
         if len(self.memory) < self.config['batch_size']:
             return
 
+        if isinstance(self.memory, M.PrioritizedReplayMemory):
+            weight_importance = self.config['prioritized_replay_weight_importance']
+            # If it's a scheduler then get the next value by calling next, otherwise just use its value
+            beta = next(weight_importance) if isinstance(weight_importance, collections.abc.Iterable) else weight_importance
+            minibatch = self.memory.sample(self.config['batch_size'], beta = beta)
+            state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, importance_weights, batch_indexes = M.zip_batch(minibatch, priority = True)
+        else:
+            minibatch = self.memory.sample(self.config['batch_size'])
+            state_batch, action_batch, reward_batch, next_state_batch, not_done_batch = M.zip_batch(minibatch)
 
@@ -44,6 +53,9 @@ class DQNAgent:
 
         expected_values = (reward_batch + (self.config['discount_rate'] * best_next_state_value)).unsqueeze(1)
 
+        if isinstance(self.memory, M.PrioritizedReplayMemory):
+            loss = (torch.as_tensor(importance_weights) * (obtained_values - expected_values)**2).mean()
+        else:
+            loss = F.mse_loss(obtained_values, expected_values)
 
         if self.logger is not None:
@@ -59,3 +71,9 @@ class DQNAgent:
                 self.target_net.partial_sync(self.config['target_sync_tau'])
             else:
                 self.target_net.sync()
 
+        if isinstance(self.memory, M.PrioritizedReplayMemory):
+            td_error = (obtained_values - expected_values).detach().abs()
+            self.memory.update_priorities(batch_indexes, td_error)
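
The learn path above accepts either a plain float or an iterable scheduler for 'prioritized_replay_weight_importance', since the schedulers introduced in this commit implement __iter__/__next__. A minimal sketch of wiring the two features together (config keys as referenced above; the network, target network, and training loop are assumed from elsewhere in the library):

    import rltorch
    import rltorch.memory as M

    config = {
        'batch_size': 32,
        'discount_rate': 0.99,
        # Anneal beta from 0.4 to 1.0, a common choice from the
        # prioritized experience replay paper (Schaul et al., 2016)
        'prioritized_replay_weight_importance':
            rltorch.scheduler.LinearScheduler(initial_value=0.4, end_value=1.0, iterations=100000),
    }
    memory = M.PrioritizedReplayMemory(capacity=2**16, alpha=0.6)
    # agent = DQNAgent(net, memory, config, target_net=target_net)
    # Each agent.learn() call then samples with the next annealed beta.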
rltorch/memory/PrioritizedReplayMemory.py (new file, +255 lines)
@@ -0,0 +1,255 @@
# From OpenAI Baselines https://github.com/openai/baselines/blob/master/baselines/deepq/replay_buffer.py

from .ReplayMemory import ReplayMemory
import operator
import random
import numpy as np

class SegmentTree(object):
    def __init__(self, capacity, operation, neutral_element):
        """Build a Segment Tree data structure.

        https://en.wikipedia.org/wiki/Segment_tree

        Can be used as a regular array, but with two
        important differences:

            a) Setting an item's value is slightly slower:
               O(lg capacity) instead of O(1).
            b) The user has access to an efficient (O(log segment size))
               `reduce` operation which reduces `operation` over
               a contiguous subsequence of items in the array.

        Parameters
        ----------
        capacity: int
            Total size of the array - must be a power of two.
        operation: lambda obj, obj -> obj
            An operation for combining elements (e.g. sum, max);
            must form a mathematical group together with the set of
            possible values for array elements (i.e. be associative).
        neutral_element: obj
            Neutral element for the operation above, e.g. float('-inf')
            for max and 0 for sum.
        """
        assert capacity > 0 and capacity & (capacity - 1) == 0, "capacity must be positive and a power of 2."
        self._capacity = capacity
        self._value = [neutral_element for _ in range(2 * capacity)]
        self._operation = operation

    def _reduce_helper(self, start, end, node, node_start, node_end):
        if start == node_start and end == node_end:
            return self._value[node]
        mid = (node_start + node_end) // 2
        if end <= mid:
            return self._reduce_helper(start, end, 2 * node, node_start, mid)
        else:
            if mid + 1 <= start:
                return self._reduce_helper(start, end, 2 * node + 1, mid + 1, node_end)
            else:
                return self._operation(
                    self._reduce_helper(start, mid, 2 * node, node_start, mid),
                    self._reduce_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end)
                )

    def reduce(self, start=0, end=None):
        """Returns result of applying `self.operation`
        to a contiguous subsequence of the array.

            self.operation(arr[start], operation(arr[start+1], operation(... arr[end])))

        Parameters
        ----------
        start: int
            beginning of the subsequence
        end: int
            end of the subsequence

        Returns
        -------
        reduced: obj
            result of reducing self.operation over the specified range of array elements.
        """
        if end is None:
            end = self._capacity
        if end < 0:
            end += self._capacity
        end -= 1
        return self._reduce_helper(start, end, 1, 0, self._capacity - 1)

    def __setitem__(self, idx, val):
        # index of the leaf
        idx += self._capacity
        self._value[idx] = val
        idx //= 2
        while idx >= 1:
            self._value[idx] = self._operation(
                self._value[2 * idx],
                self._value[2 * idx + 1]
            )
            idx //= 2

    def __getitem__(self, idx):
        assert 0 <= idx < self._capacity
        return self._value[self._capacity + idx]


class SumSegmentTree(SegmentTree):
    def __init__(self, capacity):
        super(SumSegmentTree, self).__init__(
            capacity=capacity,
            operation=operator.add,
            neutral_element=0.0
        )

    def sum(self, start=0, end=None):
        """Returns arr[start] + ... + arr[end]"""
        return super(SumSegmentTree, self).reduce(start, end)

    def find_prefixsum_idx(self, prefixsum):
        """Find the highest index `i` in the array such that
            sum(arr[0] + arr[1] + ... + arr[i - 1]) <= prefixsum

        If array values are probabilities, this function
        allows sampling indexes according to the discrete
        probability distribution efficiently.

        Parameters
        ----------
        prefixsum: float
            upper bound on the sum of array prefix

        Returns
        -------
        idx: int
            highest index satisfying the prefixsum constraint
        """
        assert 0 <= prefixsum <= self.sum() + 1e-5
        idx = 1
        while idx < self._capacity:  # while non-leaf
            if self._value[2 * idx] > prefixsum:
                idx = 2 * idx
            else:
                prefixsum -= self._value[2 * idx]
                idx = 2 * idx + 1
        return idx - self._capacity


class MinSegmentTree(SegmentTree):
    def __init__(self, capacity):
        super(MinSegmentTree, self).__init__(
            capacity=capacity,
            operation=min,
            neutral_element=float('inf')
        )

    def min(self, start=0, end=None):
        """Returns min(arr[start], ..., arr[end])"""
        return super(MinSegmentTree, self).reduce(start, end)

class PrioritizedReplayMemory(ReplayMemory):
    def __init__(self, capacity, alpha):
        """Create a Prioritized Replay buffer.

        Parameters
        ----------
        capacity: int
            Max number of transitions to store in the buffer. When the buffer
            overflows, the oldest memories are dropped.
        alpha: float
            how much prioritization is used
            (0 - no prioritization, 1 - full prioritization)

        See Also
        --------
        ReplayMemory.__init__
        """
        super(PrioritizedReplayMemory, self).__init__(capacity)
        assert alpha >= 0
        self._alpha = alpha

        it_capacity = 1
        while it_capacity < capacity:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def append(self, *args, **kwargs):
        """See ReplayMemory.append; new transitions enter with maximal priority."""
        idx = self.position
        super().append(*args, **kwargs)
        self._it_sum[idx] = self._max_priority ** self._alpha
        self._it_min[idx] = self._max_priority ** self._alpha

    def _sample_proportional(self, batch_size):
        res = []
        p_total = self._it_sum.sum(0, len(self.memory) - 1)
        every_range_len = p_total / batch_size
        for i in range(batch_size):
            mass = random.random() * every_range_len + i * every_range_len
            idx = self._it_sum.find_prefixsum_idx(mass)
            res.append(idx)
        return res

    def sample(self, batch_size, beta):
        """Sample a batch of experiences.

        Compared to ReplayMemory.sample, it also returns
        importance weights and the indexes of the sampled experiences.

        Parameters
        ----------
        batch_size: int
            How many transitions to sample.
        beta: float
            To what degree to use importance weights
            (0 - no corrections, 1 - full correction)

        Returns
        -------
        batch: list
            List of sampled transitions. Each entry is a tuple
            (state, action, reward, next_state, done, weight, index), where
            weight is the importance-sampling weight of the transition and
            index is its position in the buffer (for update_priorities).
        """
        assert beta > 0

        idxes = self._sample_proportional(batch_size)

        weights = []
        p_min = self._it_min.min() / self._it_sum.sum()
        max_weight = (p_min * len(self.memory)) ** (-beta)

        for idx in idxes:
            p_sample = self._it_sum[idx] / self._it_sum.sum()
            weight = (p_sample * len(self.memory)) ** (-beta)
            weights.append(weight / max_weight)
        weights = np.array(weights)
        encoded_sample = tuple(zip(*self._encode_sample(idxes)))
        batch = list(zip(*encoded_sample, weights, idxes))
        return batch

    def update_priorities(self, idxes, priorities):
        """Update priorities of sampled transitions.

        Sets the priority of the transition at index idxes[i] in the buffer
        to priorities[i].

        Parameters
        ----------
        idxes: [int]
            List of indexes of sampled transitions
        priorities: [float]
            List of updated priorities corresponding to
            transitions at the sampled indexes denoted by
            variable `idxes`.
        """
        assert len(idxes) == len(priorities)
        priorities += np.finfo('float').eps
        for idx, priority in zip(idxes, priorities):
            assert priority > 0
            assert 0 <= idx < len(self.memory)
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha

            self._max_priority = max(self._max_priority, priority)
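
The two segment trees make every buffer operation O(log n): find_prefixsum_idx descends from the root, going left whenever the left subtree's sum still covers the remaining probability mass, so sampling proportional to priority never scans the whole array. A small standalone sketch of the API (the transition fields are placeholders, assuming the base ReplayMemory.append stores (state, action, reward, next_state, done) tuples):

    import numpy as np
    from rltorch.memory import PrioritizedReplayMemory

    memory = PrioritizedReplayMemory(capacity=8, alpha=0.6)
    for i in range(8):
        memory.append(i, 0, float(i), i + 1, False)  # placeholder transition

    # New transitions enter with max priority; sample() appends (weight, index) to each entry
    batch = memory.sample(batch_size=4, beta=0.4)
    indexes = [entry[-1] for entry in batch]

    # After computing TD errors, their magnitudes become the new priorities
    memory.update_priorities(indexes, np.array([1.0, 0.5, 2.0, 0.1]))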
@@ -1,4 +1,4 @@
-from random import sample
+import random
 from collections import namedtuple
 import torch
 Transition = namedtuple('Transition',
@@ -22,8 +22,22 @@ class ReplayMemory(object):
         self.memory.clear()
         self.position = 0
 
+    def _encode_sample(self, indexes):
+        states, actions, rewards, next_states, dones = [], [], [], [], []
+        for i in indexes:
+            observation = self.memory[i]
+            state, action, reward, next_state, done = observation
+            states.append(state)
+            actions.append(action)
+            rewards.append(reward)
+            next_states.append(next_state)
+            dones.append(done)
+        batch = list(zip(states, actions, rewards, next_states, dones))
+        return batch
+
+
     def sample(self, batch_size):
-        return sample(self.memory, batch_size)
+        return random.sample(self.memory, batch_size)
 
     def __len__(self):
         return len(self.memory)
@@ -43,7 +57,10 @@ class ReplayMemory(object):
     def __reversed__(self):
         return reversed(self.memory)
 
-def zip_batch(minibatch):
+def zip_batch(minibatch, priority = False):
+    if priority:
+        state_batch, action_batch, reward_batch, next_state_batch, done_batch, weights, indexes = zip(*minibatch)
+    else:
+        state_batch, action_batch, reward_batch, next_state_batch, done_batch = zip(*minibatch)
 
     state_batch = torch.cat(state_batch)
@@ -52,4 +69,7 @@ def zip_batch(minibatch)
     not_done_batch = ~torch.tensor(done_batch)
     next_state_batch = torch.cat(next_state_batch)[not_done_batch]
 
+    if priority:
+        return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch, weights, indexes
+    else:
+        return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch
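
zip_batch transposes a list of per-transition tuples into per-field batches, masking next states with not_done_batch so terminal transitions drop out of the next-state tensor. A toy illustration of the shapes involved (hypothetical 1x4 states; zip_batch is assumed re-exported by rltorch.memory via the wildcard import in the next hunk):

    import torch
    from rltorch.memory import zip_batch

    minibatch = [
        (torch.zeros(1, 4), 0, 1.0, torch.ones(1, 4), False),
        (torch.ones(1, 4), 1, 0.0, torch.zeros(1, 4), True),
    ]
    states, actions, rewards, next_states, not_done = zip_batch(minibatch)
    # states has shape (2, 4); next_states has shape (1, 4), the terminal transition is masked out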
@@ -1 +1,2 @@
 from .ReplayMemory import *
+from .PrioritizedReplayMemory import *
@@ -4,7 +4,10 @@ class Network:
     """
     def __init__(self, model, optimizer, config, device = None, logger = None, name = ""):
         self.model = model
+        if 'weight_decay' in config:
+            self.optimizer = optimizer(model.parameters(), lr = config['learning_rate'], weight_decay = config['weight_decay'])
+        else:
+            self.optimizer = optimizer(model.parameters(), lr = config['learning_rate'])
         self.logger = logger
         self.name = name
         self.device = device
@@ -14,9 +17,10 @@ class Network:
     def __call__(self, *args):
         return self.model(*args)
 
-    def clamp_gradients(self):
+    def clamp_gradients(self, x = 1):
+        assert x > 0
         for param in self.model.parameters():
-            param.grad.data.clamp_(-1, 1)
+            param.grad.data.clamp_(-x, x)
 
     def zero_grad(self):
         self.model.zero_grad()
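
clamp_gradients now takes a configurable bound while keeping the previous behaviour as the default (x = 1). The usual ordering during an update is: zero the gradients, backpropagate, clamp, then step the optimizer. Sketched below, where loss comes from the surrounding training code and the step() call is an assumption about the wrapper's interface:

    net.zero_grad()        # clear stale gradients on the wrapped model
    loss.backward()        # populate param.grad for every parameter
    net.clamp_gradients()  # clamp each gradient to [-1, 1]; pass x for a different bound
    net.step()             # hypothetical: apply the optimizer update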
rltorch/scheduler/ExponentialScheduler.py (new file, +12 lines)
@@ -0,0 +1,12 @@
from .Scheduler import Scheduler
class ExponentialScheduler(Scheduler):
    def __init__(self, initial_value, end_value, iterations):
        super(ExponentialScheduler, self).__init__(initial_value, end_value, iterations)
        self.base = (end_value / initial_value) ** (1.0 / iterations)
    def __next__(self):
        if self.current_iteration < self.max_iterations:
            self.current_iteration += 1
            return self.initial_value * (self.base ** (self.current_iteration - 1))
        else:
            return self.end_value
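
The base solves initial_value * base**iterations == end_value, so the schedule interpolates geometrically between the endpoints and then holds end_value. A quick endpoint check with illustrative numbers:

    from rltorch.scheduler import ExponentialScheduler

    sched = ExponentialScheduler(initial_value=1.0, end_value=0.1, iterations=4)
    values = [next(sched) for _ in range(6)]
    # base = 0.1 ** (1/4), about 0.5623, so
    # values is approximately [1.0, 0.5623, 0.3162, 0.1778, 0.1, 0.1]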
rltorch/scheduler/LinearScheduler.py (new file, +12 lines)
@@ -0,0 +1,12 @@
from .Scheduler import Scheduler
class LinearScheduler(Scheduler):
    def __init__(self, initial_value, end_value, iterations):
        super(LinearScheduler, self).__init__(initial_value, end_value, iterations)
        self.slope = (end_value - initial_value) / iterations
    def __next__(self):
        if self.current_iteration < self.max_iterations:
            self.current_iteration += 1
            return self.slope * (self.current_iteration - 1) + self.initial_value
        else:
            return self.end_value
rltorch/scheduler/Scheduler.py (new file, +10 lines)
@@ -0,0 +1,10 @@
class Scheduler:
    def __init__(self, initial_value, end_value, iterations):
        self.initial_value = initial_value
        self.end_value = end_value
        self.max_iterations = iterations
        self.current_iteration = 0
    def __iter__(self):
        return self
    def __next__(self):
        raise NotImplementedError("Scheduler subclasses must implement __next__ to produce values")
rltorch/scheduler/__init__.py (new file, +3 lines)
@@ -0,0 +1,3 @@
from .Scheduler import Scheduler
from .LinearScheduler import LinearScheduler
from .ExponentialScheduler import ExponentialScheduler
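
Every Scheduler is an iterator, so it can be driven with next(), exactly how the DQN agent above pulls its annealed beta, or stepped once per frame for epsilon-greedy exploration. For example, with illustrative numbers:

    from rltorch.scheduler import LinearScheduler

    epsilon = LinearScheduler(initial_value=1.0, end_value=0.1, iterations=10)
    print([round(next(epsilon), 2) for _ in range(12)])
    # [1.0, 0.91, 0.82, 0.73, 0.64, 0.55, 0.46, 0.37, 0.28, 0.19, 0.1, 0.1]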