You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

136 lines
4.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import sys
import os
from pathlib import Path
# Put the 'yolov5' directory that sits next to this file at the front of
# sys.path (the `models.*` / `utils.*` imports below are presumably
# resolved from there).
current_dir = Path(__file__).parent
yolov5_path = current_dir.joinpath('yolov5')
sys.path[:0] = [str(yolov5_path)]
import numpy as np
import torch
from models.common import DetectMultiBackend
from utils.torch_utils import select_device, time_sync
from utils.general import (LOGGER, check_img_size, non_max_suppression, scale_boxes)
from utils.augmentations import letterbox
import cv2
import pathlib
# Alias PosixPath to WindowsPath, but only when actually running on Windows.
# NOTE(review): presumably this lets objects that were pickled with PosixPath
# on Linux/macOS (e.g. inside a checkpoint) be unpickled on Windows - confirm.
# The alias must not be applied on POSIX hosts: WindowsPath cannot be
# instantiated there, so redirecting PosixPath to it can make later
# Path(...) calls fail.
if os.name == 'nt':
    pathlib.PosixPath = pathlib.WindowsPath
class ObjectDetector:
    """Single-image object detector built on YOLOv5's ``DetectMultiBackend``.

    The constructor loads and warms up the model; ``detect`` performs
    letterbox preprocessing, inference and non-maximum suppression and
    returns boxes in the coordinate system of the input image.
    """

    def __init__(self, weights='yolov5s.pt', data='data/coco.yaml',
                 imgsz=(640, 640), conf_thres=0.25, iou_thres=0.45,
                 max_det=1000, device='0', half=False, dnn=False):
        """
        Initialize the object detector.

        :param weights: path to the model weights file
        :param data: path to the dataset configuration file
        :param imgsz: network input size as (height, width)
        :param conf_thres: confidence threshold
        :param iou_thres: IoU threshold (used for NMS)
        :param max_det: maximum number of detections per image
        :param device: device string ('cpu', or '0' for GPU 0)
        :param half: whether to use half precision (FP16)
        :param dnn: whether to use the OpenCV DNN backend
        """
        self.weights = weights
        self.data = data
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.max_det = max_det
        self.device = select_device(device)  # resolve the device string
        self.half = half
        self.dnn = dnn

        # Fixed inference options (not exposed through the constructor).
        self.classes = None
        self.agnostic_nms = False
        self.augment = False
        self.visualize = False

        # Load the model.
        self.model = DetectMultiBackend(self.weights, device=self.device,
                                        dnn=self.dnn, data=self.data,
                                        fp16=self.half)
        self.stride = self.model.stride
        self.names = self.model.names
        self.pt = self.model.pt

        # Validate the requested size against the model stride.
        self.imgsz = check_img_size(imgsz, s=self.stride)

        # Warm the model up with a dummy batch of the configured size.
        self.model.warmup(imgsz=(1, 3, *self.imgsz))

    def detect(self, img):
        """
        Run object detection on a single image.

        :param img: input image as a numpy array; it is treated as HWC with
            BGR channel order (the layout ``cv2.imread`` returns)
        :return: list of detections, each ``[x1, y1, x2, y2, conf]`` where
            the coordinates are integers in the input image's pixel space
            and ``conf`` is a float; the class id is not returned
        """
        im0 = img  # original image, needed to map boxes back

        # Letterbox to the network input size. Passing the model's own
        # stride and backend type mirrors YOLOv5's data loaders:
        # minimum-rectangle padding is only used for PyTorch models, and
        # the padding is aligned to the real stride instead of the default.
        im = letterbox(im0, self.imgsz, stride=self.stride, auto=self.pt)[0]

        # HWC -> CHW, then BGR -> RGB; make contiguous for torch.from_numpy.
        im = im.transpose((2, 0, 1))[::-1]
        im = np.ascontiguousarray(im)

        im = torch.from_numpy(im).to(self.device)
        # Follow the precision the loaded model actually uses.
        # NOTE(review): DetectMultiBackend is expected to expose `fp16` and
        # may clear it for backends without FP16 support - confirm.
        im = im.half() if self.model.fp16 else im.float()
        im /= 255  # scale to 0.0 - 1.0
        if im.ndim == 3:
            im = im[None]  # add the batch dimension

        # Inference only: no autograd bookkeeping is needed.
        with torch.no_grad():
            pred = self.model(im, augment=self.augment,
                              visualize=self.visualize)
            # Non-maximum suppression (NMS).
            pred = non_max_suppression(pred, self.conf_thres, self.iou_thres,
                                       self.classes, self.agnostic_nms,
                                       max_det=self.max_det)

        detections = []
        for det in pred:
            if not len(det):
                continue
            # Map boxes from the letterboxed size back to the original image.
            det[:, :4] = scale_boxes(im.shape[2:], det[:, :4],
                                     im0.shape).round()
            for *xyxy, conf, _cls in reversed(det):
                box = [round(v.item()) for v in xyxy]  # integer pixel coords
                detections.append(box + [float(conf)])
        return detections
if __name__ == '__main__':
    # Demo: run the detector on one image, draw every detection with its
    # confidence, print the detections and save the annotated image.
    detector = ObjectDetector(weights='./best_3class.pt', imgsz=(1024, 1280))
    img = cv2.imread('./20250703_bar_017.png')
    if img is None:
        # cv2.imread signals a failed load by returning None.
        print("Error: 无法加载图像")
    else:
        detections = detector.detect(img)
        font = cv2.FONT_HERSHEY_SIMPLEX
        for det in detections:
            x1, y1, x2, y2, conf = det  # each det is [x1, y1, x2, y2, conf]
            # Draw the box (BGR colour, thickness 2).
            cv2.rectangle(img, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2)
            # Draw the confidence just above the box. Clamp the baseline so
            # the text stays inside the image for boxes at the top edge.
            label = f"{conf:.2f}"
            (_, text_h), _ = cv2.getTextSize(label, font, 0.5, 1)
            label_y = max(y1 - 5, text_h)
            cv2.putText(img, label, (x1, label_y), font,
                        fontScale=0.5, color=(0, 255, 0), thickness=1)
            print(f"检测结果: {det}")
        # Save the annotated image; imwrite reports failure via its return
        # value, so only announce success when the write succeeded.
        output_path = './output.jpg'
        if cv2.imwrite(output_path, img):
            print(f"检测结果图像已保存为: {output_path}")
        else:
            print(f"Error: 无法保存图像: {output_path}")