|
|
|
|
import sys
|
|
|
|
|
import os
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
# Put the sibling ``yolov5`` directory at the front of the Python path so the
# ``models`` / ``utils`` imports further down are looked up there first.
# The location is resolved to an absolute path so the sys.path entry stays
# valid even if the working directory changes later.
current_dir = Path(__file__).resolve().parent
yolov5_path = current_dir / 'yolov5'
sys.path.insert(0, str(yolov5_path))
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import torch
|
|
|
|
|
from models.common import DetectMultiBackend
|
|
|
|
|
from utils.torch_utils import select_device, time_sync
|
|
|
|
|
from utils.general import (LOGGER, check_img_size, non_max_suppression, scale_boxes)
|
|
|
|
|
from utils.augmentations import letterbox
|
|
|
|
|
import cv2
|
|
|
|
|
|
|
|
|
|
import pathlib

# Alias PosixPath to WindowsPath, but only when actually running on Windows.
# NOTE(review): presumably this exists so that checkpoints pickled on a POSIX
# machine (which embed PosixPath objects) can be loaded on Windows - confirm.
# The alias must not be applied on POSIX hosts: WindowsPath cannot be
# instantiated there, so an unconditional patch can break Path() creation.
if os.name == 'nt':
    pathlib.PosixPath = pathlib.WindowsPath
|
|
|
|
|
|
|
|
|
|
class ObjectDetector:
    """Single-image object detector built on YOLOv5's ``DetectMultiBackend``."""

    def __init__(self, weights='yolov5s.pt', data='data/coco.yaml',
                 imgsz=(640, 640), conf_thres=0.25, iou_thres=0.45,
                 max_det=1000, device='0', half=False, dnn=False):
        """
        Initialise the object detector, load the model and warm it up.

        :param weights: path to the model weights file
        :param data: path to the dataset configuration file
        :param imgsz: network input size as (height, width); a single int is
            accepted and treated as a square size
        :param conf_thres: confidence threshold
        :param iou_thres: IoU threshold (used by NMS)
        :param max_det: maximum number of detections per image
        :param device: device string ('cpu', or '0' for GPU 0)
        :param half: whether to request half precision (FP16)
        :param dnn: whether to use the OpenCV DNN backend
        """
        self.weights = weights
        self.data = data
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.max_det = max_det
        self.device = select_device(device)  # resolve the device string
        self.half = half
        self.dnn = dnn
        self.classes = None
        self.agnostic_nms = False
        self.augment = False
        self.visualize = False

        # Load the model.
        self.model = DetectMultiBackend(self.weights, device=self.device, dnn=self.dnn,
                                        data=self.data, fp16=self.half)
        self.stride, self.names, self.pt = self.model.stride, self.model.names, self.model.pt

        # The warm-up call below unpacks the size as (height, width), so a
        # bare int is expanded to a square size first.
        if isinstance(imgsz, int):
            imgsz = (imgsz, imgsz)
        self.imgsz = check_img_size(imgsz, s=self.stride)

        # Warm the model up with one dummy batch of the final input size.
        self.model.warmup(imgsz=(1, 3, *self.imgsz))

    def detect(self, img):
        """
        Run object detection on one image.

        :param img: input image as an H x W x 3 numpy array. The channel axis
            is reversed before inference (BGR -> RGB), so BGR input such as
            the result of ``cv2.imread`` is expected.
        :return: list of detections, each ``[x1, y1, x2, y2, conf]`` with
            integer box coordinates scaled back to ``img`` and a float
            confidence. The class index is not returned.
        :raises ValueError: if ``img`` is None, empty, or not H x W x 3
        """
        if img is None or getattr(img, 'ndim', 0) != 3 or img.shape[2] != 3 or img.size == 0:
            raise ValueError('img must be a non-empty H x W x 3 numpy array')

        im = self._preprocess(img)

        # Inference only: no autograd bookkeeping is needed, and it keeps the
        # in-place box rescaling below legal on the NMS output.
        with torch.no_grad():
            pred = self.model(im, augment=self.augment, visualize=self.visualize)
            pred = non_max_suppression(pred, self.conf_thres, self.iou_thres,
                                       self.classes, self.agnostic_nms, max_det=self.max_det)

        detections = []
        for det in pred:
            if len(det):
                # Map boxes from the letterboxed input size back to the original image.
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], img.shape).round()
                detections.extend(self._rows_from_tensor(det))
        return detections

    def _preprocess(self, img):
        """Letterbox ``img`` and convert it to a normalised, batched model input tensor."""
        # Use the model's own stride, and only allow the minimal-rectangle
        # ("auto") padding when the backend is a PyTorch model.
        im = letterbox(img, self.imgsz, stride=self.stride, auto=self.pt)[0]

        # HWC -> CHW, then reverse the channel axis (BGR -> RGB).
        im = im.transpose((2, 0, 1))[::-1]
        im = np.ascontiguousarray(im)

        im = torch.from_numpy(im).to(self.device)
        # Follow the precision the backend actually uses rather than the
        # requested flag.
        im = im.half() if self.model.fp16 else im.float()
        im /= 255  # scale to 0.0 - 1.0
        return im[None]  # add the batch dimension

    @staticmethod
    def _rows_from_tensor(det):
        """Convert an (N, 6) detection tensor into ``[x1, y1, x2, y2, conf]`` lists, last row first."""
        rows = []
        for *xyxy, conf, _cls in reversed(det.tolist()):
            rows.append([round(v) for v in xyxy] + [float(conf)])
        return rows
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Demo entry point: detect objects in one image, draw the boxes and save the result.
if __name__ == '__main__':
    detector = ObjectDetector(weights='./best_3class.pt', imgsz=(1024, 1280))
    img = cv2.imread('./20250703_bar_017.png')
    if img is None:
        # The load failed (img is None), so there is nothing to run detection on.
        print("Error: 无法加载图像")
    else:
        detections = detector.detect(img)
        for det in detections:
            x1, y1, x2, y2, conf = det  # each det is [x1, y1, x2, y2, conf]

            # Draw the box (BGR colour, thickness 2).
            cv2.rectangle(img, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2)

            # Draw the confidence label just above the box; clamp the baseline
            # so the text stays inside the image for boxes touching the top edge.
            label = f"{conf:.2f}"
            cv2.putText(img, label, (x1, max(y1 - 5, 12)), cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=0.5, color=(0, 255, 0), thickness=1)

            print(f"检测结果: {det}")

        # Save the annotated image; imwrite reports failure through its
        # return value, so only announce success when it returned True.
        output_path = './output.jpg'
        if cv2.imwrite(output_path, img):
            print(f"检测结果图像已保存为: {output_path}")
        else:
            print(f"Error: 无法保存图像: {output_path}")
|