This commit is contained in:
2025-09-04 14:09:13 +08:00
parent 1ef6436575
commit e3ec978568

View File

@@ -0,0 +1,343 @@
package org.dromara.easyai.yolo;
import org.dromara.easyai.entity.Box;
import org.dromara.easyai.entity.ThreeChannelMatrix;
import org.dromara.easyai.function.ReLu;
import org.dromara.easyai.i.OutBack;
import org.dromara.easyai.nerveCenter.NerveManager;
import org.dromara.easyai.nerveEntity.SensoryNerve;
import org.dromara.easyai.tools.NMS;
import org.dromara.easyai.tools.Picture;
import java.util.*;
/**
 * Sliding-window detector: a fixed-size window is moved over an image, every
 * window is classified by a shared "type" network, and windows that are not
 * classified as background are passed to the position network owned by the
 * matching {@link TypeBody} to obtain a box.
 *
 * <p>Axis convention used throughout this class: {@code x} / {@code xSize}
 * follow the window-height axis and {@code y} / {@code ySize} follow the
 * window-width axis. {@link #getOutBoxList(List)} swaps them when building the
 * public {@link OutBox} results.
 *
 * <p>The background class always uses mapping ID {@code typeBodies.size() + 1};
 * registered types receive mapping IDs starting at 1 in insertion order.
 */
public class FastYolo {
    private final YoloConfig yoloConfig;
    // Network that classifies a window (recognition network).
    private final NerveManager typeNerveManager;
    // One entry per known type ID; holds that type's samples and position network.
    private final List<TypeBody> typeBodies = new ArrayList<>();
    private final int winWidth;
    private final int winHeight;
    // Detection-time strides, derived from the window size and getCheckStepReduce().
    private final int widthStep;
    private final int heightStep;
    // true: box confidence is the classifier output; false: the position network's trust output.
    private final boolean proTrust;

    /**
     * Builds the detector and initialises the shared type network from the configuration.
     *
     * @param yoloConfig detector configuration
     * @throws Exception if the check-step ratio is greater than 1 or either derived
     *                   detection stride is not positive, or if network setup fails
     */
    public FastYolo(YoloConfig yoloConfig) throws Exception {
        float stepReduce = yoloConfig.getCheckStepReduce();
        this.yoloConfig = yoloConfig;
        winHeight = yoloConfig.getWindowHeight();
        winWidth = yoloConfig.getWindowWidth();
        widthStep = (int) (winWidth * stepReduce);
        heightStep = (int) (winHeight * stepReduce);
        proTrust = yoloConfig.isProTrust();
        boolean stepsValid = stepReduce <= 1 && widthStep > 0 && heightStep > 0;
        if (!stepsValid) {
            throw new Exception("The stepReduce must be (0,1] and widthStep ,heightStep must Greater than 0");
        }
        // getTypeNub() + 1 outputs: presumably the extra slot is the background class
        // (see the background mapping ID used in containSample) -- TODO confirm against NerveManager.
        typeNerveManager = new NerveManager(3, yoloConfig.getHiddenNerveNub(), yoloConfig.getTypeNub() + 1,
                1, new ReLu(), yoloConfig.getStudyRate(), yoloConfig.getRegularModel(), yoloConfig.getRegular(),
                yoloConfig.getCoreNumber(), yoloConfig.getGaMa(), yoloConfig.getGMaxTh(), yoloConfig.isAuto());
        typeNerveManager.initImageNet(yoloConfig.getChannelNo(), yoloConfig.getKernelSize(), winHeight, winWidth, true,
                yoloConfig.isShowLog(), yoloConfig.getStudyRate(), new ReLu(), yoloConfig.getMinFeatureValue(),
                yoloConfig.getOneConvStudy(), yoloConfig.isNorm(), yoloConfig.getGRate());
    }

