ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

技术/广告 文章分类器(二)

2022-01-03 14:59:15  阅读:145  来源: 互联网

标签:fasttext self 分类器 train 广告 文章 test path data


文章目录


前言

本文基于上一篇博客技术/广告 文章分类器(一),作出了一些优化,将准确率由84.5%提升至94.4%


一、优化手段

1、增加训练数据

之前的训练数据集,两类数据分别只有500条左右,训练数据太少。
本文所使用数据集为45000余条,增加了90倍,应该完全够用

2、更改分类模型

之前使用多项式朴素贝叶斯,效果一般,由于使用了样本属性独立性的假设,所以如果样本属性有关联时其效果不好。因此,直接使用集成学习,达到了一个较好的效果

3、分词时加入用户词典

一些关键的词,并没有被理想分词出来,与不加入用户词典相比,准确率提高了1%左右

4、去除停用词及特殊符号

在分词之前,去除了表情及一些特殊符号,尝试过在分词之后再去除特殊符号,结果证明在分词之前去除特殊符号,效果更好,去除特殊符号后,准确率提升2%左右

二、TFIDF + AdaBoost

全部代码

class TrainBlogClsTfidfAdaBoost:
    def __init__(self):
        jieba.load_userdict(get_blog_cls_jieba_user_dict_path())

        self.train_data_dir = get_blog_cls_train_data_optimize_dir()
        self.tfidf_path = get_tfidf_path()
        self.model_path = get_adaboost_model_path()

        # self.train_data_dir = get_blog_cls_train_data_dev_dir()
        # self.tfidf_path = get_test_tfidf_path()
        # self.model_path = get_adaboost_test_model_path()

    def load(self):
        if not os.path.exists(self.model_path):
            logger.warning("开始训练,目标模型数据:", self.model_path)
            self.train()

        logger.info("加载模型")
        self.model = joblib.load(self.model_path)
        self.tf_idf = joblib.load(self.tfidf_path)

    def load_data(self):
        '''加载文件内容和标签'''
        files = get_files_path(self.train_data_dir, '.txt')
        contents = []
        labels = []

        for file in files:
            with open(file, 'r') as f:
                data = f.read()
            data = filter_content_for_blog_cls(data)
            data_cut = ' '.join(jieba.cut(data))
            contents.append(data_cut)
            label = file.split('/')[-2]
            labels.append(label)
        X_train, X_test, y_train, y_test = train_test_split(contents,
                                                            labels,
                                                            test_size=0.2,
                                                            random_state=123456)
        return X_train, X_test, y_train, y_test

    def load_stopwords(self):
        path = './data/pro/datasets/stopwords/cn_stopwords.txt'
        with open(path, 'r') as f:
            stopwords = f.read().split('\n')
        return stopwords

    def train(self):
        logger.info('开始训练...')
        stopwords = self.load_stopwords()
        X_train, X_test, y_train, y_test = self.load_data()
        tfidf = TfidfVectorizer(stop_words=stopwords, max_df=0.5)
        train_data = tfidf.fit(X_train)
        train_data = tfidf.transform(X_train)
        test_data = tfidf.transform(X_test)

        joblib.dump(tfidf, self.tfidf_path, compress=1)

        model = AdaBoostClassifier()  # 99%

        model.fit(train_data, y_train)

        predict_test = model.predict(test_data)

        joblib.dump(model, self.model_path, compress=1)

        print("准确率为:", metrics.accuracy_score(predict_test, y_test))

    def predict(self, test_data):
        test_data = filter_content_for_blog_cls(test_data)
        test_data = ' '.join(jieba.cut(test_data))

        test_vec = self.tf_idf.transform([test_data])
        res = self.model.predict(test_vec)
        return res

    def test_acc(self):
        data_path = './data/pro/datasets/blogs/blog_adver_cls/test_dev.csv'
        data = pd.read_csv(data_path)
        data = data.dropna(axis=0)
        test_text = data['content']
        text_list = []
        for text in test_text:
            text = filter_content_for_blog_cls(text)
            text = ' '.join(jieba.cut(text))
            text_list.append(text)
        label = data['label']
        test_data = self.tf_idf.transform(text_list)
        predict_test = self.model.predict(test_data)
        print("在测试集准确率为:", metrics.accuracy_score(predict_test, label))

结果:

在测试集准确率为: 0.9646315789473684

测试数据大概5000条,这个数量级,还是比较有说服力的

三、Fasttext

之前有用过fasttext来做图书分类,见「fasttext文本分类」,在三分类上准确率达到93%,在35个类别上准确率为75.6%,总体效果还不错,于是想到用fasttext来试下,看看效果是否会更好些。

全部代码

import os
import fasttext
import jieba
import logging
import random
from tqdm import tqdm
import pandas as pd
from sklearn import metrics
from common.utils import get_files_path
from common.utils import filter_content_for_blog_cls
from common.path.dataset.blog import get_blog_cls_jieba_user_dict_path

from common.path.dataset.blog import get_blog_cls_train_data_dev_dir, get_fasttext_train_data_path
from common.path.model.blog import get_blog_cls_fasttext_model_path

logger = logging.getLogger(__name__)


