ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

[Deeplearning4j应用教程09]_基于DL4J的自动编码器

2021-01-22 15:30:37  阅读:164  来源: 互联网

标签:编码器 int import 09 List Deeplearning4j org new


基于DL4J的自动编码器

一、简介

为什么要使用自动编码器? 在实践中,自动编码器通常应用于数据的降噪和降维。 这对于表示学习非常有用,而对于数据压缩则不太有用。
在深度学习中,自动编码器是“尝试”以重建其输入的神经网络。 它可以用作特征提取的一种形式,并且可以将自动编码器堆叠起来以创建“深度”网络。 由自动编码器生成的功能可以输入到其他算法中,以进行分类,聚类和异常检测。
当原始输入数据具有高维且无法轻松绘制时,自动编码器还可用于数据可视化。 通过降维,有时可以将输出压缩到2D或3D空间中,以进行更好的数据探索。
在实际应用当中,异常检测能够用于:网络入侵,欺诈检测,系统监视,传感器网络事件检测(IoT)和异常轨迹感测。

二、自编码器的工作流程

自动编码器包括:
1、编码功能(“编码器”)
2、解码功能(“解码器”)
3、距离函数(“损失函数”)
首先,输入被馈入自动编码器并转换为压缩表示。然后,解码器学习如何从压缩的表示中重建原始输入,在无监督的训练过程中,损失函数有助于纠正解码器产生的错误。 此过程是自动的(因此称为“自动”编码器); 即不需要人工干预。
学习到现在,我们应该已经知道如何使用MultiLayerNetwork和ComputationGraph创建不同的网络配置了,现在,我们将构造一个“堆叠”自动编码器,该编码器对MNIST数字执行异常检测而无需预先训练。而目的是识别异常数字,即不寻常和不典型的数字。从给定数据集的规范中“脱颖而出”的内容,事件或观察结果的识别被广泛称为异常检测。异常检测不需要标注的数据集,并且可以在无监督学习的情况下进行,这很有帮助,因为世界上大多数数据都没有标注。
通常,异常检测使用重构误差来衡量解码器的性能。正常的数据应具有较低的重构误差,而异常值应具有较高的重构误差。

三、基于DL4J的自编码器实现

3.1、导入需要的包

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.util.*;
import java.util.List;

3.2、堆叠式自动编码器

以下自动编码器使用两个堆叠的密集层进行编码。 MNIST数字转换为长度为784的平面一维数组(MNIST图像为28x28像素,当我们端对端放置它们时等于784)。在网络中,数据的大小变化情况如下:
784→250→10→250→784
代码如下:

//搭建模型
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .weightInit(WeightInit.XAVIER)
            .updater(new AdaGrad(0.05))
            .activation(Activation.RELU)
            .l2(0.0001)
            .list()
            .layer(new DenseLayer.Builder().nIn(784).nOut(250)
                .build())
            .layer(new DenseLayer.Builder().nIn(250).nOut(10)
                .build())
            .layer(new DenseLayer.Builder().nIn(10).nOut(250)
                .build())
            .layer(new OutputLayer.Builder().nIn(250).nOut(784)
                .activation(Activation.LEAKYRELU)
                .lossFunction(LossFunctions.LossFunction.MSE)
                .build())
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        //监听器
        net.setListeners(Collections.singletonList(new ScoreIterationListener(10)));

3.3、使用MNIST迭代器

像Deeplearning4j的大多数内置迭代器一样,MNIST迭代器扩展了DataSetIterator类。 该API允许简单地实例化数据集并在后台自动下载数据。
代码如下:

//加载数据,并且进行训练集与测试集的划分:40000训练数据,10000测试数据
        DataSetIterator iter = new MnistDataSetIterator(100,50000,false);

        List<INDArray> featuresTrain = new ArrayList<>();
        List<INDArray> featuresTest = new ArrayList<>();
        List<INDArray> labelsTest = new ArrayList<>();

        Random r = new Random(12345);
        while(iter.hasNext()){
            DataSet ds = iter.next();
            SplitTestAndTrain split = ds.splitTestAndTrain(80, r);  //按照8:2的比例划分数据集 (miniBatch = 100)
            featuresTrain.add(split.getTrain().getFeatures());
            DataSet dsTest = split.getTest();
            featuresTest.add(dsTest.getFeatures());
            INDArray indexes = Nd4j.argMax(dsTest.getLabels(),1); //进行独热编码转换: 表示 -> 索引
            labelsTest.add(indexes);
        }

