SSAGES  0.8.5
Software Suite for Advanced General Ensemble Simulations
StringMethod.cpp
1 
#include "ElasticBand.h"
#include "FiniteTempString.h"
#include "StringMethod.h"
#include "Swarm.h"
#include "CVs/CVManager.h"
#include "Validator/ObjectRequirement.h"
#include "Drivers/DriverException.h"
#include "Snapshot.h"
#include "spline.h"
#include "schema.h"

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
33 
34 using namespace Json;
35 
36 namespace SSAGES
37 {
38  void StringMethod::PrintString(const CVList& CV)
39  {
40  if(IsMasterRank(comm_))
41  {
42  //Write node, iteration, centers of the string and current CV value to output file
43  stringout_.precision(8);
44  stringout_ << mpiid_ << " " << iteration_ << " ";
45 
46  for(size_t i = 0; i < centers_.size(); i++)
47  stringout_ << worldstring_[mpiid_][i] << " " << CV[i]->GetValue() << " ";
48 
49  stringout_ << std::endl;
50  }
51  }
52 
53  void StringMethod::GatherNeighbors(std::vector<double> *lcv0, std::vector<double> *ucv0)
54  {
55  MPI_Status status;
56 
57  if(IsMasterRank(comm_))
58  {
59  MPI_Sendrecv(&centers_[0], centers_.size(), MPI_DOUBLE, sendneigh_, 1234,
60  &(*lcv0)[0], centers_.size(), MPI_DOUBLE, recneigh_, 1234,
61  world_, &status);
62 
63  MPI_Sendrecv(&centers_[0], centers_.size(), MPI_DOUBLE, recneigh_, 4321,
64  &(*ucv0)[0], centers_.size(), MPI_DOUBLE, sendneigh_, 4321,
65  world_, &status);
66  }
67 
68  MPI_Bcast(&(*lcv0)[0],centers_.size(),MPI_DOUBLE,0,comm_);
69  MPI_Bcast(&(*ucv0)[0],centers_.size(),MPI_DOUBLE,0,comm_);
70  }
71 
72  void StringMethod::StringReparam(double alpha_star)
73  {
74  std::vector<double> alpha_star_vector(numnodes_,0.0);
75 
76  //Reparameterization
77  //Alpha star is the uneven mesh, approximated as linear distance between points
78  if(IsMasterRank(comm_))
79  alpha_star_vector[mpiid_] = mpiid_ == 0 ? 0 : alpha_star;
80 
81  //Gather each alpha_star into a vector
82  MPI_Allreduce(MPI_IN_PLACE, &alpha_star_vector[0], numnodes_, MPI_DOUBLE, MPI_SUM, world_);
83 
84  for(size_t i = 1; i < alpha_star_vector.size(); i++)
85  alpha_star_vector[i] += alpha_star_vector[i-1];
86 
87  for(size_t i = 1; i < alpha_star_vector.size(); i++)
88  alpha_star_vector[i] /= alpha_star_vector[numnodes_ - 1];
89 
90  tk::spline spl; //Cubic spline interpolation
91 
92  for(size_t i = 0; i < centers_.size(); i++)
93  {
94  std::vector<double> cvs_new(numnodes_, 0.0);
95 
96  if(IsMasterRank(comm_))
97  cvs_new[mpiid_] = centers_[i];
98 
99  MPI_Allreduce(MPI_IN_PLACE, &cvs_new[0], numnodes_, MPI_DOUBLE, MPI_SUM, world_);
100 
101  spl.set_points(alpha_star_vector, cvs_new);
102  centers_[i] = spl(mpiid_/(numnodes_ - 1.0));
103  }
104  }
105 
106  void StringMethod::UpdateWorldString(const CVList& cvs)
107  {
108  for(size_t i = 0; i < centers_.size(); i++)
109  {
110  std::vector<double> cvs_new(numnodes_, 0.0);
111 
112  if(IsMasterRank(comm_))
113  {
114  cvs_new[mpiid_] = centers_[i];
115  }
116 
117  MPI_Allreduce(MPI_IN_PLACE, &cvs_new[0], numnodes_, MPI_DOUBLE, MPI_SUM, world_);
118 
119  for(int j = 0; j < numnodes_; j++)
120  {
121  worldstring_[j][i] = cvs_new[j];
122  //Represent worldstring in periodic space
123  worldstring_[j][i] = cvs[i]->GetPeriodicValue(worldstring_[j][i]);
124  }
125  }
126  }
127 
128  bool StringMethod::CheckEnd(const CVList& CV)
129  {
130  if(maxiterator_ && iteration_ > maxiterator_)
131  {
132  std::cout << "System has reached max string method iterations (" << maxiterator_ << ") as specified in the input file(s)." << std::endl;
133  std::cout << "Exiting now" << std::endl;
134  PrintString(CV); //Ensure that the system prints out if it's about to exit
135  MPI_Abort(world_, EXIT_FAILURE);
136  }
137 
138  int local_tolvalue = TolCheck();
139 
140  MPI_Allreduce(MPI_IN_PLACE, &local_tolvalue, 1, MPI_INT, MPI_LAND, world_);
141 
142  if(local_tolvalue)
143  {
144  std::cout << "System has converged within tolerance criteria. Exiting now" << std::endl;
145  PrintString(CV); //Ensure that the system prints out if it's about to exit
146  MPI_Abort(world_, EXIT_FAILURE);
147  }
148 
149  return true;
150  }
151 
152  void StringMethod::PreSimulation(Snapshot* snapshot, const CVManager& cvmanager)
153  {
154  char file[1024];
155  mpiid_ = snapshot->GetWalkerID();
156  sprintf(file, "node-%04d.log",mpiid_);
157  stringout_.open(file);
158 
159  auto cvs = cvmanager.GetCVs(cvmask_);
160  SetSendRecvNeighbors();
161  worldstring_.resize(numnodes_);
162  for(auto& w : worldstring_)
163  w.resize(centers_.size());
164 
165  UpdateWorldString(cvs);
166  PrintString(cvs);
167  }
168 
169  void StringMethod::PostSimulation(Snapshot*, const CVManager&)
170  {
171  stringout_.close();
172  }
173 
174  void StringMethod::SetSendRecvNeighbors()
175  {
176  std::vector<int> wiids(world_.size(), 0);
177 
178  //Set the neighbors
179  recneigh_ = -1;
180  sendneigh_ = -1;
181 
182  MPI_Allgather(&mpiid_, 1, MPI_INT, &wiids[0], 1, MPI_INT, world_);
183  numnodes_ = int(*std::max_element(wiids.begin(), wiids.end())) + 1;
184 
185  // Ugly for now...
186  for(size_t i = 0; i < wiids.size(); i++)
187  {
188  if(mpiid_ == 0)
189  {
190  sendneigh_ = comm_.size();
191  if(wiids[i] == numnodes_ - 1)
192  {
193  recneigh_ = i;
194  break;
195  }
196  }
197  else if (mpiid_ == numnodes_ - 1)
198  {
199  sendneigh_ = 0;
200  if(wiids[i] == mpiid_ - 1)
201  {
202  recneigh_ = i;
203  break;
204  }
205  }
206  else
207  {
208  if(wiids[i] == mpiid_ + 1)
209  {
210  sendneigh_ = i;
211  break;
212  }
213  if(wiids[i] == mpiid_ - 1 && recneigh_ == -1)
214  recneigh_ = i;
215  }
216  }
217  }
218 
220  StringMethod* StringMethod::Build(const Value& json,
221  const MPI_Comm& world,
222  const MPI_Comm& comm,
223  const std::string& path)
224  {
225  ObjectRequirement validator;
226  Value schema;
227  CharReaderBuilder rbuilder;
228  CharReader* reader = rbuilder.newCharReader();
229 
230  StringMethod* m = nullptr;
231 
232  reader->parse(JsonSchema::StringMethod.c_str(),
233  JsonSchema::StringMethod.c_str() + JsonSchema::StringMethod.size(),
234  &schema, nullptr);
235  validator.Parse(schema, path);
236 
237  // Validate inputs.
238  validator.Validate(json, path);
239  if(validator.HasErrors())
240  throw BuildException(validator.GetErrors());
241 
242  unsigned int wid = GetWalkerID(world, comm);
243  std::vector<double> centers;
244  for(auto& s : json["centers"][wid])
245  centers.push_back(s.asDouble());
246 
247  auto maxiterator = json.get("max_iterations", 0).asInt();
248 
249  std::vector<double> ksprings;
250  for(auto& s : json["ksprings"])
251  ksprings.push_back(s.asDouble());
252 
253  auto freq = json.get("frequency", 1).asInt();
254 
255  // Get stringmethod flavor.
256  std::string flavor = json.get("flavor", "none").asString();
257  if(flavor == "ElasticBand")
258  {
259  reader->parse(JsonSchema::ElasticBandMethod.c_str(),
260  JsonSchema::ElasticBandMethod.c_str() + JsonSchema::ElasticBandMethod.size(),
261  &schema, nullptr);
262  validator.Parse(schema, path);
263 
264  // Validate inputs.
265  validator.Validate(json, path);
266  if(validator.HasErrors())
267  throw BuildException(validator.GetErrors());
268 
269  auto eqsteps = json.get("equilibration_steps", 20).asInt();
270  auto evsteps = json.get("evolution_steps", 5).asInt();
271  auto stringspring = json.get("kstring", 10.0).asDouble();
272  auto isteps = json.get("block_iterations", 100).asInt();
273  auto tau = json.get("time_step", 0.1).asDouble();
274 
275  m = new ElasticBand(world, comm, centers,
276  maxiterator, isteps,
277  tau, ksprings, eqsteps,
278  evsteps, stringspring, freq);
279 
280  if(json.isMember("tolerance"))
281  {
282  std::vector<double> tol;
283  for(auto& s : json["tolerance"])
284  tol.push_back(s.asDouble());
285 
286  m->SetTolerance(tol);
287  }
288  }
289  else if(flavor == "FTS")
290  {
291  reader->parse(JsonSchema::FTSMethod.c_str(),
292  JsonSchema::FTSMethod.c_str() + JsonSchema::FTSMethod.size(),
293  &schema, nullptr);
294  validator.Parse(schema, path);
295 
296  // Validate inputs.
297  validator.Validate(json, path);
298  if(validator.HasErrors())
299  throw BuildException(validator.GetErrors());
300 
301  auto isteps = json.get("block_iterations", 2000).asInt();
302  auto tau = json.get("time_step", 0.1).asDouble();
303  auto kappa = json.get("kappa", 0.1).asDouble();
304  auto springiter = json.get("umbrella_iterations",2000).asDouble();
305  m = new FiniteTempString(world, comm, centers,
306  maxiterator, isteps,
307  tau, ksprings, kappa,
308  springiter, freq);
309 
310  if(json.isMember("tolerance"))
311  {
312  std::vector<double> tol;
313  for(auto& s : json["tolerance"])
314  tol.push_back(s.asDouble());
315 
316  m->SetTolerance(tol);
317  }
318  }
319  else if(flavor == "SWARM")
320  {
321  reader->parse(JsonSchema::SwarmMethod.c_str(),
322  JsonSchema::SwarmMethod.c_str() + JsonSchema::SwarmMethod.size(),
323  &schema, nullptr);
324  validator.Parse(schema, path);
325 
326  //Validate input
327  validator.Validate(json, path);
328  if(validator.HasErrors())
329  throw BuildException(validator.GetErrors());
330 
331  auto InitialSteps = json.get("initial_steps", 2500).asInt();
332  auto HarvestLength = json.get("harvest_length", 10).asInt();
333  auto NumberTrajectories = json.get("number_of_trajectories", 250).asInt();
334  auto SwarmLength = json.get("swarm_length", 20).asInt();
335 
336  m = new Swarm(world, comm, centers, maxiterator, ksprings, freq, InitialSteps, HarvestLength, NumberTrajectories, SwarmLength);
337 
338  if(json.isMember("tolerance"))
339  {
340  std::vector<double> tol;
341  for(auto& s : json["tolerance"])
342  tol.push_back(s.asDouble());
343 
344  m->SetTolerance(tol);
345  }
346  }
347 
348  return m;
349  }
350 
351 }
unsigned GetWalkerID() const
Get walker ID.
Definition: Snapshot.h:197
bool HasErrors()
Check if errors have occurred.
Definition: Requirement.h:86
Collective variable manager.
Definition: CVManager.h:42
Finite Temperature String Method.
std::vector< CollectiveVariable * > CVList
List of Collective Variables.
Definition: types.h:51
String base class for FTS, Swarm, and elastic band.
Definition: StringMethod.h:38
Class containing a snapshot of the current simulation in time.
Definition: Snapshot.h:47
Swarm of Trajectories String Method.
Definition: Swarm.h:32
virtual void Parse(Value json, const std::string &path) override
Parse JSON value to generate Requirement(s).
Exception to be thrown when building the Driver fails.
std::vector< std::string > GetErrors()
Get list of error messages.
Definition: Requirement.h:92
void SetTolerance(std::vector< double > tol)
Set the tolerance for quitting method.
Definition: StringMethod.h:205
Requirements on an object.
Multi-walker Elastic Band.
Definition: ElasticBand.h:34
CVList GetCVs(const std::vector< unsigned int > &mask=std::vector< unsigned int >()) const
Get list of CVs.
Definition: CVManager.h:81
virtual void Validate(const Value &json, const std::string &path) override
Validate JSON value.