SSAGES  0.1
A MetaDynamics Package
SSAGES::ANN Class Reference
Inheritance diagram for SSAGES::ANN: SSAGES::ANN derives from SSAGES::Method, which derives from SSAGES::EventListener.

Public Member Functions

 ANN (const MPI_Comm &world, const MPI_Comm &comm, const Eigen::VectorXi &topol, Grid< Eigen::VectorXd > *fgrid, Grid< uint > *hgrid, Grid< double > *ugrid, const std::vector< double > &lowerb, const std::vector< double > &upperb, const std::vector< double > &lowerk, const std::vector< double > &upperk, double temperature, double weight, uint nsweep)
 
void PreSimulation (Snapshot *, const class CVManager &) override
 Pre-simulation hook.
 
void PostIntegration (Snapshot *, const class CVManager &) override
 Post-integration hook.
 
void PostSimulation (Snapshot *, const class CVManager &) override
 Post-simulation hook.
 
void SetPrevWeight (double h)
 Set previous history weight.
 
void SetOutput (const std::string &outfile)
 Set name of output file.
 
void SetOutputOverwrite (bool overwrite)
 Set overwrite flag on output file.
 
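void SetConvergeIters (uint citers)
 Set number of iterations after which the full histogram weight is used (JSON key converge_iters, default 0).
 
void SetMaxIters (uint iters)
 Set maximum number of iterations (JSON key max_iters, default 1000).
 
void SetMinLoss (double loss)
 Set minimum loss (JSON key min_loss, default 0).
 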
- Public Member Functions inherited from SSAGES::Method
 Method (uint frequency, const MPI_Comm &world, const MPI_Comm &comm)
 Constructor.
 
void SetCVMask (const std::vector< uint > &mask)
 Sets the collective variable mask.
 
virtual ~Method ()
 Destructor.
 
- Public Member Functions inherited from SSAGES::EventListener
 EventListener (uint frequency)
 Constructor.
 
uint GetFrequency () const
 Get frequency of event listener.
 
virtual ~EventListener ()
 Destructor.
 

Static Public Member Functions

static ANN * Build (const Json::Value &json, const MPI_Comm &world, const MPI_Comm &comm, const std::string &path)
 
- Static Public Member Functions inherited from SSAGES::Method
static Method * BuildMethod (const Json::Value &json, const MPI_Comm &world, const MPI_Comm &comm, const std::string &path)
 Build a derived method from JSON node.
 

Private Member Functions

void TrainNetwork ()
 Trains the neural network.
 
void WriteBias ()
 Writes out the bias to file.
 

Private Attributes

Eigen::VectorXi topol_
 Network topology: number of nodes in each layer, from input to output.
 
uint sweep_
 Number of iterations per sweep.
 
uint nsweep_
 Training interval: PostIntegration() retrains the network every nsweep_ iterations.
 
uint citers_
 Number of iterations after which we turn on full weight.
 
nnet::neural_net net_
 Neural network.
 
double pweight_
 Previous histogram weight.
 
double weight_
 Current histogram weight.
 
double temp_
 System temperature.
 
double kbt_
 Thermal energy k_B T in simulation energy units, set in PreSimulation().
 
Grid< Eigen::VectorXd > * fgrid_
 Force grid.
 
Grid< uint > * hgrid_
 Histogram grid.
 
Grid< double > * ugrid_
 Unbiased histogram grid.
 
Eigen::MatrixXd hist_
 Eigen matrices of grids.
 
Eigen::MatrixXd bias_
 
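std::vector< double > lowerb_
 Lower bound of the restraint region for each CV.
 
std::vector< double > upperb_
 Upper bound of the restraint region for each CV.
 
std::vector< double > lowerk_
 Spring constant of the lower-bound restraint for each CV.
 
std::vector< double > upperk_
 Spring constant of the upper-bound restraint for each CV.
 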
std::string outfile_
 Output filename.
 
bool overwrite_
 Overwrite outputs?
 

Additional Inherited Members

- Protected Attributes inherited from SSAGES::Method
mxx::comm world_
 Global MPI communicator.
 
mxx::comm comm_
 Local MPI communicator.
 
std::vector< uint > cvmask_
 Mask which identifies which CVs to act on.
 

Detailed Description

Definition at line 9 of file ANN.h.
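The optional JSON keys that Build() reads, and the defaults it falls back on when a key is absent, can be collected in a plain struct for orientation. This is a minimal sketch: the struct name ANNOptions is hypothetical and is not part of SSAGES; only the key names and default values are taken from the Build() listing.

```cpp
#include <string>

// Hypothetical struct (not part of SSAGES) collecting the optional JSON keys
// read by SSAGES::ANN::Build(), initialized to the defaults Build() uses.
struct ANNOptions
{
    double weight = 1.0;                  // "weight": passed to the ANN constructor
    double prev_weight = 1.0;             // "prev_weight": SetPrevWeight()
    std::string output_file = "ann.out";  // "output_file": SetOutput()
    bool overwrite_output = true;         // "overwrite_output": SetOutputOverwrite()
    unsigned converge_iters = 0;          // "converge_iters": SetConvergeIters()
    unsigned max_iters = 1000;            // "max_iters": SetMaxIters()
    double min_loss = 0.0;                // "min_loss": SetMinLoss()
};
```

Keys such as "temperature", "nsweep", "topology", "grid" and the bound arrays are read without a fallback value in Build(), so they are not listed here.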

Member Function Documentation

ANN * SSAGES::ANN::Build ( const Json::Value &  json,
const MPI_Comm &  world,
const MPI_Comm &  comm,
const std::string &  path 
)
static

Definition at line 216 of file ANN.cpp.

References SSAGES::Grid< T >::BuildGrid(), Json::Requirement::GetErrors(), Json::Requirement::HasErrors(), Json::ObjectRequirement::Parse(), and Json::ObjectRequirement::Validate().

Referenced by SetOutputOverwrite().

{
    ObjectRequirement validator;
    Value schema;
    Reader reader;

    reader.parse(JsonSchema::ANNMethod, schema);
    validator.Parse(schema, path);

    // Validate inputs.
    validator.Validate(json, path);
    if(validator.HasErrors())
        throw BuildException(validator.GetErrors());

    // Grid.
    auto* fgrid = Grid<VectorXd>::BuildGrid(json.get("grid", Json::Value()));
    auto* hgrid = Grid<uint>::BuildGrid(json.get("grid", Json::Value()));
    auto* ugrid = Grid<double>::BuildGrid(json.get("grid", Json::Value()));

    // Topology.
    auto nlayers = json["topology"].size() + 2;
    VectorXi topol(nlayers);
    topol[0] = fgrid->GetDimension();
    topol[nlayers-1] = 1;
    for(int i = 0; i < json["topology"].size(); ++i)
        topol[i+1] = json["topology"][i].asInt();

    auto weight = json.get("weight", 1.).asDouble();
    auto temp = json["temperature"].asDouble();
    auto nsweep = json["nsweep"].asUInt();

    // Assume all vectors are the same size.
    std::vector<double> lowerb, upperb, lowerk, upperk;
    for(int i = 0; i < json["lower_bound_restraints"].size(); ++i)
    {
        lowerk.push_back(json["lower_bound_restraints"][i].asDouble());
        upperk.push_back(json["upper_bound_restraints"][i].asDouble());
        lowerb.push_back(json["lower_bounds"][i].asDouble());
        upperb.push_back(json["upper_bounds"][i].asDouble());
    }

    auto* m = new ANN(world, comm, topol, fgrid, hgrid, ugrid, lowerb, upperb, lowerk, upperk, temp, weight, nsweep);

    // Set optional params.
    m->SetPrevWeight(json.get("prev_weight", 1).asDouble());
    m->SetOutput(json.get("output_file", "ann.out").asString());
    m->SetOutputOverwrite( json.get("overwrite_output", true).asBool());
    m->SetConvergeIters(json.get("converge_iters", 0).asUInt());
    m->SetMaxIters(json.get("max_iters", 1000).asUInt());
    m->SetMinLoss(json.get("min_loss", 0).asDouble());

    return m;
}
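Build() sizes the network as the grid dimension (one input node per CV), followed by the hidden-layer sizes from the JSON "topology" array, followed by a single output node. A minimal sketch of that rule, using std::vector<int> in place of Eigen::VectorXi; the helper name MakeTopology is hypothetical.

```cpp
#include <vector>

// Hypothetical helper reproducing the layer-size rule in ANN::Build().
// `ndim` is the grid dimension; `hidden` holds the JSON "topology" entries.
std::vector<int> MakeTopology(int ndim, const std::vector<int>& hidden)
{
    std::vector<int> topol;
    topol.reserve(hidden.size() + 2);
    topol.push_back(ndim);       // input layer: one node per CV
    for (int h : hidden)
        topol.push_back(h);      // hidden layers, in JSON order
    topol.push_back(1);          // output layer: a single node
    return topol;
}
```

For a two-dimensional grid with "topology": [10, 8], this gives the layer sizes {2, 10, 8, 1}; an empty "topology" array gives a network with only input and output layers.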

void SSAGES::ANN::PostIntegration ( Snapshot * snapshot,
const class CVManager & cvmanager 
)
override virtual

Post-integration hook.

Parameters
snapshot Current simulation snapshot.
cvmanager Collective variable manager.

Implements SSAGES::Method.

Definition at line 66 of file ANN.cpp.

References SSAGES::Method::comm_, SSAGES::Method::cvmask_, SSAGES::CVManager::GetCVs(), SSAGES::Snapshot::GetForces(), SSAGES::Snapshot::GetIteration(), SSAGES::Snapshot::GetVirial(), and SSAGES::Method::world_.

{
    if(snapshot->GetIteration() && snapshot->GetIteration() % nsweep_ == 0)
    {
        // Switch to full blast.
        if(citers_ && snapshot->GetIteration() > citers_)
            pweight_ = 1.0;

        TrainNetwork();
        if(world_.rank() == 0)
            WriteBias();
    }

    // Get CV vals.
    auto cvs = cvmanager.GetCVs(cvmask_);
    auto n = cvs.size();

    // Determine if we are in bounds.
    RowVectorXd vec(n);
    std::vector<double> val(n);
    bool inbounds = true;
    for(size_t i = 0; i < n; ++i)
    {
        val[i] = cvs[i]->GetValue();
        vec[i] = cvs[i]->GetValue();
        if(val[i] < hgrid_->GetLower(i) || val[i] > hgrid_->GetUpper(i))
            inbounds = false;
    }

    // If in bounds, bias.
    VectorXd derivatives = VectorXd::Zero(n);
    if(inbounds)
    {
        // Record histogram hit and get gradient.
        // Only record hits on master processes since we will
        // reduce later.
        if(comm_.rank() == 0)
            hgrid_->at(val) += 1;
        //derivatives = (*fgrid_)[val];
        net_.forward_pass(vec);
        derivatives = net_.get_gradient(0);
    }
    else
    {
        if(comm_.rank() == 0)
        {
            std::cerr << "ANN (" << snapshot->GetIteration() << "): out of bounds ( ";
            for(auto& v : val)
                std::cerr << v << " ";
            std::cerr << ")" << std::endl;
        }
    }

    // Restraints.
    for(size_t i = 0; i < n; ++i)
    {
        auto cval = cvs[i]->GetValue();
        if(cval < lowerb_[i])
            derivatives[i] += lowerk_[i]*cvs[i]->GetDifference(cval - lowerb_[i]);
        else if(cval > upperb_[i])
            derivatives[i] += upperk_[i]*cvs[i]->GetDifference(cval - upperb_[i]);
    }

    // Apply bias to atoms.
    auto& forces = snapshot->GetForces();
    auto& virial = snapshot->GetVirial();

    for(size_t i = 0; i < cvs.size(); ++i)
    {
        auto& grad = cvs[i]->GetGradient();
        auto& boxgrad = cvs[i]->GetBoxGradient();

        // Update the forces in snapshot by adding in the force bias from each
        // CV to each atom based on the gradient of the CV.
        for (size_t j = 0; j < forces.size(); ++j)
            forces[j] -= derivatives[i]*grad[j];

        virial += derivatives[i]*boxgrad;
    }
}
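The first block of PostIntegration() sets the training schedule: nothing happens at iteration 0, the network is retrained whenever the iteration count is a multiple of nsweep_, and on those training iterations the previous-histogram weight pweight_ is fixed at 1.0 once the count exceeds a non-zero citers_. A sketch of that logic in isolation; TrainSchedule and OnIteration are hypothetical names, and the TrainNetwork()/WriteBias() calls themselves are omitted.

```cpp
// Hypothetical stand-in for the schedule-related members of SSAGES::ANN
// (nsweep_, citers_, pweight_); not part of SSAGES.
struct TrainSchedule
{
    unsigned nsweep;   // retrain every nsweep iterations; must be non-zero
    unsigned citers;   // 0 disables the switch to full previous weight
    double pweight;    // weight applied to the previous histogram
};

// Mirrors the first block of PostIntegration(): returns true on iterations
// where TrainNetwork() would run, and on those iterations sets pweight to
// 1.0 once iter exceeds a non-zero citers.
bool OnIteration(TrainSchedule& s, unsigned iter)
{
    if (iter == 0 || iter % s.nsweep != 0)
        return false;          // not a training iteration
    if (s.citers != 0 && iter > s.citers)
        s.pweight = 1.0;       // "switch to full blast"
    return true;
}
```

With nsweep = 100 and citers = 250, training happens at iterations 100, 200, 300, ..., and pweight first becomes 1.0 at iteration 300, because the switch is only evaluated on training iterations.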
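The restraint and force steps of PostIntegration() amount to a flat-bottom harmonic wall on each CV followed by the chain rule F_j -= (dU/ds) * ds/dx_j. The sketch below treats a single CV with plain std::array forces. Assumptions: the displacement from the violated bound is a plain subtraction, whereas SSAGES routes it through CollectiveVariable::GetDifference(), which this sketch does not model; the network gradient and the virial update are also left out. WallDerivative and ApplyBias are hypothetical names.

```cpp
#include <array>
#include <cstddef>
#include <vector>

using Vec3 = std::array<double, 3>;

// dU/ds for a flat-bottom harmonic wall on one CV value s.
// Assumption: plain subtraction (non-periodic CV), not GetDifference().
double WallDerivative(double s, double lower, double upper,
                      double klower, double kupper)
{
    if (s < lower)
        return klower * (s - lower);   // below the lower bound
    if (s > upper)
        return kupper * (s - upper);   // above the upper bound
    return 0.0;                        // inside [lower, upper]: no restraint
}

// Chain rule as in the "Apply bias to atoms" loop: F_j -= (dU/ds) * ds/dx_j.
void ApplyBias(std::vector<Vec3>& forces, const std::vector<Vec3>& grad,
               double dUds)
{
    for (std::size_t j = 0; j < forces.size(); ++j)
        for (std::size_t d = 0; d < 3; ++d)
            forces[j][d] -= dUds * grad[j][d];
}
```

In PostIntegration() the wall derivative is added to the network gradient before the chain-rule loop, so a single pass over the atoms applies both contributions.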


void SSAGES::ANN::PostSimulation ( Snapshot * snapshot,
const class CVManager & cvmanager 
)
override virtual

Post-simulation hook.

Parameters
snapshot Current simulation snapshot.
cvmanager Collective variable manager.

Implements SSAGES::Method.

Definition at line 147 of file ANN.cpp.

{
}

void SSAGES::ANN::PreSimulation ( Snapshot * snapshot,
const class CVManager & cvmanager 
)
override virtual

Pre-simulation hook.

Parameters
snapshot Current simulation snapshot.
cvmanager Collective variable manager.

Implements SSAGES::Method.

Definition at line 54 of file ANN.cpp.

References SSAGES::Snapshot::GetKb().

{
    auto ndim = hgrid_->GetDimension();
    kbt_ = snapshot->GetKb()*temp_;

    // Zero out forces and histogram; set unbiased histogram to 1.
    VectorXd vec = VectorXd::Zero(ndim);
    std::fill(hgrid_->begin(), hgrid_->end(), 0);
    std::fill(ugrid_->begin(), ugrid_->end(), 1.0);
    std::fill(fgrid_->begin(), fgrid_->end(), vec);
}
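PreSimulation() computes k_B T from the snapshot's Boltzmann constant and the configured temperature, then zeroes the histogram, sets every unbiased-histogram entry to 1.0, and fills the force grid with zero vectors of the grid dimension. A sketch with flat std::vector stand-ins for the three grids; GridState and ResetForRun are hypothetical and do not reflect the SSAGES Grid interface.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical flat stand-ins for hgrid_, ugrid_ and fgrid_ (not the Grid API).
struct GridState
{
    std::vector<unsigned> hist;              // hgrid_: visit counts
    std::vector<double> unbiased;            // ugrid_: unbiased histogram
    std::vector<std::vector<double>> force;  // fgrid_: one vector per grid point
};

// Mirrors PreSimulation(): resets the grids and returns kB * T.
double ResetForRun(GridState& g, std::size_t ndim, double kb, double temperature)
{
    std::fill(g.hist.begin(), g.hist.end(), 0u);
    std::fill(g.unbiased.begin(), g.unbiased.end(), 1.0);
    std::fill(g.force.begin(), g.force.end(), std::vector<double>(ndim, 0.0));
    return kb * temperature;
}
```

Initializing the unbiased histogram to 1.0 rather than 0 keeps every bin non-zero from the start of the run.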


The documentation for this class was generated from the following files: ANN.h, ANN.cpp