3.4、无监督训练

现在,我们已经设置了网络配置并与我们的MNIST测试/训练迭代器一起实例化了,训练只需要几行代码。
之前,我们使用setListeners()方法将ScoreIterationListener附加到模型。根据用于运行此代码电脑的浏览器,可以打开调试器/检查器以查看侦听器输出。 由于Deeplearning4j的内部使用SL4J进行日志记录,因此此输出重定向到控制台,并且Zeppelin重定向了该输出。 这有助于减少电脑的混乱情况。
代码如下:

//训练模型
int nEpochs = 3;
for( int epoch=0; epoch<nEpochs; epoch++ ){
     for(INDArray data : featuresTrain){
           net.fit(data,data);
            }
        System.out.println("Epoch " + epoch + " complete");
     }

3.5、评估模型

现在,我们已经对自动编码器进行了训练,那么,我们将根据测试数据来评估模型。每个示例将被单独打分,并且将构成一个映射,该映射将每个数字与(得分,示例)对列表相关联。
最后,我们将计算N个最佳分数和N个最差分数。
代码如下:

//根据测试数据评估模型
//分别对测试集中的每个样本评分
//组成一个映射,将每个数字与(得分,样本)对列表相关联
//然后找到每位数中N个最佳分数和N个最差分数
Map<Integer,List<Pair<Double,INDArray>>> listsByDigit = new HashMap<>();
for( int i=0; i<10; i++ ) listsByDigit.put(i,new ArrayList<>());

    for( int i=0; i<featuresTest.size(); i++ ){
        INDArray testData = featuresTest.get(i);
        INDArray labels = labelsTest.get(i);
        int nRows = testData.rows();
        for( int j=0; j<nRows; j++){
            INDArray example = testData.getRow(j, true);
            int digit = (int)labels.getDouble(j);
            double score = net.score(new DataSet(example,example));
            // 将(得分,样本)对添加到适当的列表
            List digitAllPairs = listsByDigit.get(digit);
            digitAllPairs.add(new ImmutablePair<>(score, example));
        }
    }

    //Sort each list in the map by score
    Comparator<Pair<Double, INDArray>> c = new Comparator<Pair<Double, INDArray>>() {
        @Override
        public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) {
            return Double.compare(o1.getLeft(),o2.getLeft());
        }
    };

    for(List<Pair<Double, INDArray>> digitAllPairs : listsByDigit.values()){
        Collections.sort(digitAllPairs, c);
    }

    //排序后,为每个数字选择N个最佳分数和N个最差分数(根据重构误差),其中N = 5
    List<INDArray> best = new ArrayList<>(50);
    List<INDArray> worst = new ArrayList<>(50);
    for( int i=0; i<10; i++ ){
        List<Pair<Double,INDArray>> list = listsByDigit.get(i);
        for( int j=0; j<5; j++ ){
            best.add(list.get(j).getRight());
            worst.add(list.get(list.size()-j-1).getRight());
        }
    }

3.6、结果可视化

//默认可视化
if (visualize) {
            //可视化最好和最差的数字
            MNISTVisualizer bestVisualizer = new MNISTVisualizer(2.0, best, "Best (Low Rec. Error)");
            bestVisualizer.visualize();

            MNISTVisualizer worstVisualizer = new MNISTVisualizer(2.0, worst, "Worst (High Rec. Error)");
            worstVisualizer.visualize();
        }
//可视化方法
public static class MNISTVisualizer {
        private double imageScale;
        private List<INDArray> digits;  //数字(作为行向量),每个INDArray一个
        private String title;
        private int gridWidth;

