inference.h 1.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. #ifndef INFERENCE_H
  2. #define INFERENCE_H
  3. // Cpp native
  4. #include <fstream>
  5. #include <vector>
  6. #include <string>
  7. #include <random>
  8. // OpenCV / DNN / Inference
  9. #include <opencv2/imgproc.hpp>
  10. #include <opencv2/opencv.hpp>
  11. #include <opencv2/dnn.hpp>
  12. struct Detection
  13. {
  14. int class_id{0};
  15. std::string className{};
  16. float confidence{0.0};
  17. cv::Scalar color{};
  18. cv::Rect box{};
  19. };
  20. class Inference
  21. {
  22. public:
  23. Inference(const std::string &onnxModelPath, const cv::Size2f &modelInputShape, const std::string &classesTxtFile, const bool &runWithCuda = true);
  24. std::vector<Detection> runInference(const cv::Mat &input);
  25. private:
  26. void loadClassesFromFile();
  27. void loadOnnxNetwork();
  28. cv::Mat formatToSquare(const cv::Mat &source);
  29. std::string modelPath{};
  30. std::string classesPath{};
  31. bool cudaEnabled{};
  32. std::vector<std::string> classes{};
  33. cv::Size2f modelShape{};
  34. float modelConfidenseThreshold {0.25};
  35. float modelScoreThreshold {0.45f};
  36. float modelNMSThreshold {0.50};
  37. bool letterBoxForSquare = true;
  38. cv::dnn::Net net;
  39. };
  40. #endif // INFERENCE_H