package org.dromara.easyai.bayesian;

import org.dromara.easyai.randomForest.DataTable;

import java.lang.reflect.Method;
import java.util.*;

/**
 * Naive Bayes classifier trained from a {@link DataTable} whose class labels and feature
 * values are integers.
 *
 * <p>Usage: supply a table (constructor or {@link #setDataTable(DataTable)}), call
 * {@link #study()} to estimate the probabilities, then call {@link #classify(Object)}.
 * Feature values of the object being classified are read reflectively through public
 * no-argument getters named {@code get<Feature>} that return an integer.
 */
public class NativeBayesian {
    private DataTable dataTable;
    // Prior probability of each class label.
    private final Map<Integer, Double> priorProbabilities;
    // feature name -> class label -> feature value -> P(feature value | class label).
    private final Map<String, Map<Integer, Map<Integer, Double>>> conditionalProbabilities;

    /** Creates an untrained classifier with no data table attached. */
    public NativeBayesian() {
        priorProbabilities = new HashMap<>();
        conditionalProbabilities = new HashMap<>();
    }

    /**
     * Creates an untrained classifier bound to the given table.
     *
     * @param dataTable training data; {@link #study()} must still be called before classifying
     */
    public NativeBayesian(DataTable dataTable) {
        this();
        this.dataTable = dataTable;
    }

    /** @return the table currently used for training and for enumerating class labels */
    public DataTable getDataTable() {
        return dataTable;
    }

    /**
     * Replaces the training table. Probabilities already learned are left untouched until
     * {@link #study()} is called again.
     *
     * @param dataTable the new training data
     */
    public void setDataTable(DataTable dataTable) {
        this.dataTable = dataTable;
    }

    /**
     * Predicts the class label of {@code object}.
     *
     * <p>For every class label present in the table's key column, the prior is multiplied by
     * the conditional probability of each of the object's feature values; the label with the
     * largest product is returned. A feature value that was never observed together with a
     * class contributes probability {@code 0.0} (no smoothing is applied).
     *
     * @param object instance exposing one integer-returning {@code get<Feature>()} method per
     *               non-key column of the table
     * @return the most probable class label, or {@code -1} when the table has no class labels
     *         or when any failure occurs (missing table, untrained model, missing getter, ...);
     *         in the failure case the stack trace is printed
     */
    public int classify(Object object) {
        // Posterior score of each class label.
        Map<Integer, Double> posteriorProbabilities = new HashMap<>();
        try {
            // All class labels that occur in the training data.
            List<Integer> classValues = dataTable.getTable().get(dataTable.getKey());
            Set<Integer> uniqueClasses = new HashSet<>(classValues);

            // The object's feature values do not depend on the candidate class, so the
            // reflective getter calls are done once here instead of once per class.
            List<String> features = featureNames();
            int[] featureValues = readFeatureValues(object, features);

            for (int classValue : uniqueClasses) {
                double posteriorProb = priorProbabilities.get(classValue);
                for (int i = 0; i < featureValues.length; i++) {
                    posteriorProb *= conditionalProbabilities.get(features.get(i))
                            .get(classValue)
                            .getOrDefault(featureValues[i], 0.0);
                }
                posteriorProbabilities.put(classValue, posteriorProb);
            }

            // Pick the class label with the highest posterior score.
            int maxClass = -1;
            double maxProb = -1;
            for (Map.Entry<Integer, Double> entry : posteriorProbabilities.entrySet()) {
                if (entry.getValue() > maxProb) {
                    maxProb = entry.getValue();
                    maxClass = entry.getKey();
                }
            }
            return maxClass;
        } catch (Exception e) {
            // Best effort by design: report the problem and signal "no prediction" with -1.
            e.printStackTrace();
            return -1;
        }
    }

    /**
     * Estimates the prior probability of every class label and the conditional probability of
     * every feature value given a class label from the current data table.
     *
     * <p>The new estimates are computed completely before they replace the previous ones, so
     * a failure while reading the table leaves the previously learned model intact, and
     * retraining never leaves entries from an earlier table behind.
     */
    public void study() {
        List<Integer> classValues = dataTable.getTable().get(dataTable.getKey());
        int length = dataTable.getLength();

        // Prior probabilities: relative frequency of each class label.
        Map<Integer, Integer> classCounts = new HashMap<>();
        for (int value : classValues) {
            classCounts.merge(value, 1, Integer::sum);
        }
        Map<Integer, Double> priors = new HashMap<>();
        for (Map.Entry<Integer, Integer> entry : classCounts.entrySet()) {
            priors.put(entry.getKey(), (double) entry.getValue() / length);
        }

        // Conditional probabilities: per feature, how often each value occurs within a class.
        Map<String, Map<Integer, Map<Integer, Double>>> conditionals = new HashMap<>();
        for (String feature : featureNames()) {
            List<Integer> featureValues = dataTable.getTable().get(feature);
            Map<Integer, Map<Integer, Integer>> featureCounts = new HashMap<>();
            for (int i = 0; i < length; i++) {
                int classValue = classValues.get(i);
                int featureValue = featureValues.get(i);
                featureCounts.computeIfAbsent(classValue, k -> new HashMap<>())
                        .merge(featureValue, 1, Integer::sum);
            }

            Map<Integer, Map<Integer, Double>> featureProbabilities = new HashMap<>();
            for (Map.Entry<Integer, Map<Integer, Integer>> entry : featureCounts.entrySet()) {
                int classValue = entry.getKey();
                int classTotal = classCounts.get(classValue);
                Map<Integer, Double> probabilities = new HashMap<>();
                for (Map.Entry<Integer, Integer> countEntry : entry.getValue().entrySet()) {
                    probabilities.put(countEntry.getKey(), (double) countEntry.getValue() / classTotal);
                }
                featureProbabilities.put(classValue, probabilities);
            }
            conditionals.put(feature, featureProbabilities);
        }

        // Publish the freshly computed model in place of the old one.
        priorProbabilities.clear();
        priorProbabilities.putAll(priors);
        conditionalProbabilities.clear();
        conditionalProbabilities.putAll(conditionals);
    }

    /** Returns the table's column names, in table order, without the class-label column. */
    private List<String> featureNames() {
        List<String> features = new ArrayList<>();
        for (String feature : dataTable.getKeyType()) {
            if (!feature.equals(dataTable.getKey())) {
                features.add(feature);
            }
        }
        return features;
    }

    /** Reads one integer per feature from {@code object} through its {@code get<Feature>()} getters. */
    private static int[] readFeatureValues(Object object, List<String> features)
            throws ReflectiveOperationException {
        int[] values = new int[features.size()];
        for (int i = 0; i < values.length; i++) {
            String feature = features.get(i);
            // Locale.ROOT keeps the getter name stable regardless of the JVM's default locale
            // (e.g. a Turkish locale would otherwise upper-case 'i' to a dotted capital I).
            String methodName = "get" + feature.substring(0, 1).toUpperCase(Locale.ROOT) + feature.substring(1);
            Method method = object.getClass().getMethod(methodName);
            values[i] = (int) method.invoke(object);
        }
        return values;
    }
}