    /**
     * Registers a labelled body with the {@link TypeBody} of its type ID, creating
     * that {@link TypeBody} (with the next free mapping ID) on first sight.
     */
    private void insertYoloBody(YoloBody yoloBody) throws Exception {
        TypeBody existing = getTypeBodyByTypeID(yoloBody.getTypeID());
        if (existing != null) {
            existing.insertYoloBody(yoloBody);
            return;
        }
        // Type not seen before: mapping IDs are handed out sequentially, starting at 1.
        TypeBody created = new TypeBody(yoloConfig, winWidth, winHeight);
        created.setTypeID(yoloBody.getTypeID());
        created.setMappingID(typeBodies.size() + 1);
        created.insertYoloBody(yoloBody);
        typeBodies.add(created);
    }

    /**
     * Loads a previously exported model: the shared type network plus one
     * {@link TypeBody} (size limits and position network) per stored type.
     * Loaded types are appended to the ones already registered.
     *
     * @param yoloModel model produced by {@link #getModel()}
     * @throws Exception if a network rejects its model
     */
    public void insertModel(YoloModel yoloModel) throws Exception {
        typeNerveManager.insertConvModel(yoloModel.getTypeModel());
        for (TypeModel typeModel : yoloModel.getTypeModels()) {
            TypeBody typeBody = new TypeBody(yoloConfig, winWidth, winHeight);
            typeBody.setTypeID(typeModel.getTypeID());
            typeBody.setMappingID(typeModel.getMappingID());
            typeBody.setMinWidth(typeModel.getMinWidth());
            typeBody.setMinHeight(typeModel.getMinHeight());
            typeBody.setMaxWidth(typeModel.getMaxWidth());
            typeBody.setMaxHeight(typeModel.getMaxHeight());
            typeBody.getPositionNerveManager().insertConvModel(typeModel.getPositionModel());
            typeBodies.add(typeBody);
        }
    }

    /**
     * Exports the shared type network and, for every registered type, its IDs,
     * size limits and position network.
     *
     * @return a model that can be fed back through {@link #insertModel(YoloModel)}
     * @throws Exception if a network cannot export its model
     */
    public YoloModel getModel() throws Exception {
        YoloModel yoloModel = new YoloModel();
        yoloModel.setTypeModel(typeNerveManager.getConvModel());
        List<TypeModel> typeModels = new ArrayList<>();
        for (TypeBody typeBody : typeBodies) {
            TypeModel typeModel = new TypeModel();
            typeModel.setTypeID(typeBody.getTypeID());
            typeModel.setMappingID(typeBody.getMappingID());
            typeModel.setMinHeight(typeBody.getMinHeight());
            typeModel.setMinWidth(typeBody.getMinWidth());
            typeModel.setMaxWidth(typeBody.getMaxWidth());
            typeModel.setMaxHeight(typeBody.getMaxHeight());
            typeModel.setPositionModel(typeBody.getPositionNerveManager().getConvModel());
            typeModels.add(typeModel);
        }
        yoloModel.setTypeModels(typeModels);
        return yoloModel;
    }

    /**
     * Converts the position network's output for the window at ({@code i}, {@code j})
     * into an image-space box.
     *
     * @param i            window origin along the height axis
     * @param j            window origin along the width axis
     * @param maxX         image extent along the height axis
     * @param maxY         image extent along the width axis
     * @param positionBack position network output for this window
     * @param typeBody     type the window was classified as
     * @param out          classifier output, used as confidence when {@code proTrust} is set
     */
    private Box getBox(int i, int j, int maxX, int maxY, PositionBack positionBack, TypeBody typeBody, float out) throws Exception {
        // Offsets are trained as (windowOrigin - boxCentre) / (winHeight + winWidth); invert that here.
        float scale = winHeight + winWidth;
        float centerX = i - positionBack.getDistX() * scale;
        float centerY = j - positionBack.getDistY() * scale;
        int width = (int) typeBody.getRealWidth(positionBack.getWidth());
        int height = (int) typeBody.getRealHeight(positionBack.getHeight());
        int realX = (int) (centerX - height / 2f);
        int realY = (int) (centerY - width / 2f);
        // Pull the box back inside the image, applying the lower bound last so that a
        // box larger than the image is anchored at 0 instead of a negative origin.
        realX = Math.max(0, Math.min(realX, maxX - height));
        realY = Math.max(0, Math.min(realY, maxY - width));
        Box box = new Box();
        box.setX(realX);
        box.setY(realY);
        box.setxSize(height);
        box.setySize(width);
        box.setConfidence(proTrust ? out : positionBack.getTrust());
        box.setTypeID(typeBody.getTypeID());
        return box;
    }

