// inference.cpp (5.9 KB)
  1. #include "inference.h"
  2. Inference::Inference(const std::string &onnxModelPath, const cv::Size2f &modelInputShape, const std::string &classesTxtFile, const bool &runWithCuda)
  3. {
  4. modelPath = onnxModelPath;
  5. modelShape = modelInputShape;
  6. classesPath = classesTxtFile;
  7. cudaEnabled = runWithCuda;
  8. loadOnnxNetwork();
  9. loadClassesFromFile();
  10. }
  11. std::vector<Detection> Inference::runInference(const cv::Mat &input)
  12. {
  13. cv::Mat modelInput = input;
  14. if (letterBoxForSquare && modelShape.width == modelShape.height)
  15. modelInput = formatToSquare(modelInput);
  16. cv::Mat blob;
  17. cv::dnn::blobFromImage(modelInput, blob, 1.0/255.0, modelShape, cv::Scalar(), true, false);
  18. net.setInput(blob);
  19. std::vector<cv::Mat> outputs;
  20. net.forward(outputs, net.getUnconnectedOutLayersNames());
  21. int rows = outputs[0].size[1];
  22. int dimensions = outputs[0].size[2];
  23. bool yolov8 = false;
  24. // yolov5 has an output of shape (batchSize, 25200, 85) (Num classes + box[x,y,w,h] + confidence[c])
  25. // yolov8 has an output of shape (batchSize, 84, 8400) (Num classes + box[x,y,w,h])
  26. if (dimensions > rows) // Check if the shape[2] is more than shape[1] (yolov8)
  27. {
  28. yolov8 = true;
  29. rows = outputs[0].size[2];
  30. dimensions = outputs[0].size[1];
  31. outputs[0] = outputs[0].reshape(1, dimensions);
  32. cv::transpose(outputs[0], outputs[0]);
  33. }
  34. float *data = (float *)outputs[0].data;
  35. float x_factor = modelInput.cols / modelShape.width;
  36. float y_factor = modelInput.rows / modelShape.height;
  37. std::vector<int> class_ids;
  38. std::vector<float> confidences;
  39. std::vector<cv::Rect> boxes;
  40. std::vector<std::vector<cv::Point>> kpts;
  41. for (int i = 0; i < rows; ++i)
  42. {
  43. if (yolov8)
  44. {
  45. float *classes_scores = data+4;
  46. cv::Mat scores(1, classes.size(), CV_32FC1, classes_scores);
  47. cv::Point class_id;
  48. double maxClassScore;
  49. minMaxLoc(scores, 0, &maxClassScore, 0, &class_id);
  50. if (maxClassScore > modelScoreThreshold)
  51. {
  52. confidences.push_back(maxClassScore);
  53. class_ids.push_back(class_id.x);
  54. float x = data[0];
  55. float y = data[1];
  56. float w = data[2];
  57. float h = data[3];
  58. int left = int((x - 0.5 * w) * x_factor);
  59. int top = int((y - 0.5 * h) * y_factor);
  60. int width = int(w * x_factor);
  61. int height = int(h * y_factor);
  62. int step = 3;
  63. std::vector<cv::Point> kps;
  64. for (int kpi = 0; kpi < 5; ++kpi) {
  65. float kp_x = data[5 + kpi * step];
  66. float kp_y = data[5 + kpi * step + 1];
  67. cv::Point kp(int(kp_x * x_factor), int(kp_y * y_factor));
  68. kps.push_back(kp);
  69. }
  70. kpts.push_back(kps);
  71. boxes.push_back(cv::Rect(left, top, width, height));
  72. }
  73. }
  74. else // yolov5
  75. {
  76. float confidence = data[4];
  77. if (confidence >= modelConfidenseThreshold)
  78. {
  79. float *classes_scores = data+5;
  80. cv::Mat scores(1, classes.size(), CV_32FC1, classes_scores);
  81. cv::Point class_id;
  82. double max_class_score;
  83. minMaxLoc(scores, 0, &max_class_score, 0, &class_id);
  84. if (max_class_score > modelScoreThreshold)
  85. {
  86. confidences.push_back(confidence);
  87. class_ids.push_back(class_id.x);
  88. float x = data[0];
  89. float y = data[1];
  90. float w = data[2];
  91. float h = data[3];
  92. int left = int((x - 0.5 * w) * x_factor);
  93. int top = int((y - 0.5 * h) * y_factor);
  94. int width = int(w * x_factor);
  95. int height = int(h * y_factor);
  96. boxes.push_back(cv::Rect(left, top, width, height));
  97. }
  98. }
  99. }
  100. data += dimensions;
  101. }
  102. std::vector<int> nms_result;
  103. cv::dnn::NMSBoxes(boxes, confidences, modelScoreThreshold, modelNMSThreshold, nms_result);
  104. std::vector<Detection> detections{};
  105. for (unsigned long i = 0; i < nms_result.size(); ++i)
  106. {
  107. int idx = nms_result[i];
  108. Detection result;
  109. result.class_id = class_ids[idx];
  110. result.confidence = confidences[idx];
  111. std::random_device rd;
  112. std::mt19937 gen(rd());
  113. std::uniform_int_distribution<int> dis(100, 255);
  114. result.color = cv::Scalar(dis(gen),
  115. dis(gen),
  116. dis(gen));
  117. result.className = classes[result.class_id];
  118. result.box = boxes[idx];
  119. result.kpts = kpts[idx];
  120. detections.push_back(result);
  121. }
  122. return detections;
  123. }
  124. void Inference::loadClassesFromFile()
  125. {
  126. std::ifstream inputFile(classesPath);
  127. if (inputFile.is_open())
  128. {
  129. std::string classLine;
  130. while (std::getline(inputFile, classLine))
  131. classes.push_back(classLine);
  132. inputFile.close();
  133. }
  134. classes.push_back(std::string("tea"));
  135. }
  136. void Inference::loadOnnxNetwork()
  137. {
  138. net = cv::dnn::readNetFromONNX(modelPath);
  139. if (cudaEnabled)
  140. {
  141. std::cout << "\nRunning on CUDA" << std::endl;
  142. net.setPreferableBackend(cv::dnn::DNN_BACKEND_CUDA);
  143. net.setPreferableTarget(cv::dnn::DNN_TARGET_CUDA);
  144. }
  145. else
  146. {
  147. std::cout << "\nRunning on CPU" << std::endl;
  148. net.setPreferableBackend(cv::dnn::DNN_BACKEND_OPENCV);
  149. net.setPreferableTarget(cv::dnn::DNN_TARGET_CPU);
  150. }
  151. }
  152. cv::Mat Inference::formatToSquare(const cv::Mat &source)
  153. {
  154. int col = source.cols;
  155. int row = source.rows;
  156. int _max = MAX(col, row);
  157. cv::Mat result = cv::Mat::zeros(_max, _max, CV_8UC3);
  158. source.copyTo(result(cv::Rect(0, 0, col, row)));
  159. return result;
  160. }