From 5c11f4bfe8026ba7ad972c8f0c14601bbf5e20eb Mon Sep 17 00:00:00 2001 From: inter Date: Thu, 4 Sep 2025 14:08:52 +0800 Subject: [PATCH] Add File --- .../dromara/easyai/tools/MeanClustering.java | 119 ++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 src/main/java/org/dromara/easyai/tools/MeanClustering.java diff --git a/src/main/java/org/dromara/easyai/tools/MeanClustering.java b/src/main/java/org/dromara/easyai/tools/MeanClustering.java new file mode 100644 index 0000000..8e86d38 --- /dev/null +++ b/src/main/java/org/dromara/easyai/tools/MeanClustering.java @@ -0,0 +1,119 @@ +package org.dromara.easyai.tools; + + +import org.dromara.easyai.entity.RGBNorm; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +//K均值聚类 +public class MeanClustering { + protected List matrixList = new ArrayList<>();//聚类集合 + private int length;//向量长度(模型需要返回) + protected int speciesQuantity;//种类数量(模型需要返回) + private final int maxTimes;//最大迭代次数 + protected List matrices = new ArrayList<>();//均值K模型(模型需要返回) + + public List getMatrices() { + return matrices; + } + + public float[] getResultByNorm() { + MeanSort meanSort = new MeanSort(); + float[] dm = new float[matrices.size() * length]; + matrices.sort(meanSort); + for (int i = 0; i < matrices.size(); i++) { + RGBNorm rgbNorm = matrices.get(i); + float[] rgb = rgbNorm.getRgb(); + for (int j = 0; j < rgb.length; j++) { + dm[i * rgb.length + j] = rgb[j]; + } + } + return dm; + } + + public MeanClustering(int speciesQuantity, int maxTimes) throws Exception { + this.speciesQuantity = speciesQuantity;//聚类的数量 + this.maxTimes = maxTimes; + } + + public void setFeature(float[] feature) throws Exception { + if (matrixList.isEmpty()) { + matrixList.add(feature); + length = feature.length; + } else { + if (length == feature.length) { + matrixList.add(feature); + } else { + throw new Exception("vector length is different"); + } + } + } + + private void averageMatrix() { + for (float[] rgb : matrixList) {//遍历当前集合 + float min = -1; + int id = 0; + for (int i = 0; i < speciesQuantity; i++) { + RGBNorm rgbNorm = matrices.get(i); + float dist = rgbNorm.getEDist(rgb); + if (min == -1 || dist < min) { + min = dist; + id = i; + } + } + //进簇 + RGBNorm rgbNorm = matrices.get(id); + rgbNorm.setColor(rgb); + } + //重新计算均值 + for (RGBNorm rgbNorm : matrices) { + rgbNorm.norm(); + } + } + + private boolean isNext() { + boolean isNext = false; + for (RGBNorm rgbNorm : matrices) { + isNext = rgbNorm.compare(); + if (isNext) { + break; + } + } + return isNext; + } + + private void clear() { + for (RGBNorm rgbNorm : matrices) { + rgbNorm.clear(); + } + } + + public void start() throws Exception {//开始聚类 + if (matrixList.size() > 1) { + Random random = new Random(); + for (int i = 0; i < speciesQuantity; i++) {//初始化均值向量 + int index = random.nextInt(matrixList.size()); + float[] rgb = matrixList.get(index); + RGBNorm rgbNorm = new RGBNorm(rgb, length); + //要进行深度克隆 + matrices.add(rgbNorm); + } + //进行两者的比较 + boolean isNext; + for (int i = 0; i < maxTimes; i++) { + //System.out.println("聚类:" + i); + averageMatrix(); + isNext = isNext(); + if (isNext && i < maxTimes - 1) { + clear(); + } else { + break; + } + } + } else { + throw new Exception("matrixList number less than 2"); + } + } +}