【动手学深度学习笔记】之过拟合与欠拟合实例
时间:2022-07-23
本文章向大家介绍【动手学深度学习笔记】之过拟合与欠拟合实例,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
点击【拇指笔记】,关注我的公众号。
本篇文章完整代码可以在后台回复"fit"获得
1.多项式函数拟合实验
本节以多项式函数为例,来演示模型复杂度和训练集大小对欠拟合和过拟合的影响。
第一步还是导入需要的库。
%matplotlib inlineimport sysimport torchimport torchvisionimport numpy as npimport matplotlib.pyplot as pltimport torchvision.transforms as transformsfrom torch import nnfrom time import timefrom numpy import argmaxfrom torch.nn import initfrom IPython import display
1.1 生成数据集
首先需要生成一个人工数据集,使用如下的三阶多项式函数来生成该样本的标签。
高阶转低阶全连接层来实现三阶线性神经网络。
poly_features = torch.cat((features,torch.pow(features,2),torch.pow(features,3)),1)#将x,x的平方和x的立方合为一个张量,经过了这一步之后,这个多项式就变成了输入特征为3,输出为1的全连接层。
其中随机噪声服从均值为0,标准差为0.01的正态分布。
ntrain,ntest,true_w,true_b = 100,100,[1.2,-3.4,5.6],5#定义训练集、测试集样本数和权重、偏差参数features = torch.randn((ntrain+ntest,1))#生成随机特征值poly_features = torch.cat((features,torch.pow(features,2),torch.pow(features,3)),1)#将x,x的平方和x的立方合为一个张量,经过了这一步之后,这个多项式就变成了输入特征为3,输出为1的全连接层。labels = true_w[0]*poly_features[:,0]+true_w[1]*poly_features[:,1]+true_w[2]*poly_features[:,2]+true_b#根据特征值计算标签labels = labels+torch.tensor(np.random.normal(0,0.01,size = labels.size()),dtype = torch.float)#为标签添加随机噪声项
1.2 读取数据
每个小批量设置为10,使用TensorDataset转换为张量,使用DataLoader生成迭代器。
batch_size =10dataset = torch.utils.data.TensorDataset(poly_features,labels)train_iter = torch.utils.data.DataLoader(dataset,batch_size,shuffle = True)
1.3 损失函数和优化算法
损失函数与线性拟合一样,也使用平方损失函数。优化算法依然使用小批量随机梯度下降算法。
loss = torch.nn.MSELoss()#损失函数optimizer = torch.optim.SGD(net.parameters(),lr = 0.01)#优化算法
1.4 训练模型
def train(num_epochs,train_features,test_features,train_labels,test_labels): #四个参数:训练特征值集、测试特征值集、训练标签集和测试标签集。 net = torch.nn.Linear(train_features.shape[-1], 1) #定义神经网络的输入、输出,设置为全连接神经网络。 batch_size = min(10, train_labels.shape[0]) dataset = torch.utils.data.TensorDataset(train_features, train_labels) train_iter = torch.utils.data.DataLoader(dataset, batch_size, shuffle=True) #读取数据 optimizer = torch.optim.SGD(net.parameters(), lr) for epoch in range(num_epochs+1): for X,y in train_iter: #从迭代器中读取出特征值、标签 l = loss(net(X),y.view(-1,1)) #损失值换形 optimizer.zero_grad() l.backward() optimizer.step() train_labels = train_labels.view(-1, 1) test_labels = test_labels.view(-1, 1) train_ls.append(loss(net(train_features), train_labels).item()) test_ls.append(loss(net(test_features), test_labels).item()) #记录每一个学习周期的损失值,生成列表。
1.5 图像可视化
本节主要是用多项式来形象的体现出过拟合与欠拟合,因此,我们将数据可视化出来。
因为loss太小,所以需要将loss对数化。
#可视化def draw(train_ls,test_ls): x = range(1, num_epochs + 2) x = np.array(x) train_ls = np.array(train_ls) train_ls = np.log(train_ls) test_ls = np.array(test_ls) test_ls = np.log(test_ls)
l1 = plt.plot(x, train_ls,label = 'train') l2 = plt.plot(x,test_ls,'--',label = 'test')
plt.title('Underfit') plt.xlabel('epochs') plt.ylabel('Loss') plt.legend(loc='upper right')
1.5.1 三阶多项式函数拟合正常
使用正常的三阶线性神经网络。
#正常train(num_epochs,poly_features[:n_train, :], poly_features[n_train:, :], labels[:n_train], labels[n_train:])
1.5.2 三阶多项式函数过拟合
为了达到过拟合的效果,我们使用少量训练数据(少于参数数量)。
#过拟合train(num_epochs,poly_features[0:2, :], poly_features[n_train:, :], labels[0:2],labels[n_train:])
由图可以看出, 在迭代过程中,尽管训练误差较低,但是测试数据集上的误差却很高。这是典型的过拟合现象。
1.5.3 三阶多项式函数欠拟合
为了得到欠拟合的效果,只使用一组特征值,相当于一阶线性方程(模型复杂度降低)。
#欠拟合train(num_epochs,features[:n_train, :], features[n_train:, :], labels[:n_train],labels[n_train:])
该模型的训练误差在迭代早期下降后便很难继续降低。在完成100次迭代后,训练误差依旧很高()。
- 知其所以然之永不遗忘的算法
- ZOOKEEPER集群搭建及测试
- 【Python环境】Scikit-Learn:开源的机器学习Python模块
- 【Python环境】可爱的 Python: 自然语言工具包入门
- 电脑静音工作,又听不到12306的来票音乐,纠结啊 !但春节前工作多任务重,不能安心工作,就动手做个“无声购票弹窗”工具吧!
- .net访问PostgreSQL数据库发生“找不到函数名”的问题追踪
- “领域驱动开发”实例之旅(1)--不一样的开发模式 一、分析业务需求。 二、设计领域对象模型 三、测试领域对象模型 四、设计业务处理类 五、设计Entity和Vi
- Java基础——左移和右移
- 【Python环境】利用 Python、SciKit 和文本分类来实现行为分析
- LJMM平台( Linux +Jexus+MySQL+mono) 上使用MySQL的简单总结
- 判断两个单链表是否相交(有环、无环两种)
- 【数据科学家】SparkR:数据科学家的新利器
- KMP算法浅析
- Bug修复问题
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法