You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

191 lines
5.7 KiB

#include "AI_API.h"
#include "detect.h"
#include "environment.h"
#include "Arith_obj_det.h"
using namespace AIGO;
// Width of the frame most recently passed to Arith_infer(); comp() reads it
// to locate the horizontal image centre.
static int WIDTH = 0;
// Height of the frame most recently passed to Arith_infer(); comp() reads it
// to locate the vertical image centre.
static int HEIGTH = 0;
/// Reports whether the marker path "./AI_SHOW" is accessible from the current
/// working directory; Arith_infer() uses this to decide whether to display
/// the image handed to the classifier.
bool isShow()
{
    // Mode 0 asks _access() only whether the path exists.
    const int status = _access("./AI_SHOW", 0);
    return status == 0;
}
/**
 * @brief Ordering predicate: true when box @p a lies closer to the image centre than box @p b.
 *
 * The image centre is derived from the file-level WIDTH / HEIGTH values, which
 * Arith_infer() refreshes on every call. Each box is represented by the
 * midpoint of its (x1, y1)-(x2, y2) corners.
 *
 * @param a First box.
 * @param b Second box.
 * @return true if the centre of @p a is strictly nearer to the image centre.
 */
bool comp(const objinfo &a, const objinfo &b)
{
    // Halve in floating point: `WIDTH / 2` on ints would drop the .5 for odd sizes.
    const float centre_x = static_cast<float>(WIDTH) * 0.5f;
    const float centre_y = static_cast<float>(HEIGTH) * 0.5f;

    // Offsets of each box midpoint from the image centre.
    const float a_dx = static_cast<float>(a.x1 + a.x2) * 0.5f - centre_x;
    const float a_dy = static_cast<float>(a.y1 + a.y2) * 0.5f - centre_y;
    const float b_dx = static_cast<float>(b.x1 + b.x2) * 0.5f - centre_x;
    const float b_dy = static_cast<float>(b.y1 + b.y2) * 0.5f - centre_y;

    // Squared distances order the same way as true distances, so the
    // sqrt/pow calls are unnecessary for a comparison.
    return (a_dx * a_dx + a_dy * a_dy) < (b_dx * b_dx + b_dy * b_dy);
}
/// Destroys the detection and classification model objects (if any) and
/// resets both handles to null.
API_AI_OBJ_DET::~API_AI_OBJ_DET()
{
    // `delete` on a null pointer does nothing, so no explicit null test is needed.
    delete static_cast<CudaYOLODetect*>(m_detDemo);
    m_detDemo = nullptr;

    delete static_cast<Detect_class*>(m_classDemo);
    m_classDemo = nullptr;
}
int API_AI_OBJ_DET::Arith_inial(std::string& model_path_det, const int& det_width, const int& det_height,
const int& cls_num, const float& obj_conf, const float& cls_conf, const float& nms_conf,
std::string& model_path_class, const int& class_width, const int& class_height, bool detect_class)
{
DEBUG_LOG("loadding det model from %s", model_path_det.c_str());
m_detDemo = new CudaYOLODetect (model_path_det, det_height, det_width, cls_num, obj_conf, cls_conf, nms_conf);
DEBUG_LOG("loadding classify model from %s", model_path_class.c_str());
if (detect_class)
{
m_classDemo = new Detect_class(model_path_class, class_height, class_width, cls_num);
DEBUG_LOG("both model loaded");
}
if(!m_detDemo)
{
return -1;
}
if (!m_classDemo)
{
return -2;
}
return 0;
}
int API_AI_OBJ_DET::Arith_setParam(int idx, const float& cls_conf, const float& nms_conf)
{
DEBUG_LOG("setting det conf, idx=%d->%f->%f", idx, cls_conf, nms_conf);
if (m_detDemo != NULL)
{
((CudaYOLODetect*)m_detDemo)->setParam(idx, cls_conf, nms_conf);
return 0;
}
else
{
//DEBUG_LOG("ingting error");
return 1;
}
}
int API_AI_OBJ_DET::Arith_setClassParam(int idx, const float& cls_conf)
{
DEBUG_LOG("setting classify conf, idx:%d -> %f", idx, cls_conf);
if (m_classDemo != NULL)
{
((Detect_class*)m_classDemo)->setParam(idx, cls_conf);
return 0;
}
else
{
DEBUG_LOG("setting error");
return 1;
}
}
int API_AI_OBJ_DET::Arith_setMatchRatio(float matchratio){
DEBUG_LOG("setting match ratio");
if(this->m_detDemo == NULL){
DEBUG_LOG("model is NULL, init first needed");
return 1;
}
((CudaYOLODetect*)m_detDemo)->setMatchRatio(matchratio);
return 0;
};
// Running total of classifier invocations made by Arith_infer(); it is only
// incremented and written to the debug log. Being a file-level static, it is
// shared by every API_AI_OBJ_DET instance in this translation unit.
static int goClassCount = 0;
/**
 * @brief Runs the detector on one frame, then classifies the tracks that have existed long enough.
 *
 * The frame size is stored in the file-level WIDTH / HEIGTH variables because
 * comp() reads them to order boxes by distance from the image centre.
 *
 * @param src       Input frame; forwarded to CudaYOLODetect::run().
 * @param res_det   Handed to CudaYOLODetect::run() and afterwards sorted with comp().
 * @param track_det Cleared, then filled with one entry per classified track and sorted with comp().
 * @return 0 on success;
 *         -1 when no detection model is loaded (neither output vector is touched);
 *         -2 when no classification model is loaded (the detector has already run,
 *            @p res_det is sorted and @p track_det is empty).
 */
