variable.h
Go to the documentation of this file.
1 #pragma once
2 
#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>
7 
/// A scalar node of an automatic-differentiation graph.
///
/// A Variable stores a value, an accumulated gradient, an optional operation
/// label and name, copies of the Variables it was built from (`_children`),
/// and a callback (`_backward`) that backward() invokes with `this`.
///
/// Reference semantics: a freshly constructed Variable refers to itself
/// (`ref == this`). A copy stores the `ref` of its source, so value/gradient
/// updates made through the copy (update_gradient, zero_grad,
/// gradient_descent) are forwarded to that original. `ref` is a non-owning raw
/// pointer: it dangles if the original is destroyed or relocated (a move gives
/// the object a new address; e.g. a reallocating std::vector<Variable>). A
/// moved-from Variable has `ref == nullptr`.
class Variable
{
private:
    double _value;                   // The value of the variable.
    double _gradient;                // The accumulated gradient of the variable.
    std::string _op;                 // The operation associated with the variable.
    std::string _name;               // The name of the variable.
    Variable *ref = nullptr;         // The variable updates are forwarded to (non-owning).
    std::vector<Variable> _children; // Copies of the variables this one was built from.
    // Callback invoked with `this` by backward(); a no-op by default.
    std::function<void(Variable *)> _backward = [](Variable *) {};

    // Moves value, gradient, labels, children and callback out of `source`
    // into *this, leaving `ref` untouched. `source` must be a local that does
    // not alias *this or anything inside this->_children.
    void take_state_from(Variable &source) noexcept
    {
        _value = source._value;
        _gradient = source._gradient;
        _op = std::move(source._op);
        _name = std::move(source._name);
        _backward = std::move(source._backward);
        _children = std::move(source._children);
    }

public:
    /// Creates a Variable that refers to itself.
    /// @param value    Initial value.
    /// @param gradient Initial gradient.
    /// @param op       Label of the operation that produced this variable.
    /// @param name     Optional name.
    explicit Variable(double value = 0,
                      double gradient = 0,
                      std::string op = "",
                      std::string name = "")
        : _value(value), _gradient(gradient), _op(std::move(op)),
          _name(std::move(name)), ref(this)
    {
    }

    /// Copies every field, including `ref`: the copy forwards updates to the
    /// same variable that `other` forwards to.
    Variable(const Variable &other)
        : _value(other._value), _gradient(other._gradient), _op(other._op),
          _name(other._name), ref(other.ref), _children(other._children),
          _backward(other._backward)
    {
    }

    /// Copy-assigns every field, including `ref`, from `other`.
    Variable &operator=(const Variable &other)
    {
        // Snapshot first: `other` may live inside this->_children (e.g.
        // `node = node.children()[0]`), and replacing _children would
        // overwrite or destroy it before all of its fields were read.
        Variable snapshot(other);
        take_state_from(snapshot);
        ref = snapshot.ref;
        return *this;
    }

    /// Move-assigns from `other`. Afterwards *this refers to itself and
    /// `other.ref` is nullptr; other's remaining fields are moved-from.
    Variable &operator=(Variable &&other) noexcept
    {
        if (this != &other)
        {
            // Detach into a local first so moving into _children cannot
            // destroy `other` while it is still being read.
            Variable detached(std::move(other));
            take_state_from(detached);
        }
        ref = this;
        return *this;
    }

    /// Move-constructs from `other`. The new object refers to itself and
    /// `other.ref` becomes nullptr; other's remaining fields are moved-from.
    Variable(Variable &&other) noexcept
        : _value(other._value), _gradient(other._gradient),
          _op(std::move(other._op)), _name(std::move(other._name)), ref(this),
          _children(std::move(other._children)),
          _backward(std::move(other._backward))
    {
        other.ref = nullptr;
    }

    /// Clears `ref`. Copies that point at this object are not notified and
    /// will hold a dangling pointer afterwards.
    ~Variable()
    {
        set_ref(nullptr);
    }

    /// @return The current value.
    double value() const
    {
        return _value;
    }

    /// @return The variable's name.
    const std::string &name() const
    {
        return _name;
    }

    /// Sets the variable's name.
    void set_name(const std::string &name)
    {
        _name = name;
    }

    /// Sets the value of this object only (not forwarded to `ref`).
    void set_value(double value)
    {
        _value = value;
    }

    /// @return The operation label.
    const std::string &op() const
    {
        return _op;
    }

    /// Sets the operation label.
    void set_op(const std::string &op)
    {
        _op = op;
    }

    /// Replaces the callback invoked by backward(). An empty function is
    /// allowed and makes backward() skip the callback.
    void set_backward(std::function<void(Variable *)> backward)
    {
        _backward = std::move(backward);
    }

    /// @return The accumulated gradient of this object.
    double gradient() const
    {
        return _gradient;
    }

    /// Sets the gradient of this object only (not forwarded to `ref`).
    void set_gradient(double gradient)
    {
        _gradient = gradient;
    }

    /// @return The variable that updates are forwarded to: `this` for an
    ///         original, the original's address for a copy, nullptr after a
    ///         move.
    Variable *reference() const
    {
        return ref;
    }

    /// Overrides the variable that updates are forwarded to (non-owning).
    void set_ref(Variable *reference)
    {
        this->ref = reference;
    }

    /// @return The child variables (stored by value).
    const std::vector<Variable> &children() const
    {
        return _children;
    }

    /// @return Mutable access to the child variables.
    std::vector<Variable> &mutable_children()
    {
        return _children;
    }