class TrainBlogClsFasttext:
    def __init__(self):
        jieba.load_userdict(get_blog_cls_jieba_user_dict_path())
        self.train_data_dev_dir = get_blog_cls_train_data_dev_dir()
        self.train_data_path = get_fasttext_train_data_path()
        self.fasttext_model_path = get_blog_cls_fasttext_model_path()
        self.class_name_mapping = {
            '__label__0': 'technology',
            '__label__1': 'advertisement'
        }

    
    def load(self):
        if not os.path.exists(self.fasttext_model_path):
            logger.info('开始训练模型...')
            self.train_fasttext()
        logger.info("加载模型")
        self.model = fasttext.load_model(self.fasttext_model_path)
        

    def data_process(self):
        data_dir = self.train_data_dev_dir
        files = get_files_path(data_dir, '.txt')
        
        if not os.path.exists(self.train_data_path):
            os.mkdir(self.train_data_path)
        random.shuffle(files)

        fasttext_train_data_path = os.path.join(self.train_data_path, 'train.txt')
        fasttext_test_data_path = os.path.join(self.train_data_path, 'test.txt')
        if os.path.exists(fasttext_train_data_path) and os.path.exists(fasttext_test_data_path):
            return
        lines_train = []
        lines_test = []
        all_data = []
        for file in tqdm(files, desc='正在构建训练数据: '):
            with open(file, 'r') as f:
                data = f.read()
            data = filter_content_for_blog_cls(data)
            data = ' '.join(jieba.cut(data))

            if file.find('technology') != -1:
                label = '__label__{}'.format(0)
            elif file.find('advertisement') != -1:
                label = '__label__{}'.format(1)
            else:
                print("错误的数据:{}".format(file))
            line = data + '\t' + label + '\n'
            all_data.append(line)
        
        lines_train = all_data[:int(len(all_data)*0.8)]
        lines_test = all_data[int(len(all_data)*0.8):]
        with open(fasttext_train_data_path, 'a') as f:
            f.writelines(lines_train)
        with open(fasttext_test_data_path, 'a') as f:
            f.writelines(lines_test)


    def load_stopwords(self):
        path = './data/pro/datasets/stopwords/cn_stopwords.txt'
        with open(path, 'r') as f:
            stopwords = f.read().split('\n')
        return stopwords


    def train_fasttext(self):
        self.data_process()
        data_dir = self.train_data_path
        train_path = os.path.join(data_dir, 'train.txt')
        test_path = os.path.join(data_dir, 'test.txt')

        classifier = fasttext.train_supervised(input=train_path,
                                            label="__label__",
                                            dim=100,
                                            epoch=10,
                                            lr=0.1,
                                            wordNgrams=2,
                                            loss='softmax',
                                            thread=8,
                                            verbose=True)
        classifier.save_model(self.fasttext_model_path)
        result = classifier.test(test_path)
        logger.info('Train Result:'.format(result))
        logger.info('F1 Score: {}'.format(result[1] * result[2] * 2 /
                                    (result[2] + result[1])))
    
    def predict(self, text):

        test_data = filter_content_for_blog_cls(text)
        test_data = ' '.join(jieba.cut(test_data))

        result = self.model.predict(test_data)
        class_name = result[0][0]
        res_label = self.class_name_mapping[class_name]
        return res_label

    def test_acc(self):

        data_path = './data/pro/datasets/blogs/blog_adver_cls/test_dev.csv'
        data = pd.read_csv(data_path)
        data = data.dropna(axis=0)
        test_text = data['content']
        text_list = []
        for text in test_text:
            text = filter_content_for_blog_cls(text)
            text = ' '.join(jieba.cut(text))
            text_list.append(text)
        labels = data['label']
        res_labels = []
        for text in text_list:
            label = self.model.predict(text)
            class_name = label[0][0]
            res_label = self.class_name_mapping[class_name]
            res_labels.append(res_label)
        print("在测试集准确率为:", metrics.accuracy_score(res_labels, labels))

代码没什么难的,主要就是数据处理,这里也是在分词之前去除了特殊符号,这样做效果确实有提升,可以自己尝试下。

直接看效果吧:

[INFO][2022-01-03 14:39:23][fasttext_classifier.py:33 at load]: 开始训练模型...
Read 5M words
Number of words:  261664
Number of labels: 2
Progress: 100.0% words/sec/thread: 1415852 lr:  0.000000 avg.loss:  0.059132 ETA:   0h 0m 0s
[INFO][2022-01-03 14:39:33][fasttext_classifier.py:101 at train_fasttext]: Train Result:
[INFO][2022-01-03 14:39:33][fasttext_classifier.py:102 at train_fasttext]: F1 Score: 0.9638259736027375
[INFO][2022-01-03 14:39:33][fasttext_classifier.py:35 at load]: 加载模型
Warning : `load_model` does not return WordVectorModel or SupervisedModel any more, but a `FastText` object which is very similar.
在测试集准确率为: 0.9661052631578947

在同一份测试数据集上,Fasttext准确率高了0.2%,但模型大小为912M,使用TFIDF + AdaBoost 训练出来的模型加起来也就4.9M。

实际推理速度还未测试过,因此目前使用的是占用内存更小的 TFIDF + AdaBoost。

总结

多观察数据,理解数据特征,对提升模型效果有莫大的帮助。

事实证明:

1、增加用户词典可以提升准确率
2、去除文本中的特殊字符可以提升准确率

相关文章:

1、技术/广告 文章分类器(一)
2、fasttext文本分类

标签:fasttext,self,分类器,train,广告,文章,test,path,data
来源: https://blog.csdn.net/qq_44193969/article/details/122286888

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有