SSAGES  0.8.3
Software Suite for Advanced General Ensemble Simulations
Umbrella.cpp
1 
#include "Umbrella.h"
#include "Snapshot.h"
#include "CVs/CVManager.h"
#include "Validator/ObjectRequirement.h"
#include "Drivers/DriverException.h"
#include "schema.h"
#include <iostream>
#include <memory>
28 
29 using namespace Json;
30 
31 namespace SSAGES
32 {
33  void Umbrella::PreSimulation(Snapshot* /* snapshot */, const CVManager& cvmanager)
34  {
35  if(comm_.rank() == 0)
36  {
37  if(append_)
38  umbrella_.open(filename_.c_str(), std::ofstream::out | std::ofstream::app);
39  else
40  {
41  // Write out header.
42  umbrella_.open(filename_.c_str(), std::ofstream::out);
43  umbrella_ << "#";
44  umbrella_ << "Iteration ";
45 
46  auto cvs = cvmanager.GetCVs(cvmask_);
47  for(size_t i = 0; i < cvs.size(); ++i)
48  umbrella_ << "cv_" + std::to_string(i) << " ";
49 
50  for(size_t i = 0; i < cvs.size() - 1; ++i)
51  umbrella_ << "center_" + std::to_string(i) << " ";
52  umbrella_ << "center_" + std::to_string(cvs.size() - 1) << std::endl;
53  }
54  }
55  }
56 
57  void Umbrella::PostIntegration(Snapshot* snapshot, const CVManager& cvmanager)
58  {
59  // Get necessary info.
60  auto cvs = cvmanager.GetCVs(cvmask_);
61  auto& forces = snapshot->GetForces();
62  auto& virial = snapshot->GetVirial();
63 
64  for(size_t i = 0; i < cvs.size(); ++i)
65  {
66  // Get current CV and gradient.
67  auto& cv = cvs[i];
68  auto& grad = cv->GetGradient();
69  auto& boxgrad = cv->GetBoxGradient();
70  // Compute dV/dCV.
71  auto center = GetCurrentCenter(snapshot->GetIteration(), i);
72  auto D = kspring_[i]*cv->GetDifference(center);
73 
74  // Update forces.
75  for(size_t j = 0; j < forces.size(); ++j)
76  forces[j] -= D*grad[j];
77 
78  // Update virial.
79  virial += D*boxgrad;
80  }
81 
82  if(snapshot->GetIteration() % outfreq_ == 0)
83  PrintUmbrella(cvs, snapshot->GetIteration());
84  }
85 
86  void Umbrella::PostSimulation(Snapshot*, const CVManager&)
87  {
88  if(comm_.rank() ==0)
89  umbrella_.close();
90  }
91 
92  void Umbrella::PrintUmbrella(const CVList& cvs, size_t iteration)
93  {
94  if(comm_.rank() ==0)
95  {
96  umbrella_.precision(8);
97  umbrella_ << iteration << " ";
98 
99  // Print out CV values first.
100  for(auto& cv : cvs)
101  umbrella_ << cv->GetValue() << " ";
102 
103  // Print out target (center) of each CV.
104  for(size_t i = 0; i < cvs.size() - 1; ++i)
105  umbrella_ << GetCurrentCenter(iteration, i) << " ";
106  umbrella_ << GetCurrentCenter(iteration, cvs.size() - 1);
107 
108  umbrella_ << std::endl;
109  }
110  }
111 
112  Umbrella* Umbrella::Build(const Json::Value& json,
113  const MPI_Comm& world,
114  const MPI_Comm& comm,
115  const std::string& path)
116  {
117  ObjectRequirement validator;
118  Value schema;
119  CharReaderBuilder rbuilder;
120  CharReader* reader = rbuilder.newCharReader();
121 
122  reader->parse(JsonSchema::UmbrellaMethod.c_str(),
123  JsonSchema::UmbrellaMethod.c_str() + JsonSchema::UmbrellaMethod.size(),
124  &schema, NULL);
125  validator.Parse(schema, path);
126 
127  // Validate inputs.
128  validator.Validate(json, path);
129  if(validator.HasErrors())
130  throw BuildException(validator.GetErrors());
131 
132  //TODO walker id should be obtainable in method as
133  // opposed to calculated like this.
134  unsigned int wid = mxx::comm(world).rank()/mxx::comm(comm).size();
135  bool ismulti = mxx::comm(world).size() > mxx::comm(comm).size();
136  unsigned int wcount = mxx::comm(world).size() / mxx::comm(comm).size();
137 
138  std::vector<std::vector<double>> ksprings;
139  for(auto& s : json["ksprings"])
140  {
141  std::vector<double> kspring;
142  if(s.isArray())
143  for(auto& k : s)
144  kspring.push_back(k.asDouble());
145  else
146  kspring.push_back(s.asDouble());
147 
148  ksprings.push_back(kspring);
149  }
150 
151  std::vector<std::vector<double>> centers0, centers1;
152  if(json.isMember("centers"))
153  {
154  for(auto& s : json["centers"])
155  {
156  std::vector<double> center;
157  if(s.isArray())
158  for(auto& k : s)
159  center.push_back(k.asDouble());
160  else
161  center.push_back(s.asDouble());
162 
163  centers0.push_back(center);
164  }
165  }
166  else if(json.isMember("centers0") && json.isMember("centers1") && json.isMember("timesteps"))
167  {
168  for(auto& s : json["centers0"])
169  {
170  std::vector<double> center;
171  if(s.isArray())
172  for(auto& k : s)
173  center.push_back(k.asDouble());
174  else
175  center.push_back(s.asDouble());
176 
177  centers0.push_back(center);
178  }
179 
180  for(auto& s : json["centers1"])
181  {
182  std::vector<double> center;
183  if(s.isArray())
184  for(auto& k : s)
185  center.push_back(k.asDouble());
186  else
187  center.push_back(s.asDouble());
188 
189  centers1.push_back(center);
190  }
191  }
192  else
193  throw BuildException({"Either \"centers\" or \"timesteps\", \"centers0\" and \"centers1\" must be defined for umbrella."});
194 
195  if(ksprings[0].size() != centers0[0].size())
196  throw BuildException({"Need to define a spring for every center or a center for every spring!"});
197 
198  // If only one set of center/ksprings are specified. Fill it up for multi.
199  if(ismulti)
200  {
201  if(ksprings.size() == 1)
202  for(size_t i = 1; i < wcount; ++i)
203  ksprings.push_back(ksprings[0]);
204  else if(ksprings.size() != wcount)
205  throw std::invalid_argument(path + ": Multi-walker simulations requires that the number of \"ksprings\" match the number of walkers.");
206  if(centers0.size() == 1)
207  for(size_t i = 1; i < wcount; ++i)
208  centers0.push_back(centers0[0]);
209  else if(centers0.size() != wcount)
210  throw std::invalid_argument(path + ": Multi-walker simulations requires that the number of \"centers\"/\"centers0\" match the number of walkers.");
211  if(centers1.size() == 1)
212  for(size_t i = 1; i < wcount; ++i)
213  centers1.push_back(centers1[0]);
214  else if(centers1.size()) // centers1 is optional.
215  throw std::invalid_argument(path + ": Multi-walker simulations requires that the number of \"centers1\" match the number of walkers.");
216  }
217 
218  auto freq = json.get("frequency", 1).asInt();
219 
220  size_t timesteps = 0;
221  if(json.isMember("timesteps"))
222  {
223  if(json["timesteps"].isArray())
224  timesteps = json["timesteps"][wid].asUInt();
225  else
226  timesteps = json["timesteps"].asUInt();
227  }
228 
229  std::string name = "umbrella.dat";
230  if(json["output_file"].isArray())
231  name = json["output_file"][wid].asString();
232  else if(ismulti)
233  throw std::invalid_argument(path + ": Multi-walker simulations require a separate output file for each.");
234  else
235  name = json["output_file"].asString();
236 
237  Umbrella* m = nullptr;
238  if(timesteps == 0)
239  m = new Umbrella(world, comm, ksprings[wid], centers0[wid], name, freq);
240  else
241  m = new Umbrella(world, comm, ksprings[wid], centers0[wid], centers1[wid], timesteps, name, freq);
242 
243  m->SetOutputFrequency(json.get("output_frequency",0).asInt());
244  m->SetAppend(json.get("append", false).asBool());
245 
246  return m;
247  }
248 }
std::vector< CollectiveVariable * > GetCVs(const std::vector< unsigned int > &mask=std::vector< unsigned int >()) const
Get list of CVs.
Definition: CVManager.h:80
bool HasErrors()
Check if errors have occurred.
Definition: Requirement.h:86
Collective variable manager.
Definition: CVManager.h:40
std::vector< CollectiveVariable * > CVList
List of Collective Variables.
Definition: types.h:51
void SetOutputFrequency(int outfreq)
Set output frequency.
Definition: Umbrella.h:162
const std::vector< Vector3 > & GetForces() const
Access the per-particle forces.
Definition: Snapshot.h:362
Class containing a snapshot of the current simulation in time.
Definition: Snapshot.h:43
Umbrella sampling method.
Definition: Umbrella.h:35
size_t GetIteration() const
Get the current iteration.
Definition: Snapshot.h:103
virtual void Parse(Value json, const std::string &path) override
Parse JSON value to generate Requirement(s).
Exception to be thrown when building the Driver fails.
std::vector< std::string > GetErrors()
Get list of error messages.
Definition: Requirement.h:92
Requirements on an object.
const Matrix3 & GetVirial() const
Get box virial.
Definition: Snapshot.h:133
void SetAppend(bool append)
Set append mode.
Definition: Umbrella.h:171
virtual void Validate(const Value &json, const std::string &path) override
Validate JSON value.