ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

TransE的程序实现——学习,查阅,注释

2020-09-11 23:31:36  阅读:266  来源: 互联网

标签:tmp 程序实现 double entity int fb relation TransE 查阅


代码来源:https://github.com/thunlp/KB2E/blob/master/TransE/

很久没碰C++了,为了实验成功,还是要仔仔细细抠一抠代码才行。

第一部分:头文件的引入以及常量定义

 1 #include<iostream>  /*输入输出流*/
 2 #include<cstring>   /*C字符串操作函数*/
 3 #include<cstdio>    /*标准输入输出的C++形式*/
 4 #include<map>   /*定义了一种关联容器,数据结构*/
 5 #include<vector>  /*顺序容器,常用于表示向量*/
 6 #include<string>  /*字符串操作函数*/
 7 #include<ctime>  /*日期时间结构体*/
 8 #include<cmath>  /*数学操作*/
 9 #include<cstdlib>  /*提供一些函数和符号常量,如第二部分出现的RAND_MAX*/
10 using namespace std; //添加命名空间
11 
// Circle constant used by normal() for the Gaussian normalisation factor.
// A typed constant rather than `#define pi`: a lowercase object-like macro
// would textually rewrite every later token spelled `pi` (locals, members...).
const double pi = 3.1415926535897932384626433832795;

// Distance used by the TransE score: true -> L1 distance,
// false -> (squared) L2 distance.
bool L1_flag = true;

第二部分:简单计算函数的定义

 1 //normal distribution  正态分布
 2 double rand(double min, double max)  /*返回[min,max)之间的随机数*/
 3 {
 4     return min+(max-min)*rand()/(RAND_MAX+1.0);
 5 }
 6 double normal(double x, double miu,double sigma)  /*返回一个均值为miu标准差为sigma的正态分布函数在x处的函数值*/
 7 {
 8     return 1.0/sqrt(2*pi)/sigma*exp(-1*(x-miu)*(x-miu)/(2*sigma*sigma));
 9 }
10 double randn(double miu,double sigma, double min ,double max)  /*通过产生随机数的方式,返回满足某条件的自变量x*/
11 {
12     double x,y,dScope;
13     do{
14         x=rand(min,max);  /*调用自定义函数rand(min,max)*/
15         y=normal(x,miu,sigma);  /*调用自定义函数normal(x,miu,sigma)*/
16         dScope=rand(0.0,normal(miu,miu,sigma));  /*当x值取miu时,正态函数达到最大值,该变量在0与正态函数最大值取随机数*/
17     }while(dScope>y);  /*x处的函数值小于该随机数时,循环停止*/
18     return x;
19 }
20 
// Square of x; used for the squared-L2 distance.
double sqr(double x)
{
    const double squared = x * x;
    return squared;
}

// Euclidean (L2) length of vector a.
// Takes a const reference so const vectors and temporaries are accepted;
// existing callers passing non-const vectors are unaffected.
double vec_len(const std::vector<double> &a)
{
    double res = 0;
    for (std::size_t i = 0; i < a.size(); i++)  // size_t: no signed/unsigned mix
        res += a[i] * a[i];
    return std::sqrt(res);
}

rand()返回一个介于0与RAND_MAX之间(两端都可能取到)的伪随机整数;RAND_MAX的具体值由实现决定,但至少为32767。因此自定义的rand(min,max)先用rand()/(RAND_MAX+1.0)得到[0,1)内的小数,再线性映射到[min,max)。

第三部分:变量的定义

// Output-file suffix ("bern" or "unif"), set in main().
std::string version;
// Scratch buffers for fscanf token reading (buf1 is not referenced in this file).
char buf[100000];
char buf1[100000];
// Number of relations / entities, counted while loading the data.
int relation_num;
int entity_num;
// Name -> id tables.
std::map<std::string, int> relation2id;
std::map<std::string, int> entity2id;
// Id -> name tables.
std::map<int, std::string> id2entity;
std::map<int, std::string> id2relation;

// left_entity[rel][head]  = number of training triples with that relation and head;
// right_entity[rel][tail] = number of training triples with that relation and tail.
std::map<int, std::map<int, int> > left_entity;
std::map<int, std::map<int, int> > right_entity;
// Per relation: average count per distinct head (left_num) / tail (right_num),
// used for Bernoulli negative sampling.
std::map<int, double> left_num;
std::map<int, double> right_num;

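第7行,left_entity[relation_id][headentity_id]记录:在该relation下,以该实体为头实体的训练三元组个数,即该头实体对应的尾实体个数。三个int依次为relation_id、headentity_id、个数。right_entity[relation_id][tailentity_id]记录:在该relation下,以该实体为尾实体的三元组个数,即该尾实体对应的头实体个数。三个int依次为relation_id、tailentity_id、个数。两者主要用于计算负采样时替换头/尾实体的概率p。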

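第8行,left_num[relation_id]表示在该relation下平均每个头实体对应多少个尾实体(tph)。right_num[relation_id]表示平均每个尾实体对应多少个头实体(hpt)。键int仍然表示relation_id。在bern采样中,程序以 hpt/(tph+hpt) 的概率替换尾实体,否则替换头实体;unif采样则各以1/2的概率替换。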

第四部分:训练类的定义

  1 class Train{
  2 
  3 public:
  4     map<pair<int,int>, map<int,int> > ok;
  5     void add(int x,int y,int z)
  6     {
  7         fb_h.push_back(x);
  8         fb_r.push_back(z);
  9         fb_l.push_back(y);
 10         ok[make_pair(x,z)][y]=1;
 11     }
 12     void run(int n_in,double rate_in,double margin_in,int method_in)
 13     {
 14         n = n_in;
 15         rate = rate_in;
 16         margin = margin_in;
 17         method = method_in;
 18         relation_vec.resize(relation_num);
 19         for (int i=0; i<relation_vec.size(); i++)
 20             relation_vec[i].resize(n);
 21         entity_vec.resize(entity_num);
 22         for (int i=0; i<entity_vec.size(); i++)
 23             entity_vec[i].resize(n);
 24         relation_tmp.resize(relation_num);
 25         for (int i=0; i<relation_tmp.size(); i++)
 26             relation_tmp[i].resize(n);
 27         entity_tmp.resize(entity_num);
 28         for (int i=0; i<entity_tmp.size(); i++)
 29             entity_tmp[i].resize(n);
 30         for (int i=0; i<relation_num; i++)
 31         {
 32             for (int ii=0; ii<n; ii++)
 33                 relation_vec[i][ii] = randn(0,1.0/n,-6/sqrt(n),6/sqrt(n));
 34         }
 35         for (int i=0; i<entity_num; i++)
 36         {
 37             for (int ii=0; ii<n; ii++)
 38                 entity_vec[i][ii] = randn(0,1.0/n,-6/sqrt(n),6/sqrt(n));
 39             norm(entity_vec[i]);
 40         }
 41 
 42 
 43         bfgs();
 44     }
 45 
 46 private:
 47     int n,method;
 48     double res;//loss function value
 49     double count,count1;//loss function gradient
 50     double rate,margin;
 51     double belta;
 52     vector<int> fb_h,fb_l,fb_r;
 53     vector<vector<int> > feature;
 54     vector<vector<double> > relation_vec,entity_vec;
 55     vector<vector<double> > relation_tmp,entity_tmp;
 56     double norm(vector<double> &a)
 57     {
 58         double x = vec_len(a);
 59         if (x>1)
 60         for (int ii=0; ii<a.size(); ii++)
 61                 a[ii]/=x;
 62         return 0;
 63     }
 64     int rand_max(int x)
 65     {
 66         int res = (rand()*rand())%x;
 67         while (res<0)
 68             res+=x;
 69         return res;
 70     }
 71 
 72     void bfgs()
 73     {
 74         res=0;
 75         int nbatches=100;
 76         int nepoch = 1000;
 77         int batchsize = fb_h.size()/nbatches;
 78             for (int epoch=0; epoch<nepoch; epoch++)
 79             {
 80 
 81                 res=0;
 82                  for (int batch = 0; batch<nbatches; batch++)
 83                  {
 84                      relation_tmp=relation_vec;
 85                     entity_tmp = entity_vec;
 86                      for (int k=0; k<batchsize; k++)
 87                      {
 88                         int i=rand_max(fb_h.size());
 89                         int j=rand_max(entity_num);
 90                         double pr = 1000*right_num[fb_r[i]]/(right_num[fb_r[i]]+left_num[fb_r[i]]);
 91                         if (method ==0)
 92                             pr = 500;
 93                         if (rand()%1000<pr)
 94                         {
 95                             while (ok[make_pair(fb_h[i],fb_r[i])].count(j)>0)
 96                                 j=rand_max(entity_num);
 97                             train_kb(fb_h[i],fb_l[i],fb_r[i],fb_h[i],j,fb_r[i]);
 98                         }
 99                         else
100                         {
101                             while (ok[make_pair(j,fb_r[i])].count(fb_l[i])>0)
102                                 j=rand_max(entity_num);
103                             train_kb(fb_h[i],fb_l[i],fb_r[i],j,fb_l[i],fb_r[i]);
104                         }
105                         norm(relation_tmp[fb_r[i]]);
106                         norm(entity_tmp[fb_h[i]]);
107                         norm(entity_tmp[fb_l[i]]);
108                         norm(entity_tmp[j]);
109                      }
110                     relation_vec = relation_tmp;
111                     entity_vec = entity_tmp;
112                  }
113                 cout<<"epoch:"<<epoch<<' '<<res<<endl;
114                 FILE* f2 = fopen(("relation2vec."+version).c_str(),"w");
115                 FILE* f3 = fopen(("entity2vec."+version).c_str(),"w");
116                 for (int i=0; i<relation_num; i++)
117                 {
118                     for (int ii=0; ii<n; ii++)
119                         fprintf(f2,"%.6lf\t",relation_vec[i][ii]);
120                     fprintf(f2,"\n");
121                 }
122                 for (int i=0; i<entity_num; i++)
123                 {
124                     for (int ii=0; ii<n; ii++)
125                         fprintf(f3,"%.6lf\t",entity_vec[i][ii]);
126                     fprintf(f3,"\n");
127                 }
128                 fclose(f2);
129                 fclose(f3);
130             }
131     }
132     double res1;
133     double calc_sum(int e1,int e2,int rel)
134     {
135         double sum=0;
136         if (L1_flag)
137             for (int ii=0; ii<n; ii++)
138                 sum+=fabs(entity_vec[e2][ii]-entity_vec[e1][ii]-relation_vec[rel][ii]);
139         else
140             for (int ii=0; ii<n; ii++)
141                 sum+=sqr(entity_vec[e2][ii]-entity_vec[e1][ii]-relation_vec[rel][ii]);
142         return sum;
143     }
144     void gradient(int e1_a,int e2_a,int rel_a,int e1_b,int e2_b,int rel_b)
145     {
146         for (int ii=0; ii<n; ii++)
147         {
148 
149             double x = 2*(entity_vec[e2_a][ii]-entity_vec[e1_a][ii]-relation_vec[rel_a][ii]);
150             if (L1_flag)
151                 if (x>0)
152                     x=1;
153                 else
154                     x=-1;
155             relation_tmp[rel_a][ii]-=-1*rate*x;
156             entity_tmp[e1_a][ii]-=-1*rate*x;
157             entity_tmp[e2_a][ii]+=-1*rate*x;
158             x = 2*(entity_vec[e2_b][ii]-entity_vec[e1_b][ii]-relation_vec[rel_b][ii]);
159             if (L1_flag)
160                 if (x>0)
161                     x=1;
162                 else
163                     x=-1;
164             relation_tmp[rel_b][ii]-=rate*x;
165             entity_tmp[e1_b][ii]-=rate*x;
166             entity_tmp[e2_b][ii]+=rate*x;
167         }
168     }
169     void train_kb(int e1_a,int e2_a,int rel_a,int e1_b,int e2_b,int rel_b)
170     {
171         double sum1 = calc_sum(e1_a,e2_a,rel_a);
172         double sum2 = calc_sum(e1_b,e2_b,rel_b);
173         if (sum1+margin>sum2)
174         {
175             res+=margin+sum1-sum2;
176             gradient( e1_a, e2_a, rel_a, e1_b, e2_b, rel_b);
177         }
178     }
179 };

第五部分:类变量的定义和数据集准备

 1 Train train;
 2 void prepare()
 3 {
 4     FILE* f1 = fopen("../data/entity2id.txt","r");
 5     FILE* f2 = fopen("../data/relation2id.txt","r");
 6     int x;
 7     while (fscanf(f1,"%s%d",buf,&x)==2)
 8     {
 9         string st=buf;
10         entity2id[st]=x;
11         id2entity[x]=st;
12         entity_num++;
13     }
14     while (fscanf(f2,"%s%d",buf,&x)==2)
15     {
16         string st=buf;
17         relation2id[st]=x;
18         id2relation[x]=st;
19         relation_num++;
20     }
21     FILE* f_kb = fopen("../data/train.txt","r");
22     while (fscanf(f_kb,"%s",buf)==1)
23     {
24         string s1=buf;
25         fscanf(f_kb,"%s",buf);
26         string s2=buf;
27         fscanf(f_kb,"%s",buf);
28         string s3=buf;
29         if (entity2id.count(s1)==0)
30         {
31             cout<<"miss entity:"<<s1<<endl;
32         }
33         if (entity2id.count(s2)==0)
34         {
35             cout<<"miss entity:"<<s2<<endl;
36         }
37         if (relation2id.count(s3)==0)
38         {
39             relation2id[s3] = relation_num;
40             relation_num++;
41         }
42         left_entity[relation2id[s3]][entity2id[s1]]++;
43         right_entity[relation2id[s3]][entity2id[s2]]++;
44         train.add(entity2id[s1],entity2id[s2],relation2id[s3]);
45     }
46     for (int i=0; i<relation_num; i++)
47     {
48         double sum1=0,sum2=0;
49         for (map<int,int>::iterator it = left_entity[i].begin(); it!=left_entity[i].end(); it++)
50         {
51             sum1++;
52             sum2+=it->second;
53         }
54         left_num[i]=sum2/sum1;
55     }
56     for (int i=0; i<relation_num; i++)
57     {
58         double sum1=0,sum2=0;
59         for (map<int,int>::iterator it = right_entity[i].begin(); it!=right_entity[i].end(); it++)
60         {
61             sum1++;
62             sum2+=it->second;
63         }
64         right_num[i]=sum2/sum1;
65     }
66     cout<<"relation_num="<<relation_num<<endl;
67     cout<<"entity_num="<<entity_num<<endl;
68     fclose(f_kb);
69 }

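第六部分:命令行参数查找函数ArgPos(作者最初没搞懂它的功能)。它在argv[1..argc-1]中查找形如"-size"的选项:找到时返回其下标,参数值即argv[下标+1];找不到时返回-1;若该选项是最后一个参数、缺少取值,则打印错误并退出。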

 1 int ArgPos(char *str, int argc, char **argv) {
 2   int a;
 3   for (a = 1; a < argc; a++) if (!strcmp(str, argv[a])) {
 4     if (a == argc - 1) {
 5       printf("Argument missing for %s\n", str);
 6       exit(1);
 7     }
 8     return a;
 9   }
10   return -1;
11 }

第七部分:主函数流程

 1 int main(int argc,char**argv)
 2 {
 3     srand((unsigned) time(NULL));
 4     int method = 1;
 5     int n = 100;
 6     double rate = 0.001;
 7     double margin = 1;
 8     int i;
 9     if ((i = ArgPos((char *)"-size", argc, argv)) > 0) n = atoi(argv[i + 1]);
10     if ((i = ArgPos((char *)"-margin", argc, argv)) > 0) margin = atoi(argv[i + 1]);
11     if ((i = ArgPos((char *)"-method", argc, argv)) > 0) method = atoi(argv[i + 1]);
12     cout<<"size = "<<n<<endl;
13     cout<<"learing rate = "<<rate<<endl;
14     cout<<"margin = "<<margin<<endl;
15     if (method)
16         version = "bern";
17     else
18         version = "unif";
19     cout<<"method = "<<version<<endl;
20     prepare();
21     train.run(n,rate,margin,method);
22 }

 

left_entity:在此relation下头实体对应的尾实体的个数,3个int分别表示relation_id,headentity_id,个数

标签:tmp,程序实现,double,entity,int,fb,relation,TransE,查阅
来源: https://www.cnblogs.com/real-zz-11/p/real_zz.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有