python - Is it possible to modify/prune a learned tree in scikit-learn?

2019-05-19 10:55:31  Source: Internet

Tags: python machine-learning scikit-learn random-forest decision-tree


In sklearn, the parameters of a fitted tree can be accessed:

tree.tree_.children_left
tree.tree_.children_right
tree.tree_.threshold
tree.tree_.feature

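and so on.

However, trying to write to these variables raises a "not writable" exception.

Is there a way to modify a learned tree, or to get around the "not writable" AttributeError?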

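Solution:

children_left and children_right are both int arrays that cannot be overwritten: the attribute itself cannot be reassigned. You can still modify the elements of these arrays in place. That does not lighten the data: the arrays keep their size, and the detached nodes stay in memory.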
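The difference between reassigning the attribute and writing to its elements can be checked with a small sketch. The dataset, the `max_depth=2` setting, and the variable names are assumptions chosen for illustration, and the in-place write is the behaviour described in the answer, which may differ between scikit-learn versions.

```python
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor

X, y = load_diabetes(return_X_y=True)
model = DecisionTreeRegressor(max_depth=2, random_state=0).fit(X, y)
tree = model.tree_

# Rebinding the attribute fails: it is exposed as a read-only property.
try:
    tree.children_left = tree.children_left.copy()
    rebind_failed = False
except AttributeError:
    rebind_failed = True

# Writing to individual elements reaches the tree's own node storage.
nodes_before = tree.node_count
tree.children_left[0] = -1   # -1 is TREE_LEAF
tree.children_right[0] = -1
root_is_leaf = tree.children_left[0] == -1 and tree.children_right[0] == -1

# The arrays keep their length; the detached nodes are only unreachable.
nodes_after = tree.node_count
print(rebind_failed, root_is_leaf, nodes_before == nodes_after)
```

With the root turned into a leaf, every prediction comes from the root node's stored value, while `node_count` stays the same.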

children_left : array of int, shape [node_count]
    children_left[i] holds the node id of the left child of node i.
    For leaves, children_left[i] == TREE_LEAF. Otherwise,
    children_left[i] > i. This child handles the case where
    X[:, feature[i]] <= threshold[i].

children_right : array of int, shape [node_count]
    children_right[i] holds the node id of the right child of node i.
    For leaves, children_right[i] == TREE_LEAF. Otherwise,
    children_right[i] > i. This child handles the case where
    X[:, feature[i]] > threshold[i].

feature : array of int, shape [node_count]
    feature[i] holds the feature to split on, for the internal node i.

threshold : array of double, shape [node_count]
    threshold[i] holds the threshold for the internal node i.
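As a rough illustration of how these four arrays fit together, the sketch below walks a sample from the root to its leaf by hand. The helper name `find_leaf`, the dataset, and the tree depth are assumptions made for the example; the cast to float32 mirrors the conversion scikit-learn applies to dense input before comparing against thresholds.

```python
import numpy as np
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor

TREE_LEAF = -1  # marker stored in children_left/children_right for leaves

def find_leaf(tree, x):
    """Follow the split rules from the root and return the leaf's node id."""
    x = np.asarray(x, dtype=np.float32)  # same precision sklearn uses for X
    node = 0
    while tree.children_left[node] != TREE_LEAF:
        if x[tree.feature[node]] <= tree.threshold[node]:
            node = tree.children_left[node]   # X[:, feature] <= threshold
        else:
            node = tree.children_right[node]  # X[:, feature] > threshold
    return node

X, y = load_diabetes(return_X_y=True)
model = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

leaf_ids = np.array([find_leaf(model.tree_, row) for row in X[:20]])
print(leaf_ids)
```

The result can be compared against `model.apply(X[:20])`, which returns the leaf index scikit-learn itself assigns to each sample.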

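To prune a DecisionTree based on the number of observations in each node, I use this function. You need to know that the TREE_LEAF constant is equal to -1.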

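TREE_LEAF = -1  # value sklearn stores in children_left/children_right for a leaf

def prune(decisiontree, min_samples_leaf=1):
    """Turn every node with at most min_samples_leaf samples into a leaf, in place."""
    if decisiontree.min_samples_leaf >= min_samples_leaf:
        raise ValueError('Tree already more pruned')
    decisiontree.min_samples_leaf = min_samples_leaf
    tree = decisiontree.tree_
    for i in range(tree.node_count):
        # n_node_samples[i] is the number of training samples that reached node i
        if tree.n_node_samples[i] <= min_samples_leaf:
            # Cut both child links; the subtree below becomes unreachable
            tree.children_left[i] = TREE_LEAF
            tree.children_right[i] = TREE_LEAF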
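Because the pruning only cuts child links, the stored arrays keep their original length. One way to see what the pruned tree actually uses is to count the nodes still reachable from the root. This is a sketch under assumptions: `count_reachable` is a hypothetical helper, and the loop repeats the same in-place edit as the pruning function with a threshold of 100 samples.

```python
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor

TREE_LEAF = -1

def count_reachable(tree):
    """Return (reachable_nodes, reachable_leaves), walking down from the root."""
    stack, n_nodes, n_leaves = [0], 0, 0
    while stack:
        node = stack.pop()
        n_nodes += 1
        left, right = tree.children_left[node], tree.children_right[node]
        if left == TREE_LEAF:
            n_leaves += 1
        else:
            stack.extend((left, right))
    return n_nodes, n_leaves

X, y = load_diabetes(return_X_y=True)
model = DecisionTreeRegressor(max_depth=4, random_state=0).fit(X, y)
tree = model.tree_
before = count_reachable(tree)

# Same in-place edit as the pruning function: small nodes become leaves.
for i in range(tree.node_count):
    if tree.n_node_samples[i] <= 100:
        tree.children_left[i] = TREE_LEAF
        tree.children_right[i] = TREE_LEAF

after = count_reachable(tree)
print("stored nodes:", tree.node_count)
print("reachable before:", before, "after:", after)
```

The stored node count does not change, while the reachable count drops, which is what "not lightening the data" means in practice.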

Here is an example that generates graphviz output before and after pruning:

from sklearn.tree import DecisionTreeRegressor as DTR
from sklearn.datasets import load_diabetes
from sklearn.tree import export_graphviz as export

bunch = load_diabetes()
data = bunch.data
target = bunch.target

dtr = DTR(max_depth=4)
dtr.fit(data, target)

# Recent scikit-learn releases expect the fitted estimator here, not dtr.tree_
export(decision_tree=dtr, out_file='before.dot')
prune(dtr, min_samples_leaf=100)
export(decision_tree=dtr, out_file='after.dot')

Source: https://codeday.me/bug/20190519/1134631.html
