Add File
This commit is contained in:
119
src/main/java/org/dromara/easyai/tools/MeanClustering.java
Normal file
119
src/main/java/org/dromara/easyai/tools/MeanClustering.java
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package org.dromara.easyai.tools;
|
||||||
|
|
||||||
|
|
||||||
|
import org.dromara.easyai.entity.RGBNorm;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
|
//K均值聚类
|
||||||
|
public class MeanClustering {
|
||||||
|
protected List<float[]> matrixList = new ArrayList<>();//聚类集合
|
||||||
|
private int length;//向量长度(模型需要返回)
|
||||||
|
protected int speciesQuantity;//种类数量(模型需要返回)
|
||||||
|
private final int maxTimes;//最大迭代次数
|
||||||
|
protected List<RGBNorm> matrices = new ArrayList<>();//均值K模型(模型需要返回)
|
||||||
|
|
||||||
|
public List<RGBNorm> getMatrices() {
|
||||||
|
return matrices;
|
||||||
|
}
|
||||||
|
|
||||||
|
public float[] getResultByNorm() {
|
||||||
|
MeanSort meanSort = new MeanSort();
|
||||||
|
float[] dm = new float[matrices.size() * length];
|
||||||
|
matrices.sort(meanSort);
|
||||||
|
for (int i = 0; i < matrices.size(); i++) {
|
||||||
|
RGBNorm rgbNorm = matrices.get(i);
|
||||||
|
float[] rgb = rgbNorm.getRgb();
|
||||||
|
for (int j = 0; j < rgb.length; j++) {
|
||||||
|
dm[i * rgb.length + j] = rgb[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dm;
|
||||||
|
}
|
||||||
|
|
||||||
|
public MeanClustering(int speciesQuantity, int maxTimes) throws Exception {
|
||||||
|
this.speciesQuantity = speciesQuantity;//聚类的数量
|
||||||
|
this.maxTimes = maxTimes;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setFeature(float[] feature) throws Exception {
|
||||||
|
if (matrixList.isEmpty()) {
|
||||||
|
matrixList.add(feature);
|
||||||
|
length = feature.length;
|
||||||
|
} else {
|
||||||
|
if (length == feature.length) {
|
||||||
|
matrixList.add(feature);
|
||||||
|
} else {
|
||||||
|
throw new Exception("vector length is different");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void averageMatrix() {
|
||||||
|
for (float[] rgb : matrixList) {//遍历当前集合
|
||||||
|
float min = -1;
|
||||||
|
int id = 0;
|
||||||
|
for (int i = 0; i < speciesQuantity; i++) {
|
||||||
|
RGBNorm rgbNorm = matrices.get(i);
|
||||||
|
float dist = rgbNorm.getEDist(rgb);
|
||||||
|
if (min == -1 || dist < min) {
|
||||||
|
min = dist;
|
||||||
|
id = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//进簇
|
||||||
|
RGBNorm rgbNorm = matrices.get(id);
|
||||||
|
rgbNorm.setColor(rgb);
|
||||||
|
}
|
||||||
|
//重新计算均值
|
||||||
|
for (RGBNorm rgbNorm : matrices) {
|
||||||
|
rgbNorm.norm();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean isNext() {
|
||||||
|
boolean isNext = false;
|
||||||
|
for (RGBNorm rgbNorm : matrices) {
|
||||||
|
isNext = rgbNorm.compare();
|
||||||
|
if (isNext) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return isNext;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void clear() {
|
||||||
|
for (RGBNorm rgbNorm : matrices) {
|
||||||
|
rgbNorm.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public void start() throws Exception {//开始聚类
|
||||||
|
if (matrixList.size() > 1) {
|
||||||
|
Random random = new Random();
|
||||||
|
for (int i = 0; i < speciesQuantity; i++) {//初始化均值向量
|
||||||
|
int index = random.nextInt(matrixList.size());
|
||||||
|
float[] rgb = matrixList.get(index);
|
||||||
|
RGBNorm rgbNorm = new RGBNorm(rgb, length);
|
||||||
|
//要进行深度克隆
|
||||||
|
matrices.add(rgbNorm);
|
||||||
|
}
|
||||||
|
//进行两者的比较
|
||||||
|
boolean isNext;
|
||||||
|
for (int i = 0; i < maxTimes; i++) {
|
||||||
|
//System.out.println("聚类:" + i);
|
||||||
|
averageMatrix();
|
||||||
|
isNext = isNext();
|
||||||
|
if (isNext && i < maxTimes - 1) {
|
||||||
|
clear();
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw new Exception("matrixList number less than 2");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user