inference.h 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. #ifndef INFERENCE_H
  2. #define INFERENCE_H
  3. // Cpp native
  4. #include <fstream>
  5. #include <vector>
  6. #include <string>
  7. #include <random>
  8. // OpenCV / DNN / Inference
  9. #include <opencv2/imgproc.hpp>
  10. #include <opencv2/opencv.hpp>
  11. #include <opencv2/dnn.hpp>
  12. struct Detection
  13. {
  14. int class_id{0};
  15. std::string className{};
  16. float confidence{0.0};
  17. cv::Scalar color{};
  18. cv::Rect box{};
  19. std::vector<cv::Point> kpts{};
  20. };
  21. class Inference
  22. {
  23. public:
  24. Inference(const std::string &onnxModelPath, const cv::Size2f &modelInputShape, const std::string &classesTxtFile, const bool &runWithCuda = true);
  25. std::vector<Detection> runInference(const cv::Mat &input);
  26. void setModelConfidenseThreshold(float t) { modelConfidenseThreshold = t; };
  27. void setModelScoreThreshold(float t) { modelScoreThreshold = t; };
  28. void setModelNMSThreshold(float t) { modelNMSThreshold = t; }
  29. private:
  30. void loadClassesFromFile();
  31. void loadOnnxNetwork();
  32. cv::Mat formatToSquare(const cv::Mat &source);
  33. std::string modelPath{};
  34. std::string classesPath{};
  35. bool cudaEnabled{};
  36. std::vector<std::string> classes{};
  37. cv::Size2f modelShape{};
  38. float modelConfidenseThreshold {0.25};
  39. float modelScoreThreshold {0.45f};
  40. float modelNMSThreshold {0.50};
  41. bool letterBoxForSquare = true;
  42. cv::dnn::Net net;
  43. };
  44. #endif // INFERENCE_H