
java – Mahout – Simple classification problem

2019-05-30 07:48:28  Source: Internet

Tags: java classification mahout


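I am trying to build a simple model that classifies points in a 2D space into 2 partitions:

> I train the model by specifying a few points and the partition each one belongs to.
> I use the model to predict which group (class) a test point is likely to fall into.

Unfortunately, I am not getting the expected answer. Is something missing from my code, or am I doing something wrong?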

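import java.util.HashMap;
import java.util.Map;

import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;

public class SimpleClassifier {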

    public static class Point{
        public int x;
        public int y;

        public Point(int x,int y){
            this.x = x;
            this.y = y;
        }

        @Override
        public boolean equals(Object arg0) {
            Point p = (Point)  arg0;
            return( (this.x == p.x) &&(this.y== p.y));
        }

        @Override
        public String toString() {
            return  this.x + " , " + this.y ; 
        }
    }

    public static void main(String[] args) {

        Map<Point,Integer> points = new HashMap<SimpleClassifier.Point, Integer>();

        points.put(new Point(0,0), 0);
        points.put(new Point(1,1), 0);
        points.put(new Point(1,0), 0);
        points.put(new Point(0,1), 0);
        points.put(new Point(2,2), 0);


        points.put(new Point(8,8), 1);
        points.put(new Point(8,9), 1);
        points.put(new Point(9,8), 1);
        points.put(new Point(9,9), 1);


        // 2 categories, 2 features (x and y), L1 prior
        OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression(2, 2, new L1());
        learningAlgo.learningRate(50);

        //learningAlgo.alpha(1).stepOffset(1000);

        System.out.println("training model  \n" );
        for(Point point : points.keySet()){
            Vector v = getVector(point);
            System.out.println(point  + " belongs to " + points.get(point));
            learningAlgo.train(points.get(point), v);
        }

        learningAlgo.close();


        //now classify real data
        Vector v = new RandomAccessSparseVector(2);
        v.set(0, 0.5);
        v.set(1, 0.5);

        Vector r = learningAlgo.classifyFull(v);
        System.out.println(r);

        System.out.println("ans = " );
        System.out.println("no of categories = " + learningAlgo.numCategories());
        System.out.println("no of features = " + learningAlgo.numFeatures());
        System.out.println("Probability of cluster 0 = " + r.get(0));
        System.out.println("Probability of cluster 1 = " + r.get(1));

    }

    public static Vector getVector(Point point){
        Vector v = new DenseVector(2);
        v.set(0, point.x);
        v.set(1, point.y);

        return v;
    }
}

Output:

ans = 
no of categories = 2
no of features = 2
Probability of cluster 0 = 3.9580985042775296E-4
Probability of cluster 1 = 0.9996041901495722

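The output gives cluster 1 a probability of over 99% for the test point (0.5, 0.5), even though that point sits among the cluster-0 training points. Why?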

Solution:

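The problem is that you did not include a bias (intercept) term, a feature whose value is always 1.
You need to add this bias term (1) to the feature vector you build for each point.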
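Why the intercept matters here: without it, the score is just w·x, so the decision boundary w·x = 0 must pass through the origin. The points (0.5, 0.5), (2, 2) and (8, 8) all lie on the same ray from the origin, so w·(0.5, 0.5), w·(2, 2) and w·(8, 8) always share the same sign. Any weights that put (8, 8) in cluster 1 therefore also put the test point (0.5, 0.5) in cluster 1, which matches the question's output. The following plain-Java sketch (a hypothetical `BiasDemo` class, independent of Mahout) checks this by brute force over a grid of weights, with and without an intercept.

```java
// Hypothetical demo (plain Java, not Mahout): can a linear score
// w1*x + w2*y + b put (2,2) in class 0 and (8,8) in class 1?
public class BiasDemo {

    // Linear score that logistic regression feeds into the sigmoid.
    public static double score(double w1, double w2, double b, double x, double y) {
        return w1 * x + w2 * y + b;
    }

    // True if (2,2) gets a negative score (class 0) and (8,8) a positive one (class 1).
    public static boolean separates(double w1, double w2, double b) {
        return score(w1, w2, b, 2, 2) < 0 && score(w1, w2, b, 8, 8) > 0;
    }

    // Brute-force search over weights in steps of 0.5;
    // when useBias is false the intercept is fixed at 0.
    public static boolean anySeparating(boolean useBias) {
        for (int i = -20; i <= 20; i++) {
            for (int j = -20; j <= 20; j++) {
                for (int k = -40; k <= 40; k++) {
                    double b = useBias ? k * 0.5 : 0.0;
                    if (separates(i * 0.5, j * 0.5, b)) {
                        return true;
                    }
                }
            }
        }
        return false;
    }

    public static void main(String[] args) {
        System.out.println("separable without bias: " + anySeparating(false)); // false
        System.out.println("separable with bias:    " + anySeparating(true));  // true
    }
}
```

Without a bias the search always fails, because score(8, 8) = 4 · score(2, 2) when b = 0; with a bias, weights such as w = (1, 1), b = -10 give -6 at (2, 2) and +6 at (8, 8).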

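This is a very basic mistake, and even experienced machine-learning practitioners make it. It may be worth investing some time in the underlying theory; Andrew Ng's lectures are a good place to start.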

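To make the code produce the expected output, change the following:

> Add the bias term.
> The learning rate is too high; change it to 10.

You should now get P(0) = 0.9999 for class 0.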

Here is a complete working example that gives the correct result:

import java.util.HashMap;
import java.util.Map;

import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;


class Point{
    public int x;
    public int y;

    public Point(int x,int y){
        this.x = x;
        this.y = y;
    }

    @Override
    public boolean equals(Object arg0) {
        Point p = (Point)  arg0;
        return( (this.x == p.x) &&(this.y== p.y));
    }

    @Override
    public String toString() {
        return  this.x + " , " + this.y ; 
    }
}

public class SimpleClassifier {



    public static void main(String[] args) {

            Map<Point,Integer> points = new HashMap<Point, Integer>();

            points.put(new Point(0,0), 0);
            points.put(new Point(1,1), 0);
            points.put(new Point(1,0), 0);
            points.put(new Point(0,1), 0);
            points.put(new Point(2,2), 0);

            points.put(new Point(8,8), 1);
            points.put(new Point(8,9), 1);
            points.put(new Point(9,8), 1);
            points.put(new Point(9,9), 1);


            // 2 categories, 3 features: x, y and the constant bias term
            OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression(2, 3, new L1());
            learningAlgo.lambda(0.1);       // strength of the L1 prior (regularization)
            learningAlgo.learningRate(10);  // lowered from 50

            System.out.println("training model  \n" );

            for(Point point : points.keySet()){

                Vector v = getVector(point);
                System.out.println(point  + " belongs to " + points.get(point));
                learningAlgo.train(points.get(point), v);
            }

            learningAlgo.close();

            // test point (0.5, 0.5); the bias feature must be 1 here too
            Vector v = new RandomAccessSparseVector(3);
            v.set(0, 0.5);
            v.set(1, 0.5);
            v.set(2, 1);

            Vector r = learningAlgo.classifyFull(v);
            System.out.println(r);

            System.out.println("ans = " );
            System.out.println("no of categories = " + learningAlgo.numCategories());
            System.out.println("no of features = " + learningAlgo.numFeatures());
            System.out.println("Probability of cluster 0 = " + r.get(0));
            System.out.println("Probability of cluster 1 = " + r.get(1));

    }

    public static Vector getVector(Point point){
        Vector v = new DenseVector(3);
        v.set(0, point.x);
        v.set(1, point.y);
        v.set(2, 1);   // bias (intercept) feature, always 1
        return v;
    }
}

Output:

2 , 2 belongs to 0
1 , 0 belongs to 0
9 , 8 belongs to 1
8 , 8 belongs to 1
0 , 1 belongs to 0
0 , 0 belongs to 0
1 , 1 belongs to 0
9 , 9 belongs to 1
8 , 9 belongs to 1
{0:2.470723149516907E-6,1:0.9999975292768505}
ans = 
no of categories = 2
no of features = 3
Probability of cluster 0 = 2.470723149516907E-6
Probability of cluster 1 = 0.9999975292768505

Note that I defined the Point class outside the SimpleClassifier class, but only to make the code more readable; this is not required.

Look at what happens when you change the learning rate. Read up on cross-validation to learn how to choose a learning rate.

Learning Rate => Probability of cluster 0
0.001 => 0.4991116089
0.01 => 0.492481585
0.1 => 0.469961472
1 => 0.5327745322
10 => 0.9745740393
100 => 0
1000 => 0
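Cross-validation for the learning rate can be sketched without Mahout. The hypothetical `LearningRateCV` class below trains a small logistic regression with a bias term by plain SGD (constant rate, fixed visiting order and a fixed number of epochs, all assumptions chosen for illustration), scores each candidate rate by its leave-one-out held-out log-loss on the nine training points, and keeps the best one. Its numbers will not match the table above: Mahout's `OnlineLogisticRegression` also applies the L1 prior and its own learning-rate annealing options (such as the `alpha` and `stepOffset` settings in the question's commented-out line).

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch (plain Java, not Mahout): choose a learning rate by
// leave-one-out cross-validation for logistic regression with a bias term.
public class LearningRateCV {

    // Training data from the question: {x, y, label}.
    static final double[][] DATA = {
        {0, 0, 0}, {1, 1, 0}, {1, 0, 0}, {0, 1, 0}, {2, 2, 0},
        {8, 8, 1}, {8, 9, 1}, {9, 8, 1}, {9, 9, 1}
    };

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Plain SGD on the logistic loss; returns {w1, w2, b}.
    static double[] train(List<double[]> rows, double rate, int epochs) {
        double[] w = new double[3];
        for (int e = 0; e < epochs; e++) {
            for (double[] r : rows) {
                double p = sigmoid(w[0] * r[0] + w[1] * r[1] + w[2]); // P(label = 1)
                double err = r[2] - p;
                w[0] += rate * err * r[0];
                w[1] += rate * err * r[1];
                w[2] += rate * err; // the bias feature is always 1
            }
        }
        return w;
    }

    // Mean log-loss on the held-out point, averaged over all leave-one-out folds.
    public static double looLogLoss(double rate, int epochs) {
        double total = 0;
        for (int i = 0; i < DATA.length; i++) {
            List<double[]> rows = new ArrayList<>();
            for (int j = 0; j < DATA.length; j++) {
                if (j != i) {
                    rows.add(DATA[j]);
                }
            }
            double[] w = train(rows, rate, epochs);
            double[] held = DATA[i];
            double p = sigmoid(w[0] * held[0] + w[1] * held[1] + w[2]);
            p = Math.min(Math.max(p, 1e-12), 1 - 1e-12); // avoid log(0)
            total += held[2] == 1 ? -Math.log(p) : -Math.log(1 - p);
        }
        return total / DATA.length;
    }

    // Returns the candidate rate with the lowest held-out log-loss.
    public static double bestRate(double[] candidates, int epochs) {
        double best = candidates[0];
        double bestLoss = Double.POSITIVE_INFINITY;
        for (double rate : candidates) {
            double loss = looLogLoss(rate, epochs);
            System.out.printf("rate %-6s -> held-out log-loss %.4f%n", rate, loss);
            if (loss < bestLoss) {
                bestLoss = loss;
                best = rate;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] candidates = {0.001, 0.01, 0.1, 1, 10};
        System.out.println("chosen learning rate: " + bestRate(candidates, 500));
    }
}
```

With only nine points, leave-one-out is the natural choice; with more data, k-fold cross-validation (for example k = 5 or 10) is the more common variant.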

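Choosing the learning rate:

> Stochastic gradient descent is commonly run with a fixed learning rate α. However, by slowly letting α decrease to zero as the algorithm runs, it is also possible to ensure that the parameters converge to the global minimum rather than merely oscillating around it.
> When you use a constant α, you can make an initial choice, run gradient descent while watching the cost function, and adjust the learning rate accordingly. This is explained here.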
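Both bullets can be sketched in plain Java. The hypothetical `DecayingRateSgd` class below trains logistic regression with a bias term on the question's points using the schedule α_t = α0 / (1 + t/τ). That schedule is one common choice and an assumption here, not necessarily what Mahout does internally. The class records the training log-loss after every epoch so you can watch the cost function, and passing τ = +∞ turns it into the constant-α case from the second bullet.

```java
// Hypothetical sketch (plain Java, not Mahout's annealing): SGD for logistic
// regression with a bias term and a rate that decays as alpha0 / (1 + t / tau).
public class DecayingRateSgd {

    // Training data from the question: {x, y, label}.
    static final double[][] DATA = {
        {0, 0, 0}, {1, 1, 0}, {1, 0, 0}, {0, 1, 0}, {2, 2, 0},
        {8, 8, 1}, {8, 9, 1}, {9, 8, 1}, {9, 9, 1}
    };

    // Learning rate at update t; tau = +Infinity gives a constant rate.
    public static double rate(double alpha0, double tau, long t) {
        return alpha0 / (1.0 + t / tau);
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Mean log-loss over the training data (the cost function to monitor).
    static double logLoss(double[] w) {
        double total = 0;
        for (double[] r : DATA) {
            double p = sigmoid(w[0] * r[0] + w[1] * r[1] + w[2]);
            p = Math.min(Math.max(p, 1e-12), 1 - 1e-12); // avoid log(0)
            total += r[2] == 1 ? -Math.log(p) : -Math.log(1 - p);
        }
        return total / DATA.length;
    }

    // Updates w = {w1, w2, b} in place; returns the loss after each epoch.
    public static double[] train(double[] w, double alpha0, double tau, int epochs) {
        double[] history = new double[epochs];
        long t = 0;
        for (int e = 0; e < epochs; e++) {
            for (double[] r : DATA) {
                double a = rate(alpha0, tau, t++);
                double err = r[2] - sigmoid(w[0] * r[0] + w[1] * r[1] + w[2]);
                w[0] += a * err * r[0];
                w[1] += a * err * r[1];
                w[2] += a * err; // the bias feature is always 1
            }
            history[e] = logLoss(w);
        }
        return history;
    }

    public static void main(String[] args) {
        double[] w = new double[3];
        double[] history = train(w, 0.1, 1000, 500);
        for (int e = 0; e < history.length; e += 50) {
            System.out.printf("epoch %3d  loss %.6f%n", e + 1, history[e]);
        }
        double p1 = sigmoid(w[0] * 0.5 + w[1] * 0.5 + w[2]);
        System.out.printf("P(cluster 0 | 0.5, 0.5) = %.6f%n", 1 - p1);
    }
}
```

Calling `train(w, 0.1, Double.POSITIVE_INFINITY, 500)` gives the constant-rate variant; if the printed loss rises or jumps around from epoch to epoch, the rate is probably too high.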

Source: https://codeday.me/bug/20190530/1182708.html

Copyright (C)ICode9.com, All Rights Reserved.
