# Source code for hdnnpy.training.manager

# coding: utf-8

"""Context manager to take trainer snapshot and decide whether to train
or not."""

from contextlib import AbstractContextManager
from pathlib import Path
import pickle
import signal

import chainer
import numpy as np

from hdnnpy.utils import (MPI, pprint)


class Manager(AbstractContextManager):
    """Context manager to take trainer snapshot and decide whether to
    train or not.

    While the ``with`` block is active, SIGINT and SIGTERM are routed to
    :meth:`_snapshot`, which optionally saves the trainer state and then
    raises :class:`InterruptedError` to break out of the training loop.
    On exit the previous signal handlers are restored and, unless a
    signal was received, the training result is recorded.
    """

    def __init__(self, tag, trainer, result, is_snapshot=True):
        """
        Args:
            tag (str): Tag of dataset used for training.
            trainer (~chainer.training.Trainer):
                Trainer object to be managed.
            result (dict):
                Dictionary object containing total elapsed time and
                metrics value corresponding to the type of loss
                function. Even when training is stopped / resumed, it
                is retained.
            is_snapshot (bool, optional): Take trainer snapshot if True.
        """
        self._tag = tag
        self._trainer = trainer
        self._result = result
        self._is_snapshot = is_snapshot
        self._is_allow = True
        # ``trainer.out`` may be a plain string rather than a path object;
        # normalise it so that the ``/`` joins below always work.
        out_dir = Path(trainer.out)
        self._trainer_snapshot = out_dir / 'trainer_snapshot.npz'
        self._interim_result = out_dir / 'interim_result.pickle'
        # Set by ``_snapshot`` when SIGINT / SIGTERM has been received.
        self._signum = None

    def __enter__(self):
        """Replace signal handler of SIGINT and SIGTERM.

        Returns:
            Manager: This instance, so that ``with Manager(...) as m``
            binds the manager itself.
        """
        self._old_sigint_handler = signal.signal(
            signal.SIGINT, self._snapshot)
        self._old_sigterm_handler = signal.signal(
            signal.SIGTERM, self._snapshot)
        return self

    def __exit__(self, type_, value, traceback):
        """Restore signal handler of SIGINT and SIGTERM, and record the
        result of training.

        The result is recorded only when no signal interrupted the
        training. Nothing is returned, so an exception raised inside
        the ``with`` block is never suppressed.
        """
        signal.signal(signal.SIGINT, self._old_sigint_handler)
        signal.signal(signal.SIGTERM, self._old_sigterm_handler)
        if not self._signum:
            self._result['training_time'] += self._trainer.elapsed_time
            observation = {
                key: self._to_builtin(val)
                for key, val in self._trainer.observation.items()}
            self._result['observation'].append(
                {'tag': self._tag, **observation})

    @staticmethod
    def _to_builtin(value):
        """Convert a Chainer variable or NumPy scalar to a Python scalar.

        Any other object is returned unchanged.
        """
        if isinstance(value, chainer.Variable):
            return value.data.item()
        # ``np.generic`` covers every NumPy scalar type (float32,
        # float64, integers, ...), not only ``np.float64``.
        if isinstance(value, np.generic):
            return value.item()
        return value

    @property
    def allow_to_run(self):
        """Whether the given trainer can train with the dataset."""
        return self._is_allow

    def check_to_resume(self, resume_tag):
        """Decide whether to train or not.

        If current tag of dataset is equal to ``resume_tag``, restore
        the state of trainer from snapshot file. Otherwise training is
        allowed only when no snapshot file exists in the trainer's
        output directory.

        Args:
            resume_tag (str):
                Tag of dataset when snapshot was taken last time.
        """
        if self._tag == resume_tag:
            self._resume()
            self._is_allow = True
        else:
            self._is_allow = not self._trainer_snapshot.exists()

    def _resume(self):
        """Restore the state of trainer from snapshot file."""
        pprint(f'Resume training loop from dataset tagged "{self._tag}"')
        chainer.serializers.load_npz(self._trainer_snapshot, self._trainer)
        # NOTE(review): pickle executes arbitrary code on load; this is
        # only safe as long as the file was written by ``_snapshot``.
        interim_result = pickle.loads(self._interim_result.read_bytes())
        self._result['training_time'] += interim_result['training_time']
        self._result['observation'].extend(interim_result['observation'])
        # remove snapshot
        if MPI.rank == 0:
            self._trainer_snapshot.unlink()
            self._interim_result.unlink()

    def _snapshot(self, signum, _):
        """Take trainer snapshot.

        Installed as the SIGINT / SIGTERM handler by ``__enter__``.

        Args:
            signum (int): Number of the received signal.
            _: Current stack frame passed by the interpreter (unused).

        Raises:
            InterruptedError: Always, to stop ``trainer.run()``.
        """
        self._signum = signal.Signals(signum)
        if self._is_snapshot and MPI.rank == 0:
            pprint(f'Stop {self._tag} training by signal:'
                   f' {self._signum.name}!\n'
                   f'Take trainer snapshot at epoch:'
                   f' {self._trainer.updater.epoch}')
            chainer.serializers.save_npz(self._trainer_snapshot,
                                         self._trainer)
            self._interim_result.write_bytes(pickle.dumps(self._result))
        # must raise any Exception to stop trainer.run()
        raise InterruptedError(
            f'Chainer training loop is interrupted by {self._signum.name}')