    /// Replaces the child variables with copies of `children`.
    void set_children(const std::vector<Variable> &children)
    {
        _children = children;
    }

    /// Runs this node's callback (if any) with `this`, then recursively calls
    /// backward() on every child. There is no visited set, so a node reached
    /// through several paths has its callback run once per path.
    void backward()
    {
        if (_backward)
        {
            _backward(this);
        }
        for (auto &child : _children)
        {
            child.backward();
        }
    }

    /// Adds `grad` to this gradient and forwards the same increment to `ref`
    /// when it is a different, non-null variable.
    void update_gradient(double grad)
    {
        _gradient += grad;
        if (this != ref && ref != nullptr)
        {
            ref->update_gradient(grad);
        }
    }

    /// Resets this gradient to zero and forwards the reset to `ref` when it is
    /// a different, non-null variable.
    void zero_grad()
    {
        _gradient = 0;
        if (this != ref && ref != nullptr)
        {
            ref->zero_grad();
        }
    }

    /// Applies `value -= lr * gradient` to this object, then to `ref` (using
    /// ref's own gradient) when it is a different, non-null variable.
    /// @param lr Learning rate.
    void gradient_descent(double lr)
    {
        _value -= lr * _gradient;
        if (this != ref && ref != nullptr)
        {
            ref->gradient_descent(lr);
        }
    }

    /// Applies the activation named by `activate_function` (defined in
    /// variable.cc).
    Variable activate(std::string activate_function);

public:
    // The operators and functions below are declared here and defined in
    // variable.cc.
    // NOTE(review): the Variable-Variable operators are non-const while the
    // scalar ones are const; making them const needs a matching change in
    // variable.cc.

    friend std::ostream &operator<<(std::ostream &os, const Variable &var);

    Variable operator+(const Variable &other);

    Variable operator-(const Variable &other);

    Variable operator*(const Variable &other);

    Variable operator/(const Variable &other);

    Variable operator-() const;

    Variable identity() const;

    Variable operator+(const double other) const;

    Variable operator-(const double other) const;

    Variable operator*(const double other) const;

    Variable operator/(const double other) const;

    friend Variable operator+(const double other, const Variable &var);

    friend Variable operator-(const double other, const Variable &var);

    friend Variable operator*(const double other, const Variable &var);

    friend Variable operator/(const double other, const Variable &var);

    friend Variable dot_product(const std::vector<Variable> &a,
                                const std::vector<Variable> &b);

    friend Variable dot_product(const std::vector<Variable> &a,
                                const std::vector<double> &b);

    Variable pow(const double other) const;

    Variable exp() const;

    Variable log() const;

    Variable sin() const;

    Variable cos() const;

    Variable tan() const;

    Variable sinh() const;

    Variable cosh() const;

    Variable tanh() const;

    Variable relu() const;

    Variable sigmoid() const;
};
Definition: variable.h:13
Variable sigmoid() const
Definition: variable.cc:452
Variable log() const
Definition: variable.cc:316
Variable operator-() const
Definition: variable.cc:91
void set_gradient(double gradient)
Definition: variable.h:182
Variable operator+(const Variable &other)
Definition: variable.cc:15
Variable exp() const
Definition: variable.cc:301
const std::vector< Variable > & children() const
Definition: variable.h:209
void set_name(const std::string &name)
Definition: variable.h:128
void update_gradient(double grad)
Definition: variable.h:250
Variable sinh() const
Definition: variable.cc:388
Variable tanh() const
Definition: variable.cc:421
friend std::ostream & operator<<(std::ostream &os, const Variable &var)
Definition: variable.cc:5
double gradient() const
Definition: variable.h:173
Variable * reference() const
Definition: variable.h:191
void set_ref(Variable *reference)
Definition: variable.h:200
void set_children(const std::vector< Variable > &children)
Definition: variable.h:227
friend Variable dot_product(const std::vector< Variable > &a, const std::vector< Variable > &b)
Definition: variable.cc:213
Variable identity() const
Definition: variable.cc:106
void backward()
Definition: variable.h:236
Variable cosh() const
Definition: variable.cc:404
Variable tan() const
Definition: variable.cc:367
void zero_grad()
Definition: variable.h:263
Variable(const Variable &other)
Definition: variable.h:42
std::vector< Variable > & mutable_children()
Definition: variable.h:218
Variable & operator=(const Variable &other)
Definition: variable.h:53
const std::string & op() const
Definition: variable.h:146
const std::string & name() const
Definition: variable.h:119
void set_backward(std::function< void(Variable *)> backward)
Definition: variable.h:164
Variable(Variable &&other) noexcept
Definition: variable.h:89
Variable & operator=(Variable &&other) noexcept
Definition: variable.h:71
Variable relu() const
Definition: variable.cc:436
Variable sin() const
Definition: variable.cc:335
Variable pow(const double other) const
Definition: variable.cc:280
~Variable()
Definition: variable.h:101
void gradient_descent(double lr)
Definition: variable.h:277
void set_value(double value)
Definition: variable.h:137
void set_op(const std::string &op)
Definition: variable.h:155
Variable cos() const
Definition: variable.cc:351
Variable operator*(const Variable &other)
Definition: variable.cc:49
Variable operator/(const Variable &other)
Definition: variable.cc:68
double value() const
Definition: variable.h:110
Variable(double value=0, double gradient=0, std::string op="", std::string name="")
Definition: variable.h:32
Variable activate(std::string activate_function)
Definition: variable.cc:469