Source code for olympus.utils.storage

from datetime import datetime
import os
import io
import tempfile
import torch

from olympus.utils import info
from olympus.utils.options import options


class BaseStorage:
    """Storage interface whose every operation is a no-op.

    All methods accept arbitrary arguments and return ``None``, except
    :meth:`show_memory_usage` which returns an empty dictionary.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any constructor arguments."""

    def load(self, *args, **kwargs):
        """Do nothing and return ``None``."""

    def safe_load(self, *args, **kwargs):
        """Do nothing and return ``None``."""

    def set_base(self, *args, **kwargs):
        """Do nothing and return ``None``."""

    def show_memory_usage(self):
        """Return an empty usage report."""
        return dict()

    def garbage_collect_in_memory(self, *args, **kwargs):
        """Do nothing and return ``None``."""

    def garbage_collect_on_disk(self, *args, **kwargs):
        """Do nothing and return ``None``."""

    def garbage_collect(self, *args, **kwargs):
        """Do nothing and return ``None``."""

    def open(self, *args, **kwargs):
        """Do nothing and return ``None``."""

    def write(self, *args, **kwargs):
        """Do nothing and return ``None``."""

    def read(self, *args, **kwargs):
        """Do nothing and return ``None``."""

    def exits(self, *args, **kwargs):
        """Do nothing and return ``None``."""

    def save(self, *args, **kwargs):
        """Do nothing and return ``None``."""
def NoStorage(*args, **kwargs):
    """Build a :class:`BaseStorage`, forwarding every argument to it."""
    storage = BaseStorage(*args, **kwargs)
    return storage