        public MNISTVisualizer(double imageScale, List<INDArray> digits, String title ) {
            this(imageScale, digits, title, 5);
        }
        
        public MNISTVisualizer(double imageScale, List<INDArray> digits, String title, int gridWidth ) {
            this.imageScale = imageScale;
            this.digits = digits;
            this.title = title;
            this.gridWidth = gridWidth;
        }

        public void visualize(){
            JFrame frame = new JFrame();
            frame.setTitle(title);
            frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);

            JPanel panel = new JPanel();
            panel.setLayout(new GridLayout(0,gridWidth));

            List<JLabel> list = getComponents();
            for(JLabel image : list){
                panel.add(image);
            }

            frame.add(panel);
            frame.setVisible(true);
            frame.pack();
        }

        private List<JLabel> getComponents(){
            List<JLabel> images = new ArrayList<>();
            for( INDArray arr : digits ){
                BufferedImage bi = new BufferedImage(28,28,BufferedImage.TYPE_BYTE_GRAY);
                for( int i=0; i<784; i++ ){
                    bi.getRaster().setSample(i % 28, i / 28, 0, (int)(255*arr.getDouble(i)));
                }
                ImageIcon orig = new ImageIcon(bi);
                Image imageScaled = orig.getImage().getScaledInstance((int)(imageScale*28),(int)(imageScale*28),Image.SCALE_REPLICATE); 
                ImageIcon scaled = new ImageIcon(imageScaled);
                images.add(new JLabel(scaled));
            }
            return images;
        }
    }

最终结果如下图所示:
在这里插入图片描述
最差的手写数字:
在这里插入图片描述
最好的手写数字:
在这里插入图片描述
完整代码如下:

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import javax.swing.*;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.util.*;
import java.util.List;
public class xx {

    public static boolean visualize = true;

    public static void main(String[] args) throws Exception {

        //搭建模型. 784 输入/输出 (MNIST 图片大小为 28x28).
        //784 -> 250 -> 10 -> 250 -> 784
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .weightInit(WeightInit.XAVIER)
            .updater(new AdaGrad(0.05))
            .activation(Activation.RELU)
            .l2(0.0001)
            .list()
            .layer(new DenseLayer.Builder().nIn(784).nOut(250)
                .build())
            .layer(new DenseLayer.Builder().nIn(250).nOut(10)
                .build())
            .layer(new DenseLayer.Builder().nIn(10).nOut(250)
                .build())
            .layer(new OutputLayer.Builder().nIn(250).nOut(784)
                .activation(Activation.LEAKYRELU)
                .lossFunction(LossFunctions.LossFunction.MSE)
                .build())
            .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.setListeners(Collections.singletonList(new ScoreIterationListener(10)));

       //加载数据,并且进行训练集与测试集的划分:40000训练数据,10000测试数据
        DataSetIterator iter = new MnistDataSetIterator(100,50000,false);

        List<INDArray> featuresTrain = new ArrayList<>();
        List<INDArray> featuresTest = new ArrayList<>();
        List<INDArray> labelsTest = new ArrayList<>();

        Random r = new Random(12345);
        while(iter.hasNext()){
            DataSet ds = iter.next();
            SplitTestAndTrain split = ds.splitTestAndTrain(80, r);  //按照8:2的比例进行划分(from miniBatch = 100)
            featuresTrain.add(split.getTrain().getFeatures());
            DataSet dsTest = split.getTest();
            featuresTest.add(dsTest.getFeatures());
            INDArray indexes = Nd4j.argMax(dsTest.getLabels(),1); //通过独热编码将表示转换为索引
            labelsTest.add(indexes);
        }

        //训练模型
        int nEpochs = 3;
        for( int epoch=0; epoch<nEpochs; epoch++ ){
            for(INDArray data : featuresTrain){
                net.fit(data,data);
            }
            System.out.println("Epoch " + epoch + " complete");
        }

        //根据测试数据评估模型
//分别对测试集中的每个样本评分
//组成一个映射,将每个数字与(得分,样本)对列表相关联
//然后找到每位数中N个最佳分数和N个最差分数
 Map<Integer,List<Pair<Double,INDArray>>> listsByDigit = new HashMap<>();
        for( int i=0; i<10; i++ ) listsByDigit.put(i,new ArrayList<>());

        for( int i=0; i<featuresTest.size(); i++ ){
            INDArray testData = featuresTest.get(i);
            INDArray labels = labelsTest.get(i);
            int nRows = testData.rows();
            for( int j=0; j<nRows; j++){
                INDArray example = testData.getRow(j, true);
                int digit = (int)labels.getDouble(j);
                double score = net.score(new DataSet(example,example));
                // 将(得分,样本)对添加到适当的列表
                List digitAllPairs = listsByDigit.get(digit);
                digitAllPairs.add(new ImmutablePair<>(score, example));
            }
        }

        //按分数映射对每个列表进行排序
        Comparator<Pair<Double, INDArray>> c = new Comparator<Pair<Double, INDArray>>() {
            @Override
            public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) {
                return Double.compare(o1.getLeft(),o2.getLeft());
            }
        };

