标签:trainer Java IMAGE 开发者 new model DJL Mnist build
之前,我们文章里有讲,通过softmax回归对图片进行分类,具体文章请见《使用Softmax进行分类代码实现》。今天我们通过高级API更简洁地实现多层感知机。
准备数据集
private static RandomAccessDataset getDataset(Dataset.Usage usage)
throws IOException {
Mnist mnist =
Mnist.builder()
.optUsage(usage)
.setSampling(32, true)
.optLimit(64)
.build();
mnist.prepare(new ProgressBar());
return mnist;
}
//训练集
RandomAccessDataset trainingSet = getDataset(Dataset.Usage.TRAIN);
//验证集
RandomAccessDataset validateSet = getDataset(Dataset.Usage.TEST);
模型定义
Block block = new Mlp(
Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH,
Mnist.NUM_CLASSES,
new int[] {128, 64});
初始化训练器
private static DefaultTrainingConfig setupTrainingConfig() {
String outputDir = "build/model";
SaveModelTrainingListener listener = new SaveModelTrainingListener(outputDir);
listener.setSaveModelCallback(
trainer -> {
TrainingResult result = trainer.getTrainingResult();
Model model = trainer.getModel();
float accuracy = result.getValidateEvaluation("Accuracy");
model.setProperty("Accuracy", String.format("%.5f", accuracy));
model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
});
return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
.addEvaluator(new Accuracy())
.addTrainingListeners(TrainingListener.Defaults.logging(outputDir))
.addTrainingListeners(listener);
}
开始训练并保存模型
try (Model model = Model.newInstance("mlp")) {
model.setBlock(block);
try (Trainer trainer = model.newTrainer(config)) {
trainer.setMetrics(new Metrics());
Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);
trainer.initialize(inputShape);
EasyTrain.fit(trainer, 15, trainingSet, validateSet);
//保存模型
model.save(Paths.get("build/model"), "mlp");
return trainer.getTrainingResult();
}
}
总结
在这里,我们经过15个Epoch的训练,最终在build/model
目录下,生成我们训练好的模型。后续我们将通使用我们训练的模型进行图片分类预测。关注公众号,解锁后续图片预测部分实现。
标签:trainer,Java,IMAGE,开发者,new,model,DJL,Mnist,build 来源: https://blog.csdn.net/xxcnb/article/details/122386126
本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享; 2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关; 3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关; 4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除; 5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。