
Incremental Learning / Training

2022-01-25 10:36:04  Views: 231  Source: Internet

Tags: training, lgb, image, batch, learning, train, incremental, data, ds


For large datasets that are too big to load into memory, use incremental training.

Contents

    sklearn

    https://scikit-learn.org/stable/auto_examples/applications/plot_out_of_core_classification.html#sphx-glr-auto-examples-applications-plot-out-of-core-classification-py
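The linked example pairs a stateless HashingVectorizer (which never needs to see the whole corpus) with classifiers that implement partial_fit. A minimal sketch of that pattern on toy data (the corpus and batch helper here are illustrative stand-ins, not taken from the sklearn example):

```python
import numpy as np
from sklearn.feature_extraction.text import HashingVectorizer
from sklearn.linear_model import SGDClassifier

def batches(texts, labels, batch_size):
    """Yield successive mini-batches from an (in reality, streaming) source."""
    for start in range(0, len(texts), batch_size):
        yield texts[start:start + batch_size], labels[start:start + batch_size]

texts = ["good movie", "bad film", "great plot", "terrible acting"] * 50
labels = [1, 0, 1, 0] * 50
all_classes = np.array([0, 1])

vectorizer = HashingVectorizer(n_features=2 ** 10)  # stateless: no fit needed
clf = SGDClassifier(random_state=0)

for x_batch_text, y_batch in batches(texts, labels, batch_size=20):
    x_batch = vectorizer.transform(x_batch_text)
    # classes must be passed so the model knows the full label set up front
    clf.partial_fit(x_batch, y_batch, classes=all_classes)
```

Because the vectorizer hashes rather than builds a vocabulary, any batch can be transformed independently, which is what makes the whole pipeline out-of-core.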

    lightgbm

Custom generator

    predicts = np.array([])
    y_train = np.array([])
    
    for i, (x_batch_text, y_batch) in enumerate(minibatch_iterators):
        x_batch = vectorizer.transform(x_batch_text)
    
        # SGDClassifier supports incremental learning natively via partial_fit
        sgd_clf.partial_fit(x_batch, y_batch, classes=all_classes)
    
        # LightGBM: pass the already-fitted model as init_model to continue
        # training; on the first batch the model is not fitted yet, so fit normally
        if i == 0:
            lgb_clf.fit(x_batch, y_batch)
        else:
            lgb_clf.fit(x_batch, y_batch, init_model=lgb_clf)
    
        y_train = np.hstack([y_train, y_batch])
        predicts = np.hstack([predicts, lgb_clf.predict(x_batch)])
    
        if i % 500 == 0:
            print("iter %s ==============" % i)
            metrics(y_train, predicts)  # user-defined evaluation helper
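The `minibatch_iterators` object above is assumed to exist; one way to build it is a plain generator that reads a large CSV lazily, so only one batch is ever in memory. The two-column `text,label` layout here is an assumption for illustration:

```python
import csv
from itertools import islice

def iter_minibatches(path, batch_size=1000):
    """Lazily yield (texts, labels) batches from a large CSV file
    without loading it into memory. Assumes columns: text,label."""
    with open(path, newline="", encoding="utf-8") as f:
        reader = csv.reader(f)
        next(reader)  # skip the header row
        while True:
            rows = list(islice(reader, batch_size))
            if not rows:
                break
            texts = [row[0] for row in rows]
            labels = [int(row[1]) for row in rows]
            yield texts, labels

# usage: minibatch_iterators = iter_minibatches("train.csv", batch_size=1000)
```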
    

Using pandas

    import lightgbm as lgb
    
    def increment():
        # Step 1: initialize the model as None and set the model parameters
        gbm = None
        params = {
                'task': 'train',
                'objective': 'multiclass',
                'num_class': 3,
                'boosting_type': 'gbdt',
                'learning_rate': 0.1,
                'num_leaves': 31,
                'tree_learner': 'serial',
                'min_data_in_leaf': 100,
                'metric': ['multi_logloss', 'multi_error'],
                'max_bin': 255,
                'num_trees': 300
            }
        # Step 2: stream the data in chunks (1,000,000 rows at a time)
        CHUNK_SIZE = 1000000
    
        all_data = pd.read_csv(path, chunksize=CHUNK_SIZE)
    
        i = 0
        for data_chunk in all_data:
            print('Size of loaded chunk: %i instances, %i features' % data_chunk.shape)
    
            # preprocess
            data_chunk = shuffle(data_chunk)
            x_train, y_train = pipeline(data_chunk)
    
            # build the LightGBM datasets (x_test/y_test are a held-out
            # validation set prepared beforehand)
            lgb_train = lgb.Dataset(x_train, y_train)
            lgb_eval = lgb.Dataset(x_test, y_test)
    
            # Step 3: train the model incrementally
            # The key point: the init_model and keep_training_booster
            # parameters together enable incremental training
            gbm = lgb.train(params,
                            lgb_train,
                            num_boost_round=1000,
                            valid_sets=lgb_eval,
                            init_model=gbm,             # if gbm is not None, continue from the previous chunk
                            early_stopping_rounds=10,
                            verbose_eval=False,
                            keep_training_booster=True) # keep the booster trainable between chunks
    
            # print the model's evaluation scores
            score_train = dict([(s[1], s[2]) for s in gbm.eval_train()])
            score_valid = dict([(s[1], s[2]) for s in gbm.eval_valid()])
            print('Training-set score: loss=%.4f, error=%.4f' % (score_train['multi_logloss'], score_train['multi_error']))
            print('Validation-set score: loss=%.4f, error=%.4f' % (score_valid['multi_logloss'], score_valid['multi_error']))
            i += 1
        return gbm
    gbm = increment()
    

    tensorflow

Load the previously saved network and simply continue training from there.
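In Keras terms, a minimal sketch of "load and keep training" (the toy network, data, and checkpoint path are illustrative; saving the whole model stores the optimizer state too, so fit() genuinely resumes rather than restarting the optimizer):

```python
import os
import tempfile
import numpy as np
import tensorflow as tf

x = np.random.rand(64, 4).astype("float32")
y = (x.sum(axis=1) > 2).astype("float32")
ckpt = os.path.join(tempfile.gettempdir(), "incremental_demo.keras")

# first session: build, train a little, and save the whole model
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")
model.fit(x, y, epochs=1, verbose=0)
model.save(ckpt)  # stores architecture, weights, and optimizer state

# later session: restore and just keep training on new (or the same) data
model = tf.keras.models.load_model(ckpt)
model.fit(x, y, epochs=1, verbose=0)
```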

    # define the dataset
    def load_and_preprocess_from_path_label(path, label):
        return load_and_preprocess_image(path), label
    
    def make_dataset(image_paths, image_labels, image_count, BATCH_SIZE=32, AUTOTUNE=tf.data.experimental.AUTOTUNE):
        ds = tf.data.Dataset.from_tensor_slices((image_paths, image_labels))
    
        image_label_ds = ds.map(load_and_preprocess_from_path_label, num_parallel_calls=AUTOTUNE)
    
        # set a shuffle buffer as large as the dataset to ensure the data
        # is fully shuffled
        ds = image_label_ds.shuffle(buffer_size=image_count)
        # ds = ds.repeat()
        ds = ds.batch(BATCH_SIZE)
        # `prefetch` lets the dataset fetch batches in the background
        # while the model is training
        ds = ds.prefetch(buffer_size=AUTOTUNE)
        return ds
    

    references

    https://zhuanlan.zhihu.com/p/41422048

    Source: https://www.cnblogs.com/gongyanzh/p/15841929.html
