Add File
This commit is contained in:
341
src/main/java/org/dromara/easyai/randomForest/Tree.java
Normal file
341
src/main/java/org/dromara/easyai/randomForest/Tree.java
Normal file
@@ -0,0 +1,341 @@
|
||||
package org.dromara.easyai.randomForest;
|
||||
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* @author lidapeng
|
||||
* @description
|
||||
* @date 3:12 下午 2020/2/17
|
||||
*/
|
||||
public class Tree {//决策树
|
||||
private DataTable dataTable;
|
||||
private Map<String, List<Integer>> table;//总样本
|
||||
private Node rootNode;//根节点
|
||||
private List<Integer> endList;//最终结果分类
|
||||
private final List<Node> lastNodes = new ArrayList<>();//最后一层节点集合
|
||||
private final Random random = new Random();
|
||||
private final float trustPunishment;//信任惩罚
|
||||
|
||||
public Node getRootNode() {
|
||||
return rootNode;
|
||||
}
|
||||
|
||||
public DataTable getDataTable() {
|
||||
return dataTable;
|
||||
}
|
||||
|
||||
public void setRootNode(Node rootNode) {
|
||||
this.rootNode = rootNode;
|
||||
}
|
||||
|
||||
private static class Gain {
|
||||
private float gain;//信息增益
|
||||
private float gainRatio;//信息增益率
|
||||
}
|
||||
|
||||
public Tree(float trustPunishment) {
|
||||
this.trustPunishment = trustPunishment;
|
||||
}
|
||||
|
||||
public Tree(DataTable dataTable, float trustPunishment) throws Exception {
|
||||
if (dataTable != null && dataTable.getKey() != null) {
|
||||
this.trustPunishment = trustPunishment;
|
||||
this.dataTable = dataTable;
|
||||
} else {
|
||||
throw new Exception("dataTable is empty");
|
||||
}
|
||||
}
|
||||
|
||||
private float log2(float p) {
|
||||
return (float)Math.log(p) / (float)Math.log(2);
|
||||
}
|
||||
|
||||
private float getEnt(List<Integer> list) {
|
||||
//记录了每个类别有几个
|
||||
Map<Integer, Integer> myType = new HashMap<>();
|
||||
for (int index : list) {
|
||||
int type = endList.get(index);//最终结果的类别
|
||||
if (myType.containsKey(type)) {
|
||||
myType.put(type, myType.get(type) + 1);
|
||||
} else {
|
||||
myType.put(type, 1);
|
||||
}
|
||||
}
|
||||
float ent = 0;
|
||||
//求信息熵
|
||||
for (Map.Entry<Integer, Integer> entry1 : myType.entrySet()) {
|
||||
float g = (float) entry1.getValue() / (float) list.size();//每个类别的概率
|
||||
ent = ent + g * log2(g);
|
||||
}
|
||||
return -ent;
|
||||
}
|
||||
|
||||
private float getGain(float ent, float dNub, float gain) {
|
||||
return gain + ent * dNub;
|
||||
}
|
||||
|
||||
private List<Node> createNode(Node node) {
|
||||
Set<String> attributes = node.attribute;
|
||||
List<Integer> fatherList = node.fatherList;
|
||||
if (!attributes.isEmpty()) {
|
||||
Map<String, Map<Integer, List<Integer>>> mapAll = new HashMap<>();//主键:可用属性 次主键该属性的属性值,集合该值对应的id
|
||||
float fatherEnt = getEnt(fatherList);
|
||||
int fatherNub = fatherList.size();//总样本数
|
||||
//该属性每个离散数据分类的集合
|
||||
for (int i = 0; i < fatherList.size(); i++) {
|
||||
int index = fatherList.get(i);//编号
|
||||
for (String attr : attributes) {
|
||||
if (!mapAll.containsKey(attr)) {
|
||||
mapAll.put(attr, new HashMap<>());
|
||||
}
|
||||
Map<Integer, List<Integer>> map = mapAll.get(attr);
|
||||
int attrValue = table.get(attr).get(index);//获取当前属性值
|
||||
if (!map.containsKey(attrValue)) {
|
||||
map.put(attrValue, new ArrayList<>());
|
||||
}
|
||||
List<Integer> list = map.get(attrValue);
|
||||
list.add(index);
|
||||
}
|
||||
}
|
||||
Map<String, List<Node>> nodeMap = new HashMap<>();
|
||||
int i = 0;
|
||||
float sigmaG = 0;
|
||||
Map<String, Gain> gainMap = new HashMap<>();
|
||||
for (Map.Entry<String, Map<Integer, List<Integer>>> mapEntry : mapAll.entrySet()) {
|
||||
Map<Integer, List<Integer>> map = mapEntry.getValue();//当前属性的 属性值及id
|
||||
//求信息增益
|
||||
float gain = 0;//信息增益
|
||||
float IV = 0;//增益率
|
||||
List<Node> nodeList = new ArrayList<>();
|
||||
String name = mapEntry.getKey();//可用属性名称
|
||||
nodeMap.put(name, nodeList);
|
||||
for (Map.Entry<Integer, List<Integer>> entry : map.entrySet()) {//遍历当前属性下的所有属性值的集合
|
||||
Set<String> nowAttribute = removeAttribute(attributes, name);
|
||||
Node sonNode = new Node();
|
||||
nodeList.add(sonNode);
|
||||
sonNode.attribute = nowAttribute;
|
||||
List<Integer> list = entry.getValue();//该属性值下的数据id集合
|
||||
sonNode.fatherList = list;
|
||||
sonNode.typeId = entry.getKey();//该属性值
|
||||
int myNub = list.size();//该属性值下数据的数量
|
||||
float ent = getEnt(list);//该属性值 的信息熵
|
||||
float dNub = (float) myNub / (float) fatherNub;//该属性值在 父级样本中出现的概率
|
||||
IV = dNub * log2(dNub) + IV;
|
||||
gain = getGain(ent, dNub, gain);
|
||||
}
|
||||
Gain gain1 = new Gain();
|
||||
gainMap.put(name, gain1);
|
||||
gain1.gain = fatherEnt - gain;//信息增益
|
||||
if (IV != 0) {
|
||||
gain1.gainRatio = gain1.gain / -IV;//增益率
|
||||
} else {
|
||||
gain1.gainRatio = 1000000;
|
||||
}
|
||||
sigmaG = gain1.gain + sigmaG;
|
||||
i++;
|
||||
}
|
||||
float avgGain = sigmaG / i;
|
||||
float gainRatio = -2;//最大增益率
|
||||
String key = null;//可选属性
|
||||
//System.out.println("平均信息增益==============================" + avgGain);
|
||||
for (Map.Entry<String, Gain> entry : gainMap.entrySet()) {
|
||||
Gain gain = entry.getValue();
|
||||
// System.out.println("主键:" + entry.getKey() + ",平均信息增益:" + avgGain + ",可用属性数量:" + gainMap.size()
|
||||
// + "该属性信息增益:" + gain.gain + ",该属性增益率:" + gain.gainRatio + ",当前最高增益率:" + gainRatio);
|
||||
if (gainMap.size() == 1 || ((gain.gain >= avgGain || (float)Math.abs(gain.gain - avgGain) < 0.000001) && (gain.gainRatio >= gainRatio || gainRatio == -2))) {
|
||||
gainRatio = gain.gainRatio;
|
||||
key = entry.getKey();
|
||||
}
|
||||
}
|
||||
node.key = key;
|
||||
List<Node> nodeList = nodeMap.get(key);
|
||||
for (int j = 0; j < nodeList.size(); j++) {//儿子绑定父亲关系
|
||||
nodeList.get(j).fatherNode = node;
|
||||
}
|
||||
for (int j = 0; j < nodeList.size(); j++) {
|
||||
Node node1 = nodeList.get(j);
|
||||
node1.nodeList = createNode(node1);
|
||||
}
|
||||
return nodeList;
|
||||
} else {
|
||||
//判断类别
|
||||
node.isEnd = true;//叶子节点
|
||||
node.type = getType(fatherList);
|
||||
lastNodes.add(node);//将全部最后一层节点集合
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private int getType(List<Integer> list) {
|
||||
Map<Integer, Integer> myType = new HashMap<>();
|
||||
for (int index : list) {
|
||||
int type = endList.get(index);//最终结果的类别
|
||||
if (myType.containsKey(type)) {
|
||||
myType.put(type, myType.get(type) + 1);
|
||||
} else {
|
||||
myType.put(type, 1);
|
||||
}
|
||||
}
|
||||
int type = 0;
|
||||
int nub = 0;
|
||||
for (Map.Entry<Integer, Integer> entry : myType.entrySet()) {
|
||||
int nowNub = entry.getValue();
|
||||
if (nowNub > nub) {
|
||||
type = entry.getKey();
|
||||
nub = nowNub;
|
||||
}
|
||||
}
|
||||
return type;
|
||||
}
|
||||
|
||||
private Set<String> removeAttribute(Set<String> attributes, String name) {
|
||||
Set<String> attriBute = new HashSet<>();
|
||||
for (String myName : attributes) {
|
||||
if (!myName.equals(name)) {
|
||||
attriBute.add(myName);
|
||||
}
|
||||
}
|
||||
return attriBute;
|
||||
}
|
||||
|
||||
private int getTypeId(Object ob, String name) throws Exception {
|
||||
Class<?> body = ob.getClass();
|
||||
String methodName = "get" + name.substring(0, 1).toUpperCase() + name.substring(1);
|
||||
Method method = body.getMethod(methodName);
|
||||
return Integer.parseInt(method.invoke(ob).toString());
|
||||
}
|
||||
|
||||
public TreeWithTrust judge(Object ob) throws Exception {//进行类别判断
|
||||
if (rootNode != null) {
|
||||
TreeWithTrust treeWithTrust = new TreeWithTrust();
|
||||
treeWithTrust.setTrust(1.0f);
|
||||
goTree(ob, rootNode, treeWithTrust, 0);
|
||||
return treeWithTrust;
|
||||
} else {
|
||||
throw new Exception("rootNode is null");
|
||||
}
|
||||
}
|
||||
|
||||
private void punishment(TreeWithTrust treeWithTrust) {//信任惩罚
|
||||
//System.out.println("惩罚");
|
||||
float trust = treeWithTrust.getTrust();//获取当前信任值
|
||||
trust = trust * trustPunishment;
|
||||
treeWithTrust.setTrust(trust);
|
||||
}
|
||||
|
||||
private void goTree(Object ob, Node node, TreeWithTrust treeWithTrust, int times) throws Exception {//从树顶向下攀爬
|
||||
if (!node.isEnd) {
|
||||
int myType = getTypeId(ob, node.key);//当前类别的ID
|
||||
if (myType == 0) {//做信任惩罚
|
||||
punishment(treeWithTrust);
|
||||
}
|
||||
List<Node> nodeList = node.nodeList;
|
||||
boolean isOk = false;
|
||||
for (Node testNode : nodeList) {
|
||||
if (testNode.typeId == myType) {
|
||||
isOk = true;
|
||||
node = testNode;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!isOk) {//当前类别缺失,未知的属性值
|
||||
punishment(treeWithTrust);
|
||||
int index = random.nextInt(nodeList.size());
|
||||
node = nodeList.get(index);
|
||||
}
|
||||
times++;
|
||||
goTree(ob, node, treeWithTrust, times);
|
||||
} else {
|
||||
//当以0作为结束的时候要做严厉的信任惩罚
|
||||
if (node.typeId == 0) {
|
||||
int nub = rootNode.attribute.size() - times;
|
||||
//System.out.println("惩罚次数" + nub);
|
||||
for (int i = 0; i < nub; i++) {
|
||||
punishment(treeWithTrust);
|
||||
}
|
||||
}
|
||||
treeWithTrust.setType(node.type);
|
||||
}
|
||||
}
|
||||
|
||||
public void study() throws Exception {
|
||||
if (dataTable != null && dataTable.getLength() > 0) {
|
||||
rootNode = new Node();
|
||||
table = dataTable.getTable();
|
||||
endList = dataTable.getTable().get(dataTable.getKey());
|
||||
Set<String> set = dataTable.getKeyType();
|
||||
set.remove(dataTable.getKey());
|
||||
rootNode.attribute = set;//当前可用属性
|
||||
List<Integer> list = new ArrayList<>();
|
||||
for (int i = 0; i < endList.size(); i++) {
|
||||
list.add(i);
|
||||
}
|
||||
rootNode.fatherList = list;//当前父级样本
|
||||
rootNode.nodeList = createNode(rootNode);
|
||||
//进行后剪枝
|
||||
for (Node lastNode : lastNodes) {
|
||||
prune(lastNode.fatherNode);
|
||||
}
|
||||
lastNodes.clear();
|
||||
} else {
|
||||
throw new Exception("dataTable is null");
|
||||
}
|
||||
}
|
||||
|
||||
private void prune(Node node) {//执行剪枝
|
||||
if (node != null && !node.isEnd) {
|
||||
List<Node> listNode = node.nodeList;//子节点
|
||||
if (isPrune(node, listNode)) {//剪枝
|
||||
deduction(node);
|
||||
prune(node.fatherNode);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void deduction(Node node) {
|
||||
node.isEnd = true;
|
||||
node.nodeList = null;
|
||||
node.type = getType(node.fatherList);
|
||||
}
|
||||
|
||||
private boolean isPrune(Node father, List<Node> sonNodes) {
|
||||
boolean isRemove = false;
|
||||
List<Integer> typeList = new ArrayList<>();
|
||||
for (int i = 0; i < sonNodes.size(); i++) {
|
||||
Node node = sonNodes.get(i);
|
||||
List<Integer> list = node.fatherList;
|
||||
typeList.add(getType(list));
|
||||
}
|
||||
int fatherType = getType(father.fatherList);
|
||||
int nub = getRightPoint(father.fatherList, fatherType);
|
||||
//父级该样本正确率
|
||||
float rightFather = (float) nub / (float) father.fatherList.size();
|
||||
int rightNub = 0;
|
||||
int rightAllNub = 0;
|
||||
for (int i = 0; i < sonNodes.size(); i++) {
|
||||
Node node = sonNodes.get(i);
|
||||
List<Integer> list = node.fatherList;
|
||||
int right = getRightPoint(list, typeList.get(i));
|
||||
rightNub = rightNub + right;
|
||||
rightAllNub = rightAllNub + list.size();
|
||||
}
|
||||
float rightPoint = (float) rightNub / (float) rightAllNub;//子节点正确率
|
||||
if (rightPoint <= rightFather) {
|
||||
isRemove = true;
|
||||
}
|
||||
return isRemove;
|
||||
}
|
||||
|
||||
private int getRightPoint(List<Integer> types, int type) {
|
||||
int nub = 0;
|
||||
for (int index : types) {
|
||||
int end = endList.get(index);
|
||||
if (end == type) {
|
||||
nub++;
|
||||
}
|
||||
}
|
||||
return nub;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user