class StateStorage(BaseStorage):
    """Checkpoint storage that limits how often states are written to disk.

    ``save`` serializes the state with ``torch.save``.  The state is only
    written to disk when more than ``time_buffer`` seconds elapsed since the
    last write; otherwise it is handed to the in-memory cache, which only
    keeps it when ``USE_IN_MEMORY_CACHE`` is true.

    Parameters
    ----------
    folder: str
        directory under which ``<filename>.state`` files are stored;
        it is created if missing.  The default is read from the
        ``state.storage`` option when the class is defined.

    time_buffer: int
        number of seconds that must elapse after the last write to disk
        before ``save`` writes again.  The default is read from the
        ``state.storage.time`` option when the class is defined.
    """
    Kio = 1024
    Mio = 1024 * 1024
    USE_IN_MEMORY_CACHE = False

    def __init__(self, folder=options('state.storage', '/tmp'),
                 time_buffer=options('state.storage.time', 5 * 60, type=int)):
        # typically root/task_name/experiment_name/trial_id
        self.folder = None
        self.set_base(folder)
        self.time_buffer = time_buffer
        self.last_save = datetime.utcnow()
        # name -> (io.BytesIO holding the serialized state, insertion time)
        self.cache = dict()
        # number of bytes currently held by ``self.cache``
        self.in_memory = 0
        # number of bytes accounted for in ``self.on_disk_files``
        self.on_disk = 0
        # name -> (size in bytes, time of the write)
        self.on_disk_files = dict()

    def set_base(self, folder):
        """Set the storage folder, creating it if it does not exist."""
        self.folder = folder
        os.makedirs(self.folder, exist_ok=True)

    def show_memory_usage(self):
        """Return the tracked disk and memory usage, sizes expressed in Mio."""
        return {
            'on_disk': self.on_disk / StateStorage.Mio,
            'on_disk_file_count': len(self.on_disk_files),
            'in_memory': self.in_memory / StateStorage.Mio
        }

    def garbage_collect_in_memory(self, gc_time):
        """Drop cached states older than ``gc_time`` seconds.

        Returns
        -------
        number of bytes removed from the in-memory accounting
        """
        now = datetime.utcnow()
        before = self.in_memory

        # collect first: the cache cannot be mutated while iterating over it
        expired = [
            name for name, (_, save_time) in self.cache.items()
            if (now - save_time).total_seconds() > gc_time
        ]
        for name in expired:
            self._pop_from_cache(name)

        return before - self.in_memory

    def garbage_collect_on_disk(self, gc_time):
        """Forget on-disk entries older than ``gc_time`` seconds.

        Only the bookkeeping is updated; no file is removed by this method.

        Returns
        -------
        number of bytes removed from the on-disk accounting
        """
        now = datetime.utcnow()
        before = self.on_disk

        # collect first: the dict cannot be mutated while iterating over it
        expired = [
            key for key, (_, save_time) in self.on_disk_files.items()
            if (now - save_time).total_seconds() > gc_time
        ]
        for key in expired:
            self._pop_from_disk(key)

        return before - self.on_disk

    def garbage_collect(self, gc_time):
        """Run both garbage collections and return the total bytes freed."""
        freed = self.garbage_collect_in_memory(gc_time)
        freed += self.garbage_collect_on_disk(gc_time)
        return freed

    def _file(self, filename):
        """Return the path of the state file associated with ``filename``."""
        return f'{self.folder}/{filename}.state'

    def open(self, filename, mode):
        """Open the state file of ``filename``; the caller must close it."""
        return open(self._file(filename), mode)

    def write(self, filename, data):
        """Write text ``data`` to the state file and return ``file.write``'s result."""
        # context manager so the handle is flushed and closed deterministically
        with self.open(filename, 'w') as file:
            return file.write(data)

    def read(self, filename):
        """Return the whole content of the state file as text."""
        with self.open(filename, 'r') as file:
            return file.read()

    def exits(self, filename):
        """Return True if the state file of ``filename`` exists on disk."""
        return os.path.exists(self._file(filename))

    def save(self, filename, state):
        """Serialize ``state`` and write it to disk if enough time has passed.

        Parameters
        ----------
        filename: str
            name of the state; the file written is ``<folder>/<filename>.state``

        state:
            object serialized with ``torch.save``

        Returns
        -------
        True if the state was written to disk, False if the write was skipped
        """
        path = self._file(filename)
        dirname = os.path.dirname(path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)

        # Writes the state inside a buffer
        buffer = io.BytesIO()
        torch.save(state, buffer)

        elapsed_time = (datetime.utcnow() - self.last_save).total_seconds()
        if elapsed_time <= self.time_buffer:
            # too soon since the last write: hand the buffer to the cache
            info(f'({elapsed_time:.2f} <= {self.time_buffer}) skipping checkpoint')
            self._insert_cache(filename, buffer)
            return False

        # write it inside a temporary file first so `path` never holds
        # a partially written state
        fd, name = tempfile.mkstemp('state', dir=self.folder)
        try:
            with os.fdopen(fd, 'wb') as file:
                # the view avoids copying the serialized bytes
                with buffer.getbuffer() as view:
                    size = view.nbytes
                    file.write(view)

            # move temporary file to right spot, overwriting a previous state
            os.replace(name, path)
        except BaseException:
            # do not leave the temporary file behind when the write fails
            if os.path.exists(name):
                os.remove(name)
            raise

        info(f'State was written to {path}')

        # Remove from the cache if it is in there: the disk version is newer
        self._pop_from_cache(filename)
        self._insert_disk(filename, size)
        self.last_save = datetime.utcnow()
        return True

    @staticmethod
    def _nbytes(buffer):
        """Return the number of bytes held by an ``io.BytesIO``."""
        # release the view right away: a live export blocks resize/close
        with buffer.getbuffer() as view:
            return view.nbytes

    def _insert_disk(self, key, size):
        """Record that ``key`` uses ``size`` bytes on disk."""
        self._pop_from_disk(key)
        self.on_disk_files[key] = (size, datetime.utcnow())
        self.on_disk += size

    def _pop_from_disk(self, key):
        """Remove ``key`` from the on-disk accounting, if present."""
        size, _ = self.on_disk_files.pop(key, (None, None))
        if size is not None:
            self.on_disk -= size

    def _insert_cache(self, key, buffer):
        """Keep ``buffer`` (an ``io.BytesIO``) in memory when caching is enabled."""
        if not StateStorage.USE_IN_MEMORY_CACHE:
            return

        # replace any previous entry so its size is not counted twice
        self._pop_from_cache(key)
        self.cache[key] = (buffer, datetime.utcnow())
        self.in_memory += self._nbytes(buffer)

    def _pop_from_cache(self, key):
        """Remove ``key`` from the cache and return its buffer, or None."""
        buffer, _ = self.cache.pop(key, (None, None))
        if buffer is not None:
            self.in_memory -= self._nbytes(buffer)
        return buffer

    def load(self, filename, device=None):
        """Load a state from the in-memory cache or, failing that, from disk.

        A cached state is removed from the cache by this call.

        Parameters
        ----------
        filename: str
            file to load the state from

        device: torch.device
            it indicates the location where all tensors should be loaded.
        """
        buffer = self._pop_from_cache(filename)

        if buffer is None:
            buffer = self._file(filename)
        else:
            # torch.save left the position at the end of the buffer
            buffer.seek(0)

        return torch.load(buffer, map_location=device)

    def safe_load(self, name, device):
        """Handles a few common exceptions for you and returns None if a file is not found

        Raises
        ------
        KeyboardInterrupt
            when torch reports that the state cannot be loaded on a CPU-only machine
        """
        try:
            return self.load(name, device=device)

        except RuntimeError as e:
            # This error happens when there is a mismatch between save device and current device
            if 'CPU-only machine' in str(e):
                raise KeyboardInterrupt('Job got scheduled on bad node.') from e

            # other runtime errors still yield None, but no longer silently
            info(f'State file {name} could not be loaded: {e}')
            return None

        except FileNotFoundError:
            info(f'State file {name} was not found')
            return None