tea_detect.h 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. #pragma once
  2. #include <opencv.hpp>
  3. #include "logger.h"
  4. #include "data_def.h"
  5. namespace graft_cv {
  6. class RetinaDrop {
  7. struct DropBox {
  8. float x1;
  9. float y1;
  10. float x2;
  11. float y2;
  12. };
  13. struct DropRes {
  14. float confidence;
  15. DropBox drop_box;
  16. std::vector<cv::Point2f> keypoints;
  17. };
  18. public:
  19. explicit RetinaDrop(CGcvLogger* pLogger=0, float obj_th=0.6, float nms_th=0.4);
  20. ~RetinaDrop();
  21. bool LoadModel(std::string onnx_path);
  22. std::vector<Bbox> RunModel(cv::Mat& img, CGcvLogger* pInstanceLogger=0);
  23. bool IsModelLoaded();
  24. float GetNmsThreshold();
  25. void SetThreshold(float object_threshold, float nms_threshold);
  26. private:
  27. void generate_anchors();
  28. int post_process(
  29. cv::Mat &vec_Mat,
  30. std::vector<cv::Mat> &result_matrix,
  31. std::vector<RetinaDrop::DropRes>& valid_result);
  32. void nms_detect(
  33. std::vector<DropRes>& detections,
  34. std::vector<int>& keep);
  35. static float iou_calculate(
  36. const DropBox& det_a,
  37. const DropBox& det_b);
  38. int BATCH_SIZE; //default 1
  39. int INPUT_CHANNEL; //default 3
  40. int IMAGE_WIDTH; //default 640
  41. int IMAGE_HEIGHT; //default 640
  42. float m_obj_threshold; // default 0.5
  43. float m_nms_threshold; // default 0.45
  44. cv::Mat m_refer_matrix;
  45. int m_anchor_num;
  46. int m_bbox_head;
  47. std::vector<int> m_feature_sizes;
  48. std::vector<int> m_feature_steps;
  49. std::vector<int> m_feature_maps;
  50. std::vector<std::vector<int>>m_anchor_sizes;
  51. int m_sum_of_feature;
  52. cv::dnn::Net m_model;
  53. float m_variance[2];
  54. cv::Scalar m_img_mean;
  55. cv::Size m_size_detection;
  56. bool m_model_loaded;
  57. CGcvLogger* m_pLogger;
  58. };
  59. }