You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

191 lines
5.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "AI_API.h"
#include "detect.h"
#include "environment.h"
#include "Arith_obj_det.h"
using namespace AIGO;
// Size of the most recent frame handed to Arith_infer. comp() measures box
// distances from the centre of this frame. Both have internal linkage.
namespace
{
int WIDTH = 0;
int HEIGTH = 0;  // (sic) name kept: other code in this file refers to it by this spelling
}
/// Returns true when "./AI_SHOW" exists relative to the current working directory.
/// Access mode 0 checks existence only. The flag is re-read on every call, so it
/// can be toggled at runtime by creating or removing the file.
bool isShow()
{
    const bool flagPresent = (_access("./AI_SHOW", 0) == 0);
    return flagPresent;
}
/**
 * Sort predicate: orders boxes by the distance of their centre from the image
 * centre, closest first.
 *
 * The image size is read from the file-level WIDTH/HEIGTH, which Arith_infer
 * assigns from the current frame before it sorts.
 *
 * @param a  first box
 * @param b  second box
 * @return true if the centre of @p a is strictly closer to the image centre than
 *         the centre of @p b.
 */
bool comp(const objinfo &a, const objinfo &b)
{
    // Halve in floating point. The former `WIDTH / 2` was integer division, which
    // truncated the centre of odd-sized frames (and of boxes with integral
    // coordinates) before the value reached a float.
    const float img_center_x = WIDTH * 0.5f;
    const float img_center_y = HEIGTH * 0.5f;

    const float dx1 = (a.x1 + a.x2) * 0.5f - img_center_x;
    const float dy1 = (a.y1 + a.y2) * 0.5f - img_center_y;
    const float dx2 = (b.x1 + b.x2) * 0.5f - img_center_x;
    const float dy2 = (b.y1 + b.y2) * 0.5f - img_center_y;

    // sqrt is monotonic, so comparing squared distances yields the same order
    // without calling std::pow/std::sqrt on every comparison std::sort makes.
    return dx1 * dx1 + dy1 * dy1 < dx2 * dx2 + dy2 * dy2;
}
// Releases the detector and the classifier allocated by Arith_inial.
// Deleting a null pointer is a no-op, so no explicit null checks are needed.
// NOTE(review): this class owns raw pointers. Unless copying is disabled in the
// header, copying an instance would cause a double delete; please confirm.
API_AI_OBJ_DET::~API_AI_OBJ_DET()
{
    delete static_cast<CudaYOLODetect*>(m_detDemo);
    m_detDemo = nullptr;

    delete static_cast<Detect_class*>(m_classDemo);
    m_classDemo = nullptr;
}
int API_AI_OBJ_DET::Arith_inial(std::string& model_path_det, const int& det_width, const int& det_height,
const int& cls_num, const float& obj_conf, const float& cls_conf, const float& nms_conf,
std::string& model_path_class, const int& class_width, const int& class_height, bool detect_class)
{
DEBUG_LOG("loadding det model from %s", model_path_det.c_str());
m_detDemo = new CudaYOLODetect (model_path_det, det_height, det_width, cls_num, obj_conf, cls_conf, nms_conf);
DEBUG_LOG("loadding classify model from %s", model_path_class.c_str());
if (detect_class)
{
m_classDemo = new Detect_class(model_path_class, class_height, class_width, cls_num);
DEBUG_LOG("both model loaded");
}
if(!m_detDemo)
{
return -1;
}
if (!m_classDemo)
{
return -2;
}
return 0;
}
int API_AI_OBJ_DET::Arith_setParam(int idx, const float& cls_conf, const float& nms_conf)
{
DEBUG_LOG("setting det conf, idx=%d->%f->%f", idx, cls_conf, nms_conf);
if (m_detDemo != NULL)
{
((CudaYOLODetect*)m_detDemo)->setParam(idx, cls_conf, nms_conf);
return 0;
}
else
{
//DEBUG_LOG("ingting error");
return 1;
}
}
int API_AI_OBJ_DET::Arith_setClassParam(int idx, const float& cls_conf)
{
DEBUG_LOG("setting classify conf, idx:%d -> %f", idx, cls_conf);
if (m_classDemo != NULL)
{
((Detect_class*)m_classDemo)->setParam(idx, cls_conf);
return 0;
}
else
{
DEBUG_LOG("setting error");
return 1;
}
}
int API_AI_OBJ_DET::Arith_setMatchRatio(float matchratio){
DEBUG_LOG("setting match ratio");
if(this->m_detDemo == NULL){
DEBUG_LOG("model is NULL, init first needed");
return 1;
}
((CudaYOLODetect*)m_detDemo)->setMatchRatio(matchratio);
return 0;
};
// Running count of classification runs. Arith_infer increments it and writes it to
// the debug log; it is file-local and not synchronised.
namespace
{
int goClassCount = 0;
}
/**
 * Runs detection on one frame and, if a classifier is loaded, classifies every
 * track that has persisted long enough.
 *
 * Side effect: stores src.width / src.height in the file-level WIDTH / HEIGTH,
 * which comp() uses as the reference image size when sorting.
 *
 * @param src        input frame (taken by value, as declared by the API).
 * @param res_det    passed to CudaYOLODetect::run (presumably filled with this
 *                   frame's detections; TODO confirm), then sorted with comp
 *                   (closest to the image centre first).
 * @param track_det  cleared, then filled with one objinfo per classified track
 *                   and sorted with comp.
 * @return 0 on success;
 *         -1 if no detector is loaded (res_det and track_det are left untouched);
 *         -2 if no classifier is loaded (res_det is processed, track_det is empty).
 */
