diff --git a/src/main/java/org/dromara/easyai/tools/Knn.java b/src/main/java/org/dromara/easyai/tools/Knn.java new file mode 100644 index 0000000..acc631f --- /dev/null +++ b/src/main/java/org/dromara/easyai/tools/Knn.java @@ -0,0 +1,128 @@ +package org.dromara.easyai.tools; + +import org.dromara.easyai.matrixTools.Matrix; +import org.dromara.easyai.matrixTools.MatrixOperation; + +import java.util.*; + +public class Knn extends MatrixOperation {//KNN分类器 + private Map> featureMap = new HashMap<>(); + private int length;//向量长度(需要返回) + private final int nub;//选择几个人投票 + + public Knn(int nub) { + this.nub = nub; + } + + public void setFeatureMap(Map> featureMap) { + this.featureMap = featureMap; + } + + public Map> getFeatureMap() { + return featureMap; + } + + public void removeType(int type) { + featureMap.remove(type); + } + + public void revoke(int type, int nub) {//撤销一个类别最新的 + List list = featureMap.get(type); + for (int i = 0; i < nub; i++) { + list.remove(list.size() - 1); + } + } + + public int getNub(int type) {//获取该分类模型的数量 + int nub = 0; + List list = featureMap.get(type); + if (list != null) { + nub = list.size(); + } + return nub; + } + + public void insertMatrix(Matrix vector, int tag) throws Exception { + if (vector.isVector() && vector.isRowVector()) { + if (featureMap.size() == 0) { + List list = new ArrayList<>(); + list.add(vector); + featureMap.put(tag, list); + length = vector.getY(); + } else { + if (length == vector.getY()) { + if (featureMap.containsKey(tag)) { + featureMap.get(tag).add(vector); + } else { + List list = new ArrayList<>(); + list.add(vector); + featureMap.put(tag, list); + } + } else { + throw new Exception("vector length is different"); + } + } + } else { + throw new Exception("this matrix is not vector or rowVector"); + } + } + + private void compare(float[] values, int[] types, float value, int type) { + for (int i = 0; i < values.length; i++) { + float val = values[i]; + if (val < 0) { + values[i] = value; + types[i] = type; + break; + } else { + if (value < val) { + for (int j = values.length - 2; j >= i; j--) { + values[j + 1] = values[j]; + types[j + 1] = types[j]; + } + values[i] = value; + types[i] = type; + break; + } + } + } + } + + public int getType(Matrix vector) throws Exception {//识别分类 + int ty = 0; + float[] dists = new float[nub]; + // System.out.println("测试:" + vector.getString()); + int[] types = new int[nub]; + for (int i = 0; i < nub; i++) { + dists[i] = -1; + } + for (Map.Entry> entry : featureMap.entrySet()) { + int type = entry.getKey(); + List matrices = entry.getValue(); + for (Matrix matrix : matrices) { + float dist = getEDist(matrix, vector); + compare(dists, types, dist, type); + } + } + System.out.println(Arrays.toString(types)); + Map map = new HashMap<>(); + for (int i = 0; i < nub; i++) { + int type = types[i]; + if (map.containsKey(type)) { + map.put(type, map.get(type) + 1); + } else { + map.put(type, 1); + } + } + int max = 0; + for (Map.Entry entry : map.entrySet()) { + int value = entry.getValue(); + int type = entry.getKey(); + if (value > max) { + ty = type; + max = value; + } + } + return ty; + } +}