optimizer.zero_grad()
时间:2022-07-25
本文章向大家介绍optimizer.zero_grad(),主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
交流、咨询,有疑问欢迎添加QQ 2125364717,一起交流、一起发现问题、一起进步啊,哈哈哈哈哈
传统的训练函数,一个batch是这么训练的:
for i,(images,target) in enumerate(train_loader):
# 1. input output
images = images.cuda(non_blocking=True)
target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
outputs = model(images)
loss = criterion(outputs,target)
# 2. backward
optimizer.zero_grad() # reset gradient
loss.backward()
optimizer.step()
- 获取loss:输入图像和标签,通过infer计算得到预测值,计算损失函数;
- optimizer.zero_grad() 清空过往梯度;
- loss.backward() 反向传播,计算当前梯度;
- optimizer.step() 根据梯度更新网络参数
简单的说就是进来一个batch的数据,计算一次梯度,更新一次网络,使用梯度累加是这么写的:
for i,(images,target) in enumerate(train_loader):
# 1. input output
images = images.cuda(non_blocking=True)
target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
outputs = model(images)
loss = criterion(outputs,target)
# 2.1 loss regularization
loss = loss/accumulation_steps
# 2.2 back propagation
loss.backward()
# 3. update parameters of net
if((i+1)%accumulation_steps)==0:
# optimizer the net
optimizer.step() # update parameters of net
optimizer.zero_grad() # reset gradient
- 获取loss:输入图像和标签,通过infer计算得到预测值,计算损失函数;
- loss.backward() 反向传播,计算当前梯度;
- 多次循环步骤1-2,不清空梯度,使梯度累加在已有梯度上;
- 梯度累加了一定次数后,先optimizer.step() 根据累计的梯度更新网络参数,然后optimizer.zero_grad() 清空过往梯度,为下一波梯度累加做准备;
总结来说:梯度累加就是,每次获取1个batch的数据,计算1次梯度,梯度不清空,不断累加,累加一定次数后,根据累加的梯度更新网络参数,然后清空梯度,进行下一次循环。
一定条件下,batchsize越大训练效果越好,梯度累加则实现了batchsize的变相扩大,如果accumulation_steps为8,则batchsize '变相' 扩大了8倍,是我们这种乞丐实验室解决显存受限的一个不错的trick,使用时需要注意,学习率也要适当放大。
更新1:关于BN是否有影响,之前有人是这么说的:
As far as I know, batch norm statistics get updated on each forward pass, so no problem if you don't do .backward() every time.
BN的估算是在forward阶段就已经完成的,并不冲突,只是accumulation_steps=8和真实的batchsize放大八倍相比,效果自然是差一些,毕竟八倍Batchsize的BN估算出来的均值和方差肯定更精准一些。
各位看官老爷,如果觉得对您有用麻烦赏个子,创作不易,0.1元就行了。下面是微信乞讨码:
添加描述
添加描述
- writeup | 应该不是 XSS
- Hbase 源码分析之 Get 流程及rpc原理
- MOCTF WEB 题解
- HBase行锁与MVCC分析
- HBase行锁探索
- Spring Cloud构建微服务架构:分布式服务跟踪(抽样收集)【Dalston版】
- HBase client访问ZooKeeper获取root-region-server DeadLock问题(zookeeper.ClientCnxn Unable to get data of zn
- zookeeper学习系列:四、Paxos算法和zookeeper的关系
- 有了phonegap你还android吗?
- zookeeper学习系列:三、利用zookeeper做选举和锁
- Spring Cloud构建微服务架构:分布式服务跟踪(收集原理)【Dalston版】
- zookeeper学习系列:二、api实践
- Spring Cloud构建微服务架构:分布式服务跟踪(整合logstash)【Dalston版】
- Spring Cloud构建微服务架构:分布式服务跟踪(整合zipkin)【Dalston版】
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Docker Dockerfile 指令详解与实战案例
- flume kafka和sparkstreaming整合
- Docker如何搭建私有registry镜像仓库
- Harbor介绍与企业级私有Docker镜像仓库搭建
- 如何查看docker run启动参数命令
- YAML 语言教程与使用案例
- 计算等压面要素场的基本检验指标
- Kubernetes K8S之SSL证书有效期修改
- Kubernetes K8S之通过yaml文件创建Pod与Pod常用字段详解
- Kubernetes K8S之kubectl命令详解及常用示例
- Kubernetes K8S之Pod 生命周期与init container初始化容器详解
- Kubernetes K8S之Pod生命周期与探针检测
- Kubernetes K8S之Pod 生命周期与postStart、preStop事件
- Kubernetes K8S之资源控制器RC、RS、Deployment详解
- python教程 | 最标准的地图调用方式(国家测绘局提供数据)