# Source code for espressomd.checkpointing

#
# Copyright (C) 2013-2018 The ESPResSo project
#
# This file is part of ESPResSo.
#
# ESPResSo is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ESPResSo is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
from __future__ import print_function, absolute_import

from collections import OrderedDict
import sys
import inspect
import os
import re
import signal
from espressomd.utils import is_valid_type

try:
    import cPickle as pickle
except ImportError:
    import pickle


# Convenient Checkpointing for ESPResSo
class Checkpoint(object):

    """Checkpoint handling (reading and writing).

    Registered python objects are looked up by (dotted) name in the module
    that created the :class:`Checkpoint` instance and are pickled into files
    named ``<index>.checkpoint`` inside ``<checkpoint_path>/<checkpoint_id>``.

    Parameters
    ----------
    checkpoint_id : :obj:`str`
        A string identifying a specific checkpoint. Only the characters
        ``a-z A-Z 0-9 _ -`` are allowed.
    checkpoint_path : :obj:`str`, optional
        Path for reading and writing the checkpoint.
        If not given, the CWD is used.

    """

    def __init__(self, checkpoint_id=None, checkpoint_path="."):
        # check if checkpoint_id is valid (only allow a-z A-Z 0-9 _ -)
        if not isinstance(checkpoint_id, str) or bool(
                re.compile(r"[^a-zA-Z0-9_\-]").search(checkpoint_id)):
            raise ValueError("Invalid checkpoint id.")

        if not isinstance(checkpoint_path, str):
            raise ValueError("Invalid checkpoint path.")

        self.checkpoint_objects = []
        self.checkpoint_signals = []

        # Objects are resolved by name in the module of the caller, i.e.
        # the frame one level up the stack from this constructor.
        frm = inspect.stack()[1]
        self.calling_module = inspect.getmodule(frm[0])

        checkpoint_path = os.path.join(checkpoint_path, checkpoint_id)
        self.checkpoint_dir = os.path.realpath(checkpoint_path)

        if not os.path.isdir(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)

        # The counter is the first index N (scanning consecutively from 0)
        # for which "N.checkpoint" does not exist yet; it is the index the
        # next default save() will write.
        self.counter = 0
        while os.path.isfile(os.path.join(
                self.checkpoint_dir, "{}.checkpoint".format(self.counter))):
            self.counter += 1

        # Re-install handlers for signals persisted in the signal file.
        for signum in self.read_signals():
            self.register_signal(signum)

    def __getattr_submodule(self, obj, name, default):
        """
        Generalization of getattr(). __getattr_submodule(object, "name1.sub1.sub2", None)
        will return attribute sub2 if available otherwise None.

        """
        names = name.split('.')
        # Walk the dotted path; a missing intermediate yields ``default``,
        # and getattr on that then falls through to ``default`` as well.
        for part in names[:-1]:
            obj = getattr(obj, part, default)

        return getattr(obj, names[-1], default)

    def __setattr_submodule(self, obj, name, value):
        """
        Generalization of setattr(). __setattr_submodule(object, "name1.sub1.sub2", value)
        will set attribute sub2 to value. Will raise exception if parent modules do not exist.

        """
        names = name.split('.')
        for part in names[:-1]:
            obj = getattr(obj, part, None)

        if obj is None:
            raise Exception(
                "Cannot set attribute of non existing submodules: {}\nCheck the order you registered objects for checkpointing.".format(name))
        setattr(obj, names[-1], value)

    def __hasattr_submodule(self, obj, name):
        """
        Generalization of hasattr(). __hasattr_submodule(object, "name1.sub1.sub2")
        will return True if submodule sub1 has the attribute sub2.

        """
        names = name.split('.')
        for part in names[:-1]:
            obj = getattr(obj, part, None)

        return hasattr(obj, names[-1])

    def register(self, *args):
        """Register python objects for checkpointing.

        Parameters
        ----------
        args : list of :obj:`str`
            Names of python objects to be registered for checkpointing.

        Raises
        ------
        ValueError
            If a name is not a string.
        KeyError
            If a name cannot be resolved in the calling module or is
            already registered.

        """
        for a in args:
            if not isinstance(a, str):
                raise ValueError(
                    "The object that should be checkpointed is identified with its name given as a string.")
            if not self.__hasattr_submodule(self.calling_module, a):
                raise KeyError(
                    "The given object '{}' was not found in the current scope.".format(a))
            if a in self.checkpoint_objects:
                raise KeyError(
                    "The given object '{}' is already registered for checkpointing.".format(a))
            self.checkpoint_objects.append(a)

    def unregister(self, *args):
        """Unregister python objects for checkpointing.

        Parameters
        ----------
        args : list of :obj:`str`
            Names of python objects to be unregistered for checkpointing.

        Raises
        ------
        KeyError
            If a name is not a string or was never registered.

        """
        for a in args:
            if not isinstance(a, str) or a not in self.checkpoint_objects:
                raise KeyError(
                    "The given object '{}' was not registered for checkpointing yet.".format(a))
            self.checkpoint_objects.remove(a)

    def get_registered_objects(self):
        """
        Returns a list of all object names that are registered for checkpointing.

        Note: this is the internal list itself, not a copy.

        """
        return self.checkpoint_objects

    def has_checkpoints(self):
        """Check for checkpoints.

        Returns
        -------
        bool
            True if any checkpoints exist that match checkpoint_id and
            checkpoint_path otherwise False.

        """
        return self.counter > 0

    def get_last_checkpoint_index(self):
        """
        Returns the last index of the given checkpoint id. Will raise exception
        if no checkpoints are found.

        """
        if not self.has_checkpoints():
            raise Exception(
                "No checkpoints found. Cannot return index for last checkpoint.")
        return self.counter - 1

    def save(self, checkpoint_index=None):
        """
        Saves all registered python objects in the given checkpoint directory
        using pickle (cPickle when available).

        Parameters
        ----------
        checkpoint_index : :obj:`int`, optional
            Index of the checkpoint file to write. If not given, the next
            free index is used and the internal counter is advanced, so that
            consecutive calls write consecutive checkpoints.

        """
        # get attributes of registered objects
        checkpoint_data = OrderedDict()
        for obj_name in self.checkpoint_objects:
            checkpoint_data[obj_name] = self.__getattr_submodule(
                self.calling_module, obj_name, None)

        use_counter = checkpoint_index is None
        if use_counter:
            checkpoint_index = self.counter

        filename = os.path.join(
            self.checkpoint_dir, "{}.checkpoint".format(checkpoint_index))
        # Write to a temporary file first and rename afterwards, so an
        # interrupted dump never leaves a truncated checkpoint behind.
        tmpname = filename + ".__tmp__"
        with open(tmpname, "wb") as checkpoint_file:
            pickle.dump(checkpoint_data, checkpoint_file,
                        pickle.HIGHEST_PROTOCOL)
        os.rename(tmpname, filename)

        # Only advance after the file exists, so the counter always matches
        # the files on disk (and has_checkpoints() reflects this save).
        if use_counter:
            self.counter += 1

    def load(self, checkpoint_index=None):
        """
        Loads the python objects using (c)Pickle and sets them in the calling module.
        Loaded object names become registered for checkpointing (each name
        is registered at most once).

        Parameters
        ----------
        checkpoint_index : :obj:`int`, optional
            If not given, the latest checkpoint_index will be used.

        """
        if checkpoint_index is None:
            checkpoint_index = self.get_last_checkpoint_index()

        filename = os.path.join(
            self.checkpoint_dir, "{}.checkpoint".format(checkpoint_index))
        # NOTE: pickle.load executes arbitrary code from the file; only load
        # checkpoints from trusted locations.
        with open(filename, "rb") as f:
            checkpoint_data = pickle.load(f)

        for key in checkpoint_data:
            self.__setattr_submodule(
                self.calling_module, key, checkpoint_data[key])
            # Avoid duplicate registrations when loading repeatedly or after
            # the object was already registered.
            if key not in self.checkpoint_objects:
                self.checkpoint_objects.append(key)

    def __signal_handler(self, signum, frame):
        """
        Will be called when a registered signal was sent: writes a new
        checkpoint and terminates the process with exit status ``signum``.

        """
        self.save()
        # sys.exit is always available, unlike the site-provided exit().
        sys.exit(signum)

    def read_signals(self):
        """
        Reads all registered signals from the signal file and returns a list of integers.
        An absent signal file yields an empty list.

        """
        signal_filename = os.path.join(self.checkpoint_dir, "signals")
        if not os.path.isfile(signal_filename):
            return []

        with open(signal_filename, "r") as signal_file:
            signals = signal_file.readline().strip().split()
            # will raise exception if signal file contains invalid entries
            signals = [int(i) for i in signals]
        return signals

    def __write_signal(self, signum=None):
        """Writes the given signal integer signum to the signal file.

        The file holds all registered signals on one whitespace-separated
        line; signum is appended only if not already present.

        """
        # Validate before converting, otherwise the check can never fail.
        if not is_valid_type(signum, int):
            raise ValueError("Signal must be an integer number.")
        signum = int(signum)

        signals = self.read_signals()
        if signum not in signals:
            signals.append(signum)
        signals = " ".join(str(i) for i in signals)
        with open(os.path.join(self.checkpoint_dir, "signals"), "w") as signal_file:
            signal_file.write(signals)

    def register_signal(self, signum=None):
        """Register a signal that will trigger the signal handler.

        The handler saves a checkpoint and exits; the signal is also
        persisted in the signal file so later instances re-register it.

        Parameters
        ----------
        signum : :obj:`int`
            Signal to be registered.

        """
        if not is_valid_type(signum, int):
            raise ValueError("Signal must be an integer number.")

        if signum in self.checkpoint_signals:
            raise KeyError(
                "The signal {} is already registered for checkpointing.".format(signum))

        signal.signal(signum, self.__signal_handler)
        self.checkpoint_signals.append(signum)
        self.__write_signal(signum)