基于YOLO的模型训练及模型在JAVA中的应用

内容分享5小时前发布
0 0 0

一 . 实现逻辑概述

YOLO提供了基于视觉算法的图像解析识别能力,可以根据给定的数据集做专项训练,并得到训练结果,即训练模型。有封装好的应用训练模型的maven依赖,集成依赖后可在java项目中加载训练好的模型,应用模型的能力。

二.  获取数据集

数据集的目录结构

datasets

       └ 自定义名称

                └ images –训练图片集,jpgjepgpng

                        └ train

                        └ val

                └ labels — 图片描述文件,txt

                        └ train

                        └ val

目录内容说明

执行训练模型的文件时,基于配置文件,会自动找到datasets文件夹下对应名称的数据集,数据集包括images和labels。images和labels下分别有train(训练)和val(验证),images存放图片,labels存放图片的说明文件,图片和对应说明文件的文件名需保持一致。目录结构也可改为train和val下分别包含images和labels。

  3.  labels下的说明文件内容及获取方式

3.1 说明文件示例

(第一列:物体类别,第二列:X轴坐标,第三列:Y轴坐标,第四列:X轴长度,第五列:Y轴长度)

基于YOLO的模型训练及模型在JAVA中的应用

    3.2 说明文件获取

基于本地python环境,安装 labelImg(pip install labelImg)工具,安装完成后运行labelImg.py(python ./labelImg.py)会启动可标注的图形界面(如下图)。

基于YOLO的模型训练及模型在JAVA中的应用

本地文件位置:C:Usersyanggs.pyenvpyenv-winversions3.9.12Libsite-packageslabelImg,python版本:3.9.12

打开图片,点击Create RectBox,在图片中框选要识别的物体并给定名称,即可完成一次标注。某一图片全部标注完成后,点击save,保存说明文件到指定位置,说明文件的文件名默认与图片的文件名一致。

三.  训练模型

1 在与datasets同级的test目录下新建python文件(test_new.py),内容如下

基于YOLO的模型训练及模型在JAVA中的应用

训练配置文件内容如下

基于YOLO的模型训练及模型在JAVA中的应用

执行文件训练模型

执行命令 python test_new.py, 执行结束后可在控制台看见训练中产生的文件位置和训练成果的文件位置,可同步测试训练成果是否符合要求。

注 :python版本为3.10.5,第一次执行时会提示缺失onnx,onnxruntime等module,使用pip install {module}[=={version}] 进行安装,onnx版本建议为1.13.0,如果onnx版本过高,则后续在Java中会出现问题。

四.  在java中运用模型

下载源码(感谢代码作者开源)

     https://gitcode.com/changzengli/yolo-onnx-java,

下载后仔细查看项目介绍,该项目集成了对监控、视频、图像的识别能力。

   2. 使用已训练好的模型

       2.1 把模型放在resources下的model中

       2.2  创建测试类,示例如下



public class QianziDetection {
 
    static {
        // 加载opencv动态库,
        //System.load(ClassLoader.getSystemResource("lib/opencv_java470-无用.dll").getPath());
        nu.pattern.OpenCV.loadLocally();
    }
 
