Source code for espressomd.checkpointing

#
# Copyright (C) 2013-2024 The ESPResSo project
#
# This file is part of ESPResSo.
#
# ESPResSo is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ESPResSo is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
import collections
import inspect
import pickle
import os
import re
import signal
from . import utils
from . import script_interface


[docs] class Checkpoint: """Checkpoint handling (reading and writing). Parameters ---------- checkpoint_id : :obj:`str` A string identifying a specific checkpoint. checkpoint_path : :obj:`str`, optional Path for reading and writing the checkpoint. If not given, the current working directory is used. """ def __init__(self, checkpoint_id=None, checkpoint_path="."): # check if checkpoint_id is valid (only allow a-z A-Z 0-9 _ -) if not isinstance(checkpoint_id, str) or bool( re.compile(r"[^a-zA-Z0-9_\-]").search(checkpoint_id)): raise ValueError("Invalid checkpoint id.") if not isinstance(checkpoint_path, str): raise ValueError("Invalid checkpoint path.") self.checkpoint_objects = [] self.checkpoint_signals = [] frm = inspect.stack()[1] self.calling_module = inspect.getmodule(frm[0]) checkpoint_path = os.path.join(checkpoint_path, checkpoint_id) self.checkpoint_dir = os.path.realpath(checkpoint_path) if not os.path.isdir(self.checkpoint_dir): os.makedirs(self.checkpoint_dir) # update checkpoint counter self.counter = 0 while os.path.isfile(os.path.join( self.checkpoint_dir, f"{self.counter}.checkpoint")): self.counter += 1 # init signals for signum in self.read_signals(): self.register_signal(signum) def __getattr_submodule(self, obj, name, default): """ Generalization of ``getattr()``. ``__getattr_submodule(object, "name1.sub1.sub2", None)`` will return attribute ``sub2`` if available otherwise ``None``. """ names = name.split('.') for i in range(len(names) - 1): obj = getattr(obj, names[i], default) return getattr(obj, names[-1], default) def __setattr_submodule(self, obj, name, value): """ Generalization of ``setattr()``. ``__setattr_submodule(object, "name1.sub1.sub2", value)`` will set attribute ``sub2`` to ``value``. Will raise exception if parent modules do not exist. """ names = name.split('.') for i in range(len(names) - 1): obj = getattr(obj, names[i], None) if obj is None: raise Exception( f"Cannot set attribute of non existing submodules: {name}\n" f"Check the order you registered objects for checkpointing.") setattr(obj, names[-1], value) def __hasattr_submodule(self, obj, name): """ Generalization of ``hasattr()``. ``__hasattr_submodule(object, "name1.sub1.sub2")`` will return ``True`` if submodule ``sub1`` has the attribute ``sub2``. """ names = name.split('.') for i in range(len(names) - 1): obj = getattr(obj, names[i], None) return hasattr(obj, names[-1])
[docs] def register(self, *args): """Register python objects for checkpointing. Parameters ---------- args : list of :obj:`str` Names of python objects to be registered for checkpointing. """ for varname in args: if not isinstance(varname, str): raise ValueError( "The object that should be checkpointed is identified with its name given as a string.") # if not a in dir(self.calling_module): if not self.__hasattr_submodule(self.calling_module, varname): raise KeyError( f"The given object '{varname}' was not found in the current scope.") if varname in self.checkpoint_objects: raise KeyError( f"The given object '{varname}' is already registered for checkpointing.") obj = self.__getattr_submodule(self.calling_module, varname, None) if isinstance( obj, script_interface.ScriptInterfaceHelper) and not obj._so_checkpointable: raise TypeError( f"Objects of type {type(obj)} cannot be checkpointed.") self.checkpoint_objects.append(varname)
[docs] def unregister(self, *args): """Unregister python objects for checkpointing. Parameters ---------- args : list of :obj:`str` Names of python objects to be unregistered for checkpointing. """ for varname in args: if not isinstance(varname, str): raise ValueError( "The object that should be checkpointed is identified with its name given as a string.") if varname not in self.checkpoint_objects: raise KeyError( f"The given object '{varname}' was not registered for checkpointing yet.") self.checkpoint_objects.remove(varname)
[docs] def get_registered_objects(self): """ Returns a list of all object names that are registered for checkpointing. """ return self.checkpoint_objects
[docs] def has_checkpoints(self): """Check for checkpoints. Returns ------- :obj:`bool` ``True`` if any checkpoints exist that match ``checkpoint_id`` and ``checkpoint_path`` otherwise ``False``. """ return self.counter > 0
[docs] def get_last_checkpoint_index(self): """ Returns the last index of the given checkpoint id. Will raise exception if no checkpoints are found. """ if not self.has_checkpoints(): raise Exception( "No checkpoints found. Cannot return index for last checkpoint.") return self.counter - 1
[docs] def save(self, checkpoint_index=None): """ Saves all registered python objects in the given checkpoint directory using pickle. """ # get attributes of registered objects checkpoint_data = collections.OrderedDict() for obj_name in self.checkpoint_objects: checkpoint_data[obj_name] = self.__getattr_submodule( self.calling_module, obj_name, None) if checkpoint_index is None: checkpoint_index = self.counter filename = os.path.join( self.checkpoint_dir, f"{checkpoint_index}.checkpoint") tmpname = filename + ".__tmp__" with open(tmpname, "wb") as checkpoint_file: pickle.dump(checkpoint_data, checkpoint_file, -1) os.rename(tmpname, filename)
[docs] def load(self, checkpoint_index=None): """ Loads the python objects using (c)Pickle and sets them in the calling module. Parameters ---------- checkpoint_index : :obj:`int`, optional If not given, the last ``checkpoint_index`` will be used. """ if checkpoint_index is None: checkpoint_index = self.get_last_checkpoint_index() filename = os.path.join( self.checkpoint_dir, f"{checkpoint_index}.checkpoint") with open(filename, "rb") as f: checkpoint_data = pickle.load(f) for key in checkpoint_data: self.__setattr_submodule( self.calling_module, key, checkpoint_data[key]) self.checkpoint_objects.append(key)
def __signal_handler(self, signum, frame): # pylint: disable=unused-argument """ Will be called when a registered signal was sent. """ self.save() exit(signum)
[docs] def read_signals(self): """ Reads all registered signals from the signal file and returns a list of integers. """ if not os.path.isfile(os.path.join(self.checkpoint_dir, "signals")): return [] with open(os.path.join(self.checkpoint_dir, "signals"), "r") as signal_file: signals = signal_file.readline().strip().split() signals = [int(i) for i in signals] # will raise exception if signal file contains invalid entries return signals
def __write_signal(self, signum=None): """Writes the given signal integer signum to the signal file. """ signum = int(signum) if not utils.is_valid_type(signum, int): raise ValueError("Signal must be an integer number.") signals = self.read_signals() if signum not in signals: signals.append(signum) signals = " ".join(str(i) for i in signals) with open(os.path.join(self.checkpoint_dir, "signals"), "w") as signal_file: signal_file.write(signals)
[docs] def register_signal(self, signum=None): """Register a signal that will trigger the signal handler. Parameters ---------- signum : :obj:`int` Signal to be registered. """ if not utils.is_valid_type(signum, int): raise ValueError("Signal must be an integer number.") if signum in self.checkpoint_signals: raise KeyError( f"The signal {signum} is already registered for checkpointing.") signal.signal(signum, self.__signal_handler) self.checkpoint_signals.append(signum) self.__write_signal(signum)