SSAGES  0.8.3
Software Suite for Advanced General Ensemble Simulations
ANN.cpp
1 
#include "ANN.h"

#include <memory>

#include "schema.h"
#include "Snapshot.h"
#include "mxx/bcast.hpp"
#include "CVs/CVManager.h"
#include "Drivers/DriverException.h"
#include "Validator/ObjectRequirement.h"
27 
28 using namespace Eigen;
29 using namespace nnet;
30 using namespace Json;
31 
32 namespace SSAGES
33 {
34  ANN::ANN(const MPI_Comm& world,
35  const MPI_Comm& comm,
36  const VectorXi& topol,
37  Grid<VectorXd>* fgrid,
38  Grid<unsigned int>* hgrid,
39  Grid<double>* ugrid,
40  const std::vector<double>& lowerb,
41  const std::vector<double>& upperb,
42  const std::vector<double>& lowerk,
43  const std::vector<double>& upperk,
44  double temperature,
45  double weight,
46  unsigned int nsweep) :
47  Method(1, world, comm), topol_(topol), sweep_(0), nsweep_(nsweep),
48  citers_(0), net_(topol), pweight_(1.), weight_(weight), temp_(temperature),
49  kbt_(0), fgrid_(fgrid), hgrid_(hgrid), ugrid_(ugrid), hist_(), bias_(),
50  lowerb_(lowerb), upperb_(upperb), lowerk_(lowerk), upperk_(upperk),
51  outfile_("ann.out"), overwrite_(true)
52  {
53  // Create histogram grid matrix.
54  hist_.resize(hgrid_->size(), hgrid_->GetDimension());
55 
56  // Fill it up.
57  size_t i = 0;
58  for(auto it = hgrid_->begin(); it != hgrid_->end(); ++it)
59  {
60  auto coord = it.coordinates();
61  for(size_t j = 0; j < coord.size(); ++j)
62  hist_(i, j) = coord[j];
63  ++i;
64  }
65 
66  // Initialize FES vector.
67  bias_.resize(hgrid_->size(), 1);
68  net_.forward_pass(hist_);
69  bias_.array() = net_.get_activation().col(0).array();
70  }
71 
72  void ANN::PreSimulation(Snapshot* snapshot, const CVManager&)
73  {
74  auto ndim = hgrid_->GetDimension();
75  kbt_ = snapshot->GetKb()*temp_;
76 
77  // Zero out forces and histogram.
78  VectorXd vec = VectorXd::Zero(ndim);
79  std::fill(hgrid_->begin(), hgrid_->end(), 0);
80  std::fill(ugrid_->begin(), ugrid_->end(), 1.0);
81  std::fill(fgrid_->begin(), fgrid_->end(), vec);
82  }
83 
84  void ANN::PostIntegration(Snapshot* snapshot, const CVManager& cvmanager)
85  {
86  if(snapshot->GetIteration() && snapshot->GetIteration() % nsweep_ == 0)
87  {
88  // Switch to full blast.
89  if(citers_ && snapshot->GetIteration() > citers_)
90  pweight_ = 1.0;
91 
92  TrainNetwork();
93  if(world_.rank() == 0)
94  WriteBias();
95  }
96 
97  // Get CV vals.
98  auto cvs = cvmanager.GetCVs(cvmask_);
99  auto n = cvs.size();
100 
101  // Determine if we are in bounds.
102  RowVectorXd vec(n);
103  std::vector<double> val(n);
104  bool inbounds = true;
105  for(size_t i = 0; i < n; ++i)
106  {
107  val[i] = cvs[i]->GetValue();
108  vec[i] = cvs[i]->GetValue();
109  if(val[i] < hgrid_->GetLower(i) || val[i] > hgrid_->GetUpper(i))
110  inbounds = false;
111  }
112 
113  // If in bounds, bias.
114  VectorXd derivatives = VectorXd::Zero(n);
115  if(inbounds)
116  {
117  // Record histogram hit and get gradient.
118  // Only record hits on master processes since we will
119  // reduce later.
120  if(comm_.rank() == 0)
121  hgrid_->at(val) += 1;
122  //derivatives = (*fgrid_)[val];
123  net_.forward_pass(vec);
124  derivatives = net_.get_gradient(0);
125  }
126  else
127  {
128  if(comm_.rank() == 0)
129  {
130  std::cerr << "ANN (" << snapshot->GetIteration() << "): out of bounds ( ";
131  for(auto& v : val)
132  std::cerr << v << " ";
133  std::cerr << ")" << std::endl;
134  }
135  }
136 
137  // Restraints.
138  for(size_t i = 0; i < n; ++i)
139  {
140  auto cval = cvs[i]->GetValue();
141  if(cval < lowerb_[i])
142  derivatives[i] += lowerk_[i]*cvs[i]->GetDifference(cval - lowerb_[i]);
143  else if(cval > upperb_[i])
144  derivatives[i] += upperk_[i]*cvs[i]->GetDifference(cval - upperb_[i]);
145  }
146 
147  // Apply bias to atoms.
148  auto& forces = snapshot->GetForces();
149  auto& virial = snapshot->GetVirial();
150 
151  for(size_t i = 0; i < cvs.size(); ++i)
152  {
153  auto& grad = cvs[i]->GetGradient();
154  auto& boxgrad = cvs[i]->GetBoxGradient();
155 
156  // Update the forces in snapshot by adding in the force bias from each
157  // CV to each atom based on the gradient of the CV.
158  for (size_t j = 0; j < forces.size(); ++j)
159  forces[j] -= derivatives[i]*grad[j];
160 
161  virial += derivatives[i]*boxgrad;
162  }
163  }
164 
166  {
167  }
168 
170  {
171  // Increment cycle counter.
172  ++sweep_;
173 
174  // Reduce histogram across procs.
175  mxx::allreduce(hgrid_->data(), hgrid_->size(), std::plus<unsigned int>(), world_);
176 
177  // Synchronize grid in case it's periodic.
178  hgrid_->syncGrid();
179 
180  // Update FES estimator. Synchronize unbiased histogram.
183  uhist.array() = pweight_*uhist.array() + hist.cast<double>()*(1./kbt_*bias_).array().exp()*weight_;
184  ugrid_->syncGrid();
185  hist.setZero();
186 
187  bias_.array() = kbt_*uhist.array().log();
188  bias_.array() -= bias_.minCoeff();
189 
190  // Train network.
191  net_.autoscale(hist_, bias_);
192  if(world_.rank() == 0)
193  {
194  net_.train(hist_, bias_, true);
195  }
196 
197  // Send optimal nnet params to all procs.
198  vector_t wb = net_.get_wb();
199  mxx::bcast(wb.data(), wb.size(), 0, world_);
200  net_.set_wb(wb);
201 
202  // Evaluate and subtract off min value for applied bias.
203  net_.forward_pass(hist_);
204  bias_.array() = net_.get_activation().col(0).array();
205  bias_.array() -= bias_.minCoeff();
206 
207  // Calc new bias force.
208  for(size_t i = 0; i < fgrid_->size(); ++i)
209  {
210  MatrixXd forces = net_.get_gradient(i);
211  fgrid_->data()[i] = forces.row(i).transpose();
212  }
213  }
214 
216  {
217  net_.write("netstate.dat");
218 
219  std::string filename = overwrite_ ? outfile_ : outfile_ + std::to_string(sweep_);
220  std::ofstream file(filename);
221  file.precision(16);
222  net_.forward_pass(hist_);
223  matrix_t y = net_.get_activation();
224  for(int i = 0; i < y.rows(); ++i)
225  {
226  for(int j = 0; j < hist_.cols(); ++j)
227  file << std::fixed << hist_(i,j) << " ";
228  file << std::fixed << ugrid_->data()[i] << " " << std::fixed << y(i) << "\n";
229  }
230 
231  file.close();
232  }
233 
235  const Json::Value& json,
236  const MPI_Comm& world,
237  const MPI_Comm& comm,
238  const std::string& path)
239  {
240  ObjectRequirement validator;
241  Value schema;
242  CharReaderBuilder rbuilder;
243  CharReader* reader = rbuilder.newCharReader();
244 
245  reader->parse(JsonSchema::ANNMethod.c_str(),
246  JsonSchema::ANNMethod.c_str() + JsonSchema::ANNMethod.size(),
247  &schema, NULL);
248  validator.Parse(schema, path);
249 
250  // Validate inputs.
251  validator.Validate(json, path);
252  if(validator.HasErrors())
253  throw BuildException(validator.GetErrors());
254 
255  // Grid.
256  auto* fgrid = Grid<VectorXd>::BuildGrid(json.get("grid", Json::Value()));
257  auto* hgrid = Grid<unsigned int>::BuildGrid(json.get("grid", Json::Value()));
258  auto* ugrid = Grid<double>::BuildGrid(json.get("grid", Json::Value()));
259 
260  // Topology.
261  auto nlayers = json["topology"].size() + 2;
262  VectorXi topol(nlayers);
263  topol[0] = fgrid->GetDimension();
264  topol[nlayers-1] = 1;
265  for(int i = 0; i < static_cast<int>(json["topology"].size()); ++i)
266  topol[i+1] = json["topology"][i].asInt();
267 
268  auto weight = json.get("weight", 1.).asDouble();
269  auto temp = json["temperature"].asDouble();
270  auto nsweep = json["nsweep"].asUInt();
271 
272  // Assume all vectors are the same size.
273  std::vector<double> lowerb, upperb, lowerk, upperk;
274  for(int i = 0; i < static_cast<int>(json["lower_bound_restraints"].size()); ++i)
275  {
276  lowerk.push_back(json["lower_bound_restraints"][i].asDouble());
277  upperk.push_back(json["upper_bound_restraints"][i].asDouble());
278  lowerb.push_back(json["lower_bounds"][i].asDouble());
279  upperb.push_back(json["upper_bounds"][i].asDouble());
280  }
281 
282  auto* m = new ANN(world, comm, topol, fgrid, hgrid, ugrid, lowerb, upperb, lowerk, upperk, temp, weight, nsweep);
283 
284  // Set optional params.
285  m->SetPrevWeight(json.get("prev_weight", 1).asDouble());
286  m->SetOutput(json.get("output_file", "ann.out").asString());
287  m->SetOutputOverwrite( json.get("overwrite_output", true).asBool());
288  m->SetConvergeIters(json.get("converge_iters", 0).asUInt());
289  m->SetMaxIters(json.get("max_iters", 1000).asUInt());
290  m->SetMinLoss(json.get("min_loss", 0).asDouble());
291 
292  return m;
293  }
294 }
std::vector< CollectiveVariable * > GetCVs(const std::vector< unsigned int > &mask=std::vector< unsigned int >()) const
Get CV iterator.
Definition: CVManager.h:80
const std::vector< double > GetLower() const
Return the lower edges of the Grid.
Definition: GridBase.h:226
bool HasErrors()
Check if errors have occurred.
Definition: Requirement.h:86
Collective variable manager.
Definition: CVManager.h:40
size_t size() const
Get the size of the internal storage vector.
Definition: GridBase.h:321
const T & at(const std::vector< int > &indices) const
Access Grid element read-only.
Definition: GridBase.h:541
std::vector< double > lowerk_
Definition: ANN.h:83
Artificial Neural Network Method.
Definition: ANN.h:35
const std::vector< Vector3 > & GetForces() const
Access the per-particle forces.
Definition: Snapshot.h:362
ANN(const MPI_Comm &world, const MPI_Comm &comm, const Eigen::VectorXi &topol, Grid< Eigen::VectorXd > *fgrid, Grid< unsigned int > *hgrid, Grid< double > *ugrid, const std::vector< double > &lowerb, const std::vector< double > &upperb, const std::vector< double > &lowerk, const std::vector< double > &upperk, double temperature, double weight, unsigned int nsweep)
Constructor.
Definition: ANN.cpp:34
void syncGrid()
Sync the grid.
Definition: GridBase.h:141
Class containing a snapshot of the current simulation in time.
Definition: Snapshot.h:43
Basic Grid.
Definition: Grid.h:58
mxx::comm comm_
Local MPI communicator.
Definition: Method.h:47
Interface for Method implementations.
Definition: Method.h:43
size_t GetDimension() const
Get the dimension.
Definition: GridBase.h:190
mxx::comm world_
Global MPI communicator.
Definition: Method.h:46
unsigned int citers_
Number of iterations after which we turn on full weight.
Definition: ANN.h:47
Grid< Eigen::VectorXd > * fgrid_
Force grid.
Definition: ANN.h:63
void PostSimulation(Snapshot *snapshot, const class CVManager &cvmanager) override
Post-simulation hook.
Definition: ANN.cpp:165
double GetKb() const
Get system Kb.
Definition: Snapshot.h:163
size_t GetIteration() const
Get the current iteration.
Definition: Snapshot.h:103
virtual void Parse(Value json, const std::string &path) override
Parse JSON value to generate Requirement(s).
unsigned int sweep_
Definition: ANN.h:43
Map for histogram and coefficients.
Definition: Basis.h:41
void WriteBias()
Writes out the bias to file.
Definition: ANN.cpp:215
iterator begin()
Return iterator at first grid point.
Definition: Grid.h:527
Exception to be thrown when building the Driver fails.
double temp_
Definition: ANN.h:59
Grid< unsigned int > * hgrid_
Histogram grid.
Definition: ANN.h:66
std::vector< std::string > GetErrors()
Get list of error messages.
Definition: Requirement.h:92
const std::vector< double > GetUpper() const
Return the upper edges of the Grid.
Definition: GridBase.h:257
void TrainNetwork()
Trains the neural network.
Definition: ANN.cpp:169
void PreSimulation(Snapshot *snapshot, const class CVManager &cvmanager) override
Pre-simulation hook.
Definition: ANN.cpp:72
Requirements on an object.
std::vector< unsigned int > cvmask_
Mask which identifies which CVs to act on.
Definition: Method.h:50
bool overwrite_
Overwrite outputs?
Definition: ANN.h:90
std::string outfile_
Output filename.
Definition: ANN.h:87
static ANN * Build(const Json::Value &json, const MPI_Comm &world, const MPI_Comm &comm, const std::string &path)
Build a derived method from JSON node.
Definition: ANN.cpp:234
nnet::neural_net net_
Neural network.
Definition: ANN.h:50
const Matrix3 & GetVirial() const
Get box virial.
Definition: Snapshot.h:133
static Grid< T > * BuildGrid(const Json::Value &json)
Set up the grid.
Definition: Grid.h:127
std::vector< double > lowerb_
Definition: ANN.h:78
Eigen::MatrixXd hist_
Definition: ANN.h:73
Grid< double > * ugrid_
Unbiased histogram grid.
Definition: ANN.h:69
T * data()
Get pointer to the internal data storage vector.
Definition: GridBase.h:333
double pweight_
Definition: ANN.h:54
void PostIntegration(Snapshot *snapshot, const class CVManager &cvmanager) override
Post-integration hook.
Definition: ANN.cpp:84
virtual void Validate(const Value &json, const std::string &path) override
Validate JSON value.
iterator end()
Return iterator after last valid grid point.
Definition: Grid.h:540