    public static void main(String[] args) throws OrtException {
 
        String model_path = "src\main\resources\model\gangjin.onnx";
 
        List<double[]> colors = new ArrayList<>();
 
        float confThreshold = 0.35F;
 
        float nmsThreshold = 0.55F;
 
        String[] labels = {"gangjin"};
 
        // 加载ONNX模型
        OrtEnvironment environment = OrtEnvironment.getEnvironment();
        OrtSession.SessionOptions sessionOptions = new OrtSession.SessionOptions();
 
        // 使用gpu,需要本机按钻过cuda,并修改pom.xml,不安装也能运行本程序
        // sessionOptions.addCUDA(0);
 
        OrtSession session = environment.createSession(model_path, sessionOptions);
        String meteStr = session.getMetadata().getCustomMetadata().get("names");
 
 
        labels = new String[meteStr.split(",").length];
 
 
        Pattern pattern = Pattern.compile("'([^']*)'");
        Matcher matcher = pattern.matcher(meteStr);
 
        int h = 0;
        while (matcher.find()) {
            labels[h] = matcher.group(1);
            Random random = new Random();
            double[] color = {random.nextDouble()*256, random.nextDouble()*256, random.nextDouble()*256};
            colors.add(color);
            h++;
        }
        // 输出基本信息
        session.getInputInfo().keySet().forEach(x-> {
            try {
                System.out.println("input name = " + x);
                System.out.println(session.getInputInfo().get(x).getInfo().toString());
            } catch (OrtException e) {
                throw new RuntimeException(e);
            }
        });
 
        // 要检测的图片所在目录
        String imagePath = "images/newImage";
        Map<String, String> map = getImagePathMap(imagePath);
        for(String fileName : map.keySet()){
            String imageFilePath = map.get(fileName);
            System.out.println(imageFilePath);
            // 读取 image
            Mat img = Imgcodecs.imread(imageFilePath);
            Mat image = img.clone();
            //Imgproc.cvtColor(image, image, Imgproc.COLOR_BGR2RGB);
 
 
            // 在这里先定义下框的粗细、字的大小、字的类型、字的颜色(按比例设置大小粗细比较好一些)
            int minDwDh = Math.min(img.width(), img.height());
            int thickness = minDwDh/ODConfig.lineThicknessRatio;
            long start_time = System.currentTimeMillis();
            // 更改 image 尺寸
            Letterbox letterbox = new Letterbox();
            image = letterbox.letterbox(image);
 
            double ratio  = letterbox.getRatio();
            double dw = letterbox.getDw();
            double dh = letterbox.getDh();
            int rows  = letterbox.getHeight();
            int cols  = letterbox.getWidth();
            int channels = image.channels();
 
            // 将Mat对象的像素值赋值给Float[]对象
            float[] pixels = new float[channels * rows * cols];
            for (int i = 0; i < rows; i++) {
                for (int j = 0; j < cols; j++) {
                    double[] pixel = image.get(j,i);
                    for (int k = 0; k < channels; k++) {
                        // 这样设置相当于同时做了image.transpose((2, 0, 1))操作
                        pixels[rows*cols*k+j*cols+i] = (float) pixel[k]/255.0f;
                    }
                }
            }
 
            // 创建OnnxTensor对象
            long[] shape = { 1L, (long)channels, (long)rows, (long)cols };
            OnnxTensor tensor = OnnxTensor.createTensor(environment, FloatBuffer.wrap(pixels), shape);
            HashMap<String, OnnxTensor> stringOnnxTensorHashMap = new HashMap<>();
            stringOnnxTensorHashMap.put(session.getInputInfo().keySet().iterator().next(), tensor);
 
            // 运行推理
            OrtSession.Result output = session.run(stringOnnxTensorHashMap);
            float[][] outputData = ((float[][][])output.get(0).getValue())[0];
 
            outputData = transposeMatrix(outputData);
            Map<Integer, List<float[]>> class2Bbox = new HashMap<>();
 
            for (float[] bbox : outputData) {
 
 
                float[] conditionalProbabilities = Arrays.copyOfRange(bbox, 4, bbox.length);
                int label = argmax(conditionalProbabilities);
                float conf = conditionalProbabilities[label];
                if (conf < confThreshold) continue;
 
                bbox[4] = conf;
 
                // xywh to (x1, y1, x2, y2)
                xywh2xyxy(bbox);
 
                // skip invalid predictions
                if (bbox[0] >= bbox[2] || bbox[1] >= bbox[3]) continue;
 
 
                class2Bbox.putIfAbsent(label, new ArrayList<>());
                class2Bbox.get(label).add(bbox);
            }
 
            List<Detection> detections = new ArrayList<>();
            for (Map.Entry<Integer, List<float[]>> entry : class2Bbox.entrySet()) {
                int label = entry.getKey();
                List<float[]> bboxes = entry.getValue();
                bboxes = nonMaxSuppression(bboxes, nmsThreshold);
                for (float[] bbox : bboxes) {
                    String labelString = labels[label];
                    detections.add(new Detection(labelString,entry.getKey(), Arrays.copyOfRange(bbox, 0, 4), bbox[4]));
                }
            }
 
 
            for (Detection detection : detections) {
                float[] bbox = detection.getBbox();
                System.out.println("======="+detection.toString());
                // 画框
                Point topLeft = new Point((bbox[0]-dw)/ratio, (bbox[1]-dh)/ratio);
                Point bottomRight = new Point((bbox[2]-dw)/ratio, (bbox[3]-dh)/ratio);
                Scalar color = new Scalar(colors.get(detection.getClsId()));
                Imgproc.rectangle(img, topLeft, bottomRight, color, thickness);
                // 框上写文字
                Point boxNameLoc = new Point((bbox[0]-dw)/ratio, (bbox[1]-dh)/ratio-3);
 
                Imgproc.putText(img, detection.getLabel(), boxNameLoc, Imgproc.FONT_HERSHEY_SIMPLEX, 0.7, color, thickness);
            }
            System.out.printf("time:%d ms.", (System.currentTimeMillis() - start_time));
 
            System.out.println();
            //服务器部署:由于服务器没有桌面,所以无法弹出画面预览,主要注释一下代码
 
            // 保存图像到同级目录
            // Imgcodecs.imwrite(ODConfig.savePicPath, img);
            // 弹窗展示图像
            HighGui.imshow("Display Image", img);
            // 按任意按键关闭弹窗画面,结束程序
            HighGui.waitKey();
        }
        HighGui.destroyAllWindows();
        System.exit(0);
 
    }
 
