Tensorflow学习:使用Tensorflow搭建深层网络分类器
根据官方文档整理而来的,主要是对Iris数据集进行分类。使用tf.contrib.learn.tf.contrib.learn快速搭建一个深层网络分类器,
步骤
- 导入csv数据
- 搭建网络分类器
- 训练网络
- 计算测试集正确率
- 对新样本进行分类
数据
Iris数据集包含150行数据,有三种不同的Iris品种分类。每一行数据给出了四个特征信息和一个分类信息。
现在已经将数据分为训练集和测试集
- A training set of 120 samples
http://download.tensorflow.org/data/iris_training.csv
- A test set of 30 samples
http://download.tensorflow.org/data/iris_test.csv
网络搭建
1. 首先,导入tensorflow 和 numpy
2. 导入数据
# 定义数据地址
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"
# 导入数据
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=IRIS_TRAINING,
target_dtype=np.int,
features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=IRIS_TEST,
target_dtype=np.int,
features_dtype=np.float32)
load_csv_with_header() 有三个参数 - filename, 数据地址 - target_dtype, 目标值的numpy datatype(iris的目标值是0,1,2,所以是np.int) - features_dtype, 特征值的numpy datatype .
3. 搭建网络结构
# 每行数据4个特征,都是real-value的
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
# 3层DNN,3分类问题
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
hidden_units=[10, 20, 10],
n_classes=3,
model_dir="/tmp/iris_model")
参数解释 - feature_columns 特征值 - hidden_units=[10, 20, 10]. 3个隐藏层,包含的隐藏神经元依次是10, 20, 10 - n_classes 类别个数 - model_dir 模型保存地址
4. 训练数据
classifier.fit(x=training_set.data, y=training_set.target, steps=2000)
steps 为训练次数
5. 计算准确率
accuracy_score = classifier.evaluate(x=test_set.data, y=test_set.target)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))
运行结果是
6. 对新样本进行预测
# Classify two new flower samples.
new_samples = np.array(
[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = list(classifier.predict(new_samples, as_iterable=True))
print('Predictions: {}'.format(str(y)))
运行结果为:
完整代码
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
import numpy as np
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=IRIS_TRAINING,
target_dtype=np.int,
features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
filename=IRIS_TEST,
target_dtype=np.int,
features_dtype=np.float32)
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
hidden_units=[10, 20, 10],
n_classes=3,
model_dir="/tmp/iris_model")
classifier.fit(x=training_set.data,
y=training_set.target,
steps=2000)
accuracy_score = classifier.evaluate(x=test_set.data,
y=test_set.target)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))
new_samples = np.array(
[[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = list(classifier.predict(new_samples, as_iterable=True))
print('Predictions: {}'.format(str(y)))
- Flash/Flex学习笔记(39):弹性运动
- 兼容Mono的下一代云环境Web开发框架ASP.NET vNext
- ASP.NET vNext 概述
- 丰富排版页面——为你的wordpress主题添加短代码形式美化框
- 开放式管理基础结构 OMI
- 人类设计了游戏和AI 2017年AI在游戏中打败了人类
- WordPress 代码实现相关文章(列表模式)功能
- 自动刷新页面
- Python语言被列入全国计算机等级考试科目中
- WordPress纯代码高仿 无觅相关文章 图文模式功能
- 各种序列化库的性能数据
- WordPress内置搜索结果只有一篇文章时自动跳转到该文章
- Flash/Flex学习笔记(23):运动学原理
- WordPress重定向作者归档链接到“关于”页面
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 基于Centos7 部署Varnish缓存代理服务器
- Android getActivity()为空的问题解决办法
- Android Notification使用方法总结
- Linux下redis的持久化、主从同步与哨兵详解
- 详解Android(共享元素)转场动画开发实践
- Android自定义View中attrs.xml的实例详解
- Android开发之缓冲dialog对话框创建、使用与封装操作
- Windows 10 下安装 Apache 2.4.41的教程
- Android基础教程数据存储之文件存储
- 关于Android短信验证码的获取的示例
- Linux下制作给ARM开发板使用的文件系统
- Android开发之ProgressBar字体随着进度条的加载而滚动
- Android 动态注册监听网络变化实例详解
- linux 服务器自动备份脚本的方法(mysql、附件备份)
- 详解Android中Glide与CircleImageView加载圆形图片的问题