You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

84 lines
2.1 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import numpy as np
import cv2
import time
import torch
from . import detect_api as detect
from . import deep_sort as ds
# # 系统状态定义
# GLB_STATUS = {
# "GLB_STATUS_UNKOWN": 0,
# "GLB_STATUS_WAIT": 1,
# "GLB_STATUS_SEARCH": 2,
# "GLB_STATUS_TRACK": 3,
# "GLB_STATUS_SCAN": 4,
# "GLB_STATUS_LOST": 5,
# "GLB_STATUS_FSCAN": 6,
# "GLB_STATUS_LOCK": 7,
# "GLB_STATUS_LOCKFAILED": 8,
# "GLB_STATUS_MOTRACK": 9,
# "GLB_STATUS_AIM": 10,
# }
# # 锁定模式
# LockMode = {
# "LOCK_NONE": 0,
# "LOCK_AUTO": 10,
# "LOCK_POINT": 21,
# "LOCK_RECT": 22,
# "LOCK_UNLOCK": 3,
# "LOCK_ID": 4,
# }
class EOController:
    """Couple an object detector with a DeepSort tracker.

    ``run`` detects objects in a frame and feeds them to the tracker;
    ``parse_cmd`` forwards lock commands to the tracker.
    """

    def __init__(
        self,
        *,
        weights='./ArithControl/model/best_3class.pt',
        imgsz=(1024, 1280),
        reid_checkpoint='./ArithControl/deep_sort/deep/checkpoint/ckpt.t7',
        enable_reid=True,
    ):
        """Build the detector and the tracker.

        All arguments default to the values previously hard-coded here, so
        ``EOController()`` behaves exactly as before.

        Args:
            weights: Weights path passed to ``detect.ObjectDetector``.
                NOTE(review): relative paths are presumably resolved against
                the current working directory — confirm.
            imgsz: Image size passed to ``detect.ObjectDetector``
                (dimension order as that class expects — TODO confirm).
            reid_checkpoint: Checkpoint path passed as the first argument of
                ``ds.DeepSort``.
            enable_reid: Forwarded to ``ds.DeepSort``.
        """
        # Object detector.
        self.detector = detect.ObjectDetector(weights=weights, imgsz=imgsz)
        # DeepSort tracker.
        self.ds = ds.DeepSort(reid_checkpoint, enable_reid=enable_reid)

    def run(self, frame):
        """Detect objects in ``frame`` and update the tracker with them.

        Args:
            frame: Image passed unchanged to ``self.detector.detect`` and to
                ``self.ds.update``.

        Returns:
            The value returned by ``self.ds.update`` when at least one
            detection is found; otherwise an empty list.
        """
        detections = self.detector.detect(frame)
        if detections is None or len(detections) == 0:
            # Without this early return the method would reach its `return`
            # with no result bound and raise UnboundLocalError.
            # NOTE(review): the tracker is not updated on empty frames, so its
            # tracks are not advanced here — confirm whether ds.DeepSort has
            # a way to age tracks without detections.
            return []

        # Columns are used below as [cls, x1, y1, x2, y2, conf]; this assumes
        # that is the detector's row layout — TODO confirm.
        det = torch.tensor(detections, dtype=torch.float32)
        top_left = det[:, 1:3]
        bottom_right = det[:, 3:5]
        # DeepSort input uses centre-x, centre-y, width, height.
        bbox_xywh = torch.column_stack([
            (top_left + bottom_right) / 2,  # cx, cy
            bottom_right - top_left,        # w, h
        ])
        confidences = det[:, 5]             # conf
        clss = det[:, 0].long()             # cls
        # A single update call processes every target in the frame.
        return self.ds.update(bbox_xywh, confidences, clss, frame)

    def parse_cmd(self, cmd):
        """Dispatch a lock command to the tracker according to ``cmd.mode``.

        Recognised modes:
            'LockMode' or 'Lock_POINT' -> ``self.ds.lock_point(cmd.point)``
            'Lock_RECT'                -> ``self.ds.lock_rect(cmd.rect)``
            'Lock_ID'                  -> ``self.ds.lock_id(cmd.id)``
            'Lock_UNLOCK'              -> ``self.ds.lock_unlock()``
        Any other mode is silently ignored.
        """
        mode = cmd.mode
        if mode in ('LockMode', 'Lock_POINT'):
            # NOTE(review): 'LockMode' breaks the 'Lock_<KIND>' naming of the
            # other branches and looks like a typo for 'Lock_POINT'. Both are
            # accepted so existing senders keep working — confirm what
            # callers actually send.
            self.ds.lock_point(cmd.point)
        elif mode == 'Lock_RECT':
            self.ds.lock_rect(cmd.rect)
        elif mode == 'Lock_ID':
            self.ds.lock_id(cmd.id)
        elif mode == 'Lock_UNLOCK':
            self.ds.lock_unlock()