    /**
     * Converts internal boxes to result boxes, swapping the axes: the internal
     * {@code x}/{@code xSize} (height axis) becomes {@code y}/{@code height}, and
     * {@code y}/{@code ySize} becomes {@code x}/{@code width}.
     */
    private List<OutBox> getOutBoxList(List<Box> boxes) {
        List<OutBox> outBoxes = new ArrayList<>();
        for (Box box : boxes) {
            OutBox outBox = new OutBox();
            outBox.setX(box.getY());
            outBox.setY(box.getX());
            outBox.setHeight(box.getxSize());
            outBox.setWidth(box.getySize());
            outBox.setTypeID(String.valueOf(box.getTypeID()));
            outBox.setTrust(box.getConfidence());
            outBoxes.add(outBox);
        }
        return outBoxes;
    }

    /**
     * Detects objects by sliding the window over the image with the detection strides.
     *
     * @param th      image to scan
     * @param eventID event ID forwarded to the networks
     * @return the boxes kept by NMS, or {@code null} when no window produced a box
     * @throws Exception if a network call fails
     */
    public List<OutBox> look(ThreeChannelMatrix th, long eventID) throws Exception {
        int x = th.getX();
        int y = th.getY();
        List<Box> boxes = new ArrayList<>();
        NMS nms = new NMS(yoloConfig.getIouTh());
        float pth = yoloConfig.getPth();
        int backgroundID = typeBodies.size() + 1;
        for (int i = 0; i <= x - winHeight; i += heightStep) {
            for (int j = 0; j <= y - winWidth; j += widthStep) {
                YoloTypeBack yoloTypeBack = new YoloTypeBack();
                ThreeChannelMatrix window = th.cutChannel(i, j, winHeight, winWidth);
                study(eventID, typeNerveManager.getConvInput(), window, false, null, yoloTypeBack);
                int mappingID = yoloTypeBack.getId(); // mapping ID reported by the classifier
                float out = yoloTypeBack.getOut();
                if (mappingID != backgroundID && out > pth) {
                    TypeBody typeBody = getTypeBodyByMappingID(mappingID);
                    if (typeBody == null) {
                        // No registered type carries this mapping ID, so there is no
                        // position network to ask; skip the window instead of failing.
                        continue;
                    }
                    PositionBack positionBack = new PositionBack();
                    SensoryNerve convInput = typeBody.getPositionNerveManager().getConvInput();
                    study(eventID, convInput, window, false, null, positionBack);
                    boxes.add(getBox(i, j, x, y, positionBack, typeBody, out));
                }
            }
        }
        if (boxes.isEmpty()) {
            // Kept as null (not an empty list) because existing callers may test for it.
            return null;
        }
        return getOutBoxList(nms.start(boxes));
    }

    /**
     * Trains on the given samples: first registers every labelled body, then runs
     * {@code getEnhance()} passes over all samples, printing progress to stdout.
     *
     * @param yoloSamples labelled training images
     * @throws Exception if registration or training fails
     */
    public void toStudy(List<YoloSample> yoloSamples) throws Exception {
        for (YoloSample yoloSample : yoloSamples) {
            for (YoloBody yoloBody : yoloSample.getYoloBodies()) {
                insertYoloBody(yoloBody);
            }
        }
        int enh = yoloConfig.getEnhance();
        for (int i = 0; i < enh; i++) {
            // Separator added so the pass index and the total are not printed as one number.
            System.out.println("第===========================" + i + "/" + enh);
            int index = 0;
            for (YoloSample yoloSample : yoloSamples) {
                index++;
                System.out.println("index:" + index + ",size:" + yoloSamples.size());
                study(yoloSample);
            }
        }
    }

