# coding: utf-8
"""Custom chainer training extensions."""
import chainer
from chainer.training import Extension
import matplotlib.pyplot as plt
import numpy as np
class ScatterPlot(Extension):
    """Trainer extension to output predictions/labels scatter plots.

    Labels are gathered to the root MPI process once, at construction.
    Each time the extension is invoked, every process predicts on its
    own part of the dataset, the predictions are gathered to the root
    process, and the root process saves one scatter plot per property.
    """

    def __init__(self, dataset, model, comm):
        """
        Args:
            dataset (HDNNPDataset):
                Test dataset to plot a scatter plot. It has to have both
                input dataset and label dataset.
            model (HighDimensionalNNP): HDNNP model to evaluate.
            comm (~chainermn.CommunicatorBase):
                ChainerMN communicator instance.
        """
        self._order = dataset.property.order
        self._model = model
        self._comm = comm.mpi_comm
        # The attributes below are filled in by `_init_labels`.
        self._properties = []
        self._coefficients = []
        self._units = []
        self._inputs = []
        self._labels = []
        self._predictions = []
        self._init_labels(dataset)

    def __call__(self, trainer):
        """Execute scatter plot extension.

        | Perform prediction with the parameters of the model when this
          extension was executed, using the data set at initialization.
        | Horizontal axis shows the predicted values and vertical axis
          shows the true values.
        | Plot configurations are written in :meth:`_plot`.

        Args:
            trainer (~chainer.training.Trainer):
                Trainer object that invokes this extension.
        """
        with chainer.using_config('train', False), \
                chainer.using_config('enable_backprop', False):
            predictions = self._model.predict(self._inputs, self._order)

        is_root = self._comm.Get_rank() == 0
        for i in range(self._order + 1):
            pred_send = predictions[i].data
            if is_root:
                # Pass explicit per-rank element counts, exactly as
                # `_init_labels` does for the labels. Given a bare
                # receive buffer, mpi4py assumes every rank contributes
                # an equally sized block, which breaks when the ranks
                # hold different numbers of samples.
                pred_recv = (self._predictions[i],
                             self._count * pred_send[0].size)
                self._comm.Gatherv(pred_send, pred_recv, root=0)
                self._plot(trainer,
                           self._coefficients[i] * self._predictions[i],
                           self._coefficients[i] * self._labels[i],
                           self._properties[i], self._units[i])
            else:
                # Non-root processes only send; they receive nothing.
                self._comm.Gatherv(pred_send, None, root=0)

        # Release every figure opened by `_plot` during this invocation.
        plt.close('all')

    def _init_labels(self, dataset):
        """Gather label dataset to root process and initialize other
        instance variables.

        On the root process this also allocates the receive buffers
        that `__call__` later fills with the gathered predictions.

        Args:
            dataset (HDNNPDataset):
                Dataset whose ``inputs/<i>`` and ``labels/<i>`` entries
                are read for ``i`` in ``0..order``.
        """
        self._properties = dataset.property.properties
        self._coefficients = dataset.property.coefficients
        self._units = dataset.property.units

        batch = chainer.dataset.concat_examples(dataset)
        self._inputs = [batch[f'inputs/{i}'] for i in range(self._order + 1)]
        labels = [batch[f'labels/{i}'] for i in range(self._order + 1)]

        # Number of samples held by each process. `gather` returns the
        # list only on the root process, so this is meaningful there only.
        self._count = np.array(self._comm.gather(len(labels[0]), root=0))

        is_root = self._comm.Get_rank() == 0
        for i in range(self._order + 1):
            label_send = labels[i]
            if is_root:
                total_size = sum(self._count)
                sample_shape = label_send[0].shape
                # Buffer for predictions, same shape as the gathered labels.
                prediction = np.empty((total_size,) + sample_shape,
                                      dtype=np.float32)
                self._predictions.append(prediction)

                label = np.empty((total_size,) + sample_shape,
                                 dtype=np.float32)
                # Gatherv counts are in elements, not samples.
                label_recv = (label, self._count * label_send[0].size)
                self._comm.Gatherv(label_send, label_recv, root=0)
                self._labels.append(label)
            else:
                self._comm.Gatherv(label_send, None, root=0)

    @staticmethod
    def _plot(trainer, prediction, label, property_, unit):
        """Plot and save a scatter plot.

        Both axes are limited to the range of ``label``. The figure is
        saved as ``<property_>.png`` under ``trainer.out`` and is not
        closed here.

        Args:
            trainer (~chainer.training.Trainer):
                Trainer providing the output directory and the epoch.
            prediction: Predicted values, drawn on the horizontal axis.
            label: True values, drawn on the vertical axis.
            property_ (str): Property name used in the title and the
                file name.
            unit (str): Unit shown in both axis labels.
        """
        fig = plt.figure(figsize=(10, 10))
        min_ = np.min(label)
        max_ = np.max(label)
        plt.scatter(prediction, label, c='blue')
        plt.xlabel(f'Prediction ({unit})')
        plt.ylabel(f'Label ({unit})')
        plt.xlim(min_, max_)
        plt.ylim(min_, max_)
        plt.text(0.5, 0.9,
                 f'{property_} @ epoch={trainer.updater.epoch}',
                 ha='center', transform=plt.gcf().transFigure)
        fig.savefig(trainer.out/f'{property_}.png')
def set_log_scale(_, a, __):
    """Change y axis scale as log scale.

    The first and third positional arguments are accepted and ignored.
    NOTE(review): the three-argument shape presumably matches a plot
    post-processing callback (figure, axes, summary) -- confirm against
    the caller.

    Args:
        _: Ignored.
        a: Object providing ``set_yscale``; it is called with ``'log'``.
        __: Ignored.
    """
    a.set_yscale('log')