int API_AI_OBJ_DET::Arith_infer(ImgMat src, std::vector<objinfo>& res_det, std::vector<objinfo>& track_det)
{
    WIDTH = src.width;
    HEIGTH = src.height;

    if (m_detDemo == nullptr)
    {
        return -1;
    }

    // NOTE(review): both log lines print the address of the `data` member of the
    // local `src` copy, not the value of `src.data` -- confirm which was intended.
    DEBUG_LOG("datainput: %p, %d, %d", &src.data, src.width, src.height);
    DEBUG_LOG("PimgBuff: %p, %d, %d", &src.data, src.width, src.height);

    // Run YOLO detection inference.
    std::vector<std::shared_ptr<OBJTRACK>> track_res =
        static_cast<CudaYOLODetect*>(m_detDemo)->run(src, res_det);

    std::sort(res_det.begin(), res_det.end(), comp);
    track_det.clear();

    if (m_classDemo == nullptr)
    {
        return -2;
    }

    // Classify the tracks produced by the detector.
    // Binding by const reference avoids a shared_ptr copy (atomic ref-count
    // update) per iteration; the pointee is still modifiable.
    for (const std::shared_ptr<OBJTRACK>& t_info : track_res)
    {
        // Only tracks whose _nExistCnt is above 14 are classified.
        // NOTE(review): the original comment said "16 detections" -- confirm the intended threshold.
        if (t_info->_nExistCnt <= 14)
        {
            continue;
        }

        goClassCount++;
        DEBUG_LOG("goClassCount : %d", goClassCount);

        obj_res_class det_res_class;
        // Preset to -1 before the call; presumably a "no class" marker -- confirm against run_class().
        det_res_class.cls_id_class = -1;
        static_cast<Detect_class*>(m_classDemo)->run_class(t_info->m_imgBufferCV.data, det_res_class);

        if (isShow())
        {
            cv::imshow("ClassifyImage", t_info->m_imgBufferCV);
            cv::waitKey(10);
        }

        // Store the class id on the track before exporting it.
        t_info->_cls = det_res_class.cls_id_class;
        objinfo d;
        t_info->toObjInfo(d);
        DEBUG_LOG("classnum: %d-%d-%d\n", det_res_class.cls_id_class, t_info->_trackID, d.trackID);
        track_det.push_back(d);
    }

    std::sort(track_det.begin(), track_det.end(), comp);
    return 0;
}