diff --git a/rltorch/log.py b/rltorch/log.py index d24a08f..ff3335f 100644 --- a/rltorch/log.py +++ b/rltorch/log.py @@ -1,42 +1,14 @@ -from collections import Counter +from collections import Counter, defaultdict +from typing import Dict, List, Any import numpy as np import torch -class Logger: - """ - Keeps track of lists of items seperated by tags. - - Notes - ----- - Logger is a dictionary of lists. - """ - def __init__(self): - self.log = {} - def append(self, tag, value): - if tag not in self.log.keys(): - self.log[tag] = [] - self.log[tag].append(value) - def clear(self): - self.log.clear() - def keys(self): - return self.log.keys() - def __len__(self): - return len(self.log) - def __iter__(self): - return iter(self.log) - def __contains__(self, value): - return value in self.log - def __getitem__(self, index): - return self.log[index] - def __setitem__(self, index, value): - self.log[index] = value - def __reversed__(self): - return reversed(self.log) +Logger: Dict[Any, List[Any]] = defaultdict(list) class LogWriter: """ - Takes a logger and writes it to a writter. - While keeping track of the number of times it + Takes a logger and writes it to a writter. + While keeping track of the number of times it a certain tag. Notes