    /** Converts a labelled body into an internal box, swapping the axes (body x/width -> box y/ySize). */
    private Box changeBox(YoloBody yoloBody) {
        Box box = new Box();
        box.setX(yoloBody.getY());
        box.setY(yoloBody.getX());
        box.setTypeID(yoloBody.getTypeID());
        box.setxSize(yoloBody.getHeight());
        box.setySize(yoloBody.getWidth());
        return box;
    }

    /**
     * Builds the training target for the window at ({@code i}, {@code j}).
     *
     * <p>The labelled box with the largest ratio from {@code nms.getSRatio} above
     * {@code getContainIouTh()} is selected; if none qualifies the window is marked as
     * background. For a matched box the target holds the normalised centre offset,
     * the normalised size and a trust flag that is 1 only when the box centre lies
     * inside the window.
     *
     * @param boxes   labelled boxes of the current image
     * @param testBox the window expressed as a box
     * @param nms     helper used for the overlap ratio
     * @param i       window origin along the height axis
     * @param j       window origin along the width axis
     */
    private YoloMessage containSample(List<Box> boxes, Box testBox, NMS nms, int i, int j) {
        float containIouTh = yoloConfig.getContainIouTh();
        float bestRatio = 0;
        Box bestBox = null;
        for (Box box : boxes) {
            float ratio = nms.getSRatio(testBox, box, false);
            if (ratio > containIouTh && ratio > bestRatio) { // overlaps enough and beats the current best
                bestRatio = ratio;
                bestBox = box;
            }
        }
        YoloMessage yoloMessage = new YoloMessage();
        if (bestBox == null) {
            yoloMessage.setBackGround(true);
            yoloMessage.setMappingID(typeBodies.size() + 1);
            return yoloMessage;
        }
        int centerX = bestBox.getX() + bestBox.getxSize() / 2;
        int centerY = bestBox.getY() + bestBox.getySize() / 2;
        float scale = winHeight + winWidth;
        float distX = (float) (i - centerX) / scale;
        float distY = (float) (j - centerY) / scale;
        TypeBody typeBody = getTypeBodyByTypeID(bestBox.getTypeID());
        float height = typeBody.getOneHeight(bestBox.getxSize());
        float width = typeBody.getOneWidth(bestBox.getySize());
        boolean centreInWindow = centerX >= i && centerX <= (i + winHeight)
                && centerY >= j && centerY <= (j + winWidth);
        yoloMessage.setWidth(width);
        yoloMessage.setHeight(height);
        yoloMessage.setDistX(distX);
        yoloMessage.setDistY(distY);
        yoloMessage.setTrust(centreInWindow ? 1 : 0);
        yoloMessage.setMappingID(typeBody.getMappingID());
        yoloMessage.setTypeBody(typeBody);
        return yoloMessage;
    }

    /** Converts every labelled body of an image into an internal box. */
    private List<Box> getBoxes(List<YoloBody> yoloBodies) {
        List<Box> boxes = new ArrayList<>();
        for (YoloBody yoloBody : yoloBodies) {
            boxes.add(changeBox(yoloBody));
        }
        return boxes;
    }

    /**
     * Returns the messages in random order.
     *
     * <p>Works on a copy, so the list passed in is left untouched, and uses
     * {@link Collections#shuffle(List)} instead of repeated random removal.
     */
    private List<YoloMessage> anySort(List<YoloMessage> sentences) {
        List<YoloMessage> shuffled = new ArrayList<>(sentences);
        Collections.shuffle(shuffled);
        return shuffled;
    }