    public static void scaleCoords(float[] bbox, float orgW, float orgH, float padW, float padH, float gain) {
        // xmin, ymin, xmax, ymax -> (xmin_org, ymin_org, xmax_org, ymax_org)
        bbox[0] = Math.max(0, Math.min(orgW - 1, (bbox[0] - padW) / gain));
        bbox[1] = Math.max(0, Math.min(orgH - 1, (bbox[1] - padH) / gain));
        bbox[2] = Math.max(0, Math.min(orgW - 1, (bbox[2] - padW) / gain));
        bbox[3] = Math.max(0, Math.min(orgH - 1, (bbox[3] - padH) / gain));
    }
    public static void xywh2xyxy(float[] bbox) {
        float x = bbox[0];
        float y = bbox[1];
        float w = bbox[2];
        float h = bbox[3];
 
        bbox[0] = x - w * 0.5f;
        bbox[1] = y - h * 0.5f;
        bbox[2] = x + w * 0.5f;
        bbox[3] = y + h * 0.5f;
    }
 
    public static float[][] transposeMatrix(float [][] m){
        float[][] temp = new float[m[0].length][m.length];
        for (int i = 0; i < m.length; i++)
            for (int j = 0; j < m[0].length; j++)
                temp[j][i] = m[i][j];
        return temp;
    }
 
    public static List<float[]> nonMaxSuppression(List<float[]> bboxes, float iouThreshold) {
 
        List<float[]> bestBboxes = new ArrayList<>();
 
        bboxes.sort(Comparator.comparing(a -> a[4]));
 
        while (!bboxes.isEmpty()) {
            float[] bestBbox = bboxes.remove(bboxes.size() - 1);
            bestBboxes.add(bestBbox);
            bboxes = bboxes.stream().filter(a -> computeIOU(a, bestBbox) < iouThreshold).collect(Collectors.toList());
        }
 
        return bestBboxes;
    }
 
    public static float computeIOU(float[] box1, float[] box2) {
 
        float area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]);
        float area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]);
 
        float left = Math.max(box1[0], box2[0]);
        float top = Math.max(box1[1], box2[1]);
        float right = Math.min(box1[2], box2[2]);
        float bottom = Math.min(box1[3], box2[3]);
 
        float interArea = Math.max(right - left, 0) * Math.max(bottom - top, 0);
        float unionArea = area1 + area2 - interArea;
        return Math.max(interArea / unionArea, 1e-8f);
 
    }
 
    //返回最大值的索引
    public static int argmax(float[] a) {
        float re = -Float.MAX_VALUE;
        int arg = -1;
        for (int i = 0; i < a.length; i++) {
            if (a[i] >= re) {
                re = a[i];
                arg = i;
            }
        }
        return arg;
    }
 
    public static Map<String, String> getImagePathMap(String imagePath){
        Map<String, String> map = new TreeMap<>();
        File file = new File(imagePath);
        if(file.isFile()){
            map.put(file.getName(), file.getAbsolutePath());
        }else if(file.isDirectory()){
            for(File tmpFile : Objects.requireNonNull(file.listFiles())){
                map.putAll(getImagePathMap(tmpFile.getPath()));
            }
        }
        return map;
    }

YOLO官网:YOLO官网

YOLOgit源码:YOLO-GIT源码

python版本管理工具:https://github.com/pyenv-win/pyenv-win

© 版权声明

相关文章

暂无评论

none
暂无评论...