You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

136 lines
4.7 KiB

2 weeks ago
import sys
import os
# Make the vendored yolov5 checkout importable. It lives in a `yolov5`
# directory next to this file (inside ArithControl), and its top-level
# packages (`models`, `utils`) are imported directly below, so the directory
# is placed at the front of sys.path -- but only if it is not already there.
current_dir = os.path.dirname(os.path.abspath(__file__))
yolov5_path = os.path.join(current_dir, 'yolov5')
if sys.path.count(yolov5_path) == 0:
    sys.path.insert(0, yolov5_path)
import numpy as np
import torch
# 直接导入具体模块
from models.common import DetectMultiBackend
from utils.torch_utils import select_device, time_sync
from utils.general import (LOGGER, check_img_size, non_max_suppression, scale_boxes)
from utils.augmentations import letterbox
import cv2
import pathlib
# Weights pickled on Linux/macOS can contain pathlib.PosixPath objects, which
# cannot be instantiated on Windows; aliasing PosixPath to WindowsPath lets
# them be unpickled there. Apply the alias ONLY on Windows: pathlib.Path()
# looks up PosixPath at call time on POSIX hosts, so the alias would break
# every Path(...) call there (including yolov5's own path handling).
if os.name == 'nt':
    pathlib.PosixPath = pathlib.WindowsPath
class ObjectDetector:
    """Single-image object detector built on yolov5's ``DetectMultiBackend``.

    The model is loaded and warmed up once in ``__init__``; call
    :meth:`detect` for each image.
    """

    def __init__(self, weights='yolov5s.pt', data='data/coco.yaml',
                 imgsz=(640, 640), conf_thres=0.25, iou_thres=0.45,
                 max_det=1000, device='0', half=False, dnn=False):
        """
        Initialize the object detector and load the model.

        :param weights: path to the model weights file
        :param data: path to the dataset config file, passed to DetectMultiBackend
        :param imgsz: network input size (height, width); passed through
            ``check_img_size`` with the model stride (presumably rounded to a
            multiple of it)
        :param conf_thres: confidence threshold used by NMS
        :param iou_thres: IoU threshold used by NMS
        :param max_det: maximum number of detections kept per image
        :param device: device string ('cpu', or '0' for GPU 0)
        :param half: request FP16 inference; the input dtype actually used is
            taken from ``self.model.fp16``, which DetectMultiBackend may
            presumably disable for backends that do not support FP16
        :param dnn: whether to use the OpenCV DNN backend
        """
        self.weights = weights
        self.data = data
        self.conf_thres = conf_thres
        self.iou_thres = iou_thres
        self.max_det = max_det
        self.device = select_device(device)
        self.half = half
        self.dnn = dnn
        # Fixed inference/NMS options, kept as attributes so callers can change them.
        self.classes = None          # class filter passed to non_max_suppression
        self.agnostic_nms = False
        self.augment = False
        self.visualize = False

        # Load the model.
        self.model = DetectMultiBackend(self.weights, device=self.device, dnn=self.dnn,
                                       data=self.data, fp16=self.half)
        self.stride, self.names, self.pt = self.model.stride, self.model.names, self.model.pt
        self.imgsz = check_img_size(imgsz, s=self.stride)
        # Model warm-up; no gradients are needed for it.
        with torch.no_grad():
            self.model.warmup(imgsz=(1, 3, *self.imgsz))

    @torch.no_grad()  # pure inference: do not record autograd state
    def detect(self, img):
        """
        Run object detection on a single image.

        :param img: input image as an HxWx3 numpy array in BGR channel order
            (channels are reversed to RGB during preprocessing)
        :return: list of detections, each ``[cls, x1, y1, x2, y2, conf]``:
            ``cls`` is the integer class index, ``x1, y1, x2, y2`` are integer
            box corners rescaled (via ``scale_boxes``) to ``img``'s size, and
            ``conf`` is the float confidence
        :raises ValueError: if ``img`` is None (e.g. a failed ``cv2.imread``)
        """
        if img is None:
            raise ValueError("img is None; expected an HxWx3 numpy array")

        im = self._preprocess(img)
        pred = self.model(im, augment=self.augment, visualize=self.visualize)
        # Non-maximum suppression (NMS).
        pred = non_max_suppression(pred, self.conf_thres, self.iou_thres,
                                   self.classes, self.agnostic_nms, max_det=self.max_det)

        detections = []
        for det in pred:  # one entry per batch image; the batch size is 1 here
            if len(det):
                # Map boxes from the letterboxed network input back to img's size.
                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], img.shape).round()
                detections.extend(self._format_detections(det))
        return detections

    def _preprocess(self, img):
        """Letterbox *img* and return a 1x3xHxW tensor on ``self.device`` scaled to [0, 1]."""
        # Pad with the model's own stride; pad to the full imgsz (auto=False)
        # for non-PyTorch backends, mirroring yolov5's detect.py.
        im = letterbox(img, self.imgsz, stride=self.stride, auto=self.pt)[0]
        # HWC -> CHW, then reverse the channel axis: BGR -> RGB.
        im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1])
        im = torch.from_numpy(im).to(self.device)
        # Match the dtype the loaded model actually runs in.
        im = im.half() if self.model.fp16 else im.float()
        im /= 255  # 0-255 -> 0.0-1.0
        return im[None]  # add the batch dimension

    @staticmethod
    def _format_detections(det):
        """Convert rows of ``x1, y1, x2, y2, conf, cls`` into ``[cls, x1, y1, x2, y2, conf]`` lists.

        Rows are emitted in reverse of their order in *det*; coordinates are
        rounded to Python ints.
        """
        return [[int(cls)] + [round(x.item()) for x in xyxy] + [float(conf)]
                for *xyxy, conf, cls in reversed(det)]
def main():
    """Run the detector on a sample image and save an annotated copy.

    Reads ``./20250703_bar_017.png``, runs an ``ObjectDetector`` loaded from
    ``./best_3class.pt``, draws each detected box with its class name and
    confidence, and writes the result to ``./output.jpg``.
    """
    # Check the image first so a missing file does not pay for model loading.
    img = cv2.imread('./20250703_bar_017.png')
    if img is None:
        print("Error: 无法加载图像")
        return

    detector = ObjectDetector(weights='./best_3class.pt', imgsz=(1024, 1280))
    detections = detector.detect(img)
    for det in detections:
        # detect() returns [cls, x1, y1, x2, y2, conf]; the class index comes first.
        cls, x1, y1, x2, y2, conf = det
        # Draw the box (BGR green, thickness 2).
        cv2.rectangle(img, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2)
        label = f"{detector.names[cls]} {conf:.2f}"
        # Put the label above the box, or inside it when the box touches the top edge.
        text_y = y1 - 5 if y1 > 15 else y1 + 15
        cv2.putText(img, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=0.5, color=(0, 255, 0), thickness=1)
        print(f"检测结果: {det}")

    # Save the annotated image; cv2.imwrite reports failure via its return value.
    output_path = './output.jpg'
    if cv2.imwrite(output_path, img):
        print(f"检测结果图像已保存为: {output_path}")
    else:
        print(f"Error: failed to write {output_path}")


if __name__ == '__main__':
    main()