    /**
     * Trains on one image: loads it, cuts it into windows using the training stride
     * ({@code getStepReduce()}), builds a target per window and trains on the
     * windows in random order.
     *
     * @throws Exception if either training stride is smaller than 1 or training fails
     */
    private void study(YoloSample yoloSample) throws Exception {
        List<YoloBody> yoloBodies = yoloSample.getYoloBodies(); // labelled bodies of this image
        List<Box> boxes = getBoxes(yoloBodies);
        String url = yoloSample.getLocationURL(); // image location
        NMS nms = new NMS(yoloConfig.getIouTh());
        ThreeChannelMatrix pic = Picture.getThreeMatrix(url, false);
        List<YoloMessage> yoloMessageList = new ArrayList<>();
        float stepReduce = yoloConfig.getStepReduce();
        int stepX = (int) (winHeight * stepReduce);
        int stepY = (int) (winWidth * stepReduce);
        if (stepX < 1 || stepY < 1) {
            throw new Exception("训练步长收缩后步长必须大于0");
        }
        for (int i = 0; i <= pic.getX() - winHeight; i += stepX) {
            for (int j = 0; j <= pic.getY() - winWidth; j += stepY) {
                Box testBox = new Box();
                testBox.setX(i);
                testBox.setY(j);
                testBox.setxSize(winHeight);
                testBox.setySize(winWidth);
                YoloMessage yoloMessage = containSample(boxes, testBox, nms, i, j);
                yoloMessage.setPic(pic.cutChannel(i, j, winHeight, winWidth));
                yoloMessageList.add(yoloMessage);
            }
        }
        if (!yoloMessageList.isEmpty()) {
            studyImage(anySort(yoloMessageList));
        }
    }

    /**
     * Finds the registered type with the given mapping ID.
     *
     * @param mappingID mapping ID assigned at registration
     * @return the first matching type, or {@code null} if none is registered
     */
    public TypeBody getTypeBodyByMappingID(int mappingID) {
        for (TypeBody candidate : typeBodies) {
            if (candidate.getMappingID() == mappingID) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Finds the registered type with the given type ID.
     *
     * @param typeID type ID taken from the labelled samples or a loaded model
     * @return the first matching type, or {@code null} if none is registered
     */
    public TypeBody getTypeBodyByTypeID(int typeID) {
        for (TypeBody candidate : typeBodies) {
            if (candidate.getTypeID() == typeID) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Trains both networks on the prepared windows: every window trains the type
     * network towards its mapping ID (event ID 1); non-background windows also
     * train their type's position network (event ID 2) with outputs 1..5 =
     * distX, distY, width, height, trust.
     */
    private void studyImage(List<YoloMessage> yoloMessageList) throws Exception {
        for (YoloMessage yoloMessage : yoloMessageList) {
            ThreeChannelMatrix small = yoloMessage.getPic();
            Map<Integer, Float> typeE = new HashMap<>();
            typeE.put(yoloMessage.getMappingID(), 1f);
            study(1, typeNerveManager.getConvInput(), small, true, typeE, null);
            if (yoloMessage.isBackGround()) {
                continue; // background windows carry no position target
            }
            Map<Integer, Float> positionE = new HashMap<>();
            positionE.put(1, yoloMessage.getDistX());
            positionE.put(2, yoloMessage.getDistY());
            positionE.put(3, yoloMessage.getWidth());
            positionE.put(4, yoloMessage.getHeight());
            positionE.put(5, yoloMessage.getTrust());
            NerveManager position = yoloMessage.getTypeBody().getPositionNerveManager();
            study(2, position.getConvInput(), small, true, positionE, null);
        }
    }

    /**
     * Thin wrapper that posts one window to a network input, for training
     * ({@code isStudy == true}, targets in {@code E}) or for inference
     * (results delivered to {@code back}).
     */
    private void study(long eventID, SensoryNerve convInput, ThreeChannelMatrix feature, boolean isStudy, Map<Integer, Float> E, OutBack back) throws Exception {
        // NOTE(review): the meaning of the trailing `false` flag is defined by SensoryNerve -- confirm there.
        convInput.postThreeChannelMatrix(eventID, feature, isStudy, E, back, false);
    }
}