机器学习之决策树熵&信息增量求解算法实现
时间:2022-05-06
本文章向大家介绍机器学习之决策树熵&信息增量求解算法实现,主要内容包括其使用实例、应用技巧、基本知识点总结和需要注意事项,具有一定的参考价值,需要的朋友可以参考一下。
此文不对理论做相关阐述,仅涉及代码实现:
1.熵计算公式:
P为正例,Q为反例
Entropy(S) = -PLog2(P) - QLog2(Q);
2.信息增量计算:
Gain(S,Sv) = Entropy(S) - (|Sv|/|S|)ΣEntropy(Sv);
举例:
转化数据输入:
5 14
Outlook Sunny Sunny Overcast Rain Rain Rain Overcast Sunny Sunny Rain Sunny Overcast Overcast Rain
Temperature Hot Hot Hot Mild Cool Cool Cool Mild Cool Mild Mild Mild Hot Mild
Humidity High High High High Normal Normal Normal High Normal Normal Normal High Normal High
Wind Weak Strong Weak Weak Weak Strong Strong Weak Weak Weak Strong Strong Weak Strong
PlayTennis No No Yes Yes Yes No Yes No Yes Yes Yes Yes Yes No
Outlook Temperature Humidity Wind PlayTennis
1 package com.qunar.data.tree;
2
3 /**
4 * *********************************************************
5 * <p/>
6 * Author: XiJun.Gong
7 * Date: 2016-09-02 15:28
8 * Version: default 1.0.0
9 * Class description:
10 * <p>统计该类型出现的次数</p>
11 * <p/>
12 * *********************************************************
13 */
14 public class CountMap<T> {
15
16 private T key; //类型
17 private int value; //出现的次数
18
19 public CountMap() {
20 this(null, 0);
21 }
22
23 public CountMap(T key, int value) {
24 this.key = key;
25 this.value = value;
26 }
27
28 public T getKey() {
29 return key;
30 }
31
32 public void setKey(T key) {
33 this.key = key;
34 }
35
36 public int getValue() {
37 return value;
38 }
39
40 public void setValue(int value) {
41 this.value = value;
42 }
43 }
1 package com.qunar.data.tree;
2
3 import com.google.common.collect.ArrayListMultimap;
4 import com.google.common.collect.Maps;
5 import com.google.common.collect.Multimap;
6 import com.google.common.collect.Sets;
7
8 import java.util.*;
9
10 /**
11 * *********************************************************
12 * <p/>
13 * Author: XiJun.Gong
14 * Date: 2016-09-02 14:24
15 * Version: default 1.0.0
16 * Class description:
17 * <p>决策树</p>
18 * <p/>
19 * *********************************************************
20 */
21
22 public class DecisionTree<T, K> {
23
24 private static String positiveExampleType = "Yes";
25 private static String counterExampleType = "No";
26
27
28 public double pLog2(final double p) {
29 if (0 == p) return 0;
30 return p * (Math.log(p) / Math.log(2));
31 }
32
33 /**
34 * 熵计算
35 *
36 * @param positiveExample 正例个数
37 * @param counterExample 反例个数
38 * @return 熵值
39 */
40 public double entropy(final double positiveExample, final double counterExample) {
41
42 double total = positiveExample + counterExample;
43 double positiveP = positiveExample / total;
44 double counterP = counterExample / total;
45 return -1d * (pLog2(positiveP) + pLog2(counterP));
46 }
47
48 /**
49 * @param features 特征列表
50 * @param results 对应结果
51 * @return 将信息整合成新的格式
52 */
53 public Multimap<T, CountMap<K>> merge(final List<T> features, final List<T> results) {
54 //数据转化
55 Multimap<T, CountMap<K>> InfoMap = ArrayListMultimap.create();
56 Iterator result = results.iterator();
57 for (T feature : features) {
58 K res = (K) result.next();
59 boolean tag = false;
60 Collection<CountMap<K>> countMaps = InfoMap.get(feature);
61 for (CountMap countMap : countMaps) {
62 if (countMap.getKey().equals(res)) {
63 /*修改值*/
64 int num = countMap.getValue() + 1;
65 InfoMap.remove(feature, countMap);
66 InfoMap.put(feature, new CountMap<K>(res, num));
67 tag = true;
68 break;
69 }
70 }
71 if (!tag)
72 InfoMap.put(feature, new CountMap<K>(res, 1));
73 }
74
75 return InfoMap;
76 }
77
78 /**
79 * 信息增益
80 *
81 * @param infoMap 因素(Outlook,Temperature,Humidity,Wind)对应的结果
82 * @param dataTable 输入的数据表
83 * @param type 因素中的类型(Outlook{Sunny,Overcast,Rain})
84 * @param entropyS 总的熵值
85 * @param totalSize 总的样本数
86 * @return 信息增益
87 */
88 public double gain(Multimap<T, CountMap<K>> infoMap,
89 Map<K, List<T>> dataTable,
90 final String type,
91 double entropyS,
92 final int totalSize) {
93 //去重
94 Set<T> subTypes = Sets.newHashSet();
95 subTypes.addAll(dataTable.get(type));
96 /*计算*/
97 for (T subType : subTypes) {
98 Collection<CountMap<K>> countMaps = infoMap.get(subType);
99 double subSize = 0;
100 double positiveExample = 0;
101 double counterExample = 0;
102 for (CountMap<K> countMap : countMaps) {
103 subSize += countMap.getValue();
104 if (positiveExampleType.equals(countMap.getKey()))
105 positiveExample = countMap.getValue();
106 else
107 counterExample = countMap.getValue();
108 }
109 entropyS -= (subSize / totalSize) * entropy(positiveExample, counterExample);
110 }
111 return entropyS;
112 }
113
114 /**
115 * 计算
116 *
117 * @param dataTable 数据表
118 * @param types 因素列表{Outlook,Temperature,Humidity,Wind}
119 * @param resultType 结果(PlayTennis)
120 * @return 返回信息增益集合
121 */
122 public Map<String, Double> calculate(Map<K, List<T>> dataTable, List<K> types, K resultType) {
123
124 Map<String, Double> answer = Maps.newHashMap();
125 List<T> results = dataTable.get(resultType);
126 int totalSize = results.size();
127 int positiveExample = 0;
128 int counterExample = 0;
129 double entropyS = 0d;
130 for (T ExampleType : results) {
131 if (positiveExampleType.equals(ExampleType)) {
132 ++positiveExample;
133 continue;
134 }
135 ++counterExample;
136 }
137 /*计算总的熵*/
138 entropyS = entropy(positiveExample, counterExample);
139
140 Multimap<T, CountMap<K>> infoMap;
141 for (K type : types) {
142 infoMap = merge(dataTable.get(type), results);
143 double _gain = gain(infoMap, dataTable, (String) type, entropyS, totalSize);
144 answer.put((String) type, _gain);
145 }
146 return answer;
147 }
148
149 } 1package com.qunar.data.tree;
2
3 import com.google.common.collect.Lists;
4 import com.google.common.collect.Maps;
5
6 import java.util.*;
7
8 /**
9 * *********************************************************
10 * <p/>
11 * Author: XiJun.Gong
12 * Date: 2016-09-02 16:43
13 * Version: default 1.0.0
14 * Class description:
15 * <p/>
16 * *********************************************************
17 */
18 public class Main {
19
20 public static void main(String args[]) {
21
22 Scanner scanner = new Scanner(System.in);
23 while (scanner.hasNext()) {
24 DecisionTree<String, String> dt = new DecisionTree();
25 Map<String, List<String>> dataTable = Maps.newHashMap();
26 /*Map<String, List<String>> dataTable = Maps.newHashMap();*/
27 List<String> types = Lists.newArrayList();
28 String resultType;
29 int factorSize = scanner.nextInt();
30 int demoSize = scanner.nextInt();
31 String type;
32
33 for (int i = 0; i < factorSize; i++) {
34 List<String> demos = Lists.newArrayList();
35 type = scanner.next();
36 for (int j = 0; j < demoSize; j++) {
37 demos.add(scanner.next());
38 }
39 dataTable.put(type, demos);
40 }
41 for (int i = 1; i < factorSize; i++) {
42 types.add(scanner.next());
43 }
44 resultType = scanner.next();
45 Map<String, Double> ans = dt.calculate(dataTable, types, resultType);
46 List<Map.Entry<String, Double>> list = new ArrayList<Map.Entry<String, Double>>(ans.entrySet());
47 Collections.sort(list, new Comparator<Map.Entry<String, Double>>() {
48
49
50 @Override
51 public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
52 return (o2.getValue() > o1.getValue() ? 1 : -1);
53 }
54 });
55
56 for (Map.Entry<String, Double> iterator : list) {
57 System.out.println(iterator.getKey() + "= " + iterator.getValue());
58 }
59 }
60 }
61
62 }
63 /**
64 *使用举例:*
65 5 14
66 Outlook Sunny Sunny Overcast Rain Rain Rain Overcast Sunny Sunny Rain Sunny Overcast Overcast Rain
67 Temperature Hot Hot Hot Mild Cool Cool Cool Mild Cool Mild Mild Mild Hot Mild
68 Humidity High High High High Normal Normal Normal High Normal Normal Normal High Normal High
69 Wind Weak Strong Weak Weak Weak Strong Strong Weak Weak Weak Strong Strong Weak Strong
70 PlayTennis No No Yes Yes Yes No Yes No Yes Yes Yes Yes Yes No
71 Outlook Temperature Humidity Wind PlayTennis
72 */
结果:
Outlook= 0.2467498197744391
Humidity= 0.15183550136234136
Wind= 0.04812703040826927
Temperature= 0.029222565658954647
- 输入一个数字,然后计算出从1到输入数字的和,要求,如果输入的数字小于1,则重新输入,直到输入正确的数字为止
- Linux基础(day76)
- zabbix设置QQ邮箱告警
- 关于JSON CSRF的一些思考
- linux学习第七十篇:expect脚本同步文件,expect脚本指定host和要同步的文件,构建文件分发系统,批量远程执行命令
- linux学习第六十九篇:分发系统介绍,expect脚本远程登录,expect脚本远程执行命令,expect脚本传递参数
- linux学习第六十八篇:告警系统邮件引擎,运行告警系统
- linux学习第六十七篇:告警系统主脚本,告警系统配置文件,告警系统监控项目
- linux学习第六十六篇:shell中的函数,shell中的数组,告警系统需求分析
- linux学习第六十五篇:for循环,while循环, break跳出循环,continue结束本次循环
- linux学习第六十四篇:Shell脚本中的逻辑判断,文件目录属性判断, if特殊用法,case判断
- linux学习第六十三篇:Shell脚本介绍,Shell脚本结构和执行,date命令用法,Shell脚本中的变量
- 熔断Hystrix使用尝鲜
- 报警系统QuickAlarm之默认报警规则扩展
- JavaScript 教程
- JavaScript 编辑工具
- JavaScript 与HTML
- JavaScript 与Java
- JavaScript 数据结构
- JavaScript 基本数据类型
- JavaScript 特殊数据类型
- JavaScript 运算符
- JavaScript typeof 运算符
- JavaScript 表达式
- JavaScript 类型转换
- JavaScript 基本语法
- JavaScript 注释
- Javascript 基本处理流程
- Javascript 选择结构
- Javascript if 语句
- Javascript if 语句的嵌套
- Javascript switch 语句
- Javascript 循环结构
- Javascript 循环结构实例
- Javascript 跳转语句
- Javascript 控制语句总结
- Javascript 函数介绍
- Javascript 函数的定义
- Javascript 函数调用
- Javascript 几种特殊的函数
- JavaScript 内置函数简介
- Javascript eval() 函数
- Javascript isFinite() 函数
- Javascript isNaN() 函数
- parseInt() 与 parseFloat()
- escape() 与 unescape()
- Javascript 字符串介绍
- Javascript length属性
- javascript 字符串函数
- Javascript 日期对象简介
- Javascript 日期对象用途
- Date 对象属性和方法
- Javascript 数组是什么
- Javascript 创建数组
- Javascript 数组赋值与取值
- Javascript 数组属性和方法
- 数据库char varchar nchar nvarchar,编码Unicode,UTF8,GBK等,Sql语句中文前为什么加N(一次线上数据存储乱码排查)
- [Maven][maven-shade-plugin]告警[WARNING] maven-shade-plugin has detected that some class files are pre
- asp.net core 3.1多种身份验证方案,cookie和jwt混合认证授权
- 只知道java反射,宁知道内省吗?
- JDK1.8新特性(七):默认方法,真香,开动!接口?我要升级!!
- Windows10上安装Linux子系统(WSL2,Ubuntu),配合Windows Terminal使用,还要什么自行车
- [Maven][maven-site-plugin]告警[WARNING] No project URL defined - decoration links will not be relativi
- QListWidget添加删除
- 使用GitHub Actions编译项目并将Jar发布到Maven Central仓库
- 为啥Flutter Hooks没有受到太多关注和青睐?
- 二叉搜索树删除节点 动画演示
- 并发与竞态 (自旋锁)
- [Maven][taglist-maven-plugin]告警[WARNING] Using legacy tag format
- [Maven][l10n-maven-plugin]告警[WARNING] No dictionary file under folder
- Python基础之多文件项目的演练