// tea_detect.cpp — drop-detection implementations (RetinaFace-style and YOLO-based).
#include "tea_detect.h"

#include <opencv.hpp>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <exception>
#include <iostream>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>

using namespace cv;
using namespace std;
  6. namespace graft_cv {
  7. RetinaDrop::RetinaDrop(CGcvLogger* pLogger, float obj_th, float nms_th)
  8. :m_model_loaded(false)
  9. {
  10. BATCH_SIZE = 1;
  11. INPUT_CHANNEL = 3;
  12. IMAGE_WIDTH = 640; // default 640
  13. IMAGE_HEIGHT = 640; // default 640
  14. m_obj_threshold = obj_th;//default 0.6;
  15. m_nms_threshold = nms_th; //default0.4;
  16. m_anchor_num = 2;
  17. m_bbox_head = 4;
  18. m_variance[0] = 0.1f;
  19. m_variance[1] = 0.2f;
  20. //m_img_mean(123.0, 104.0, 117.0)
  21. m_img_mean[0] = 123.0;
  22. m_img_mean[1] = 104.0;
  23. m_img_mean[2] = 117.0;
  24. m_img_mean[3] = 0;
  25. //cv::Size size_detection(640, 640)
  26. m_size_detection.width = IMAGE_WIDTH;
  27. m_size_detection.height = IMAGE_HEIGHT;
  28. m_feature_steps = {8,16,32};
  29. m_pLogger = pLogger;
  30. for (const int step : m_feature_steps) {
  31. assert(step != 0);
  32. int feature_map = IMAGE_HEIGHT / step;
  33. m_feature_maps.push_back(feature_map);
  34. int feature_size = feature_map * feature_map;
  35. m_feature_sizes.push_back(feature_size);
  36. }
  37. m_anchor_sizes = { { 16,32 } ,{ 64,128},{ 256, 512 }};
  38. m_sum_of_feature = std::accumulate(m_feature_sizes.begin(), m_feature_sizes.end(), 0) * m_anchor_num;
  39. generate_anchors();
  40. if (m_pLogger) {
  41. m_pLogger->INFO(string("RetinaDrop object initialized"));
  42. }
  43. }
  44. RetinaDrop::~RetinaDrop() = default;
  45. bool RetinaDrop::IsModelLoaded() {
  46. return m_model_loaded;
  47. };
  48. void RetinaDrop::SetThreshold(float object_threshold, float nms_threshold)
  49. {
  50. this->m_obj_threshold = object_threshold;
  51. this->m_nms_threshold = nms_threshold;
  52. }
  53. bool RetinaDrop::LoadModel(std::string onnx_path) {
  54. if (m_pLogger) {
  55. m_pLogger->INFO(string("Loading detection model: ")+onnx_path);
  56. }
  57. else { std::cout << "Loading detection model: " << onnx_path<<std::endl; }
  58. try {
  59. m_model = cv::dnn::readNetFromONNX(onnx_path);
  60. if (m_pLogger) {m_pLogger->INFO(string("Detection model loaded"));}
  61. m_model_loaded = true;
  62. return m_model_loaded;
  63. }
  64. catch (...)
  65. {
  66. if (m_pLogger) { m_pLogger->ERRORINFO(string("loading model failed")); }
  67. }
  68. return false;
  69. }
  70. std::vector<Bbox> RetinaDrop::RunModel(cv::Mat& img, CGcvLogger* pInstanceLogger)
  71. {
  72. std::vector<Bbox> result;
  73. if (img.empty()) {
  74. if (pInstanceLogger) {
  75. pInstanceLogger->ERRORINFO(string("RunModel(), input image is empty"));
  76. }
  77. throw(string("image is empty"));
  78. }
  79. if (!m_model_loaded) {
  80. pInstanceLogger->ERRORINFO(string("model is NOT loaded"));
  81. }
  82. cv::Mat blob = cv::dnn::blobFromImage(
  83. img,
  84. 1.0,
  85. m_size_detection,
  86. m_img_mean);
  87. m_model.setInput(blob);
  88. std::vector<std::string> outNames = m_model.getUnconnectedOutLayersNames();
  89. vector<Mat>outputs;// location(1x16800x4), confidence(1x16800x2), keypoint(1x16800x2)
  90. if (pInstanceLogger) {
  91. pInstanceLogger->INFO(string("RunModel(), before forward()"));
  92. }
  93. m_model.forward(outputs, outNames);
  94. std::vector<RetinaDrop::DropRes> rects;
  95. int n = post_process(img, outputs,rects);
  96. for (const auto& rect : rects) {
  97. Bbox box;
  98. box.score = rect.confidence;
  99. box.x1 = (int)rect.drop_box.x1;
  100. box.y1 = (int)rect.drop_box.y1;
  101. box.x2 = (int)rect.drop_box.x2;
  102. box.y2 = (int)rect.drop_box.y2;
  103. box.ppoint[0] = rect.keypoints[0].x;
  104. box.ppoint[1] = rect.keypoints[0].y;
  105. box.ppoint[2] = rect.keypoints[1].x;
  106. box.ppoint[3] = rect.keypoints[1].y;
  107. box.ppoint[4] = rect.keypoints[2].x;
  108. box.ppoint[5] = rect.keypoints[2].y;
  109. box.ppoint[6] = rect.keypoints[3].x;
  110. box.ppoint[7] = rect.keypoints[3].y;
  111. box.ppoint[8] = rect.keypoints[4].x;
  112. box.ppoint[9] = rect.keypoints[4].y;
  113. box.operate_point[0] = 0.0;
  114. box.operate_point[1] = 0.0;
  115. box.operate_angle = 0.0;
  116. box.area = 0.0;
  117. box.status = 0;
  118. result.push_back(box);
  119. }
  120. if (pInstanceLogger) {
  121. stringstream buff;
  122. buff << "detected object: " << n;
  123. pInstanceLogger->INFO(buff.str());
  124. }
  125. return result;
  126. }
  127. void RetinaDrop::generate_anchors() {
  128. m_refer_matrix = cv::Mat(m_sum_of_feature, m_bbox_head, CV_32FC1);
  129. int line = 0;
  130. for (size_t feature_map = 0; feature_map < m_feature_maps.size(); feature_map++) {
  131. for (int height = 0; height < m_feature_maps[feature_map]; ++height) {
  132. for (int width = 0; width < m_feature_maps[feature_map]; ++width) {
  133. for (int anchor = 0; anchor < m_anchor_sizes[feature_map].size(); ++anchor) {
  134. auto* row = m_refer_matrix.ptr<float>(line);
  135. row[0] = (float)(width+0.5) * m_feature_steps[feature_map]/(float)IMAGE_WIDTH;
  136. row[1] = (float)(height+0.5) * m_feature_steps[feature_map]/(float)IMAGE_HEIGHT;
  137. row[2] = m_anchor_sizes[feature_map][anchor]/(float)IMAGE_WIDTH;
  138. row[3] = m_anchor_sizes[feature_map][anchor]/(float)IMAGE_HEIGHT;
  139. line++;
  140. }
  141. }
  142. }
  143. }
  144. }
// Decodes the raw network outputs into pixel-space detections, then applies NMS.
// src_img       : only cols/rows are used, to rescale normalized coordinates.
// result_matrix : raw output tensors. Indexing below assumes
//                 [0] = box regressions   (4 floats per anchor),
//                 [2] = class confidences (2 floats per anchor; index 1 = object),
//                 [1] = keypoint offsets  (10 floats per anchor = 5 x/y pairs).
//                 NOTE(review): RunModel's comment lists the order as
//                 location/confidence/keypoint — confirm against the ONNX export.
// valid_result  : receives the detections surviving NMS.
// Returns the number of detections kept.
int RetinaDrop::post_process(
    cv::Mat &src_img,
    vector<cv::Mat> &result_matrix,
    std::vector<RetinaDrop::DropRes>& valid_result
    )
{
    valid_result.clear();
    std::vector<DropRes> result;
    for (int item = 0; item < m_sum_of_feature; ++item) {
        // Flat per-anchor pointers into each contiguous output tensor.
        float* cur_bbox = (float*)result_matrix[0].data + item * 4;//result_matrix[0].step;
        float* cur_conf = (float*)result_matrix[2].data + item * 2;//result_matrix[1].step;
        float* cur_keyp = (float*)result_matrix[1].data + item * 10;//result_matrix[2].step;
        if (cur_conf[1] > m_obj_threshold) {
            DropRes headbox;
            headbox.confidence = cur_conf[1];
            auto* anchor = m_refer_matrix.ptr<float>(item);
            auto* keyp = cur_keyp;
            float cx, cy, kx, ky;
            // SSD-style decoding: center offsets scaled by variance and anchor
            // size; width/height recovered via exp of the scaled regression.
            cx = anchor[0] + cur_bbox[0] * m_variance[0] * anchor[2];
            cy = anchor[1] + cur_bbox[1] * m_variance[0] * anchor[3];
            kx = anchor[2] * exp(cur_bbox[2] * m_variance[1]);
            ky = anchor[3] * exp(cur_bbox[3] * m_variance[1]);
            // Convert center/size to corners; after these lines (cx,cy) is the
            // top-left and (kx,ky) the bottom-right.
            cx -= kx / 2.0f;
            cy -= ky / 2.0f;
            kx += cx;
            ky += cy;
            // Scale normalized [0,1] coordinates to source-image pixels.
            headbox.drop_box.x1 = cx * src_img.cols;
            headbox.drop_box.y1 = cy * src_img.rows;
            headbox.drop_box.x2 = kx * src_img.cols;
            headbox.drop_box.y2 = ky * src_img.rows;
            // Decode the five landmarks the same way as the box center.
            for (int ki = 0; ki < 5; ++ki) {
                float kp_x = anchor[0] + keyp[2*ki] * m_variance[0] * anchor[2];
                float kp_y = anchor[1] + keyp[2*ki+1] * m_variance[0] * anchor[3];
                kp_x *= src_img.cols;
                kp_y *= src_img.rows;
                headbox.keypoints.push_back(cv::Point2f(kp_x, kp_y));
            }
            /*float kp_x = anchor[0] + keyp[0] * m_variance[0] * anchor[2];
            float kp_y = anchor[1] + keyp[1] * m_variance[0] * anchor[3];
            kp_x *= src_img.cols;
            kp_y *= src_img.rows;
            headbox.keypoints = {
            cv::Point2f(kp_x,kp_y)
            };*/
            result.push_back(headbox);
        }
    }
    // Suppress overlapping candidates and emit the survivors.
    vector<int> keep;
    nms_detect(result,keep);
    for (size_t i = 0; i < keep.size(); ++i) {
        valid_result.push_back(result[keep[i]]);
    }
    return (int)valid_result.size();
}
  199. void RetinaDrop::nms_detect(
  200. std::vector<DropRes> & detections,
  201. vector<int> & keep)
  202. {
  203. keep.clear();
  204. if (detections.size() == 1) {
  205. keep.push_back(0);
  206. return;
  207. }
  208. sort(detections.begin(), detections.end(),
  209. [=](const DropRes& left, const DropRes& right) {
  210. return left.confidence > right.confidence;
  211. });
  212. vector<int> order;
  213. for (size_t i = 0; i < detections.size(); ++i) { order.push_back((int)i); }
  214. while (order.size()) {
  215. int i = order[0];
  216. keep.push_back(i);
  217. vector<int> del_idx;
  218. for (size_t j = 1; j < order.size(); ++j) {
  219. float iou = iou_calculate(
  220. detections[i].drop_box,
  221. detections[order[j]].drop_box);
  222. if (iou > m_nms_threshold) {
  223. del_idx.push_back((int)j);
  224. }
  225. }
  226. vector<int> order_update;
  227. for (size_t j = 1; j < order.size(); ++j) {
  228. vector<int>::iterator it = find(del_idx.begin(), del_idx.end(), j);
  229. if (it == del_idx.end()) {
  230. order_update.push_back(order[j]);
  231. }
  232. }
  233. order.clear();
  234. order.assign(order_update.begin(), order_update.end());
  235. }
  236. }
  237. float RetinaDrop::iou_calculate(
  238. const RetinaDrop::DropBox & det_a,
  239. const RetinaDrop::DropBox & det_b)
  240. {
  241. float aa = (det_a.x2 - det_a.x1 + 1) * (det_a.y2 - det_a.y1 + 1);
  242. float ab = (det_b.x2 - det_b.x1 + 1) * (det_b.y2 - det_b.y1 + 1);
  243. float xx1 = max(det_a.x1, det_b.x1);
  244. float yy1 = max(det_a.y1, det_b.y1);
  245. float xx2 = min(det_a.x2, det_b.x2);
  246. float yy2 = min(det_a.y2, det_b.y2);
  247. float w = (float)max(0.0, xx2 - xx1 + 1.0);
  248. float h = (float)max(0.0, yy2 - yy1 + 1.0);
  249. float inter = w * h;
  250. float ovr = inter / (aa + ab - inter);
  251. return ovr;
  252. }
  253. float RetinaDrop::GetNmsThreshold() { return m_nms_threshold; }
  254. //////////////////////////////////////////////////////////////////////////////////
  255. //////////////////////////////////////////////////////////////////////////////////
  256. YoloDrop::YoloDrop(CGcvLogger* pLogger, float obj_th, float nms_th)
  257. :m_model_loaded(false),
  258. m_pInfer(0),
  259. m_runWithCuda(false)
  260. {
  261. BATCH_SIZE = 1;
  262. INPUT_CHANNEL = 3;
  263. IMAGE_WIDTH = 640; // default 640
  264. IMAGE_HEIGHT = 640; // default 640
  265. m_obj_threshold = obj_th;//default 0.6;
  266. m_nms_threshold = nms_th; //default0.4;
  267. m_anchor_num = 2;
  268. m_bbox_head = 4;
  269. m_variance[0] = 0.1f;
  270. m_variance[1] = 0.2f;
  271. //m_img_mean(123.0, 104.0, 117.0)
  272. m_img_mean[0] = 123.0;
  273. m_img_mean[1] = 104.0;
  274. m_img_mean[2] = 117.0;
  275. m_img_mean[3] = 0;
  276. //cv::Size size_detection(640, 640)
  277. m_size_detection.width = IMAGE_WIDTH;
  278. m_size_detection.height = IMAGE_HEIGHT;
  279. m_feature_steps = { 8,16,32 };
  280. m_pLogger = pLogger;
  281. /*for (const int step : m_feature_steps) {
  282. assert(step != 0);
  283. int feature_map = IMAGE_HEIGHT / step;
  284. m_feature_maps.push_back(feature_map);
  285. int feature_size = feature_map * feature_map;
  286. m_feature_sizes.push_back(feature_size);
  287. }
  288. m_anchor_sizes = { { 16,32 } ,{ 64,128 },{ 256, 512 } };
  289. m_sum_of_feature = std::accumulate(m_feature_sizes.begin(), m_feature_sizes.end(), 0) * m_anchor_num;
  290. generate_anchors();*/
  291. if (m_pLogger) {
  292. m_pLogger->INFO(string("YoloDrop object initialized"));
  293. }
  294. }
  295. YoloDrop::~YoloDrop() = default;
  296. bool YoloDrop::IsModelLoaded() {
  297. return m_model_loaded;
  298. };
  299. void YoloDrop::SetThreshold(float object_threshold, float nms_threshold)
  300. {
  301. this->m_obj_threshold = object_threshold;
  302. this->m_nms_threshold = nms_threshold;
  303. if (m_pInfer) {
  304. m_pInfer->setModelNMSThreshold(m_nms_threshold);
  305. m_pInfer->setModelScoreThreshold(m_obj_threshold);
  306. }
  307. }
  308. bool YoloDrop::LoadModel(std::string onnx_path) {
  309. if (m_pInfer) {
  310. delete m_pInfer;
  311. m_pInfer = 0;
  312. m_model_loaded = false;
  313. }
  314. cv::Size2f modelInputShape((float)IMAGE_WIDTH, (float)IMAGE_HEIGHT);
  315. if (m_pLogger) {
  316. m_pLogger->INFO(string("Loading detection model: ") + onnx_path);
  317. }
  318. else { std::cout << "Loading detection model: " << onnx_path << std::endl; }
  319. try {
  320. m_pInfer = new Inference(onnx_path, modelInputShape, "", m_runWithCuda);
  321. if (!m_pInfer) {
  322. throw(string("inference init error"));
  323. }
  324. m_pInfer->setModelNMSThreshold(m_nms_threshold);
  325. m_pInfer->setModelScoreThreshold(m_obj_threshold);
  326. if (m_pLogger) { m_pLogger->INFO(string("Detection model loaded")); }
  327. m_model_loaded = true;
  328. return m_model_loaded;
  329. }
  330. catch (...)
  331. {
  332. if (m_pLogger) { m_pLogger->ERRORINFO(string("loading model failed")); }
  333. }
  334. return false;
  335. }
  336. std::vector<Bbox> YoloDrop::RunModel(cv::Mat& frame, CGcvLogger* pInstanceLogger)
  337. {
  338. std::vector<Bbox> result;
  339. if (frame.empty()) {
  340. if (pInstanceLogger) {
  341. pInstanceLogger->ERRORINFO(string("RunModel(), input image is empty"));
  342. }
  343. throw(string("image is empty"));
  344. }
  345. if (!m_model_loaded) {
  346. pInstanceLogger->ERRORINFO(string("model is NOT loaded"));
  347. throw(string("model is NOT loaded"));
  348. }
  349. // Inference starts here...
  350. std::vector<Detection> output = m_pInfer->runInference(frame);
  351. int detections = output.size();
  352. std::cout << "Number of detections:" << detections << std::endl;
  353. for (int i = 0; i < detections; ++i)
  354. {
  355. Detection detection = output[i];
  356. cv::Rect box = detection.box;
  357. cv::Scalar color = detection.color;
  358. std::vector<cv::Point> pts = detection.kpts;
  359. Bbox box_out;
  360. box_out.score = detection.confidence;
  361. box_out.x1 = box.x;
  362. box_out.y1 = box.y;
  363. box_out.x2 = box.x + box.width;
  364. box_out.y2 = box.y + box.height;
  365. box_out.ppoint[0] = pts[0].x;
  366. box_out.ppoint[1] = pts[0].y;
  367. box_out.ppoint[2] = pts[1].x;
  368. box_out.ppoint[3] = pts[1].y;
  369. box_out.ppoint[4] = pts[2].x;
  370. box_out.ppoint[5] = pts[2].y;
  371. box_out.ppoint[6] = pts[3].x;
  372. box_out.ppoint[7] = pts[3].y;
  373. box_out.ppoint[8] = pts[4].x;
  374. box_out.ppoint[9] = pts[4].y;
  375. box_out.operate_point[0] = 0.0;
  376. box_out.operate_point[1] = 0.0;
  377. box_out.operate_angle = 0.0;
  378. box_out.area = 0.0;
  379. box_out.status = 0;
  380. result.push_back(box_out);
  381. //// Detection box
  382. //cv::rectangle(frame, box, color, 2);
  383. //// Detection box text
  384. //std::string classString = detection.className + ' ' + std::to_string(detection.confidence).substr(0, 4);
  385. //cv::Size textSize = cv::getTextSize(classString, cv::FONT_HERSHEY_DUPLEX, 1, 2, 0);
  386. //cv::Rect textBox(box.x, box.y - 40, textSize.width + 10, textSize.height + 20);
  387. //cv::rectangle(frame, textBox, color, cv::FILLED);
  388. //cv::putText(frame, classString, cv::Point(box.x + 5, box.y - 10), cv::FONT_HERSHEY_DUPLEX, 1, cv::Scalar(0, 0, 0), 2, 0);
  389. //for (auto& pt : pts) {
  390. // cv::circle(frame, pt, 3, cv::Scalar(0, 0, 255));
  391. //}
  392. }
  393. // Inference ends here...
  394. // This is only for preview purposes
  395. /*float scale = 0.8;
  396. cv::resize(frame, frame, cv::Size(frame.cols*scale, frame.rows*scale));
  397. cv::imshow("Inference", frame);
  398. cv::waitKey(-1);*/
  399. if (pInstanceLogger) {
  400. stringstream buff;
  401. buff << "detected object: " << detections;
  402. pInstanceLogger->INFO(buff.str());
  403. }
  404. return result;
  405. }
  406. void YoloDrop::generate_anchors() {
  407. m_refer_matrix = cv::Mat(m_sum_of_feature, m_bbox_head, CV_32FC1);
  408. int line = 0;
  409. for (size_t feature_map = 0; feature_map < m_feature_maps.size(); feature_map++) {
  410. for (int height = 0; height < m_feature_maps[feature_map]; ++height) {
  411. for (int width = 0; width < m_feature_maps[feature_map]; ++width) {
  412. for (int anchor = 0; anchor < m_anchor_sizes[feature_map].size(); ++anchor) {
  413. auto* row = m_refer_matrix.ptr<float>(line);
  414. row[0] = (float)(width + 0.5) * m_feature_steps[feature_map] / (float)IMAGE_WIDTH;
  415. row[1] = (float)(height + 0.5) * m_feature_steps[feature_map] / (float)IMAGE_HEIGHT;
  416. row[2] = m_anchor_sizes[feature_map][anchor] / (float)IMAGE_WIDTH;
  417. row[3] = m_anchor_sizes[feature_map][anchor] / (float)IMAGE_HEIGHT;
  418. line++;
  419. }
  420. }
  421. }
  422. }
  423. }
// Decodes raw SSD/RetinaFace-style outputs into pixel-space detections, then
// applies NMS. NOTE(review): the YOLO path's RunModel delegates decoding to
// Inference and does not call this method — it appears to be retained from the
// RetinaDrop implementation; confirm before removing.
// src_img       : only cols/rows are used, to rescale normalized coordinates.
// result_matrix : raw output tensors, indexed as [0]=boxes (4/anchor),
//                 [2]=confidences (2/anchor; index 1 = object),
//                 [1]=keypoints (10/anchor = 5 x/y pairs).
// valid_result  : receives the detections surviving NMS.
// Returns the number of detections kept.
int YoloDrop::post_process(
    cv::Mat &src_img,
    vector<cv::Mat> &result_matrix,
    std::vector<YoloDrop::DropRes>& valid_result
    )
{
    valid_result.clear();
    std::vector<DropRes> result;
    for (int item = 0; item < m_sum_of_feature; ++item) {
        // Flat per-anchor pointers into each contiguous output tensor.
        float* cur_bbox = (float*)result_matrix[0].data + item * 4;//result_matrix[0].step;
        float* cur_conf = (float*)result_matrix[2].data + item * 2;//result_matrix[1].step;
        float* cur_keyp = (float*)result_matrix[1].data + item * 10;//result_matrix[2].step;
        if (cur_conf[1] > m_obj_threshold) {
            DropRes headbox;
            headbox.confidence = cur_conf[1];
            auto* anchor = m_refer_matrix.ptr<float>(item);
            auto* keyp = cur_keyp;
            float cx, cy, kx, ky;
            // SSD-style decoding: center offsets scaled by variance and anchor
            // size; width/height recovered via exp of the scaled regression.
            cx = anchor[0] + cur_bbox[0] * m_variance[0] * anchor[2];
            cy = anchor[1] + cur_bbox[1] * m_variance[0] * anchor[3];
            kx = anchor[2] * exp(cur_bbox[2] * m_variance[1]);
            ky = anchor[3] * exp(cur_bbox[3] * m_variance[1]);
            // Convert center/size to corners; after these lines (cx,cy) is the
            // top-left and (kx,ky) the bottom-right.
            cx -= kx / 2.0f;
            cy -= ky / 2.0f;
            kx += cx;
            ky += cy;
            // Scale normalized [0,1] coordinates to source-image pixels.
            headbox.drop_box.x1 = cx * src_img.cols;
            headbox.drop_box.y1 = cy * src_img.rows;
            headbox.drop_box.x2 = kx * src_img.cols;
            headbox.drop_box.y2 = ky * src_img.rows;
            // Decode the five landmarks the same way as the box center.
            for (int ki = 0; ki < 5; ++ki) {
                float kp_x = anchor[0] + keyp[2 * ki] * m_variance[0] * anchor[2];
                float kp_y = anchor[1] + keyp[2 * ki + 1] * m_variance[0] * anchor[3];
                kp_x *= src_img.cols;
                kp_y *= src_img.rows;
                headbox.keypoints.push_back(cv::Point2f(kp_x, kp_y));
            }
            /*float kp_x = anchor[0] + keyp[0] * m_variance[0] * anchor[2];
            float kp_y = anchor[1] + keyp[1] * m_variance[0] * anchor[3];
            kp_x *= src_img.cols;
            kp_y *= src_img.rows;
            headbox.keypoints = {
            cv::Point2f(kp_x,kp_y)
            };*/
            result.push_back(headbox);
        }
    }
    // Suppress overlapping candidates and emit the survivors.
    vector<int> keep;
    nms_detect(result, keep);
    for (size_t i = 0; i < keep.size(); ++i) {
        valid_result.push_back(result[keep[i]]);
    }
    return (int)valid_result.size();
}
  478. void YoloDrop::nms_detect(
  479. std::vector<DropRes> & detections,
  480. vector<int> & keep)
  481. {
  482. keep.clear();
  483. if (detections.size() == 1) {
  484. keep.push_back(0);
  485. return;
  486. }
  487. sort(detections.begin(), detections.end(),
  488. [=](const DropRes& left, const DropRes& right) {
  489. return left.confidence > right.confidence;
  490. });
  491. vector<int> order;
  492. for (size_t i = 0; i < detections.size(); ++i) { order.push_back((int)i); }
  493. while (order.size()) {
  494. int i = order[0];
  495. keep.push_back(i);
  496. vector<int> del_idx;
  497. for (size_t j = 1; j < order.size(); ++j) {
  498. float iou = iou_calculate(
  499. detections[i].drop_box,
  500. detections[order[j]].drop_box);
  501. if (iou > m_nms_threshold) {
  502. del_idx.push_back((int)j);
  503. }
  504. }
  505. vector<int> order_update;
  506. for (size_t j = 1; j < order.size(); ++j) {
  507. vector<int>::iterator it = find(del_idx.begin(), del_idx.end(), j);
  508. if (it == del_idx.end()) {
  509. order_update.push_back(order[j]);
  510. }
  511. }
  512. order.clear();
  513. order.assign(order_update.begin(), order_update.end());
  514. }
  515. }
  516. float YoloDrop::iou_calculate(
  517. const YoloDrop::DropBox & det_a,
  518. const YoloDrop::DropBox & det_b)
  519. {
  520. float aa = (det_a.x2 - det_a.x1 + 1) * (det_a.y2 - det_a.y1 + 1);
  521. float ab = (det_b.x2 - det_b.x1 + 1) * (det_b.y2 - det_b.y1 + 1);
  522. float xx1 = max(det_a.x1, det_b.x1);
  523. float yy1 = max(det_a.y1, det_b.y1);
  524. float xx2 = min(det_a.x2, det_b.x2);
  525. float yy2 = min(det_a.y2, det_b.y2);
  526. float w = (float)max(0.0, xx2 - xx1 + 1.0);
  527. float h = (float)max(0.0, yy2 - yy1 + 1.0);
  528. float inter = w * h;
  529. float ovr = inter / (aa + ab - inter);
  530. return ovr;
  531. }
  532. float YoloDrop::GetNmsThreshold() { return m_nms_threshold; }
  533. }