【专知-Deeplearning4j深度学习教程03】使用多层神经网络分类MNIST数据集:图文+代码
【导读】主题链路知识是我们专知的核心功能之一,为用户提供AI领域系统性的知识学习服务,一站式学习人工智能的知识,包含人工智能( 机器学习、自然语言处理、计算机视觉等)、大数据、编程语言、系统架构。使用请访问专知 进行主题搜索查看 - 桌面电脑访问www.zhuanzhi.ai, 手机端访问www.zhuanzhi.ai 或关注微信公众号后台回复" 专知"进入专知,搜索主题查看。继Pytorch教程后,我们推出面向Java程序员的深度学习教程DeepLearning4J。Deeplearning4j的案例和资料很少,官方的doc文件也非常简陋,基本上所有的类和函数的都没有解释。为此,我们推出来自中科院自动化所专知小组博士生Hujun创作的-分布式Java开源深度学习框架Deeplearning4j学习教程,第三篇,使用多层神经网络分类MNIST数据集(手写数字识别)。
- Deeplearning4j开发环境配置
- ND4J(DL4J的矩阵运算库)教程
- 基于DL4J的CNN、AutoEncoder、RNN、Word2Vec等模型的实现
MNIST数据集
MNIST由手写数字图片组成,包含0-9十种数字,常被用作测试机器学习算法性能的基准数据集。MNIST包含了一个有60000张图片的训练集和一个有10000张图片的测试集。深度学习在MNIST上可以达到99.7%的准确率。
Deeplearning4j中直接集成了MNIST数据集,例如可以直接用下面的代码加载训练集和测试集:
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
神经网络结构
本教程使用具有1个隐藏层的MLP作为网络的结构,使用RELU作为隐藏层的激活函数,使用SOFTMAX作为输出层的激活函数。
从图中可以看出,网络具有输入层、隐藏层和输出层一共3层,但在代码编写时,会将该网络看作由2个层组成(2次变换):
- Layer 0: 一个Dense Layer(全连接层),由输入层进行线性变换变为隐藏层,并使用RELU对变换结果进行激活。用公式表达形式为
H = relu(XW_0 + b_0)
,其中:- X: 输入层,是形状为[batch_size, input_dim]的矩阵,矩阵的每行对应一个样本,每列对应一个特征(一个像素)
- H: 隐藏层的输出,是形状为[batch_size, hidden_dim]的矩阵,矩阵的每行对应一个样本隐藏层的输出
- relu: 使用RELU激活函数进行激活
- W_0: 形状为[input_dim, hidden_dim]的矩阵,是全连接层线性变换的参数
- b_0: 形状为[hidden_dim]的矩阵,是全连接层线性变换的参数(偏置)
- Layer 1: 一个Dense Layer(全连接层),由隐藏层进行线性变换为输出层,并使用SOFTMAX对变换结果进行激活。用公式表达形式为:
OUTPUT = softmax(HW_1 + b_1)
,其中:- OUTPUT: 输出层,是形状为[batch_size, output_dim]的矩阵,矩阵的每行对应一个样本,每列对应样本属于某类的概率。例如该例子中第0列表示输入手写数字为1的概率。
- softmax: 使用SOFTMAX激活函数进行激活
- W_1: 形状为[hidden_dim, output_dim]的矩阵,是全连接层线性变换的参数
- b_1: 形状为[output_dim]的矩阵,是全连接层线性变换的参数(偏置)
神经网络的训练过程,即神经网络参数的调整过程。待参数能够很好地预测测试集中样本的类别(label),神经网络就训练成功了。
代码
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;import org.slf4j.LoggerFactory;/**
* 本示例使用Deeplearning4j构建了一个多层感知器(MLP)来进行手写数字(MNIST)的识别
* 该示例中的神经网络只有1个隐藏层
*
* 输入层的维度是numRows*numColumns(图像像素行数*图像像素列数),即每个手写数字图像的像素数量(28*28)
* 隐藏层的大小为1000,使用RELU作为激活函数
* 输出层为SOFTMAX层,用于表示输入图像属于每个分类的概率(概率总和为1)
*
*/public class MLPMnistSingleLayerExample {
private static Logger log = LoggerFactory.getLogger(MLPMnistSingleLayerExample.class);
public static void main(String[] args) throws Exception {
//number of rows and columns in the input pictures
final int numRows = 28;
final int numColumns = 28;
int outputNum = 10; // 手写字符类别的数量
int batchSize = 128; // batch大小,一个batch中的输入使用相同的神经网络参数
int rngSeed = 123; // 设置一个随机种子,使得每次跑程序获得的随机值相同
int numEpochs = 15; // 训练时每扫描一遍数据集算一个Epoch
//Deeplearning4j内置的MNIST数据集
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);
log.info("Build model....");
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.seed(rngSeed) // 为模型设置随机种子
// 使用随机梯度下降作为优化算法
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.iterations(1)
.learningRate(0.006) // 设置学习速率
.updater(Updater.NESTEROVS)
.regularization(true).l2(1e-4) //设置L2正则系数,设置L2正则可以降低过拟合的程度
.list() //开始构建MLP网络(多层感知器)
.layer(0, new DenseLayer.Builder() //设置第一个Dense层
.nIn(numRows * numColumns) //输入为28*28
.nOut(1000) //输出为1000
.activation(Activation.RELU) //使用RELU激活
.weightInit(WeightInit.XAVIER) //设置初始化方法
.build())
.layer(1, new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) //设置第二个Dense层,OutputLayer也是Dense层
.nIn(1000) //输入为1000
.nOut(outputNum) //输出为10,即手写数字的类别数量
.activation(Activation.SOFTMAX) //使用SOFTMAX激活
.weightInit(WeightInit.XAVIER)
.build())
.pretrain(false).backprop(true) //进行反向传播,不进行预训练
.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init(); //每隔1个iteration就输出一次score
model.setListeners(new ScoreIterationListener(1));
log.info("Train model....");
for( int i=0; i<numEpochs; i++ ){
model.fit(mnistTrain);
}
log.info("Evaluate model....");
Evaluation eval = new Evaluation(outputNum); //创建一个评价器
while(mnistTest.hasNext()){
DataSet next = mnistTest.next();
INDArray output = model.output(next.getFeatureMatrix()); //模型的预测结果
eval.eval(next.getLabels(), output); //根据真实的结果和模型的预测结果对模型进行评价
}
log.info(eval.stats());
log.info("****************Example finished********************");
}
}
运行代码,输出如下:
- TensorFlow从0到1丨 第五篇:TensorFlow轻松搞定线性回归
- 【直播】我的基因组59:把我的数据伪装成23andme或wegene的芯片数据
- asp.net web api客户端调用
- 细说WebSocket - Node篇
- TensorFlow从0到1丨 第六篇:解锁梯度下降算法
- .Net多线程编程—误用点分析
- Web开发常见的几个漏洞解决方法
- .Net多线程编程—同步机制
- .Net多线程编程—Parallel LINQ、线程池
- 没有自己的服务器如何学习生物数据分析(下篇)
- .Net多线程编程—并发集合
- .Net多线程编程—任务Task
- 学会WCF之试错法——安全配置报错分析
- 生物信息学技能面试题(第5题)-根据GTF画基因的多个转录本结构
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 3分钟短文:说说Laravel模型关联关系最单纯的“一对一”
- Redis 缓存性能实践及总结
- 如何优雅的在react-hook中进行网络请求
- Git commit emoji 食用指南
- 编译安装 ProtoBuf 扩展
- 一键搭建 KMS 服务
- vuepress-theme-yur 使用教程
- 使用 XDebug + Webgrind 进行 PHP 程序性能分析
- 我给自己组装了一台 ITX 小台式
- PHPStorm 常用插件集合
- 优雅地调试线上代码
- WebStorm 配置 ESLint
- Yur 主题更新日志
- 使用 Forestry 管理基于 GitHub 的图床
- 从零开始搭建 VuePress 静态博客