ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

DJL-Java开发者动手学深度学习之图片分类

2022-01-08 21:00:27  阅读:217  来源: 互联网

标签:trainer Java IMAGE 开发者 new model DJL Mnist build


之前,我们文章里有讲,通过softmax回归对图片进行分类,具体文章请见《使用Softmax进行分类代码实现》。今天我们通过高级API更简洁地实现多层感知机。

准备数据集

private static RandomAccessDataset getDataset(Dataset.Usage usage)
    throws IOException {
    Mnist mnist =
        Mnist.builder()
        .optUsage(usage)
        .setSampling(32, true)
        .optLimit(64)
        .build();
    mnist.prepare(new ProgressBar());
    return mnist;
}

//训练集
RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN);

//验证集
RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST);
        

模型定义

 Block block = new Mlp(
                        Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH,
                        Mnist.NUM_CLASSES,
                        new int[] {128, 64});

初始化训练器

private static DefaultTrainingConfig setupTrainingConfig() {
    String outputDir = "build/model";
    SaveModelTrainingListener listener = new SaveModelTrainingListener(outputDir);
    listener.setSaveModelCallback(
        trainer -> {
            TrainingResult result = trainer.getTrainingResult();
            Model model = trainer.getModel();
            float accuracy = result.getValidateEvaluation("Accuracy");
            model.setProperty("Accuracy", String.format("%.5f", accuracy));
            model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
        });

    return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
        .addEvaluator(new Accuracy())
        .addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
        .addTrainingListeners(listener);
}

开始训练并保存模型

try (Model model = Model.newInstance("mlp")) {
    model.setBlock(block);
    try (Trainer trainer = model.newTrainer(config)) {
        trainer.setMetrics(new Metrics());
        Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);
        trainer.initialize(inputShape);
        EasyTrain.fit(trainer, 15, trainingSet, validateSet);

        //保存模型
        model.save(Paths.get("build/model"), "mlp"); 
        return trainer.getTrainingResult();
    }
}

总结

在这里,我们经过15个Epoch的训练,最终在build/model目录下,生成我们训练好的模型。后续我们将通使用我们训练的模型进行图片分类预测。关注公众号,解锁后续图片预测部分实现。

标签:trainer,Java,IMAGE,开发者,new,model,DJL,Mnist,build
来源: https://blog.csdn.net/xxcnb/article/details/122386126

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有