SSAGES  0.8.3
Software Suite for Advanced General Ensemble Simulations
ANN.h
1 
20 #pragma once
21 
22 #include "Method.h"
23 #include "Grids/Grid.h"
24 #include "nnet/nnet.h"
25 
26 namespace SSAGES
27 {
29 
35  class ANN : public Method
36  {
37  private:
39  Eigen::VectorXi topol_;
40 
43  unsigned int sweep_, nsweep_;
45 
47  unsigned int citers_;
48 
50  nnet::neural_net net_;
51 
54  double pweight_, weight_;
56 
59  double temp_, kbt_;
61 
64 
67 
70 
73  Eigen::MatrixXd hist_, bias_;
75 
78  std::vector<double> lowerb_, upperb_;
80 
83  std::vector<double> lowerk_, upperk_;
85 
87  std::string outfile_;
88 
90  bool overwrite_;
91 
93  void TrainNetwork();
94 
96  void WriteBias();
97 
98  public:
100 
117  ANN(const MPI_Comm& world,
118  const MPI_Comm& comm,
119  const Eigen::VectorXi& topol,
120  Grid<Eigen::VectorXd>* fgrid,
121  Grid<unsigned int>* hgrid,
122  Grid<double>* ugrid,
123  const std::vector<double>& lowerb,
124  const std::vector<double>& upperb,
125  const std::vector<double>& lowerk,
126  const std::vector<double>& upperk,
127  double temperature,
128  double weight,
129  unsigned int nsweep
130  );
131 
133 
137  void PreSimulation(Snapshot* snapshot, const class CVManager& cvmanager) override;
138 
140 
144  void PostIntegration(Snapshot* snapshot, const class CVManager& cvmanager) override;
145 
147 
151  void PostSimulation(Snapshot* snapshot, const class CVManager& cvmanager) override;
152 
154  void SetPrevWeight(double h)
155  {
156  pweight_ = h;
157  }
158 
160  void SetOutput(const std::string& outfile)
161  {
162  outfile_ = outfile;
163  }
164 
166  void SetOutputOverwrite(bool overwrite)
167  {
168  overwrite_ = overwrite;
169  }
170 
172  void SetConvergeIters(unsigned int citers)
173  {
174  citers_ = citers;
175  }
176 
178  void SetMaxIters(unsigned int iters)
179  {
180  auto params = net_.get_train_params();
181  params.max_iter = iters;
182  net_.set_train_params(params);
183  }
184 
186  void SetMinLoss(double loss)
187  {
188  auto params = net_.get_train_params();
189  params.min_loss = loss;
190  net_.set_train_params(params);
191  }
192 
194  static ANN* Build(
195  const Json::Value& json,
196  const MPI_Comm& world,
197  const MPI_Comm& comm,
198  const std::string& path);
199 
200  ~ANN()
201  {
202  delete fgrid_;
203  delete hgrid_;
204  }
205  };
206 }
void SetOutput(const std::string &outfile)
Set name of output file.
Definition: ANN.h:160
Collective variable manager.
Definition: CVManager.h:40
std::vector< double > lowerk_
Definition: ANN.h:83
Artificial Neural Network Method.
Definition: ANN.h:35
ANN(const MPI_Comm &world, const MPI_Comm &comm, const Eigen::VectorXi &topol, Grid< Eigen::VectorXd > *fgrid, Grid< unsigned int > *hgrid, Grid< double > *ugrid, const std::vector< double > &lowerb, const std::vector< double > &upperb, const std::vector< double > &lowerk, const std::vector< double > &upperk, double temperature, double weight, unsigned int nsweep)
Constructor.
Definition: ANN.cpp:34
Class containing a snapshot of the current simulation in time.
Definition: Snapshot.h:43
void SetMinLoss(double loss)
Set minimum loss function value (should be zero for production).
Definition: ANN.h:186
Interface for Method implementations.
Definition: Method.h:43
unsigned int citers_
Number of iterations after which we turn on full weight.
Definition: ANN.h:47
Grid< Eigen::VectorXd > * fgrid_
Force grid.
Definition: ANN.h:63
void PostSimulation(Snapshot *snapshot, const class CVManager &cvmanager) override
Post-simulation hook.
Definition: ANN.cpp:165
unsigned int sweep_
Definition: ANN.h:43
void WriteBias()
Writes out the bias to file.
Definition: ANN.cpp:215
void SetConvergeIters(unsigned int citers)
Set number of iterations after which we turn on full weight.
Definition: ANN.h:172
double temp_
Definition: ANN.h:59
Grid< unsigned int > * hgrid_
Histogram grid.
Definition: ANN.h:66
void TrainNetwork()
Trains the neural network.
Definition: ANN.cpp:169
void PreSimulation(Snapshot *snapshot, const class CVManager &cvmanager) override
Pre-simulation hook.
Definition: ANN.cpp:72
void SetOutputOverwrite(bool overwrite)
Set overwrite flag on output file.
Definition: ANN.h:166
Eigen::VectorXi topol_
Neural network topology.
Definition: ANN.h:39
void SetPrevWeight(double h)
Set previous history weight.
Definition: ANN.h:154
bool overwrite_
Overwrite outputs?
Definition: ANN.h:90
std::string outfile_
Output filename.
Definition: ANN.h:87
static ANN * Build(const Json::Value &json, const MPI_Comm &world, const MPI_Comm &comm, const std::string &path)
Build a derived method from JSON node.
Definition: ANN.cpp:234
void SetMaxIters(unsigned int iters)
Set maximum number of training iterations per sweep.
Definition: ANN.h:178
nnet::neural_net net_
Neural network.
Definition: ANN.h:50
std::vector< double > lowerb_
Definition: ANN.h:78
Eigen::MatrixXd hist_
Definition: ANN.h:73
Grid< double > * ugrid_
Unbiased histogram grid.
Definition: ANN.h:69
double pweight_
Definition: ANN.h:54
void PostIntegration(Snapshot *snapshot, const class CVManager &cvmanager) override
Post-integration hook.
Definition: ANN.cpp:84