int API_AI_OBJ_DET::Arith_infer(ImgMat src, std::vector<objinfo>& res_det, std::vector<objinfo>& track_det)
{
    WIDTH = src.width;
    HEIGTH = src.height;

    if (m_detDemo == nullptr)
    {
        return -1;
    }

    // Log the pixel buffer address itself. The old code printed &src.data, i.e. the
    // address of a field in this by-value copy, which says nothing about the input.
    DEBUG_LOG("datainput: %p, %d, %d", static_cast<const void*>(src.data), src.width, src.height);

    // Run YOLO detection on the frame.
    DEBUG_LOG("PimgBuff: %p, %d, %d", static_cast<const void*>(src.data), src.width, src.height);
    const std::vector<std::shared_ptr<OBJTRACK>> track_res =
        static_cast<CudaYOLODetect*>(m_detDemo)->run(src, res_det);
    std::sort(res_det.begin(), res_det.end(), comp);

    track_det.clear();
    if (m_classDemo == nullptr)
    {
        return -2;
    }

    // Classify tracks based on the detection result.
    Detect_class* const classifier = static_cast<Detect_class*>(m_classDemo);

    // isShow() queries the filesystem, so evaluate it once per frame rather than
    // once per classified track.
    const bool show_class_image = isShow();

    // A track is classified once it has been seen more than this many times.
    // NOTE(review): the original comment said classification starts after 16
    // detections, but `> 14` first fires at a count of 15. Confirm the intended threshold.
    const int min_exist_count = 14;

    track_det.reserve(track_res.size());
    // Bind by reference: copying each shared_ptr would cost an atomic inc/dec per track.
    for (const std::shared_ptr<OBJTRACK>& t_info : track_res)
    {
        if (t_info->_nExistCnt <= min_exist_count)
        {
            continue;
        }

        goClassCount++;
        DEBUG_LOG("goClassCount : %d", goClassCount);

        obj_res_class det_res_class;
        // Stays -1 unless run_class overwrites it (presumably "no class"; confirm).
        det_res_class.cls_id_class = -1;
        classifier->run_class(t_info->m_imgBufferCV.data, det_res_class);

        if (show_class_image)
        {
            cv::imshow("ClassifyImage", t_info->m_imgBufferCV);
            cv::waitKey(10);
        }

        t_info->_cls = det_res_class.cls_id_class;
        objinfo d;
        t_info->toObjInfo(d);
        DEBUG_LOG("classnum: %d-%d-%d\n", det_res_class.cls_id_class, t_info->_trackID, d.trackID);
        track_det.push_back(d);
    }

    std::sort(track_det.begin(), track_det.end(), comp);
    return 0;
}