mlp.h
Go to the documentation of this file.
1 #pragma once
2 
3 #include <vector>
4 
5 #include "../layer/layer.h"
6 
11 class MLP
12 {
13 private:
14  size_t _n_in; // The number of input connections to the MLP.
15  std::vector<size_t>
16  _n_outs; // The number of output connections for each layer in the MLP.
17  std::vector<Layer> _layers; // The layers in the MLP.
18  std::vector<std::vector<Variable>>
19  _results; // The output results for each layer in the MLP.
20  std::vector<Variable> _parameters; // All parameters of the MLP.
21 
22 public:
28  MLP(size_t n_in, std::vector<size_t> n_outs) : _n_in(n_in), _n_outs(n_outs)
29  {
30  _layers.reserve(n_outs.size());
31  _results.resize(n_outs.size());
32  size_t n_prev = n_in;
33  size_t num_parameters = 0;
34  for (size_t n_out : n_outs)
35  {
36  num_parameters += n_out * (n_prev + 1);
37  }
38  _parameters.reserve(num_parameters);
39 
40  for (size_t i = 0; i < n_outs.size(); i++)
41  {
42  _layers.emplace_back(n_prev, n_outs[i], "tanh");
43  n_prev = n_outs[i];
44  for (size_t j = 0; j < _layers[i].parameters().size(); j++)
45  {
46  _parameters.push_back(_layers[i].parameters()[j]);
47  }
48  }
49  }
50 
55  const std::vector<Layer> &layers() const
56  {
57  return _layers;
58  }
59 
64  const std::vector<std::vector<Variable>> &results() const
65  {
66  return _results;
67  }
68 
73  const std::vector<Variable> &parameters() const
74  {
75  return _parameters;
76  }
77 
82  std::vector<Variable> &mutable_parameters()
83  {
84  return _parameters;
85  }
86 
92  std::vector<Variable> &forward(const std::vector<double> &inputs);
93 };
Definition: mlp.h:12
const std::vector< std::vector< Variable > > & results() const
Definition: mlp.h:64
std::vector< Variable > & mutable_parameters()
Definition: mlp.h:82
std::vector< Variable > & forward(const std::vector< double > &inputs)
Definition: mlp.cc:4
const std::vector< Layer > & layers() const
Definition: mlp.h:55
const std::vector< Variable > & parameters() const
Definition: mlp.h:73
MLP(size_t n_in, std::vector< size_t > n_outs)
Definition: mlp.h:28