SSAGES  0.8.3
Software Suite for Advanced General Ensemble Simulations
StringMethod.cpp
1 
#include <memory>

#include "ElasticBand.h"
#include "FiniteTempString.h"
#include "StringMethod.h"
#include "Swarm.h"
#include "CVs/CVManager.h"
#include "Validator/ObjectRequirement.h"
#include "Drivers/DriverException.h"
#include "Snapshot.h"
#include "spline.h"
#include "schema.h"
33 
34 using namespace Json;
35 
36 namespace SSAGES
37 {
38  void StringMethod::PrintString(const CVList& CV)
39  {
40  if(comm_.rank() == 0)
41  {
42  //Write node, iteration, centers of the string and current CV value to output file
43  stringout_.precision(8);
44  stringout_ << mpiid_ << " " << iteration_ << " ";
45 
46  for(size_t i = 0; i < centers_.size(); i++)
47  stringout_ << worldstring_[mpiid_][i] << " " << CV[i]->GetValue() << " ";
48 
49  stringout_ << std::endl;
50  }
51  }
52 
53  void StringMethod::GatherNeighbors(std::vector<double> *lcv0, std::vector<double> *ucv0)
54  {
55  MPI_Status status;
56 
57  if(comm_.rank() == 0)
58  {
59  MPI_Sendrecv(&centers_[0], centers_.size(), MPI_DOUBLE, sendneigh_, 1234,
60  &(*lcv0)[0], centers_.size(), MPI_DOUBLE, recneigh_, 1234,
61  world_, &status);
62 
63  MPI_Sendrecv(&centers_[0], centers_.size(), MPI_DOUBLE, recneigh_, 4321,
64  &(*ucv0)[0], centers_.size(), MPI_DOUBLE, sendneigh_, 4321,
65  world_, &status);
66  }
67 
68  MPI_Bcast(&(*lcv0)[0],centers_.size(),MPI_DOUBLE,0,comm_);
69  MPI_Bcast(&(*ucv0)[0],centers_.size(),MPI_DOUBLE,0,comm_);
70  }
71 
72  void StringMethod::StringReparam(double alpha_star)
73  {
74  std::vector<double> alpha_star_vector(numnodes_,0.0);
75 
76  //Reparameterization
77  //Alpha star is the uneven mesh, approximated as linear distance between points
78  if(comm_.rank()==0)
79  alpha_star_vector[mpiid_] = mpiid_ == 0 ? 0 : alpha_star;
80 
81  //Gather each alpha_star into a vector
82  MPI_Allreduce(MPI_IN_PLACE, &alpha_star_vector[0], numnodes_, MPI_DOUBLE, MPI_SUM, world_);
83 
84  for(size_t i = 1; i < alpha_star_vector.size(); i++)
85  alpha_star_vector[i] += alpha_star_vector[i-1];
86 
87  for(size_t i = 1; i < alpha_star_vector.size(); i++)
88  alpha_star_vector[i] /= alpha_star_vector[numnodes_ - 1];
89 
90  tk::spline spl; //Cubic spline interpolation
91 
92  for(size_t i = 0; i < centers_.size(); i++)
93  {
94  std::vector<double> cvs_new(numnodes_, 0.0);
95 
96  if(comm_.rank() == 0)
97  cvs_new[mpiid_] = centers_[i];
98 
99  MPI_Allreduce(MPI_IN_PLACE, &cvs_new[0], numnodes_, MPI_DOUBLE, MPI_SUM, world_);
100 
101  spl.set_points(alpha_star_vector, cvs_new);
102  centers_[i] = spl(mpiid_/(numnodes_ - 1.0));
103  }
104  }
105 
106  void StringMethod::UpdateWorldString(const CVList& cvs)
107  {
108  for(size_t i = 0; i < centers_.size(); i++)
109  {
110  std::vector<double> cvs_new(numnodes_, 0.0);
111 
112  if(comm_.rank() == 0)
113  {
114  cvs_new[mpiid_] = centers_[i];
115  }
116 
117  MPI_Allreduce(MPI_IN_PLACE, &cvs_new[0], numnodes_, MPI_DOUBLE, MPI_SUM, world_);
118 
119  for(int j = 0; j < numnodes_; j++)
120  {
121  worldstring_[j][i] = cvs_new[j];
122  //Represent worldstring in periodic space
123  worldstring_[j][i] = cvs[i]->GetPeriodicValue(worldstring_[j][i]);
124  }
125  }
126  }
127 
128  bool StringMethod::CheckEnd(const CVList& CV)
129  {
130  if(maxiterator_ && iteration_ > maxiterator_)
131  {
132  std::cout << "System has reached max string method iterations (" << maxiterator_ << ") as specified in the input file(s)." << std::endl;
133  std::cout << "Exiting now" << std::endl;
134  PrintString(CV); //Ensure that the system prints out if it's about to exit
135  MPI_Abort(world_, EXIT_FAILURE);
136  }
137 
138  int local_tolvalue = TolCheck();
139 
140  MPI_Allreduce(MPI_IN_PLACE, &local_tolvalue, 1, MPI_INT, MPI_LAND, world_);
141 
142  if(local_tolvalue)
143  {
144  std::cout << "System has converged within tolerance criteria. Exiting now" << std::endl;
145  PrintString(CV); //Ensure that the system prints out if it's about to exit
146  MPI_Abort(world_, EXIT_FAILURE);
147  }
148 
149  return true;
150  }
151 
152  void StringMethod::PreSimulation(Snapshot* snapshot, const CVManager& cvmanager)
153  {
154  char file[1024];
155  mpiid_ = snapshot->GetWalkerID();
156  sprintf(file, "node-%04d.log",mpiid_);
157  stringout_.open(file);
158 
159  auto cvs = cvmanager.GetCVs(cvmask_);
160  SetSendRecvNeighbors();
161  worldstring_.resize(numnodes_);
162  for(auto& w : worldstring_)
163  w.resize(centers_.size());
164 
165  UpdateWorldString(cvs);
166  PrintString(cvs);
167  }
168 
169  void StringMethod::SetSendRecvNeighbors()
170  {
171  std::vector<int> wiids(world_.size(), 0);
172 
173  //Set the neighbors
174  recneigh_ = -1;
175  sendneigh_ = -1;
176 
177  MPI_Allgather(&mpiid_, 1, MPI_INT, &wiids[0], 1, MPI_INT, world_);
178  numnodes_ = int(*std::max_element(wiids.begin(), wiids.end())) + 1;
179 
180  // Ugly for now...
181  for(size_t i = 0; i < wiids.size(); i++)
182  {
183  if(mpiid_ == 0)
184  {
185  sendneigh_ = comm_.size();
186  if(wiids[i] == numnodes_ - 1)
187  {
188  recneigh_ = i;
189  break;
190  }
191  }
192  else if (mpiid_ == numnodes_ - 1)
193  {
194  sendneigh_ = 0;
195  if(wiids[i] == mpiid_ - 1)
196  {
197  recneigh_ = i;
198  break;
199  }
200  }
201  else
202  {
203  if(wiids[i] == mpiid_ + 1)
204  {
205  sendneigh_ = i;
206  break;
207  }
208  if(wiids[i] == mpiid_ - 1 && recneigh_ == -1)
209  recneigh_ = i;
210  }
211  }
212  }
213 
215  StringMethod* StringMethod::Build(const Value& json,
216  const MPI_Comm& world,
217  const MPI_Comm& comm,
218  const std::string& path)
219  {
220  ObjectRequirement validator;
221  Value schema;
222  CharReaderBuilder rbuilder;
223  CharReader* reader = rbuilder.newCharReader();
224 
225  StringMethod* m = nullptr;
226 
227  reader->parse(JsonSchema::StringMethod.c_str(),
228  JsonSchema::StringMethod.c_str() + JsonSchema::StringMethod.size(),
229  &schema, NULL);
230  validator.Parse(schema, path);
231 
232  // Validate inputs.
233  validator.Validate(json, path);
234  if(validator.HasErrors())
235  throw BuildException(validator.GetErrors());
236 
237  unsigned int wid = mxx::comm(world).rank()/mxx::comm(comm).size();
238  std::vector<double> centers;
239  for(auto& s : json["centers"][wid])
240  centers.push_back(s.asDouble());
241 
242  auto maxiterator = json.get("max_iterations", 0).asInt();
243 
244  std::vector<double> ksprings;
245  for(auto& s : json["ksprings"])
246  ksprings.push_back(s.asDouble());
247 
248  auto freq = json.get("frequency", 1).asInt();
249 
250  // Get stringmethod flavor.
251  std::string flavor = json.get("flavor", "none").asString();
252  if(flavor == "ElasticBand")
253  {
254  reader->parse(JsonSchema::ElasticBandMethod.c_str(),
255  JsonSchema::ElasticBandMethod.c_str() + JsonSchema::ElasticBandMethod.size(),
256  &schema, NULL);
257  validator.Parse(schema, path);
258 
259  // Validate inputs.
260  validator.Validate(json, path);
261  if(validator.HasErrors())
262  throw BuildException(validator.GetErrors());
263 
264  auto eqsteps = json.get("equilibration_steps", 20).asInt();
265  auto evsteps = json.get("evolution_steps", 5).asInt();
266  auto stringspring = json.get("kstring", 10.0).asDouble();
267  auto isteps = json.get("block_iterations", 100).asInt();
268  auto tau = json.get("time_step", 0.1).asDouble();
269 
270  m = new ElasticBand(world, comm, centers,
271  maxiterator, isteps,
272  tau, ksprings, eqsteps,
273  evsteps, stringspring, freq);
274 
275  if(json.isMember("tolerance"))
276  {
277  std::vector<double> tol;
278  for(auto& s : json["tolerance"])
279  tol.push_back(s.asDouble());
280 
281  m->SetTolerance(tol);
282  }
283  }
284  else if(flavor == "FTS")
285  {
286  reader->parse(JsonSchema::FTSMethod.c_str(),
287  JsonSchema::FTSMethod.c_str() + JsonSchema::FTSMethod.size(),
288  &schema, NULL);
289  validator.Parse(schema, path);
290 
291  // Validate inputs.
292  validator.Validate(json, path);
293  if(validator.HasErrors())
294  throw BuildException(validator.GetErrors());
295 
296  auto isteps = json.get("block_iterations", 2000).asInt();
297  auto tau = json.get("time_step", 0.1).asDouble();
298  auto kappa = json.get("kappa", 0.1).asDouble();
299  auto springiter = json.get("umbrella_iterations",2000).asDouble();
300  m = new FiniteTempString(world, comm, centers,
301  maxiterator, isteps,
302  tau, ksprings, kappa,
303  springiter, freq);
304 
305  if(json.isMember("tolerance"))
306  {
307  std::vector<double> tol;
308  for(auto& s : json["tolerance"])
309  tol.push_back(s.asDouble());
310 
311  m->SetTolerance(tol);
312  }
313  }
314  else if(flavor == "SWARM")
315  {
316  reader->parse(JsonSchema::SwarmMethod.c_str(),
317  JsonSchema::SwarmMethod.c_str() + JsonSchema::SwarmMethod.size(),
318  &schema, NULL);
319  validator.Parse(schema, path);
320 
321  //Validate input
322  validator.Validate(json, path);
323  if(validator.HasErrors())
324  throw BuildException(validator.GetErrors());
325 
326  auto InitialSteps = json.get("initial_steps", 2500).asInt();
327  auto HarvestLength = json.get("harvest_length", 10).asInt();
328  auto NumberTrajectories = json.get("number_of_trajectories", 250).asInt();
329  auto SwarmLength = json.get("swarm_length", 20).asInt();
330 
331  m = new Swarm(world, comm, centers, maxiterator, ksprings, freq, InitialSteps, HarvestLength, NumberTrajectories, SwarmLength);
332 
333  if(json.isMember("tolerance"))
334  {
335  std::vector<double> tol;
336  for(auto& s : json["tolerance"])
337  tol.push_back(s.asDouble());
338 
339  m->SetTolerance(tol);
340  }
341  }
342 
343  return m;
344  }
345 
346 }
std::vector< CollectiveVariable * > GetCVs(const std::vector< unsigned int > &mask=std::vector< unsigned int >()) const
Get CV iterator.
Definition: CVManager.h:80
bool HasErrors()
Check if errors have occurred.
Definition: Requirement.h:86
Collective variable manager.
Definition: CVManager.h:40
Finite Temperature String Method.
std::vector< CollectiveVariable * > CVList
List of Collective Variables.
Definition: types.h:51
String base class for FTS, Swarm, and elastic band.
Definition: StringMethod.h:38
Class containing a snapshot of the current simulation in time.
Definition: Snapshot.h:43
Swarm of Trajectories String Method.
Definition: Swarm.h:32
virtual void Parse(Value json, const std::string &path) override
Parse JSON value to generate Requirement(s).
Exception to be thrown when building the Driver fails.
std::vector< std::string > GetErrors()
Get list of error messages.
Definition: Requirement.h:92
void SetTolerance(std::vector< double > tol)
Set the tolerance for quitting method.
Definition: StringMethod.h:202
Requirements on an object.
unsigned GetWalkerID() const
Get walker ID.
Definition: Snapshot.h:193
Multi-walker Elastic Band.
Definition: ElasticBand.h:34
virtual void Validate(const Value &json, const std::string &path) override
Validate JSON value.