This commit is contained in:
2025-09-04 14:09:01 +08:00
parent 7dd1c82a9c
commit e7213a20f9

View File

@@ -0,0 +1,341 @@
package org.dromara.easyai.randomForest;
import java.lang.reflect.Method;
import java.util.*;
/**
* @author lidapeng
* @description
* @date 3:12 下午 2020/2/17
*/
public class Tree {//决策树
private DataTable dataTable;
private Map<String, List<Integer>> table;//总样本
private Node rootNode;//根节点
private List<Integer> endList;//最终结果分类
private final List<Node> lastNodes = new ArrayList<>();//最后一层节点集合
private final Random random = new Random();
private final float trustPunishment;//信任惩罚
public Node getRootNode() {
return rootNode;
}
public DataTable getDataTable() {
return dataTable;
}
public void setRootNode(Node rootNode) {
this.rootNode = rootNode;
}
private static class Gain {
private float gain;//信息增益
private float gainRatio;//信息增益率
}
public Tree(float trustPunishment) {
this.trustPunishment = trustPunishment;
}
public Tree(DataTable dataTable, float trustPunishment) throws Exception {
if (dataTable != null && dataTable.getKey() != null) {
this.trustPunishment = trustPunishment;
this.dataTable = dataTable;
} else {
throw new Exception("dataTable is empty");
}
}
private float log2(float p) {
return (float)Math.log(p) / (float)Math.log(2);
}
private float getEnt(List<Integer> list) {
//记录了每个类别有几个
Map<Integer, Integer> myType = new HashMap<>();
for (int index : list) {
int type = endList.get(index);//最终结果的类别
if (myType.containsKey(type)) {
myType.put(type, myType.get(type) + 1);
} else {
myType.put(type, 1);
}
}
float ent = 0;
//求信息熵
for (Map.Entry<Integer, Integer> entry1 : myType.entrySet()) {
float g = (float) entry1.getValue() / (float) list.size();//每个类别的概率
ent = ent + g * log2(g);
}
return -ent;
}
private float getGain(float ent, float dNub, float gain) {
return gain + ent * dNub;
}
private List<Node> createNode(Node node) {
Set<String> attributes = node.attribute;
List<Integer> fatherList = node.fatherList;
if (!attributes.isEmpty()) {
Map<String, Map<Integer, List<Integer>>> mapAll = new HashMap<>();//主键:可用属性 次主键该属性的属性值集合该值对应的id
float fatherEnt = getEnt(fatherList);
int fatherNub = fatherList.size();//总样本数
//该属性每个离散数据分类的集合
for (int i = 0; i < fatherList.size(); i++) {
int index = fatherList.get(i);//编号
for (String attr : attributes) {
if (!mapAll.containsKey(attr)) {
mapAll.put(attr, new HashMap<>());
}
Map<Integer, List<Integer>> map = mapAll.get(attr);
int attrValue = table.get(attr).get(index);//获取当前属性值
if (!map.containsKey(attrValue)) {
map.put(attrValue, new ArrayList<>());
}
List<Integer> list = map.get(attrValue);
list.add(index);
}
}
Map<String, List<Node>> nodeMap = new HashMap<>();
int i = 0;
float sigmaG = 0;
Map<String, Gain> gainMap = new HashMap<>();
for (Map.Entry<String, Map<Integer, List<Integer>>> mapEntry : mapAll.entrySet()) {
Map<Integer, List<Integer>> map = mapEntry.getValue();//当前属性的 属性值及id
//求信息增益
float gain = 0;//信息增益
float IV = 0;//增益率
List<Node> nodeList = new ArrayList<>();
String name = mapEntry.getKey();//可用属性名称
nodeMap.put(name, nodeList);
for (Map.Entry<Integer, List<Integer>> entry : map.entrySet()) {//遍历当前属性下的所有属性值的集合
Set<String> nowAttribute = removeAttribute(attributes, name);
Node sonNode = new Node();
nodeList.add(sonNode);
sonNode.attribute = nowAttribute;
List<Integer> list = entry.getValue();//该属性值下的数据id集合
sonNode.fatherList = list;
sonNode.typeId = entry.getKey();//该属性值
int myNub = list.size();//该属性值下数据的数量
float ent = getEnt(list);//该属性值 的信息熵
float dNub = (float) myNub / (float) fatherNub;//该属性值在 父级样本中出现的概率
IV = dNub * log2(dNub) + IV;
gain = getGain(ent, dNub, gain);
}
Gain gain1 = new Gain();
gainMap.put(name, gain1);
gain1.gain = fatherEnt - gain;//信息增益
if (IV != 0) {
gain1.gainRatio = gain1.gain / -IV;//增益率
} else {
gain1.gainRatio = 1000000;
}
sigmaG = gain1.gain + sigmaG;
i++;
}
float avgGain = sigmaG / i;
float gainRatio = -2;//最大增益率
String key = null;//可选属性
//System.out.println("平均信息增益==============================" + avgGain);
for (Map.Entry<String, Gain> entry : gainMap.entrySet()) {
Gain gain = entry.getValue();
// System.out.println("主键:" + entry.getKey() + ",平均信息增益:" + avgGain + ",可用属性数量:" + gainMap.size()
// + "该属性信息增益:" + gain.gain + ",该属性增益率:" + gain.gainRatio + ",当前最高增益率:" + gainRatio);
if (gainMap.size() == 1 || ((gain.gain >= avgGain || (float)Math.abs(gain.gain - avgGain) < 0.000001) && (gain.gainRatio >= gainRatio || gainRatio == -2))) {
gainRatio = gain.gainRatio;
key = entry.getKey();
}
}
node.key = key;
List<Node> nodeList = nodeMap.get(key);
for (int j = 0; j < nodeList.size(); j++) {//儿子绑定父亲关系
nodeList.get(j).fatherNode = node;
}
for (int j = 0; j < nodeList.size(); j++) {
Node node1 = nodeList.get(j);
node1.nodeList = createNode(node1);
}
return nodeList;
} else {
//判断类别
node.isEnd = true;//叶子节点
node.type = getType(fatherList);
lastNodes.add(node);//将全部最后一层节点集合
return null;
}
}
private int getType(List<Integer> list) {
Map<Integer, Integer> myType = new HashMap<>();
for (int index : list) {
int type = endList.get(index);//最终结果的类别
if (myType.containsKey(type)) {
myType.put(type, myType.get(type) + 1);
} else {
myType.put(type, 1);
}
}
int type = 0;
int nub = 0;
for (Map.Entry<Integer, Integer> entry : myType.entrySet()) {
int nowNub = entry.getValue();
if (nowNub > nub) {
type = entry.getKey();
nub = nowNub;
}
}
return type;
}
private Set<String> removeAttribute(Set<String> attributes, String name) {
Set<String> attriBute = new HashSet<>();
for (String myName : attributes) {
if (!myName.equals(name)) {
attriBute.add(myName);
}
}
return attriBute;
}
private int getTypeId(Object ob, String name) throws Exception {
Class<?> body = ob.getClass();
String methodName = "get" + name.substring(0, 1).toUpperCase() + name.substring(1);
Method method = body.getMethod(methodName);
return Integer.parseInt(method.invoke(ob).toString());
}
public TreeWithTrust judge(Object ob) throws Exception {//进行类别判断
if (rootNode != null) {
TreeWithTrust treeWithTrust = new TreeWithTrust();
treeWithTrust.setTrust(1.0f);
goTree(ob, rootNode, treeWithTrust, 0);
return treeWithTrust;
} else {
throw new Exception("rootNode is null");
}
}
private void punishment(TreeWithTrust treeWithTrust) {//信任惩罚
//System.out.println("惩罚");
float trust = treeWithTrust.getTrust();//获取当前信任值
trust = trust * trustPunishment;
treeWithTrust.setTrust(trust);
}
private void goTree(Object ob, Node node, TreeWithTrust treeWithTrust, int times) throws Exception {//从树顶向下攀爬
if (!node.isEnd) {
int myType = getTypeId(ob, node.key);//当前类别的ID
if (myType == 0) {//做信任惩罚
punishment(treeWithTrust);
}
List<Node> nodeList = node.nodeList;
boolean isOk = false;
for (Node testNode : nodeList) {
if (testNode.typeId == myType) {
isOk = true;
node = testNode;
break;
}
}
if (!isOk) {//当前类别缺失,未知的属性值
punishment(treeWithTrust);
int index = random.nextInt(nodeList.size());
node = nodeList.get(index);
}
times++;
goTree(ob, node, treeWithTrust, times);
} else {
//当以0作为结束的时候要做严厉的信任惩罚
if (node.typeId == 0) {
int nub = rootNode.attribute.size() - times;
//System.out.println("惩罚次数" + nub);
for (int i = 0; i < nub; i++) {
punishment(treeWithTrust);
}
}
treeWithTrust.setType(node.type);
}
}
public void study() throws Exception {
if (dataTable != null && dataTable.getLength() > 0) {
rootNode = new Node();
table = dataTable.getTable();
endList = dataTable.getTable().get(dataTable.getKey());
Set<String> set = dataTable.getKeyType();
set.remove(dataTable.getKey());
rootNode.attribute = set;//当前可用属性
List<Integer> list = new ArrayList<>();
for (int i = 0; i < endList.size(); i++) {
list.add(i);
}
rootNode.fatherList = list;//当前父级样本
rootNode.nodeList = createNode(rootNode);
//进行后剪枝
for (Node lastNode : lastNodes) {
prune(lastNode.fatherNode);
}
lastNodes.clear();
} else {
throw new Exception("dataTable is null");
}
}
private void prune(Node node) {//执行剪枝
if (node != null && !node.isEnd) {
List<Node> listNode = node.nodeList;//子节点
if (isPrune(node, listNode)) {//剪枝
deduction(node);
prune(node.fatherNode);
}
}
}
private void deduction(Node node) {
node.isEnd = true;
node.nodeList = null;
node.type = getType(node.fatherList);
}
private boolean isPrune(Node father, List<Node> sonNodes) {
boolean isRemove = false;
List<Integer> typeList = new ArrayList<>();
for (int i = 0; i < sonNodes.size(); i++) {
Node node = sonNodes.get(i);
List<Integer> list = node.fatherList;
typeList.add(getType(list));
}
int fatherType = getType(father.fatherList);
int nub = getRightPoint(father.fatherList, fatherType);
//父级该样本正确率
float rightFather = (float) nub / (float) father.fatherList.size();
int rightNub = 0;
int rightAllNub = 0;
for (int i = 0; i < sonNodes.size(); i++) {
Node node = sonNodes.get(i);
List<Integer> list = node.fatherList;
int right = getRightPoint(list, typeList.get(i));
rightNub = rightNub + right;
rightAllNub = rightAllNub + list.size();
}
float rightPoint = (float) rightNub / (float) rightAllNub;//子节点正确率
if (rightPoint <= rightFather) {
isRemove = true;
}
return isRemove;
}
private int getRightPoint(List<Integer> types, int type) {
int nub = 0;
for (int index : types) {
int end = endList.get(index);
if (end == type) {
nub++;
}
}
return nub;
}
}