轻松学Pytorch –使用torchvision实现对象检测
大家好,前面一篇文章介绍了torchvision的模型ResNet50实现图像分类,这里再给大家介绍一下如何使用torchvision自带的对象检测模型Faster-RCNN实现对象检测。Torchvision自带的对象检测模型是基于COCO数据集训练的,最小分辨率支持800, 最大支持1333的输入图像。
Faster-RCNN模型
Faster-RCNN模型的基础网络是ResNet50, ROI生成使用了RPN,加上头部组成。图示如下:
在torchvision框架下可以通过下面的代码直接下载预训练模型,
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()
对模型使用GPU加速支持
# 使用GPU
train_on_gpu = torch.cuda.is_available()
if train_on_gpu:
model.cuda()
推理输出有三个信息分别为:
boxes:表示对象框
scores:表示每个对象得分
labels:表示对于的分类标签
图像检测
使用模型实现图像检测,支持90个类别的对象检测,代码实现如下:
def faster_rcnn_image_detection():
image = cv.imread("D:/images/cars.jpg")
blob = transform(image)
c, h, w = blob.shape
input_x = blob.view(1, c, h, w)
output = model(input_x.cuda())[0]
boxes = output['boxes'].cpu().detach().numpy()
scores = output['scores'].cpu().detach().numpy()
labels = output['labels'].cpu().detach().numpy()
index = 0
for x1, y1, x2, y2 in boxes:
if scores[index] > 0.5:
print(x1, y1, x2, y2)
cv.rectangle(image, (np.int32(x1), np.int32(y1)),
(np.int32(x2), np.int32(y2)), (0, 255, 255), 1, 8, 0)
label_id = labels[index]
label_txt = coco_names[str(label_id)]
cv.putText(image, label_txt, (np.int32(x1), np.int32(y1)), cv.FONT_HERSHEY_PLAIN, 1.0, (0, 0, 255), 1)
index += 1
cv.imshow("Faster-RCNN Detection Demo", image)
cv.waitKey(0)
cv.destroyAllWindows()
运行结果下:
视频实时对象检测
基于OpenCV实现视频文件或者摄像头读取,完成视频的实时对象检测,代码实现如下:
1capture = cv.VideoCapture("D:/images/video/vehicle.ts")
2while True:
3 ret, frame = capture.read()
4 if ret is not True:
5 break
6 blob = transform(frame)
7 c, h, w = blob.shape
8 input_x = blob.view(1, c, h, w)
9 output = model(input_x.cuda())[0]
10 boxes = output['boxes'].cpu().detach().numpy()
11 scores = output['scores'].cpu().detach().numpy()
12 labels = output['labels'].cpu().detach().numpy()
13 index = 0
14 for x1, y1, x2, y2 in boxes:
15 if scores[index] > 0.5:
16 cv.rectangle(frame, (np.int32(x1), np.int32(y1)),
17 (np.int32(x2), np.int32(y2)), (0, 255, 255), 1, 8, 0)
18 label_id = labels[index]
19 label_txt = coco_names[str(label_id)]
20 cv.putText(frame, label_txt, (np.int32(x1), np.int32(y1)), cv.FONT_HERSHEY_PLAIN, 1.0, (0, 0, 255), 1)
21 index += 1
22 wk = cv.waitKey(1)
23 if wk == 27:
24 break
25 cv.imshow("video detection Demo", frame)
运行结果如下:
- PHP页面跳转代码
- 分布式事务 TCC-Transaction 源码解析 —— 调试环境搭建
- 机器学习入门——使用python进行监督学习
- 推荐算法的介绍,第一部分——协同过滤与奇异值分解
- 在ASP中实现UNIX时间戳
- 【学术】厉害了我的哥,国外技术大咖仿造了谷歌的Arts &Culture,找到古代的“你”
- 【技巧】应赛技巧,教你如何在Kaggle比赛中排在前1%
- 熔断器 Hystrix 源码解析 —— 命令执行(一)之正常执行逻辑
- 智能主题检测与无监督机器学习:识别颜色教程
- 如何下载安装Weka机器学习工作平台
- Dubbo 源码解析 —— LoadBalance
- 如何处理机器学习中类的不平衡问题
- 【死磕Java并发】—– Java内存模型之重排序
- Mask R-CNN源代码终于来了,还有它背后的物体检测平台
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- Docsify 安装
- Docsify 初始化文件夹
- ELK 日志系统集成 Skywalking 调用链 ID
- ChartCenter ——为您的K8s之旅保驾护航v
- leetcode链表之删除链表的节点
- iOS打包的那一些事情
- 腾讯云服务器(CentOS 7、Tencent Linux)手动搭建LNMP环境(linux+Nginx+Mariadb+PHP)
- iOS技术面试题及答案
- 虽然现在有可以去码的软件了,可视频是如何自动跟踪打码的?
- 2020-09-12:手撕代码:最小公倍数,复杂度多少?
- Mac App推荐
- 美团面试问ThreadLocal,学妹一口气给他说了四种!
- BFE.dev前端刷题#108. 用队列(Queue)实现栈(Stack)
- Kafka消费过程关键源码解析
- leetcode链表之两个链表的第一个公共节点