fastText  d00d36476b15
Fast text processing tool/library
model.h
Go to the documentation of this file.
1 
10 #ifndef FASTTEXT_MODEL_H
11 #define FASTTEXT_MODEL_H
12 
13 #include <vector>
14 #include <random>
15 #include <utility>
16 #include <memory>
17 
18 #include "args.h"
19 #include "matrix.h"
20 #include "vector.h"
21 #include "qmatrix.h"
22 #include "real.h"
23 
24 #define SIGMOID_TABLE_SIZE 512
25 #define MAX_SIGMOID 8
26 #define LOG_TABLE_SIZE 512
27 
28 namespace fasttext {
29 
// Element type of Model::tree (std::vector<Node>), which the class groups
// under "used for hierarchical softmax".
// NOTE(review): the per-field meanings below are inferred from the member
// names only; confirm them against the implementation file.
30 struct Node {
31  int32_t parent; // presumably the index of this node's parent -- TODO confirm
32  int32_t left; // presumably the index of the left child -- TODO confirm
33  int32_t right; // presumably the index of the right child -- TODO confirm
34  int64_t count; // presumably an occurrence count used when building the tree -- TODO confirm
35  bool binary; // presumably the branch bit assigned to this node -- TODO confirm
36 };
37 
38 class Model {
39  private:
40  std::shared_ptr<Matrix> wi_;
41  std::shared_ptr<Matrix> wo_;
42  std::shared_ptr<QMatrix> qwi_;
43  std::shared_ptr<QMatrix> qwo_;
44  std::shared_ptr<Args> args_;
48  int32_t hsz_;
49  int32_t osz_;
51  int64_t nexamples_;
54  // used for negative sampling:
55  std::vector<int32_t> negatives;
56  size_t negpos;
57  // used for hierarchical softmax:
58  std::vector< std::vector<int32_t> > paths;
59  std::vector< std::vector<bool> > codes;
60  std::vector<Node> tree;
61 
62  static bool comparePairs(const std::pair<real, int32_t>&,
63  const std::pair<real, int32_t>&);
64 
65  int32_t getNegative(int32_t target);
66  void initSigmoid();
67  void initLog();
68 
69  static const int32_t NEGATIVE_TABLE_SIZE = 10000000;
70 
71  public:
72  Model(std::shared_ptr<Matrix>, std::shared_ptr<Matrix>,
73  std::shared_ptr<Args>, int32_t);
74  ~Model();
75 
76  real binaryLogistic(int32_t, bool, real);
77  real negativeSampling(int32_t, real);
78  real hierarchicalSoftmax(int32_t, real);
79  real softmax(int32_t, real);
80 
81  void predict(const std::vector<int32_t>&, int32_t,
82  std::vector<std::pair<real, int32_t>>&,
83  Vector&, Vector&) const;
84  void predict(const std::vector<int32_t>&, int32_t,
85  std::vector<std::pair<real, int32_t>>&);
86  void dfs(int32_t, int32_t, real,
87  std::vector<std::pair<real, int32_t>>&,
88  Vector&) const;
89  void findKBest(int32_t, std::vector<std::pair<real, int32_t>>&,
90  Vector&, Vector&) const;
91  void update(const std::vector<int32_t>&, int32_t, real);
92  void computeHidden(const std::vector<int32_t>&, Vector&) const;
93  void computeOutputSoftmax(Vector&, Vector&) const;
94  void computeOutputSoftmax();
95 
96  void setTargetCounts(const std::vector<int64_t>&);
97  void initTableNegatives(const std::vector<int64_t>&);
98  void buildTree(const std::vector<int64_t>&);
99  real getLoss() const;
100  real sigmoid(real) const;
101  real log(real) const;
102 
103  std::minstd_rand rng;
104  bool quant_;
105  void setQuantizePointer(std::shared_ptr<QMatrix>, std::shared_ptr<QMatrix>, bool);
106 };
107 
108 }
109 
110 #endif
real * t_sigmoid
Definition: model.h:52
real loss_
Definition: model.h:50
int32_t parent
Definition: model.h:31
std::shared_ptr< QMatrix > qwo_
Definition: model.h:43
fasttext::Node
Definition: model.h:30
std::minstd_rand rng
Definition: model.h:103
int32_t right
Definition: model.h:33
std::vector< Node > tree
Definition: model.h:60
Definition: args.cc:17
Definition: vector.h:23
int64_t count
Definition: model.h:34
real * t_log
Definition: model.h:53
void predict(int argc, char **argv)
Definition: main.cc:138
fasttext::Model
Definition: model.h:38
std::vector< std::vector< bool > > codes
Definition: model.h:59
int32_t hsz_
Definition: model.h:48
bool quant_
Definition: model.h:104
Vector grad_
Definition: model.h:47
std::shared_ptr< QMatrix > qwi_
Definition: model.h:42
size_t negpos
Definition: model.h:56
std::vector< std::vector< int32_t > > paths
Definition: model.h:58
bool binary
Definition: model.h:35
std::shared_ptr< Matrix > wo_
Definition: model.h:41
std::shared_ptr< Matrix > wi_
Definition: model.h:40
Vector output_
Definition: model.h:46
float real
Definition: real.h:15
int64_t nexamples_
Definition: model.h:51
int32_t left
Definition: model.h:32
std::shared_ptr< Args > args_
Definition: model.h:44
std::vector< int32_t > negatives
Definition: model.h:55
Vector hidden_
Definition: model.h:45
int32_t osz_
Definition: model.h:49