Add File
This commit is contained in:
139
src/main/java/org/dromara/easyai/randomForest/RandomForest.java
Normal file
139
src/main/java/org/dromara/easyai/randomForest/RandomForest.java
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
package org.dromara.easyai.randomForest;
|
||||||
|
|
||||||
|
import org.dromara.easyai.tools.ArithUtil;
|
||||||
|
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author lidapeng
|
||||||
|
* @description 随机森林
|
||||||
|
* @date 3:50 下午 2020/2/22
|
||||||
|
*/
|
||||||
|
public class RandomForest {
|
||||||
|
private Random random = new Random();
|
||||||
|
private Tree[] forest;
|
||||||
|
private float trustTh = 0.1F;//信任阈值
|
||||||
|
private float trustPunishment = 0.1F;//信任惩罚
|
||||||
|
|
||||||
|
public float getTrustPunishment() {
|
||||||
|
return trustPunishment;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setTrustPunishment(float trustPunishment) {
|
||||||
|
this.trustPunishment = trustPunishment;
|
||||||
|
}
|
||||||
|
|
||||||
|
public float getTrustTh() {
|
||||||
|
return trustTh;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setTrustTh(float trustTh) {
|
||||||
|
this.trustTh = trustTh;
|
||||||
|
}
|
||||||
|
|
||||||
|
public RandomForest() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public RandomForest(int treeNub) throws Exception {
|
||||||
|
if (treeNub > 0) {
|
||||||
|
forest = new Tree[treeNub];
|
||||||
|
} else {
|
||||||
|
throw new Exception("Number of trees must be greater than 0");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public RfModel getModel() {//获取模型
|
||||||
|
RfModel rfModel = new RfModel();
|
||||||
|
Map<Integer, Node> nodeMap = new HashMap<>();
|
||||||
|
for (int i = 0; i < forest.length; i++) {
|
||||||
|
Node node = forest[i].getRootNode();
|
||||||
|
nodeMap.put(i, node);
|
||||||
|
}
|
||||||
|
rfModel.setNodeMap(nodeMap);
|
||||||
|
return rfModel;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int forest(Object object) throws Exception {//随机森林识别
|
||||||
|
Map<Integer, Float> map = new HashMap<>();
|
||||||
|
for (int i = 0; i < forest.length; i++) {
|
||||||
|
Tree tree = forest[i];
|
||||||
|
TreeWithTrust treeWithTrust = tree.judge(object);
|
||||||
|
int type = treeWithTrust.getType();
|
||||||
|
//System.out.println(type);
|
||||||
|
float trust = treeWithTrust.getTrust();
|
||||||
|
if (map.containsKey(type)) {
|
||||||
|
map.put(type, map.get(type) + trust);
|
||||||
|
} else {
|
||||||
|
map.put(type, trust);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int type = 0;
|
||||||
|
float nub = 0;
|
||||||
|
for (Map.Entry<Integer, Float> entry : map.entrySet()) {
|
||||||
|
float myNub = entry.getValue();
|
||||||
|
//System.out.println("type==" + entry.getKey() + ",nub==" + myNub);
|
||||||
|
if (myNub > nub) {
|
||||||
|
type = entry.getKey();
|
||||||
|
nub = myNub;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (nub < ArithUtil.mul(forest.length, trustTh)) {
|
||||||
|
type = 0;
|
||||||
|
}
|
||||||
|
return type;
|
||||||
|
}
|
||||||
|
|
||||||
|
//rf初始化
|
||||||
|
public void init(DataTable dataTable) throws Exception {
|
||||||
|
//一棵树属性的数量
|
||||||
|
if (dataTable.getSize() > 4) {
|
||||||
|
int kNub = (int) ((int) (float)Math.log(dataTable.getSize()) / (float)Math.log(2));
|
||||||
|
//int kNub = dataTable.getSize() / 2;
|
||||||
|
// System.out.println("knNub==" + kNub);
|
||||||
|
for (int i = 0; i < forest.length; i++) {
|
||||||
|
Tree tree = new Tree(getRandomData(dataTable, kNub), trustPunishment);
|
||||||
|
forest[i] = tree;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw new Exception("Number of feature categories must be greater than 3");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void study() throws Exception {//学习
|
||||||
|
for (int i = 0; i < forest.length; i++) {
|
||||||
|
//System.out.println("开始学习==" + i + ",treeNub==" + forest.length);
|
||||||
|
Tree tree = forest[i];
|
||||||
|
tree.study();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void insert(Object object) {//添加学习参数
|
||||||
|
for (int i = 0; i < forest.length; i++) {
|
||||||
|
Tree tree = forest[i];
|
||||||
|
tree.getDataTable().insert(object);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//从总属性列表中随机挑选属性kNub个属性数量
|
||||||
|
private DataTable getRandomData(DataTable dataTable, int kNub) throws Exception {
|
||||||
|
Set<String> attr = dataTable.getKeyType();
|
||||||
|
Set<String> myName = new HashSet<>();
|
||||||
|
String key = dataTable.getKey();//结果
|
||||||
|
List<String> list = new ArrayList<>();
|
||||||
|
for (String name : attr) {//加载主键
|
||||||
|
if (!name.equals(key)) {
|
||||||
|
list.add(name);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int i = 0; i < kNub; i++) {
|
||||||
|
int index = random.nextInt(list.size());
|
||||||
|
myName.add(list.get(index));
|
||||||
|
list.remove(index);
|
||||||
|
}
|
||||||
|
myName.add(key);
|
||||||
|
//System.out.println(myName);
|
||||||
|
DataTable data = new DataTable(myName);
|
||||||
|
data.setKey(key);
|
||||||
|
return data;
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user