        for(List<Pair<Double, INDArray>> digitAllPairs : listsByDigit.values()){
            Collections.sort(digitAllPairs, c);
        }

        排序后,为每个数字选择N个最佳分数和N个最差分数(根据重构误差),其中N = 5
        List<INDArray> best = new ArrayList<>(50);
        List<INDArray> worst = new ArrayList<>(50);
        for( int i=0; i<10; i++ ){
            List<Pair<Double,INDArray>> list = listsByDigit.get(i);
            for( int j=0; j<5; j++ ){
                best.add(list.get(j).getRight());
                worst.add(list.get(list.size()-j-1).getRight());
            }
        }

        //默认进行可视化
        if (visualize) {
            //可视化最好与最差的数字
            MNISTVisualizer bestVisualizer = new MNISTVisualizer(2.0, best, "Best (Low Rec. Error)");
            bestVisualizer.visualize();

            MNISTVisualizer worstVisualizer = new MNISTVisualizer(2.0, worst, "Worst (High Rec. Error)");
            worstVisualizer.visualize();
        }
    }

    public static class MNISTVisualizer {
        private double imageScale;
        private List<INDArray> digits;  //Digits (as row vectors), one per INDArray
        private String title;
        private int gridWidth;

        public MNISTVisualizer(double imageScale, List<INDArray> digits, String title ) {
            this(imageScale, digits, title, 5);
        }
        //设定可视化的图片大小、数字、标题与显示图片的网格大小
        public MNISTVisualizer(double imageScale, List<INDArray> digits, String title, int gridWidth ) {
            this.imageScale = imageScale;
            this.digits = digits;
            this.title = title;
            this.gridWidth = gridWidth;
        }

        public void visualize(){
            JFrame frame = new JFrame();
            frame.setTitle(title);
            frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);

            JPanel panel = new JPanel();
            panel.setLayout(new GridLayout(0,gridWidth));

            List<JLabel> list = getComponents();
            for(JLabel image : list){
                panel.add(image);
            }

            frame.add(panel);
            frame.setVisible(true);
            frame.pack();
        }

        private List<JLabel> getComponents(){
            List<JLabel> images = new ArrayList<>();
            for( INDArray arr : digits ){
                BufferedImage bi = new BufferedImage(28,28,BufferedImage.TYPE_BYTE_GRAY);
                for( int i=0; i<784; i++ ){
                    bi.getRaster().setSample(i % 28, i / 28, 0, (int)(255*arr.getDouble(i)));
                }
                ImageIcon orig = new ImageIcon(bi);
                Image imageScaled = orig.getImage().getScaledInstance((int)(imageScale*28),(int)(imageScale*28),Image.SCALE_REPLICATE);
                ImageIcon scaled = new ImageIcon(imageScaled);
                images.add(new JLabel(scaled));
            }
            return images;
        }
    }
}

标签:编码器,int,import,09,List,Deeplearning4j,org,new
来源: https://blog.csdn.net/weixin_33980484/article/details/112985781

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有