From e7213a20f9a445695384cb8ae7d4e0725b007638 Mon Sep 17 00:00:00 2001 From: inter Date: Thu, 4 Sep 2025 14:09:01 +0800 Subject: [PATCH] Add File --- .../org/dromara/easyai/randomForest/Tree.java | 341 ++++++++++++++++++ 1 file changed, 341 insertions(+) create mode 100644 src/main/java/org/dromara/easyai/randomForest/Tree.java diff --git a/src/main/java/org/dromara/easyai/randomForest/Tree.java b/src/main/java/org/dromara/easyai/randomForest/Tree.java new file mode 100644 index 0000000..94f1c0e --- /dev/null +++ b/src/main/java/org/dromara/easyai/randomForest/Tree.java @@ -0,0 +1,341 @@ +package org.dromara.easyai.randomForest; + + +import java.lang.reflect.Method; +import java.util.*; + +/** + * @author lidapeng + * @description + * @date 3:12 下午 2020/2/17 + */ +public class Tree {//决策树 + private DataTable dataTable; + private Map> table;//总样本 + private Node rootNode;//根节点 + private List endList;//最终结果分类 + private final List lastNodes = new ArrayList<>();//最后一层节点集合 + private final Random random = new Random(); + private final float trustPunishment;//信任惩罚 + + public Node getRootNode() { + return rootNode; + } + + public DataTable getDataTable() { + return dataTable; + } + + public void setRootNode(Node rootNode) { + this.rootNode = rootNode; + } + + private static class Gain { + private float gain;//信息增益 + private float gainRatio;//信息增益率 + } + + public Tree(float trustPunishment) { + this.trustPunishment = trustPunishment; + } + + public Tree(DataTable dataTable, float trustPunishment) throws Exception { + if (dataTable != null && dataTable.getKey() != null) { + this.trustPunishment = trustPunishment; + this.dataTable = dataTable; + } else { + throw new Exception("dataTable is empty"); + } + } + + private float log2(float p) { + return (float)Math.log(p) / (float)Math.log(2); + } + + private float getEnt(List list) { + //记录了每个类别有几个 + Map myType = new HashMap<>(); + for (int index : list) { + int type = endList.get(index);//最终结果的类别 + if (myType.containsKey(type)) { + myType.put(type, myType.get(type) + 1); + } else { + myType.put(type, 1); + } + } + float ent = 0; + //求信息熵 + for (Map.Entry entry1 : myType.entrySet()) { + float g = (float) entry1.getValue() / (float) list.size();//每个类别的概率 + ent = ent + g * log2(g); + } + return -ent; + } + + private float getGain(float ent, float dNub, float gain) { + return gain + ent * dNub; + } + + private List createNode(Node node) { + Set attributes = node.attribute; + List fatherList = node.fatherList; + if (!attributes.isEmpty()) { + Map>> mapAll = new HashMap<>();//主键:可用属性 次主键该属性的属性值,集合该值对应的id + float fatherEnt = getEnt(fatherList); + int fatherNub = fatherList.size();//总样本数 + //该属性每个离散数据分类的集合 + for (int i = 0; i < fatherList.size(); i++) { + int index = fatherList.get(i);//编号 + for (String attr : attributes) { + if (!mapAll.containsKey(attr)) { + mapAll.put(attr, new HashMap<>()); + } + Map> map = mapAll.get(attr); + int attrValue = table.get(attr).get(index);//获取当前属性值 + if (!map.containsKey(attrValue)) { + map.put(attrValue, new ArrayList<>()); + } + List list = map.get(attrValue); + list.add(index); + } + } + Map> nodeMap = new HashMap<>(); + int i = 0; + float sigmaG = 0; + Map gainMap = new HashMap<>(); + for (Map.Entry>> mapEntry : mapAll.entrySet()) { + Map> map = mapEntry.getValue();//当前属性的 属性值及id + //求信息增益 + float gain = 0;//信息增益 + float IV = 0;//增益率 + List nodeList = new ArrayList<>(); + String name = mapEntry.getKey();//可用属性名称 + nodeMap.put(name, nodeList); + for (Map.Entry> entry : map.entrySet()) {//遍历当前属性下的所有属性值的集合 + Set nowAttribute = removeAttribute(attributes, name); + Node sonNode = new Node(); + nodeList.add(sonNode); + sonNode.attribute = nowAttribute; + List list = entry.getValue();//该属性值下的数据id集合 + sonNode.fatherList = list; + sonNode.typeId = entry.getKey();//该属性值 + int myNub = list.size();//该属性值下数据的数量 + float ent = getEnt(list);//该属性值 的信息熵 + float dNub = (float) myNub / (float) fatherNub;//该属性值在 父级样本中出现的概率 + IV = dNub * log2(dNub) + IV; + gain = getGain(ent, dNub, gain); + } + Gain gain1 = new Gain(); + gainMap.put(name, gain1); + gain1.gain = fatherEnt - gain;//信息增益 + if (IV != 0) { + gain1.gainRatio = gain1.gain / -IV;//增益率 + } else { + gain1.gainRatio = 1000000; + } + sigmaG = gain1.gain + sigmaG; + i++; + } + float avgGain = sigmaG / i; + float gainRatio = -2;//最大增益率 + String key = null;//可选属性 + //System.out.println("平均信息增益==============================" + avgGain); + for (Map.Entry entry : gainMap.entrySet()) { + Gain gain = entry.getValue(); +// System.out.println("主键:" + entry.getKey() + ",平均信息增益:" + avgGain + ",可用属性数量:" + gainMap.size() +// + "该属性信息增益:" + gain.gain + ",该属性增益率:" + gain.gainRatio + ",当前最高增益率:" + gainRatio); + if (gainMap.size() == 1 || ((gain.gain >= avgGain || (float)Math.abs(gain.gain - avgGain) < 0.000001) && (gain.gainRatio >= gainRatio || gainRatio == -2))) { + gainRatio = gain.gainRatio; + key = entry.getKey(); + } + } + node.key = key; + List nodeList = nodeMap.get(key); + for (int j = 0; j < nodeList.size(); j++) {//儿子绑定父亲关系 + nodeList.get(j).fatherNode = node; + } + for (int j = 0; j < nodeList.size(); j++) { + Node node1 = nodeList.get(j); + node1.nodeList = createNode(node1); + } + return nodeList; + } else { + //判断类别 + node.isEnd = true;//叶子节点 + node.type = getType(fatherList); + lastNodes.add(node);//将全部最后一层节点集合 + return null; + } + } + + private int getType(List list) { + Map myType = new HashMap<>(); + for (int index : list) { + int type = endList.get(index);//最终结果的类别 + if (myType.containsKey(type)) { + myType.put(type, myType.get(type) + 1); + } else { + myType.put(type, 1); + } + } + int type = 0; + int nub = 0; + for (Map.Entry entry : myType.entrySet()) { + int nowNub = entry.getValue(); + if (nowNub > nub) { + type = entry.getKey(); + nub = nowNub; + } + } + return type; + } + + private Set removeAttribute(Set attributes, String name) { + Set attriBute = new HashSet<>(); + for (String myName : attributes) { + if (!myName.equals(name)) { + attriBute.add(myName); + } + } + return attriBute; + } + + private int getTypeId(Object ob, String name) throws Exception { + Class body = ob.getClass(); + String methodName = "get" + name.substring(0, 1).toUpperCase() + name.substring(1); + Method method = body.getMethod(methodName); + return Integer.parseInt(method.invoke(ob).toString()); + } + + public TreeWithTrust judge(Object ob) throws Exception {//进行类别判断 + if (rootNode != null) { + TreeWithTrust treeWithTrust = new TreeWithTrust(); + treeWithTrust.setTrust(1.0f); + goTree(ob, rootNode, treeWithTrust, 0); + return treeWithTrust; + } else { + throw new Exception("rootNode is null"); + } + } + + private void punishment(TreeWithTrust treeWithTrust) {//信任惩罚 + //System.out.println("惩罚"); + float trust = treeWithTrust.getTrust();//获取当前信任值 + trust = trust * trustPunishment; + treeWithTrust.setTrust(trust); + } + + private void goTree(Object ob, Node node, TreeWithTrust treeWithTrust, int times) throws Exception {//从树顶向下攀爬 + if (!node.isEnd) { + int myType = getTypeId(ob, node.key);//当前类别的ID + if (myType == 0) {//做信任惩罚 + punishment(treeWithTrust); + } + List nodeList = node.nodeList; + boolean isOk = false; + for (Node testNode : nodeList) { + if (testNode.typeId == myType) { + isOk = true; + node = testNode; + break; + } + } + if (!isOk) {//当前类别缺失,未知的属性值 + punishment(treeWithTrust); + int index = random.nextInt(nodeList.size()); + node = nodeList.get(index); + } + times++; + goTree(ob, node, treeWithTrust, times); + } else { + //当以0作为结束的时候要做严厉的信任惩罚 + if (node.typeId == 0) { + int nub = rootNode.attribute.size() - times; + //System.out.println("惩罚次数" + nub); + for (int i = 0; i < nub; i++) { + punishment(treeWithTrust); + } + } + treeWithTrust.setType(node.type); + } + } + + public void study() throws Exception { + if (dataTable != null && dataTable.getLength() > 0) { + rootNode = new Node(); + table = dataTable.getTable(); + endList = dataTable.getTable().get(dataTable.getKey()); + Set set = dataTable.getKeyType(); + set.remove(dataTable.getKey()); + rootNode.attribute = set;//当前可用属性 + List list = new ArrayList<>(); + for (int i = 0; i < endList.size(); i++) { + list.add(i); + } + rootNode.fatherList = list;//当前父级样本 + rootNode.nodeList = createNode(rootNode); + //进行后剪枝 + for (Node lastNode : lastNodes) { + prune(lastNode.fatherNode); + } + lastNodes.clear(); + } else { + throw new Exception("dataTable is null"); + } + } + + private void prune(Node node) {//执行剪枝 + if (node != null && !node.isEnd) { + List listNode = node.nodeList;//子节点 + if (isPrune(node, listNode)) {//剪枝 + deduction(node); + prune(node.fatherNode); + } + } + } + + private void deduction(Node node) { + node.isEnd = true; + node.nodeList = null; + node.type = getType(node.fatherList); + } + + private boolean isPrune(Node father, List sonNodes) { + boolean isRemove = false; + List typeList = new ArrayList<>(); + for (int i = 0; i < sonNodes.size(); i++) { + Node node = sonNodes.get(i); + List list = node.fatherList; + typeList.add(getType(list)); + } + int fatherType = getType(father.fatherList); + int nub = getRightPoint(father.fatherList, fatherType); + //父级该样本正确率 + float rightFather = (float) nub / (float) father.fatherList.size(); + int rightNub = 0; + int rightAllNub = 0; + for (int i = 0; i < sonNodes.size(); i++) { + Node node = sonNodes.get(i); + List list = node.fatherList; + int right = getRightPoint(list, typeList.get(i)); + rightNub = rightNub + right; + rightAllNub = rightAllNub + list.size(); + } + float rightPoint = (float) rightNub / (float) rightAllNub;//子节点正确率 + if (rightPoint <= rightFather) { + isRemove = true; + } + return isRemove; + } + + private int getRightPoint(List types, int type) { + int nub = 0; + for (int index : types) { + int end = endList.get(index); + if (end == type) { + nub++; + } + } + return nub; + } +}