基于Pytorch的动态卷积复现
时间:2022-07-24
本文章向大家介绍基于Pytorch的动态卷积复现,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
【GaintPandaCV导语】 最近动态卷积开始有人进行了研究,也有不少的论文发表(动态卷积论文合集https://github.com/kaijieshi7/awesome-dynamic-convolution),但是动态卷积具体的实现代码却很少有文章给出。本文以微软发表在CVPR2020上面的文章为例,详细的讲解了动态卷积实现的难点以及如何动分组卷积巧妙的解决。希望能给大家以启发。
论文的题目为《Dynamic Convolution: Attention over Convolution Kernels》
paper的地址arxiv.org/pdf/1912.0345
简单回顾
这篇文章主要是改进传统卷积,让每层的卷积参数在推理的时候也是随着输入可变的,而不是传统卷积中对任何输入都是固定不变的参数。(由于本文主要说明的是代码如何实现,所以推荐给大家一个讲解论文的连接:Happy:动态滤波器卷积|DynamicConv)
class attention2d(nn.Module):
def __init__(self, in_planes, K,):
super(attention2d, self).__init__()
self.avgpool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Conv2d(in_planes, K, 1,)
self.fc2 = nn.Conv2d(K, K, 1,)
def forward(self, x):
x = self.avgpool(x)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x).view(x.size(0), -1)
return F.softmax(x, 1)
具体代码如下:
class Dynamic_conv2d(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, K=4,):
super(Dynamic_conv2d, self).__init__()
assert in_planes%groups==0
self.in_planes = in_planes
self.out_planes = out_planes
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.bias = bias
self.K = K
self.attention = attention2d(in_planes, K, )
self.weight = nn.Parameter(torch.Tensor(K, out_planes, in_planes//groups, kernel_size, kernel_size), requires_grad=True)
if bias:
self.bias = nn.Parameter(torch.Tensor(K, out_planes))
else:
self.bias = None
def forward(self, x):#将batch视作维度变量,进行组卷积,因为组卷积的权重是不同的,动态卷积的权重也是不同的
softmax_attention = self.attention(x)
batch_size, in_planes, height, width = x.size()
x = x.view(1, -1, height, width)# 变化成一个维度进行组卷积
weight = self.weight.view(self.K, -1)
# 动态卷积的权重的生成, 生成的是batch_size个卷积参数(每个参数不同)
aggregate_weight = torch.mm(softmax_attention, weight).view(-1, self.in_planes, self.kernel_size, self.kernel_size)
if self.bias is not None:
aggregate_bias = torch.mm(softmax_attention, self.bias).view(-1)
output = F.conv2d(x, weight=aggregate_weight, bias=aggregate_bias, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups*batch_size)
else:
output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
dilation=self.dilation, groups=self.groups * batch_size)
output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
return output
完整的代码在github.com/kaijieshi7/D,大家觉得有帮助的话,求点个星星。
纸上得来终觉浅,绝知此事要躬行。试下代码,方能体会其中妙处。
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 内网横向移动:Kerberos认证与(哈希)票据传递攻击
- 诺禾致源linux下数据下载
- 技巧 | OpenCV中如何绘制与填充多边形
- Swift guard
- PyTorch实现TPU版本CNN模型
- 使用NLP检测和对抗AI假新闻
- kallisto --genomebam报错解决(GTF文件的坑)
- linux查找文件
- TCP 协议面试灵魂 12 问,问到你怀疑人生!
- 方差分析简介(结合COVID-19案例)
- mysql计算两个时间字段的时间差
- 学生党学编程,有这个开源项目就够了!
- 【最强ResNet改进系列】Res2Net:一种新的多尺度网络结构,性能提升显著
- Java中的锁以及sychronized实现机制(十)
- Web 指纹识别之路