Add File
This commit is contained in:
112
src/main/java/org/dromara/easyai/bayesian/NativeBayesian.java
Normal file
112
src/main/java/org/dromara/easyai/bayesian/NativeBayesian.java
Normal file
@@ -0,0 +1,112 @@
|
||||
package org.dromara.easyai.bayesian;
|
||||
|
||||
import org.dromara.easyai.randomForest.DataTable;
|
||||
|
||||
import java.lang.reflect.Method;
|
||||
import java.util.*;
|
||||
|
||||
public class NativeBayesian {
|
||||
private DataTable dataTable;
|
||||
// 存储每个类别的先验概率
|
||||
private final Map<Integer, Double> priorProbabilities;
|
||||
// 存储每个特征在每个类别下的条件概率
|
||||
private final Map<String, Map<Integer, Map<Integer, Double>>> conditionalProbabilities;
|
||||
|
||||
public NativeBayesian() {
|
||||
priorProbabilities = new HashMap<>();
|
||||
conditionalProbabilities = new HashMap<>();
|
||||
}
|
||||
|
||||
public NativeBayesian(DataTable dataTable) {
|
||||
this();
|
||||
this.dataTable = dataTable;
|
||||
}
|
||||
|
||||
public DataTable getDataTable() {
|
||||
return dataTable;
|
||||
}
|
||||
|
||||
public void setDataTable(DataTable dataTable) {
|
||||
this.dataTable = dataTable;
|
||||
}
|
||||
|
||||
public int classify(Object object) {
|
||||
// 存储每个类别的后验概率
|
||||
Map<Integer, Double> posteriorProbabilities = new HashMap<>();
|
||||
try {
|
||||
// 获取所有可能的类别
|
||||
List<Integer> classValues = dataTable.getTable().get(dataTable.getKey());
|
||||
Set<Integer> uniqueClasses = new HashSet<>(classValues);
|
||||
|
||||
// 遍历每个类别
|
||||
for (int classValue : uniqueClasses) {
|
||||
double posteriorProb = priorProbabilities.get(classValue);
|
||||
// 遍历每个特征
|
||||
for (String feature : dataTable.getKeyType()) {
|
||||
if (!feature.equals(dataTable.getKey())) {
|
||||
String methodName = "get" + feature.substring(0, 1).toUpperCase() + feature.substring(1);
|
||||
Method method = object.getClass().getMethod(methodName);
|
||||
int featureValue = (int) method.invoke(object);
|
||||
// 计算条件概率
|
||||
posteriorProb *= conditionalProbabilities.get(feature).get(classValue).getOrDefault(featureValue, 0.0);
|
||||
}
|
||||
}
|
||||
posteriorProbabilities.put(classValue, posteriorProb);
|
||||
}
|
||||
|
||||
// 找到后验概率最大的类别
|
||||
int maxClass = -1;
|
||||
double maxProb = -1;
|
||||
for (Map.Entry<Integer, Double> entry : posteriorProbabilities.entrySet()) {
|
||||
if (entry.getValue() > maxProb) {
|
||||
maxProb = entry.getValue();
|
||||
maxClass = entry.getKey();
|
||||
}
|
||||
}
|
||||
return maxClass;
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
public void study() {
|
||||
|
||||
// 计算先验概率
|
||||
List<Integer> classValues = dataTable.getTable().get(dataTable.getKey());
|
||||
Map<Integer, Integer> classCounts = new HashMap<>();
|
||||
for (int value : classValues) {
|
||||
classCounts.put(value, classCounts.getOrDefault(value, 0) + 1);
|
||||
}
|
||||
for (Map.Entry<Integer, Integer> entry : classCounts.entrySet()) {
|
||||
priorProbabilities.put(entry.getKey(), (double) entry.getValue() / dataTable.getLength());
|
||||
}
|
||||
|
||||
// 计算条件概率
|
||||
for (String feature : dataTable.getKeyType()) {
|
||||
if (!feature.equals(dataTable.getKey())) {
|
||||
Map<Integer, Map<Integer, Integer>> featureCounts = new HashMap<>();
|
||||
List<Integer> featureValues = dataTable.getTable().get(feature);
|
||||
for (int i = 0; i < dataTable.getLength(); i++) {
|
||||
int classValue = classValues.get(i);
|
||||
int featureValue = featureValues.get(i);
|
||||
featureCounts.computeIfAbsent(classValue, k -> new HashMap<>()).put(featureValue, featureCounts.get(classValue).getOrDefault(featureValue, 0) + 1);
|
||||
}
|
||||
Map<Integer, Map<Integer, Double>> featureProbabilities = new HashMap<>();
|
||||
for (Map.Entry<Integer, Map<Integer, Integer>> entry : featureCounts.entrySet()) {
|
||||
int classValue = entry.getKey();
|
||||
Map<Integer, Integer> counts = entry.getValue();
|
||||
Map<Integer, Double> probabilities = new HashMap<>();
|
||||
for (Map.Entry<Integer, Integer> countEntry : counts.entrySet()) {
|
||||
int featureValue = countEntry.getKey();
|
||||
int count = countEntry.getValue();
|
||||
probabilities.put(featureValue, (double) count / classCounts.get(classValue));
|
||||
}
|
||||
featureProbabilities.put(classValue, probabilities);
|
||||
}
|
||||
conditionalProbabilities.put(feature, featureProbabilities);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user