ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

python – 在sklearn DecisionTreeClassifier中修剪不必要的叶子

2019-10-07 03:59:04  阅读:354  来源: 互联网

标签:decision-tree python scikit-learn pruning


我使用sklearn.tree.DecisionTreeClassifier来构建决策树.使用最佳参数设置,我得到一个有不必要叶子的树(参见下面的示例图片 – 我不需要概率,所以标记为红色的叶节点是不必要的分割)

Tree

是否有任何第三方库用于修剪这些不必要的节点?还是代码片段?我可以写一个,但我无法想象我是第一个有这个问题的人……

要复制的代码:

from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets
iris = datasets.load_iris()
X = iris.data
y = iris.target
mdl = DecisionTreeClassifier(max_leaf_nodes=8)
mdl.fit(X,y)

PS:我尝试了多次关键词搜索,并且很惊讶地发现什么都没有 – 在sklearn中是否真的没有后期修剪?

PPS:响应可能的重复:虽然the suggested question可能在我自己编码修剪算法时帮助我,但它回答了一个不同的问题 – 我想摆脱不改变最终决定的叶子,而另一个问题想要一个最小阈值用于拆分节点.

PPPS:显示的树是一个显示我的问题的例子.我知道创建树的参数设置不是最理想的.我不是要求优化这个特定的树,我需要进行后修剪以摆脱可能有用的叶子,如果一个人需要类概率,但如果一个人只对最有可能的类感兴趣则没有帮助.

解决方法:

使用ncfirth的链接,我能够修改那里的代码,以便它适合我的问题:

from sklearn.tree._tree import TREE_LEAF

def is_leaf(inner_tree, index):
    # Check whether node is leaf node
    return (inner_tree.children_left[index] == TREE_LEAF and 
            inner_tree.children_right[index] == TREE_LEAF)

def prune_index(inner_tree, decisions, index=0):
    # Start pruning from the bottom - if we start from the top, we might miss
    # nodes that become leaves during pruning.
    # Do not use this directly - use prune_duplicate_leaves instead.
    if not is_leaf(inner_tree, inner_tree.children_left[index]):
        prune_index(inner_tree, decisions, inner_tree.children_left[index])
    if not is_leaf(inner_tree, inner_tree.children_right[index]):
        prune_index(inner_tree, decisions, inner_tree.children_right[index])

    # Prune children if both children are leaves now and make the same decision:     
    if (is_leaf(inner_tree, inner_tree.children_left[index]) and
        is_leaf(inner_tree, inner_tree.children_right[index]) and
        (decisions[index] == decisions[inner_tree.children_left[index]]) and 
        (decisions[index] == decisions[inner_tree.children_right[index]])):
        # turn node into a leaf by "unlinking" its children
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
        ##print("Pruned {}".format(index))

def prune_duplicate_leaves(mdl):
    # Remove leaves if both 
    decisions = mdl.tree_.value.argmax(axis=2).flatten().tolist() # Decision for each node
    prune_index(mdl.tree_, decisions)

在DecisionTreeClassifier clf上使用它:

prune_duplicate_leaves(clf)

编辑:修复了更复杂树木的错误

标签:decision-tree,python,scikit-learn,pruning
来源: https://codeday.me/bug/20191007/1864307.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有