ScatterPlot

class hdnnpy.training.extensions.ScatterPlot(dataset, model, comm)

Bases: chainer.training.extension.Extension

Trainer extension that outputs scatter plots of predictions against labels.

Parameters:
  • dataset (HDNNPDataset) – Test dataset used to draw the scatter plot. It must contain both an input dataset and a label dataset.
  • model (HighDimensionalNNP) – HDNNP model to evaluate.
  • comm (CommunicatorBase) – ChainerMN communicator instance.
__call__(trainer)

Execute the scatter plot extension.

Runs prediction on the dataset given at initialization, using the model parameters as they are at the moment this extension is invoked.
The horizontal axis shows the predicted values and the vertical axis shows the true values.
The plot configuration is defined in _plot().
Parameters:
  • trainer (Trainer) – Trainer object that invokes this extension.
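
The call pattern described above can be illustrated with a minimal, library-free sketch. It is not the hdnnpy implementation: ScatterPlotSketch and LinearModel are hypothetical stand-ins, the comm argument and the actual plotting in _plot() are omitted, and the dataset is assumed to be a plain list of (input, label) pairs. The sketch only shows that the dataset is fixed at initialization while predictions reflect the model parameters at call time, with predictions on the horizontal axis and true values on the vertical axis.

```python
class ScatterPlotSketch:
    """Hypothetical stand-in that mirrors the documented call pattern."""

    def __init__(self, dataset, model):
        # The dataset is stored once, at initialization.
        self._dataset = dataset
        self._model = model

    def __call__(self, trainer=None):
        # Predict with whatever parameters the model holds right now.
        points = []
        for x, true_value in self._dataset:
            predicted = self._model(x)
            # Horizontal coordinate: prediction; vertical coordinate: true value.
            points.append((predicted, true_value))
        return points


class LinearModel:
    """Hypothetical one-parameter model used only for this illustration."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, x):
        return self.weight * x


dataset = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # (input, label) pairs
model = LinearModel(weight=1.0)
extension = ScatterPlotSketch(dataset, model)

before = extension()  # points computed with the initial parameters
model.weight = 2.0    # pretend a training step updated the model
after = extension()   # same dataset, new parameters
```

In this sketch, points that fall on the diagonal (prediction equal to true value) indicate a perfect fit, which is how a predictions/labels scatter plot is usually read.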