You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

84 lines
2.1 KiB

2 weeks ago
import numpy as np
import cv2
import time
import torch
from . import detect_api as detect
from . import deep_sort as ds
# # System status definitions
# GLB_STATUS = {
# "GLB_STATUS_UNKOWN": 0,
# "GLB_STATUS_WAIT": 1,
# "GLB_STATUS_SEARCH": 2,
# "GLB_STATUS_TRACK": 3,
# "GLB_STATUS_SCAN": 4,
# "GLB_STATUS_LOST": 5,
# "GLB_STATUS_FSCAN": 6,
# "GLB_STATUS_LOCK": 7,
# "GLB_STATUS_LOCKFAILED": 8,
# "GLB_STATUS_MOTRACK": 9,
# "GLB_STATUS_AIM": 10,
# }
# # Lock modes
# LockMode = {
# "LOCK_NONE": 0,
# "LOCK_AUTO": 10,
# "LOCK_POINT": 21,
# "LOCK_RECT": 22,
# "LOCK_UNLOCK": 3,
# "LOCK_ID": 4,
# }
class EOController:
    """Run object detection on a frame and feed the results to a DeepSort tracker.

    Also forwards lock commands (point / rect / id / unlock) to the tracker
    via ``parse_cmd``.
    """

    def __init__(self, *,
                 weights='./ArithControl/model/best_3class.pt',
                 imgsz=(1024, 1280),
                 reid_ckpt='./ArithControl/deep_sort/deep/checkpoint/ckpt.t7',
                 enable_reid=True):
        """Build the detector and the DeepSort tracker.

        All arguments are keyword-only and default to the values this class
        previously hard-coded, so ``EOController()`` behaves exactly as before.

        Args:
            weights: Detector weights path, passed to ``detect.ObjectDetector``.
            imgsz: Image size passed through to ``detect.ObjectDetector``.
            reid_ckpt: Checkpoint path passed as the first argument to
                ``ds.DeepSort``.
            enable_reid: Forwarded to ``ds.DeepSort``.
        """
        # Initialize the detector.
        self.detector = detect.ObjectDetector(weights=weights, imgsz=imgsz)
        # DeepSort tracker.
        self.ds = ds.DeepSort(reid_ckpt, enable_reid=enable_reid)

    def run(self, frame):
        """Detect objects in ``frame`` and update the tracker with them.

        Args:
            frame: Image passed unchanged to the detector and to the tracker.

        Returns:
            Whatever ``self.ds.update`` returns for this frame, or ``None`` when
            the detector reports no detections (the tracker is not called then).
        """
        detections = self.detector.detect(frame)
        if len(detections) == 0:
            # Early return: the result name used to be bound only when there
            # were detections, so an empty frame raised UnboundLocalError.
            # NOTE(review): the tracker is not updated on empty frames, so it
            # gets no call for them here — confirm this is intended.
            return None

        # Convert to a float32 tensor. Assumes each row is
        # [cls, x1, y1, x2, y2, conf], as implied by the slicing below —
        # TODO confirm against detect_api.
        det_tensor = torch.tensor(detections, dtype=torch.float32)
        top_left = det_tensor[:, 1:3]
        bottom_right = det_tensor[:, 3:5]
        # DeepSort input uses centre-x, centre-y, width, height.
        bbox_xywh = torch.column_stack([
            (top_left + bottom_right) / 2,  # cx, cy
            bottom_right - top_left,        # w, h
        ])
        confidences = det_tensor[:, 5]   # conf
        clss = det_tensor[:, 0].long()   # cls
        # A single update call processes every detection in this frame.
        return self.ds.update(bbox_xywh, confidences, clss, frame)

    def parse_cmd(self, cmd):
        """Dispatch a lock command to the tracker.

        Args:
            cmd: Object with a ``mode`` string attribute plus whatever the
                selected mode reads: ``point`` for point lock, ``rect`` for
                'Lock_RECT', ``id`` for 'Lock_ID'; 'Lock_UNLOCK' reads nothing
                else.

        Unrecognised modes are silently ignored.
        """
        mode = cmd.mode
        # 'LockMode' is the string this method has always matched for point
        # lock; 'Lock_POINT' is also accepted so the name lines up with the
        # other 'Lock_*' modes. NOTE(review): confirm which string callers send.
        if mode in ('LockMode', 'Lock_POINT'):
            self.ds.lock_point(cmd.point)
        elif mode == 'Lock_RECT':
            self.ds.lock_rect(cmd.rect)
        elif mode == 'Lock_ID':
            self.ds.lock_id(cmd.id)
        elif mode == 'Lock_UNLOCK':
            self.ds.lock_unlock()