Add File
This commit is contained in:
343
src/main/java/org/dromara/easyai/yolo/FastYolo.java
Normal file
343
src/main/java/org/dromara/easyai/yolo/FastYolo.java
Normal file
@@ -0,0 +1,343 @@
|
||||
package org.dromara.easyai.yolo;
|
||||
|
||||
|
||||
import org.dromara.easyai.entity.Box;
|
||||
import org.dromara.easyai.entity.ThreeChannelMatrix;
|
||||
import org.dromara.easyai.function.ReLu;
|
||||
import org.dromara.easyai.i.OutBack;
|
||||
import org.dromara.easyai.nerveCenter.NerveManager;
|
||||
import org.dromara.easyai.nerveEntity.SensoryNerve;
|
||||
import org.dromara.easyai.tools.NMS;
|
||||
import org.dromara.easyai.tools.Picture;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
public class FastYolo {//yolo
|
||||
private final YoloConfig yoloConfig;
|
||||
private final NerveManager typeNerveManager;//识别网络
|
||||
private final List<TypeBody> typeBodies = new ArrayList<>();//样本数据及集合
|
||||
private final int winWidth;
|
||||
private final int winHeight;
|
||||
private final int widthStep;
|
||||
private final int heightStep;
|
||||
private final boolean proTrust;
|
||||
|
||||
public FastYolo(YoloConfig yoloConfig) throws Exception {
|
||||
float stepReduce = yoloConfig.getCheckStepReduce();
|
||||
this.yoloConfig = yoloConfig;
|
||||
winHeight = yoloConfig.getWindowHeight();
|
||||
winWidth = yoloConfig.getWindowWidth();
|
||||
widthStep = (int) (winWidth * stepReduce);
|
||||
heightStep = (int) (winHeight * stepReduce);
|
||||
proTrust = yoloConfig.isProTrust();
|
||||
if (stepReduce <= 1 && widthStep > 0 && heightStep > 0) {
|
||||
typeNerveManager = new NerveManager(3, yoloConfig.getHiddenNerveNub(), yoloConfig.getTypeNub() + 1,
|
||||
1, new ReLu(), yoloConfig.getStudyRate(), yoloConfig.getRegularModel(), yoloConfig.getRegular()
|
||||
, yoloConfig.getCoreNumber(), yoloConfig.getGaMa(), yoloConfig.getGMaxTh(), yoloConfig.isAuto());
|
||||
typeNerveManager.initImageNet(yoloConfig.getChannelNo(), yoloConfig.getKernelSize(), winHeight, winWidth, true,
|
||||
yoloConfig.isShowLog(), yoloConfig.getStudyRate(), new ReLu(), yoloConfig.getMinFeatureValue(), yoloConfig.getOneConvStudy()
|
||||
, yoloConfig.isNorm(), yoloConfig.getGRate());
|
||||
} else {
|
||||
throw new Exception("The stepReduce must be (0,1] and widthStep ,heightStep must Greater than 0");
|
||||
}
|
||||
}
|
||||
|
||||
private void insertYoloBody(YoloBody yoloBody) throws Exception {
|
||||
boolean here = false;
|
||||
for (TypeBody typeBody : typeBodies) {
|
||||
if (typeBody.getTypeID() == yoloBody.getTypeID()) {
|
||||
here = true;
|
||||
typeBody.insertYoloBody(yoloBody);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!here) {//不存在
|
||||
TypeBody typeBody = new TypeBody(yoloConfig, winWidth, winHeight);
|
||||
typeBody.setTypeID(yoloBody.getTypeID());
|
||||
typeBody.setMappingID(typeBodies.size() + 1);
|
||||
typeBody.insertYoloBody(yoloBody);
|
||||
typeBodies.add(typeBody);
|
||||
}
|
||||
}
|
||||
|
||||
public void insertModel(YoloModel yoloModel) throws Exception {
|
||||
typeNerveManager.insertConvModel(yoloModel.getTypeModel());
|
||||
List<TypeModel> typeModels = yoloModel.getTypeModels();
|
||||
for (TypeModel typeModel : typeModels) {
|
||||
TypeBody typeBody = new TypeBody(yoloConfig, winWidth, winHeight);
|
||||
typeBody.setTypeID(typeModel.getTypeID());
|
||||
typeBody.setMappingID(typeModel.getMappingID());
|
||||
typeBody.setMinWidth(typeModel.getMinWidth());
|
||||
typeBody.setMinHeight(typeModel.getMinHeight());
|
||||
typeBody.setMaxWidth(typeModel.getMaxWidth());
|
||||
typeBody.setMaxHeight(typeModel.getMaxHeight());
|
||||
typeBody.getPositionNerveManager().insertConvModel(typeModel.getPositionModel());
|
||||
typeBodies.add(typeBody);
|
||||
}
|
||||
}
|
||||
|
||||
public YoloModel getModel() throws Exception {
|
||||
YoloModel yoloModel = new YoloModel();
|
||||
yoloModel.setTypeModel(typeNerveManager.getConvModel());
|
||||
List<TypeModel> typeModels = new ArrayList<>();
|
||||
for (TypeBody typeBody : typeBodies) {
|
||||
TypeModel typeModel = new TypeModel();
|
||||
typeModel.setTypeID(typeBody.getTypeID());
|
||||
typeModel.setMappingID(typeBody.getMappingID());
|
||||
typeModel.setMinHeight(typeBody.getMinHeight());
|
||||
typeModel.setMinWidth(typeBody.getMinWidth());
|
||||
typeModel.setMaxWidth(typeBody.getMaxWidth());
|
||||
typeModel.setMaxHeight(typeBody.getMaxHeight());
|
||||
typeModel.setPositionModel(typeBody.getPositionNerveManager().getConvModel());
|
||||
typeModels.add(typeModel);
|
||||
}
|
||||
yoloModel.setTypeModels(typeModels);
|
||||
return yoloModel;
|
||||
}
|
||||
|
||||
private Box getBox(int i, int j, int maxX, int maxY, PositionBack positionBack, TypeBody typeBody, float out) throws Exception {
|
||||
float zhou = winHeight + winWidth;
|
||||
Box box = new Box();
|
||||
float centerX = i - positionBack.getDistX() * zhou;
|
||||
float centerY = j - positionBack.getDistY() * zhou;
|
||||
int width = (int) typeBody.getRealWidth(positionBack.getWidth());
|
||||
int height = (int) typeBody.getRealHeight(positionBack.getHeight());
|
||||
int realX = (int) (centerX - height / 2f);
|
||||
int realY = (int) (centerY - width / 2f);
|
||||
if (realX < 0) {
|
||||
realX = 0;
|
||||
}
|
||||
if (realY < 0) {
|
||||
realY = 0;
|
||||
}
|
||||
if (realX + height > maxX) {
|
||||
realX = maxX - height;
|
||||
}
|
||||
if (realY + width > maxY) {
|
||||
realY = maxY - width;
|
||||
}
|
||||
box.setX(realX);
|
||||
box.setY(realY);
|
||||
box.setxSize(height);
|
||||
box.setySize(width);
|
||||
float trust;
|
||||
if (proTrust) {
|
||||
trust = out;
|
||||
} else {
|
||||
trust = positionBack.getTrust();
|
||||
}
|
||||
box.setConfidence(trust);
|
||||
box.setTypeID(typeBody.getTypeID());
|
||||
return box;
|
||||
}
|
||||
|
||||
private List<OutBox> getOutBoxList(List<Box> boxes) {
|
||||
List<OutBox> outBoxes = new ArrayList<>();
|
||||
for (Box box : boxes) {
|
||||
OutBox outBox = new OutBox();
|
||||
outBox.setX(box.getY());
|
||||
outBox.setY(box.getX());
|
||||
outBox.setHeight(box.getxSize());
|
||||
outBox.setWidth(box.getySize());
|
||||
outBox.setTypeID(String.valueOf(box.getTypeID()));
|
||||
outBox.setTrust(box.getConfidence());
|
||||
outBoxes.add(outBox);
|
||||
}
|
||||
return outBoxes;
|
||||
}
|
||||
|
||||
public List<OutBox> look(ThreeChannelMatrix th, long eventID) throws Exception {
|
||||
int x = th.getX();
|
||||
int y = th.getY();
|
||||
List<Box> boxes = new ArrayList<>();
|
||||
NMS nms = new NMS(yoloConfig.getIouTh());
|
||||
float pth = yoloConfig.getPth();
|
||||
for (int i = 0; i <= x - winHeight; i += heightStep) {
|
||||
for (int j = 0; j <= y - winWidth; j += widthStep) {
|
||||
YoloTypeBack yoloTypeBack = new YoloTypeBack();
|
||||
PositionBack positionBack = new PositionBack();
|
||||
ThreeChannelMatrix myTh = th.cutChannel(i, j, winHeight, winWidth);
|
||||
study(eventID, typeNerveManager.getConvInput(), myTh, false, null, yoloTypeBack);
|
||||
int mappingID = yoloTypeBack.getId();//映射id
|
||||
float out = yoloTypeBack.getOut();
|
||||
if (mappingID != typeBodies.size() + 1 && out > pth) {
|
||||
TypeBody typeBody = getTypeBodyByMappingID(mappingID);
|
||||
SensoryNerve convInput = typeBody.getPositionNerveManager().getConvInput();
|
||||
study(eventID, convInput, myTh, false, null, positionBack);
|
||||
boxes.add(getBox(i, j, x, y, positionBack, typeBody, out));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (boxes.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
return getOutBoxList(nms.start(boxes));
|
||||
}
|
||||
|
||||
|
||||
public void toStudy(List<YoloSample> yoloSamples) throws Exception {
|
||||
for (YoloSample yoloSample : yoloSamples) {
|
||||
List<YoloBody> yoloBodies = yoloSample.getYoloBodies();
|
||||
for (YoloBody yoloBody : yoloBodies) {
|
||||
insertYoloBody(yoloBody);
|
||||
}
|
||||
}
|
||||
int enh = yoloConfig.getEnhance();
|
||||
for (int i = 0; i < enh; i++) {
|
||||
System.out.println("第===========================" + i + "次" + "共" + enh + "次");
|
||||
int index = 0;
|
||||
for (YoloSample yoloSample : yoloSamples) {
|
||||
index++;
|
||||
System.out.println("index:" + index + ",size:" + yoloSamples.size());
|
||||
study(yoloSample);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private Box changeBox(YoloBody yoloBody) {
|
||||
Box box = new Box();
|
||||
box.setX(yoloBody.getY());
|
||||
box.setY(yoloBody.getX());
|
||||
box.setTypeID(yoloBody.getTypeID());
|
||||
box.setxSize(yoloBody.getHeight());
|
||||
box.setySize(yoloBody.getWidth());
|
||||
return box;
|
||||
}
|
||||
|
||||
private YoloMessage containSample(List<Box> boxes, Box testBox, NMS nms, int i, int j) {
|
||||
float containIouTh = yoloConfig.getContainIouTh();
|
||||
float maxIou = 0;
|
||||
Box myBox = null;
|
||||
YoloMessage yoloMessage = new YoloMessage();
|
||||
for (Box box : boxes) {
|
||||
float iou = nms.getSRatio(testBox, box, false);
|
||||
if (iou > containIouTh && iou > maxIou) {//有相交
|
||||
maxIou = iou;
|
||||
myBox = box;
|
||||
}
|
||||
}
|
||||
if (myBox != null) {
|
||||
int centerX = myBox.getX() + myBox.getxSize() / 2;
|
||||
int centerY = myBox.getY() + myBox.getySize() / 2;
|
||||
float zhou = winHeight + winWidth;
|
||||
float distX = (float) (i - centerX) / zhou;
|
||||
float distY = (float) (j - centerY) / zhou;
|
||||
TypeBody typeBody = getTypeBodyByTypeID(myBox.getTypeID());
|
||||
float height = typeBody.getOneHeight(myBox.getxSize());
|
||||
float width = typeBody.getOneWidth(myBox.getySize());
|
||||
float trust = 0;
|
||||
if (centerX >= i && centerX <= (i + winHeight) && centerY >= j && centerY <= (j + winWidth)) {
|
||||
trust = 1;
|
||||
}
|
||||
yoloMessage.setWidth(width);
|
||||
yoloMessage.setHeight(height);
|
||||
yoloMessage.setDistX(distX);
|
||||
yoloMessage.setDistY(distY);
|
||||
yoloMessage.setTrust(trust);
|
||||
yoloMessage.setMappingID(typeBody.getMappingID());
|
||||
yoloMessage.setTypeBody(typeBody);
|
||||
} else {
|
||||
yoloMessage.setBackGround(true);
|
||||
yoloMessage.setMappingID(typeBodies.size() + 1);
|
||||
}
|
||||
return yoloMessage;
|
||||
}
|
||||
|
||||
private List<Box> getBoxes(List<YoloBody> yoloBodies) {
|
||||
List<Box> boxes = new ArrayList<>();
|
||||
for (YoloBody yoloBody : yoloBodies) {
|
||||
boxes.add(changeBox(yoloBody));
|
||||
}
|
||||
return boxes;
|
||||
}
|
||||
|
||||
private List<YoloMessage> anySort(List<YoloMessage> sentences) {//做乱序
|
||||
Random random = new Random();
|
||||
List<YoloMessage> sent = new ArrayList<>();
|
||||
int time = sentences.size();
|
||||
for (int i = 0; i < time; i++) {
|
||||
int size = sentences.size();
|
||||
int index = random.nextInt(size);
|
||||
sent.add(sentences.get(index));
|
||||
sentences.remove(index);
|
||||
}
|
||||
return sent;
|
||||
}
|
||||
|
||||
private void study(YoloSample yoloSample) throws Exception {//
|
||||
List<YoloBody> yoloBodies = yoloSample.getYoloBodies();//集合
|
||||
List<Box> boxes = getBoxes(yoloBodies);
|
||||
String url = yoloSample.getLocationURL();//地址
|
||||
NMS nms = new NMS(yoloConfig.getIouTh());
|
||||
ThreeChannelMatrix pic = Picture.getThreeMatrix(url, false);
|
||||
List<YoloMessage> yoloMessageList = new ArrayList<>();
|
||||
float stepReduce = yoloConfig.getStepReduce();
|
||||
int stepX = (int) (winHeight * stepReduce);
|
||||
int stepY = (int) (winWidth * stepReduce);
|
||||
if (stepX < 1 || stepY < 1) {
|
||||
throw new Exception("训练步长收缩后步长必须大于0");
|
||||
}
|
||||
for (int i = 0; i <= pic.getX() - winHeight; i += stepX) {
|
||||
for (int j = 0; j <= pic.getY() - winWidth; j += stepY) {
|
||||
Box testBox = new Box();
|
||||
testBox.setX(i);
|
||||
testBox.setY(j);
|
||||
testBox.setxSize(winHeight);
|
||||
testBox.setySize(winWidth);
|
||||
YoloMessage yoloMessage = containSample(boxes, testBox, nms, i, j);
|
||||
yoloMessage.setPic(pic.cutChannel(i, j, winHeight, winWidth));
|
||||
yoloMessageList.add(yoloMessage);
|
||||
}
|
||||
}
|
||||
if (!yoloMessageList.isEmpty()) {
|
||||
studyImage(anySort(yoloMessageList));
|
||||
}
|
||||
}
|
||||
|
||||
public TypeBody getTypeBodyByMappingID(int mappingID) {
|
||||
TypeBody ty = null;
|
||||
for (TypeBody yb : typeBodies) {
|
||||
if (yb.getMappingID() == mappingID) {
|
||||
ty = yb;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return ty;
|
||||
}
|
||||
|
||||
public TypeBody getTypeBodyByTypeID(int typeID) {
|
||||
TypeBody ty = null;
|
||||
for (TypeBody yb : typeBodies) {
|
||||
if (yb.getTypeID() == typeID) {
|
||||
ty = yb;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return ty;
|
||||
}
|
||||
|
||||
private void studyImage(List<YoloMessage> yoloMessageList) throws Exception {
|
||||
for (YoloMessage yoloMessage : yoloMessageList) {
|
||||
Map<Integer, Float> typeE = new HashMap<>();
|
||||
ThreeChannelMatrix small = yoloMessage.getPic();
|
||||
int mappingID = yoloMessage.getMappingID();
|
||||
typeE.put(mappingID, 1f);
|
||||
study(1, typeNerveManager.getConvInput(), small, true, typeE, null);
|
||||
if (!yoloMessage.isBackGround()) {
|
||||
Map<Integer, Float> positionE = new HashMap<>();
|
||||
positionE.put(1, yoloMessage.getDistX());
|
||||
positionE.put(2, yoloMessage.getDistY());
|
||||
positionE.put(3, yoloMessage.getWidth());
|
||||
positionE.put(4, yoloMessage.getHeight());
|
||||
positionE.put(5, yoloMessage.getTrust());
|
||||
NerveManager position = yoloMessage.getTypeBody().getPositionNerveManager();
|
||||
study(2, position.getConvInput(), small, true, positionE, null);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private void study(long eventID, SensoryNerve convInput, ThreeChannelMatrix feature, boolean isStudy, Map<Integer, Float> E, OutBack back) throws Exception {
|
||||
convInput.postThreeChannelMatrix(eventID, feature, isStudy, E, back, false);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user