commit 0c1a2f8eee14179ef2b8ec207c3f8199817149b9
Author: wangchongwu <759291707@qq.com>
Date: Fri Aug 15 18:30:56 2025 +0800

    init

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..615abcb
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+__pycache__
+*.pyc
+
+*.MP4
\ No newline at end of file
diff --git a/3rdparty/ultralytics-YOLO-DeepSort-ByteTrack-PyQt-GUI b/3rdparty/ultralytics-YOLO-DeepSort-ByteTrack-PyQt-GUI
new file mode 160000
index 0000000..03b11dd
--- /dev/null
+++ b/3rdparty/ultralytics-YOLO-DeepSort-ByteTrack-PyQt-GUI
@@ -0,0 +1 @@
+Subproject commit 03b11dd6b5d13a8d4ab069ffeb2a28d4d939575d
diff --git a/ArithControl/Arith_EOController.py b/ArithControl/Arith_EOController.py
new file mode 100644
index 0000000..e5385ef
--- /dev/null
+++ b/ArithControl/Arith_EOController.py
@@ -0,0 +1,83 @@
+import numpy as np
+import cv2
+import time
+import torch
+from . import detect_api as detect
+from . import deep_sort as ds
+
+
+# # System status definitions
+# GLB_STATUS = {
+#     "GLB_STATUS_UNKOWN": 0,
+#     "GLB_STATUS_WAIT": 1,
+#     "GLB_STATUS_SEARCH": 2,
+#     "GLB_STATUS_TRACK": 3,
+#     "GLB_STATUS_SCAN": 4,
+#     "GLB_STATUS_LOST": 5,
+#     "GLB_STATUS_FSCAN": 6,
+#     "GLB_STATUS_LOCK": 7,
+#     "GLB_STATUS_LOCKFAILED": 8,
+#     "GLB_STATUS_MOTRACK": 9,
+#     "GLB_STATUS_AIM": 10,
+# }
+
+# # Lock modes
+# LockMode = {
+#     "LOCK_NONE": 0,
+#     "LOCK_AUTO": 10,
+#     "LOCK_POINT": 21,
+#     "LOCK_RECT": 22,
+#     "LOCK_UNLOCK": 3,
+#     "LOCK_ID": 4,
+# }
+
+
+class EOController:
+    def __init__(self):
+        # initialize the detector
+        self.detector = detect.ObjectDetector(
+            weights='./ArithControl/model/best_3class.pt',
+            imgsz=(1024, 1280))
+        # DeepSort tracker
+        self.ds = ds.DeepSort(
+            './ArithControl/deep_sort/deep/checkpoint/ckpt.t7',
+            enable_reid=True
+        )
+
+    def run(self, frame):
+        detections = self.detector.detect(frame)
+
+        if len(detections) > 0:
+            # convert to torch.Tensor; each detection is [cls, x1, y1, x2, y2, conf]
+            det_tensor = torch.tensor(detections, dtype=torch.float32)
+            bbox_xywh = torch.column_stack([
+                (det_tensor[:, 1:3] + det_tensor[:, 3:5]) / 2,  # cx, cy (box centers)
+                det_tensor[:, 3:5] - det_tensor[:, 1:3]         # w, h
+            ])
+            confidences = det_tensor[:, 5]   # conf
+            clss = det_tensor[:, 0].long()   # cls
+
+            # a single update() call handles all targets
+            # note: DeepSort expects center-based (cx, cy, w, h) boxes
+            pipe_out = self.ds.update(bbox_xywh, confidences, clss, frame)
+
+            # return the pipeline output
+            return pipe_out
+        return []
+
+    # parse a control command
+    def parse_cmd(self, cmd):
+        if cmd.mode == 'Lock_POINT':
+            self.ds.lock_point(cmd.point)
+        elif cmd.mode == 'Lock_RECT':
+            self.ds.lock_rect(cmd.rect)
+        elif cmd.mode == 'Lock_ID':
+            self.ds.lock_id(cmd.id)
+        elif cmd.mode == 'Lock_UNLOCK':
+            self.ds.lock_unlock()
diff --git a/ArithControl/deep_sort/README.md b/ArithControl/deep_sort/README.md
new file mode 100644
index 0000000..e89c9b3
--- /dev/null
+++ b/ArithControl/deep_sort/README.md
@@ -0,0 +1,3 @@
+# Deep Sort
+
+This is the implementation of Deep SORT with PyTorch.
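A minimal driver sketch for wiring the EOController above into a video loop (assumptions: OpenCV is installed, the repository root is on sys.path so that ArithControl imports as a package, and 'input.mp4' is a placeholder path):

    import cv2
    from ArithControl.Arith_EOController import EOController

    controller = EOController()
    cap = cv2.VideoCapture('input.mp4')
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        tracks = controller.run(frame) or []  # guard in case run() yields None
        for x1, y1, x2, y2, cls_, track_id in tracks:
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(frame, f'id={int(track_id)}', (x1, y1 - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
        cv2.imshow('tracks', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cap.release()
    cv2.destroyAllWindows()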
\ No newline at end of file diff --git a/ArithControl/deep_sort/__init__.py b/ArithControl/deep_sort/__init__.py new file mode 100644 index 0000000..5fe5d0f --- /dev/null +++ b/ArithControl/deep_sort/__init__.py @@ -0,0 +1,21 @@ +from .deep_sort import DeepSort + + +__all__ = ['DeepSort', 'build_tracker'] + + +def build_tracker(cfg, use_cuda): + return DeepSort(cfg.DEEPSORT.REID_CKPT, + max_dist=cfg.DEEPSORT.MAX_DIST, min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE, + nms_max_overlap=cfg.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE, + max_age=cfg.DEEPSORT.MAX_AGE, n_init=cfg.DEEPSORT.N_INIT, nn_budget=cfg.DEEPSORT.NN_BUDGET, use_cuda=use_cuda) + + + + + + + + + + diff --git a/ArithControl/deep_sort/deep/__init__.py b/ArithControl/deep_sort/deep/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ArithControl/deep_sort/deep/checkpoint/.gitkeep b/ArithControl/deep_sort/deep/checkpoint/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/ArithControl/deep_sort/deep/checkpoint/ckpt.t7 b/ArithControl/deep_sort/deep/checkpoint/ckpt.t7 new file mode 100644 index 0000000..d253aae Binary files /dev/null and b/ArithControl/deep_sort/deep/checkpoint/ckpt.t7 differ diff --git a/ArithControl/deep_sort/deep/evaluate.py b/ArithControl/deep_sort/deep/evaluate.py new file mode 100644 index 0000000..31c40a4 --- /dev/null +++ b/ArithControl/deep_sort/deep/evaluate.py @@ -0,0 +1,15 @@ +import torch + +features = torch.load("features.pth") +qf = features["qf"] +ql = features["ql"] +gf = features["gf"] +gl = features["gl"] + +scores = qf.mm(gf.t()) +res = scores.topk(5, dim=1)[1][:,0] +top1correct = gl[res].eq(ql).sum().item() + +print("Acc top1:{:.3f}".format(top1correct/ql.size(0))) + + diff --git a/ArithControl/deep_sort/deep/feature_extractor.py b/ArithControl/deep_sort/deep/feature_extractor.py new file mode 100644 index 0000000..0443e37 --- /dev/null +++ b/ArithControl/deep_sort/deep/feature_extractor.py @@ -0,0 +1,55 @@ +import torch +import torchvision.transforms as transforms +import numpy as np +import cv2 +import logging + +from .model import Net + +class Extractor(object): + def __init__(self, model_path, use_cuda=True): + self.net = Net(reid=True) + self.device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu" + state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)['net_dict'] + self.net.load_state_dict(state_dict) + logger = logging.getLogger("root.tracker") + logger.info("Loading weights from {}... Done!".format(model_path)) + self.net.to(self.device) + self.size = (64, 128) + self.norm = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ]) + + + + def _preprocess(self, im_crops): + """ + TODO: + 1. to float with scale from 0 to 1 + 2. resize to (64, 128) as Market1501 dataset did + 3. concatenate to a numpy array + 3. to torch Tensor + 4. 
normalize
+        """
+        def _resize(im, size):
+            return cv2.resize(im.astype(np.float32)/255., size)
+
+        im_batch = torch.cat([self.norm(_resize(im, self.size)).unsqueeze(0) for im in im_crops], dim=0).float()
+        return im_batch
+
+    def __call__(self, im_crops):
+        im_batch = self._preprocess(im_crops)
+        with torch.no_grad():
+            im_batch = im_batch.to(self.device)
+            features = self.net(im_batch)
+        return features.cpu().numpy()
+
+
+if __name__ == '__main__':
+    img = cv2.imread("demo.jpg")[:, :, (2, 1, 0)]
+    extr = Extractor("checkpoint/ckpt.t7")
+    feature = extr([img])  # __call__ expects a list of crops
+    print(feature.shape)
diff --git a/ArithControl/deep_sort/deep/model.py b/ArithControl/deep_sort/deep/model.py
new file mode 100644
index 0000000..97e8754
--- /dev/null
+++ b/ArithControl/deep_sort/deep/model.py
@@ -0,0 +1,104 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class BasicBlock(nn.Module):
+    def __init__(self, c_in, c_out, is_downsample=False):
+        super(BasicBlock, self).__init__()
+        self.is_downsample = is_downsample
+        if is_downsample:
+            self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False)
+        else:
+            self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(c_out)
+        self.relu = nn.ReLU(True)
+        self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(c_out)
+        if is_downsample:
+            self.downsample = nn.Sequential(
+                nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),
+                nn.BatchNorm2d(c_out)
+            )
+        elif c_in != c_out:
+            self.downsample = nn.Sequential(
+                nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),
+                nn.BatchNorm2d(c_out)
+            )
+            self.is_downsample = True
+
+    def forward(self, x):
+        y = self.conv1(x)
+        y = self.bn1(y)
+        y = self.relu(y)
+        y = self.conv2(y)
+        y = self.bn2(y)
+        if self.is_downsample:
+            x = self.downsample(x)
+        return F.relu(x.add(y), True)
+
+def make_layers(c_in, c_out, repeat_times, is_downsample=False):
+    blocks = []
+    for i in range(repeat_times):
+        if i == 0:
+            blocks += [BasicBlock(c_in, c_out, is_downsample=is_downsample), ]
+        else:
+            blocks += [BasicBlock(c_out, c_out), ]
+    return nn.Sequential(*blocks)
+
+class Net(nn.Module):
+    def __init__(self, num_classes=751, reid=False):
+        super(Net, self).__init__()
+        # input: 3 x 128 x 64
+        self.conv = nn.Sequential(
+            nn.Conv2d(3, 64, 3, stride=1, padding=1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+            # nn.Conv2d(32,32,3,stride=1,padding=1),
+            # nn.BatchNorm2d(32),
+            # nn.ReLU(inplace=True),
+            nn.MaxPool2d(3, 2, padding=1),
+        )
+        # 64 x 64 x 32
+        self.layer1 = make_layers(64, 64, 2, False)
+        # 64 x 64 x 32
+        self.layer2 = make_layers(64, 128, 2, True)
+        # 128 x 32 x 16
+        self.layer3 = make_layers(128, 256, 2, True)
+        # 256 x 16 x 8
+        self.layer4 = make_layers(256, 512, 2, True)
+        # 512 x 8 x 4
+        self.avgpool = nn.AvgPool2d((8, 4), 1)
+        # 512 x 1 x 1
+        self.reid = reid
+        self.classifier = nn.Sequential(
+            nn.Linear(512, 256),
+            nn.BatchNorm1d(256),
+            nn.ReLU(inplace=True),
+            nn.Dropout(),
+            nn.Linear(256, num_classes),
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.avgpool(x)
+        x = x.view(x.size(0), -1)
+        # B x 512
+        if self.reid:
+            x = x.div(x.norm(p=2, dim=1, keepdim=True))
+            return x
+        # classifier
+        x = self.classifier(x)
+        return x
+
+
+if __name__ == '__main__':
+    net = Net()
+    x = torch.randn(4, 3, 128, 64)
+    y = net(x)
+    import ipdb; ipdb.set_trace()
diff --git a/ArithControl/deep_sort/deep/original_model.py b/ArithControl/deep_sort/deep/original_model.py
new file mode 100644
index
0000000..72453a6 --- /dev/null +++ b/ArithControl/deep_sort/deep/original_model.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class BasicBlock(nn.Module): + def __init__(self, c_in, c_out,is_downsample=False): + super(BasicBlock,self).__init__() + self.is_downsample = is_downsample + if is_downsample: + self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=2, padding=1, bias=False) + else: + self.conv1 = nn.Conv2d(c_in, c_out, 3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(c_out) + self.relu = nn.ReLU(True) + self.conv2 = nn.Conv2d(c_out,c_out,3,stride=1,padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(c_out) + if is_downsample: + self.downsample = nn.Sequential( + nn.Conv2d(c_in, c_out, 1, stride=2, bias=False), + nn.BatchNorm2d(c_out) + ) + elif c_in != c_out: + self.downsample = nn.Sequential( + nn.Conv2d(c_in, c_out, 1, stride=1, bias=False), + nn.BatchNorm2d(c_out) + ) + self.is_downsample = True + + def forward(self,x): + y = self.conv1(x) + y = self.bn1(y) + y = self.relu(y) + y = self.conv2(y) + y = self.bn2(y) + if self.is_downsample: + x = self.downsample(x) + return F.relu(x.add(y),True) + +def make_layers(c_in,c_out,repeat_times, is_downsample=False): + blocks = [] + for i in range(repeat_times): + if i ==0: + blocks += [BasicBlock(c_in,c_out, is_downsample=is_downsample),] + else: + blocks += [BasicBlock(c_out,c_out),] + return nn.Sequential(*blocks) + +class Net(nn.Module): + def __init__(self, num_classes=625 ,reid=False): + super(Net,self).__init__() + # 3 128 64 + self.conv = nn.Sequential( + nn.Conv2d(3,32,3,stride=1,padding=1), + nn.BatchNorm2d(32), + nn.ELU(inplace=True), + nn.Conv2d(32,32,3,stride=1,padding=1), + nn.BatchNorm2d(32), + nn.ELU(inplace=True), + nn.MaxPool2d(3,2,padding=1), + ) + # 32 64 32 + self.layer1 = make_layers(32,32,2,False) + # 32 64 32 + self.layer2 = make_layers(32,64,2,True) + # 64 32 16 + self.layer3 = make_layers(64,128,2,True) + # 128 16 8 + self.dense = nn.Sequential( + nn.Dropout(p=0.6), + nn.Linear(128*16*8, 128), + nn.BatchNorm1d(128), + nn.ELU(inplace=True) + ) + # 256 1 1 + self.reid = reid + self.batch_norm = nn.BatchNorm1d(128) + self.classifier = nn.Sequential( + nn.Linear(128, num_classes), + ) + + def forward(self, x): + x = self.conv(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = x.view(x.size(0),-1) + if self.reid: + x = self.dense[0](x) + x = self.dense[1](x) + x = x.div(x.norm(p=2,dim=1,keepdim=True)) + return x + x = self.dense(x) + # B x 128 + # classifier + x = self.classifier(x) + return x + + +if __name__ == '__main__': + net = Net(reid=True) + x = torch.randn(4,3,128,64) + y = net(x) + import ipdb; ipdb.set_trace() + + diff --git a/ArithControl/deep_sort/deep/test.py b/ArithControl/deep_sort/deep/test.py new file mode 100644 index 0000000..ebd5903 --- /dev/null +++ b/ArithControl/deep_sort/deep/test.py @@ -0,0 +1,77 @@ +import torch +import torch.backends.cudnn as cudnn +import torchvision + +import argparse +import os + +from model import Net + +parser = argparse.ArgumentParser(description="Train on market1501") +parser.add_argument("--data-dir",default='data',type=str) +parser.add_argument("--no-cuda",action="store_true") +parser.add_argument("--gpu-id",default=0,type=int) +args = parser.parse_args() + +# device +device = "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu" +if torch.cuda.is_available() and not args.no_cuda: + cudnn.benchmark = True + +# data loader +root = args.data_dir 
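+# Note: with reid=True, both Net variants above emit L2-normalized embeddings,
+# so cosine similarity reduces to a plain dot product — which is why
+# evaluate.py can score query/gallery pairs with qf.mm(gf.t()). A quick sketch
+# with dummy tensors (illustration only, not part of the evaluation flow):
+#     net = Net(reid=True).eval()
+#     emb = net(torch.randn(2, 3, 128, 64))  # each row has unit L2 norm
+#     cos_sim = float(emb[0] @ emb[1])       # dot product == cosine similarity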
+query_dir = os.path.join(root,"query") +gallery_dir = os.path.join(root,"gallery") +transform = torchvision.transforms.Compose([ + torchvision.transforms.Resize((128,64)), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) +]) +queryloader = torch.utils.data.DataLoader( + torchvision.datasets.ImageFolder(query_dir, transform=transform), + batch_size=64, shuffle=False +) +galleryloader = torch.utils.data.DataLoader( + torchvision.datasets.ImageFolder(gallery_dir, transform=transform), + batch_size=64, shuffle=False +) + +# net definition +net = Net(reid=True) +assert os.path.isfile("./checkpoint/ckpt.t7"), "Error: no checkpoint file found!" +print('Loading from checkpoint/ckpt.t7') +checkpoint = torch.load("./checkpoint/ckpt.t7") +net_dict = checkpoint['net_dict'] +net.load_state_dict(net_dict, strict=False) +net.eval() +net.to(device) + +# compute features +query_features = torch.tensor([]).float() +query_labels = torch.tensor([]).long() +gallery_features = torch.tensor([]).float() +gallery_labels = torch.tensor([]).long() + +with torch.no_grad(): + for idx,(inputs,labels) in enumerate(queryloader): + inputs = inputs.to(device) + features = net(inputs).cpu() + query_features = torch.cat((query_features, features), dim=0) + query_labels = torch.cat((query_labels, labels)) + + for idx,(inputs,labels) in enumerate(galleryloader): + inputs = inputs.to(device) + features = net(inputs).cpu() + gallery_features = torch.cat((gallery_features, features), dim=0) + gallery_labels = torch.cat((gallery_labels, labels)) + +gallery_labels -= 2 + +# save features +features = { + "qf": query_features, + "ql": query_labels, + "gf": gallery_features, + "gl": gallery_labels +} +torch.save(features,"features.pth") \ No newline at end of file diff --git a/ArithControl/deep_sort/deep/train.jpg b/ArithControl/deep_sort/deep/train.jpg new file mode 100644 index 0000000..3635a61 Binary files /dev/null and b/ArithControl/deep_sort/deep/train.jpg differ diff --git a/ArithControl/deep_sort/deep/train.py b/ArithControl/deep_sort/deep/train.py new file mode 100644 index 0000000..a931763 --- /dev/null +++ b/ArithControl/deep_sort/deep/train.py @@ -0,0 +1,189 @@ +import argparse +import os +import time + +import numpy as np +import matplotlib.pyplot as plt +import torch +import torch.backends.cudnn as cudnn +import torchvision + +from model import Net + +parser = argparse.ArgumentParser(description="Train on market1501") +parser.add_argument("--data-dir",default='data',type=str) +parser.add_argument("--no-cuda",action="store_true") +parser.add_argument("--gpu-id",default=0,type=int) +parser.add_argument("--lr",default=0.1, type=float) +parser.add_argument("--interval",'-i',default=20,type=int) +parser.add_argument('--resume', '-r',action='store_true') +args = parser.parse_args() + +# device +device = "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu" +if torch.cuda.is_available() and not args.no_cuda: + cudnn.benchmark = True + +# data loading +root = args.data_dir +train_dir = os.path.join(root,"train") +test_dir = os.path.join(root,"test") +transform_train = torchvision.transforms.Compose([ + torchvision.transforms.RandomCrop((128,64),padding=4), + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) +]) +transform_test = torchvision.transforms.Compose([ + torchvision.transforms.Resize((128,64)), + 
torchvision.transforms.ToTensor(),
+    torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+trainloader = torch.utils.data.DataLoader(
+    torchvision.datasets.ImageFolder(train_dir, transform=transform_train),
+    batch_size=64, shuffle=True
+)
+testloader = torch.utils.data.DataLoader(
+    torchvision.datasets.ImageFolder(test_dir, transform=transform_test),
+    batch_size=64, shuffle=True
+)
+num_classes = max(len(trainloader.dataset.classes), len(testloader.dataset.classes))
+
+# net definition
+start_epoch = 0
+best_acc = 0.
+net = Net(num_classes=num_classes)
+if args.resume:
+    assert os.path.isfile("./checkpoint/ckpt.t7"), "Error: no checkpoint file found!"
+    print('Loading from checkpoint/ckpt.t7')
+    checkpoint = torch.load("./checkpoint/ckpt.t7")
+    # import ipdb; ipdb.set_trace()
+    net_dict = checkpoint['net_dict']
+    net.load_state_dict(net_dict)
+    best_acc = checkpoint['acc']
+    start_epoch = checkpoint['epoch']
+net.to(device)
+
+# loss and optimizer
+criterion = torch.nn.CrossEntropyLoss()
+optimizer = torch.optim.SGD(net.parameters(), args.lr, momentum=0.9, weight_decay=5e-4)
+
+# train function for each epoch
+def train(epoch):
+    print("\nEpoch : %d" % (epoch+1))
+    net.train()
+    training_loss = 0.
+    train_loss = 0.
+    correct = 0
+    total = 0
+    interval = args.interval
+    start = time.time()
+    for idx, (inputs, labels) in enumerate(trainloader):
+        # forward
+        inputs, labels = inputs.to(device), labels.to(device)
+        outputs = net(inputs)
+        loss = criterion(outputs, labels)
+
+        # backward
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        # accumulating
+        training_loss += loss.item()
+        train_loss += loss.item()
+        correct += outputs.max(dim=1)[1].eq(labels).sum().item()
+        total += labels.size(0)
+
+        # print
+        if (idx+1) % interval == 0:
+            end = time.time()
+            print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
+                100.*(idx+1)/len(trainloader), end-start, training_loss/interval, correct, total, 100.*correct/total
+            ))
+            training_loss = 0.
+            start = time.time()
+
+    return train_loss/len(trainloader), 1. - correct/total
+
+def test(epoch):
+    global best_acc
+    net.eval()
+    test_loss = 0.
+ correct = 0 + total = 0 + start = time.time() + with torch.no_grad(): + for idx, (inputs, labels) in enumerate(testloader): + inputs, labels = inputs.to(device), labels.to(device) + outputs = net(inputs) + loss = criterion(outputs, labels) + + test_loss += loss.item() + correct += outputs.max(dim=1)[1].eq(labels).sum().item() + total += labels.size(0) + + print("Testing ...") + end = time.time() + print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format( + 100.*(idx+1)/len(testloader), end-start, test_loss/len(testloader), correct, total, 100.*correct/total + )) + + # saving checkpoint + acc = 100.*correct/total + if acc > best_acc: + best_acc = acc + print("Saving parameters to checkpoint/ckpt.t7") + checkpoint = { + 'net_dict':net.state_dict(), + 'acc':acc, + 'epoch':epoch, + } + if not os.path.isdir('checkpoint'): + os.mkdir('checkpoint') + torch.save(checkpoint, './checkpoint/ckpt.t7') + + return test_loss/len(testloader), 1.- correct/total + +# plot figure +x_epoch = [] +record = {'train_loss':[], 'train_err':[], 'test_loss':[], 'test_err':[]} +fig = plt.figure() +ax0 = fig.add_subplot(121, title="loss") +ax1 = fig.add_subplot(122, title="top1err") +def draw_curve(epoch, train_loss, train_err, test_loss, test_err): + global record + record['train_loss'].append(train_loss) + record['train_err'].append(train_err) + record['test_loss'].append(test_loss) + record['test_err'].append(test_err) + + x_epoch.append(epoch) + ax0.plot(x_epoch, record['train_loss'], 'bo-', label='train') + ax0.plot(x_epoch, record['test_loss'], 'ro-', label='val') + ax1.plot(x_epoch, record['train_err'], 'bo-', label='train') + ax1.plot(x_epoch, record['test_err'], 'ro-', label='val') + if epoch == 0: + ax0.legend() + ax1.legend() + fig.savefig("train.jpg") + +# lr decay +def lr_decay(): + global optimizer + for params in optimizer.param_groups: + params['lr'] *= 0.1 + lr = params['lr'] + print("Learning rate adjusted to {}".format(lr)) + +def main(): + for epoch in range(start_epoch, start_epoch+40): + train_loss, train_err = train(epoch) + test_loss, test_err = test(epoch) + draw_curve(epoch, train_loss, train_err, test_loss, test_err) + if (epoch+1)%20==0: + lr_decay() + + +if __name__ == '__main__': + main() diff --git a/ArithControl/deep_sort/deep_sort.py b/ArithControl/deep_sort/deep_sort.py new file mode 100644 index 0000000..8092f42 --- /dev/null +++ b/ArithControl/deep_sort/deep_sort.py @@ -0,0 +1,120 @@ +import numpy as np +import torch + +from .deep.feature_extractor import Extractor +from .sort.nn_matching import NearestNeighborDistanceMetric +from .sort.preprocessing import non_max_suppression +from .sort.detection import Detection +from .sort.tracker import Tracker + + +__all__ = ['DeepSort'] + + +class DeepSort(object): + def __init__(self, model_path, max_dist=0.2, min_confidence=0.3, nms_max_overlap=1.0, max_iou_distance=0.7, + max_age=70, n_init=3, nn_budget=100, use_cuda=True,enable_reid=True): + self.min_confidence = min_confidence + self.nms_max_overlap = nms_max_overlap + self.enable_reid = enable_reid + if self.enable_reid: + self.extractor = Extractor(model_path, use_cuda=use_cuda) + else: + self.extractor = None + + max_cosine_distance = max_dist + nn_budget = 100 + metric = NearestNeighborDistanceMetric( + "cosine", max_cosine_distance, nn_budget) + self.tracker = Tracker( + metric, max_iou_distance=max_iou_distance, max_age=max_age, n_init=n_init) + + def update(self, bbox_xywh, confidences, clss, ori_img): + self.height, self.width = ori_img.shape[:2] + # 
generate detections
+        # extract appearance features only when ReID is enabled
+        if self.enable_reid:
+            features = self._get_features(bbox_xywh, ori_img)
+        else:
+            # placeholder features (the extractor's feature dimension is 512)
+            features = [np.zeros(512) for _ in range(len(bbox_xywh))]
+
+        bbox_tlwh = self._xywh_to_tlwh(bbox_xywh)
+        detections = [Detection(bbox_tlwh[i], clss[i], conf, features[i]) for i, conf in enumerate(
+            confidences) if conf > self.min_confidence]
+        # update tracker
+        self.tracker.predict()
+        self.tracker.update(detections)
+
+        # output bbox identities
+        outputs = []
+        for track in self.tracker.tracks:
+            if not track.is_confirmed() or track.time_since_update > 1:
+                continue
+            box = track.to_tlwh()
+            x1, y1, x2, y2 = self._tlwh_to_xyxy(box)
+            outputs.append((x1, y1, x2, y2, track.cls_, track.track_id))
+        return outputs
+
+    @staticmethod
+    def _xywh_to_tlwh(bbox_xywh):
+        if isinstance(bbox_xywh, np.ndarray):
+            bbox_tlwh = bbox_xywh.copy()
+        elif isinstance(bbox_xywh, torch.Tensor):
+            bbox_tlwh = bbox_xywh.clone()
+        if bbox_tlwh.size(0) > 0:
+            bbox_tlwh[:, 0] = bbox_xywh[:, 0] - bbox_xywh[:, 2]/2.
+            bbox_tlwh[:, 1] = bbox_xywh[:, 1] - bbox_xywh[:, 3]/2.
+        return bbox_tlwh
+
+    def _xywh_to_xyxy(self, bbox_xywh):
+        x, y, w, h = bbox_xywh
+        x1 = max(int(x-w/2), 0)
+        x2 = min(int(x+w/2), self.width-1)
+        y1 = max(int(y-h/2), 0)
+        y2 = min(int(y+h/2), self.height-1)
+        return x1, y1, x2, y2
+
+    def _tlwh_to_xyxy(self, bbox_tlwh):
+        """
+        Convert bbox from (top-left x, top-left y, w, h) to (x1, y1, x2, y2).
+        Thanks JieChen91@github.com for reporting this bug!
+        """
+        x, y, w, h = bbox_tlwh
+        x1 = max(int(x), 0)
+        x2 = min(int(x+w), self.width-1)
+        y1 = max(int(y), 0)
+        y2 = min(int(y+h), self.height-1)
+        return x1, y1, x2, y2
+
+    def _xyxy_to_tlwh(self, bbox_xyxy):
+        x1, y1, x2, y2 = bbox_xyxy
+
+        t = x1
+        l = y1
+        w = int(x2-x1)
+        h = int(y2-y1)
+        return t, l, w, h
+
+    def _get_features(self, bbox_xywh, ori_img):
+        im_crops = []
+        for box in bbox_xywh:
+            x1, y1, x2, y2 = self._xywh_to_xyxy(box)
+            im = ori_img[y1:y2, x1:x2]
+            im_crops.append(im)
+        if im_crops:
+            features = self.extractor(im_crops)
+        else:
+            features = np.array([])
+        return features
+
+    # manual-lock hooks (placeholders, not implemented yet)
+    def lock_point(self, point):
+        pass
+
+    def lock_rect(self, rect):
+        pass
+
+    def lock_id(self, id):
+        pass
+
+    def lock_unlock(self):
+        pass
\ No newline at end of file
diff --git a/ArithControl/deep_sort/sort/__init__.py b/ArithControl/deep_sort/sort/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/ArithControl/deep_sort/sort/detection.py b/ArithControl/deep_sort/sort/detection.py
new file mode 100644
index 0000000..ec306db
--- /dev/null
+++ b/ArithControl/deep_sort/sort/detection.py
@@ -0,0 +1,28 @@
+# vim: expandtab:ts=4:sw=4
+import numpy as np
+
+
+class Detection(object):
+
+    def __init__(self, tlwh, cls_, confidence, feature):
+        self.tlwh = np.asarray(tlwh, dtype=np.float32)
+        self.cls_ = cls_
+        self.confidence = float(confidence)
+        self.feature = np.asarray(feature, dtype=np.float32)
+
+    def to_tlbr(self):
+        """Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
+        `(top left, bottom right)`.
+        """
+        ret = self.tlwh.copy()
+        ret[2:] += ret[:2]
+        return ret
+
+    def to_xyah(self):
+        """Convert bounding box to format `(center x, center y, aspect ratio,
+        height)`, where the aspect ratio is `width / height`.
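+        For example, the tlwh box (10, 20, 50, 100) maps to xyah (35, 70, 0.5, 100).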
+ """ + ret = self.tlwh.copy() + ret[:2] += ret[2:] / 2 + ret[2] /= ret[3] + return ret diff --git a/ArithControl/deep_sort/sort/iou_matching.py b/ArithControl/deep_sort/sort/iou_matching.py new file mode 100644 index 0000000..b06418a --- /dev/null +++ b/ArithControl/deep_sort/sort/iou_matching.py @@ -0,0 +1,82 @@ +# vim: expandtab:ts=4:sw=4 +from __future__ import absolute_import +import numpy as np +from . import linear_assignment + + +def iou(bbox, candidates): + """Computer intersection over union. + + Parameters + ---------- + bbox : ndarray + A bounding box in format `(top left x, top left y, width, height)`. + candidates : ndarray + A matrix of candidate bounding boxes (one per row) in the same format + as `bbox`. + + Returns + ------- + ndarray + The intersection over union in [0, 1] between the `bbox` and each + candidate. A higher score means a larger fraction of the `bbox` is + occluded by the candidate. + + """ + bbox_tl, bbox_br = bbox[:2], bbox[:2] + bbox[2:] + candidates_tl = candidates[:, :2] + candidates_br = candidates[:, :2] + candidates[:, 2:] + + tl = np.c_[np.maximum(bbox_tl[0], candidates_tl[:, 0])[:, np.newaxis], + np.maximum(bbox_tl[1], candidates_tl[:, 1])[:, np.newaxis]] + br = np.c_[np.minimum(bbox_br[0], candidates_br[:, 0])[:, np.newaxis], + np.minimum(bbox_br[1], candidates_br[:, 1])[:, np.newaxis]] + wh = np.maximum(0., br - tl) + + area_intersection = wh.prod(axis=1) + area_bbox = bbox[2:].prod() + area_candidates = candidates[:, 2:].prod(axis=1) + return area_intersection / (area_bbox + area_candidates - area_intersection) + + +def iou_cost(tracks, detections, track_indices=None, + detection_indices=None): + """An intersection over union distance metric. + + Parameters + ---------- + tracks : List[deep_sort.track.Track] + A list of tracks. + detections : List[deep_sort.detection.Detection] + A list of detections. + track_indices : Optional[List[int]] + A list of indices to tracks that should be matched. Defaults to + all `tracks`. + detection_indices : Optional[List[int]] + A list of indices to detections that should be matched. Defaults + to all `detections`. + + Returns + ------- + ndarray + Returns a cost matrix of shape + len(track_indices), len(detection_indices) where entry (i, j) is + `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`. + + """ + if track_indices is None: + track_indices = np.arange(len(tracks)) + if detection_indices is None: + detection_indices = np.arange(len(detections)) + + cost_matrix = np.zeros((len(track_indices), len(detection_indices))) + for row, track_idx in enumerate(track_indices): + if tracks[track_idx].time_since_update > 1: + cost_matrix[row, :] = linear_assignment.INFTY_COST + continue + + bbox = tracks[track_idx].to_tlwh() + candidates = np.asarray([detections[i].tlwh for i in detection_indices]) + cost_matrix[row, :] = 1. - iou(bbox, candidates) + + return cost_matrix diff --git a/ArithControl/deep_sort/sort/kalman_filter.py b/ArithControl/deep_sort/sort/kalman_filter.py new file mode 100644 index 0000000..787a76e --- /dev/null +++ b/ArithControl/deep_sort/sort/kalman_filter.py @@ -0,0 +1,229 @@ +# vim: expandtab:ts=4:sw=4 +import numpy as np +import scipy.linalg + + +""" +Table for the 0.95 quantile of the chi-square distribution with N degrees of +freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv +function and used as Mahalanobis gating threshold. 
+""" +chi2inv95 = { + 1: 3.8415, + 2: 5.9915, + 3: 7.8147, + 4: 9.4877, + 5: 11.070, + 6: 12.592, + 7: 14.067, + 8: 15.507, + 9: 16.919} + + +class KalmanFilter(object): + """ + A simple Kalman filter for tracking bounding boxes in image space. + + The 8-dimensional state space + + x, y, a, h, vx, vy, va, vh + + contains the bounding box center position (x, y), aspect ratio a, height h, + and their respective velocities. + + Object motion follows a constant velocity model. The bounding box location + (x, y, a, h) is taken as direct observation of the state space (linear + observation model). + + """ + + def __init__(self): + ndim, dt = 4, 1. + + # Create Kalman filter model matrices. + self._motion_mat = np.eye(2 * ndim, 2 * ndim) + for i in range(ndim): + self._motion_mat[i, ndim + i] = dt + self._update_mat = np.eye(ndim, 2 * ndim) + + # Motion and observation uncertainty are chosen relative to the current + # state estimate. These weights control the amount of uncertainty in + # the model. This is a bit hacky. + self._std_weight_position = 1. / 20 + self._std_weight_velocity = 1. / 160 + + def initiate(self, measurement): + """Create track from unassociated measurement. + + Parameters + ---------- + measurement : ndarray + Bounding box coordinates (x, y, a, h) with center position (x, y), + aspect ratio a, and height h. + + Returns + ------- + (ndarray, ndarray) + Returns the mean vector (8 dimensional) and covariance matrix (8x8 + dimensional) of the new track. Unobserved velocities are initialized + to 0 mean. + + """ + mean_pos = measurement + mean_vel = np.zeros_like(mean_pos) + mean = np.r_[mean_pos, mean_vel] + + std = [ + 2 * self._std_weight_position * measurement[3], + 2 * self._std_weight_position * measurement[3], + 1e-2, + 2 * self._std_weight_position * measurement[3], + 10 * self._std_weight_velocity * measurement[3], + 10 * self._std_weight_velocity * measurement[3], + 1e-5, + 10 * self._std_weight_velocity * measurement[3]] + covariance = np.diag(np.square(std)) + return mean, covariance + + def predict(self, mean, covariance): + """Run Kalman filter prediction step. + + Parameters + ---------- + mean : ndarray + The 8 dimensional mean vector of the object state at the previous + time step. + covariance : ndarray + The 8x8 dimensional covariance matrix of the object state at the + previous time step. + + Returns + ------- + (ndarray, ndarray) + Returns the mean vector and covariance matrix of the predicted + state. Unobserved velocities are initialized to 0 mean. + + """ + std_pos = [ + self._std_weight_position * mean[3], + self._std_weight_position * mean[3], + 1e-2, + self._std_weight_position * mean[3]] + std_vel = [ + self._std_weight_velocity * mean[3], + self._std_weight_velocity * mean[3], + 1e-5, + self._std_weight_velocity * mean[3]] + motion_cov = np.diag(np.square(np.r_[std_pos, std_vel])) + + mean = np.dot(self._motion_mat, mean) + covariance = np.linalg.multi_dot(( + self._motion_mat, covariance, self._motion_mat.T)) + motion_cov + + return mean, covariance + + def project(self, mean, covariance): + """Project state distribution to measurement space. + + Parameters + ---------- + mean : ndarray + The state's mean vector (8 dimensional array). + covariance : ndarray + The state's covariance matrix (8x8 dimensional). + + Returns + ------- + (ndarray, ndarray) + Returns the projected mean and covariance matrix of the given state + estimate. 
+ + """ + std = [ + self._std_weight_position * mean[3], + self._std_weight_position * mean[3], + 1e-1, + self._std_weight_position * mean[3]] + innovation_cov = np.diag(np.square(std)) + + mean = np.dot(self._update_mat, mean) + covariance = np.linalg.multi_dot(( + self._update_mat, covariance, self._update_mat.T)) + return mean, covariance + innovation_cov + + def update(self, mean, covariance, measurement): + """Run Kalman filter correction step. + + Parameters + ---------- + mean : ndarray + The predicted state's mean vector (8 dimensional). + covariance : ndarray + The state's covariance matrix (8x8 dimensional). + measurement : ndarray + The 4 dimensional measurement vector (x, y, a, h), where (x, y) + is the center position, a the aspect ratio, and h the height of the + bounding box. + + Returns + ------- + (ndarray, ndarray) + Returns the measurement-corrected state distribution. + + """ + projected_mean, projected_cov = self.project(mean, covariance) + + chol_factor, lower = scipy.linalg.cho_factor( + projected_cov, lower=True, check_finite=False) + kalman_gain = scipy.linalg.cho_solve( + (chol_factor, lower), np.dot(covariance, self._update_mat.T).T, + check_finite=False).T + innovation = measurement - projected_mean + + new_mean = mean + np.dot(innovation, kalman_gain.T) + new_covariance = covariance - np.linalg.multi_dot(( + kalman_gain, projected_cov, kalman_gain.T)) + return new_mean, new_covariance + + def gating_distance(self, mean, covariance, measurements, + only_position=False): + """Compute gating distance between state distribution and measurements. + + A suitable distance threshold can be obtained from `chi2inv95`. If + `only_position` is False, the chi-square distribution has 4 degrees of + freedom, otherwise 2. + + Parameters + ---------- + mean : ndarray + Mean vector over the state distribution (8 dimensional). + covariance : ndarray + Covariance of the state distribution (8x8 dimensional). + measurements : ndarray + An Nx4 dimensional matrix of N measurements, each in + format (x, y, a, h) where (x, y) is the bounding box center + position, a the aspect ratio, and h the height. + only_position : Optional[bool] + If True, distance computation is done with respect to the bounding + box center position only. + + Returns + ------- + ndarray + Returns an array of length N, where the i-th element contains the + squared Mahalanobis distance between (mean, covariance) and + `measurements[i]`. + + """ + mean, covariance = self.project(mean, covariance) + if only_position: + mean, covariance = mean[:2], covariance[:2, :2] + measurements = measurements[:, :2] + + cholesky_factor = np.linalg.cholesky(covariance) + d = measurements - mean + z = scipy.linalg.solve_triangular( + cholesky_factor, d.T, lower=True, check_finite=False, + overwrite_b=True) + squared_maha = np.sum(z * z, axis=0) + return squared_maha diff --git a/ArithControl/deep_sort/sort/linear_assignment.py b/ArithControl/deep_sort/sort/linear_assignment.py new file mode 100644 index 0000000..2006230 --- /dev/null +++ b/ArithControl/deep_sort/sort/linear_assignment.py @@ -0,0 +1,159 @@ +# vim: expandtab:ts=4:sw=4 +from __future__ import absolute_import +import numpy as np +# from sklearn.utils.linear_assignment_ import linear_assignment +from scipy.optimize import linear_sum_assignment as linear_assignment +from . 
import kalman_filter + + +INFTY_COST = 1e+5 + + +def min_cost_matching( + distance_metric, max_distance, tracks, detections, track_indices=None, + detection_indices=None): + if track_indices is None: + track_indices = np.arange(len(tracks)) + if detection_indices is None: + detection_indices = np.arange(len(detections)) + + if len(detection_indices) == 0 or len(track_indices) == 0: + return [], track_indices, detection_indices # Nothing to match. + + cost_matrix = distance_metric( + tracks, detections, track_indices, detection_indices) + cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5 + + row_indices, col_indices = linear_assignment(cost_matrix) + + matches, unmatched_tracks, unmatched_detections = [], [], [] + for col, detection_idx in enumerate(detection_indices): + if col not in col_indices: + unmatched_detections.append(detection_idx) + for row, track_idx in enumerate(track_indices): + if row not in row_indices: + unmatched_tracks.append(track_idx) + for row, col in zip(row_indices, col_indices): + track_idx = track_indices[row] + detection_idx = detection_indices[col] + if cost_matrix[row, col] > max_distance: + unmatched_tracks.append(track_idx) + unmatched_detections.append(detection_idx) + else: + matches.append((track_idx, detection_idx)) + return matches, unmatched_tracks, unmatched_detections + + +def matching_cascade( + distance_metric, max_distance, cascade_depth, tracks, detections, + track_indices=None, detection_indices=None): + """Run matching cascade. + + Parameters + ---------- + distance_metric : Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray + The distance metric is given a list of tracks and detections as well as + a list of N track indices and M detection indices. The metric should + return the NxM dimensional cost matrix, where element (i, j) is the + association cost between the i-th track in the given track indices and + the j-th detection in the given detection indices. + max_distance : float + Gating threshold. Associations with cost larger than this value are + disregarded. + cascade_depth: int + The cascade depth, should be se to the maximum track age. + tracks : List[track.Track] + A list of predicted tracks at the current time step. + detections : List[detection.Detection] + A list of detections at the current time step. + track_indices : Optional[List[int]] + List of track indices that maps rows in `cost_matrix` to tracks in + `tracks` (see description above). Defaults to all tracks. + detection_indices : Optional[List[int]] + List of detection indices that maps columns in `cost_matrix` to + detections in `detections` (see description above). Defaults to all + detections. + + Returns + ------- + (List[(int, int)], List[int], List[int]) + Returns a tuple with the following three entries: + * A list of matched track and detection indices. + * A list of unmatched track indices. + * A list of unmatched detection indices. 
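+    For example, at cascade level 0 only tracks with time_since_update == 1
+    (seen in the previous frame) are considered, at level 1 those missed one
+    extra frame, and so on — so recently updated tracks get matching priority.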
+ + """ + if track_indices is None: + track_indices = list(range(len(tracks))) + if detection_indices is None: + detection_indices = list(range(len(detections))) + + unmatched_detections = detection_indices + matches = [] + for level in range(cascade_depth): + if len(unmatched_detections) == 0: # No detections left + break + + track_indices_l = [ + k for k in track_indices + if tracks[k].time_since_update == 1 + level + ] + if len(track_indices_l) == 0: # Nothing to match at this level + continue + + matches_l, _, unmatched_detections = \ + min_cost_matching( + distance_metric, max_distance, tracks, detections, + track_indices_l, unmatched_detections) + matches += matches_l + unmatched_tracks = list(set(track_indices) - set(k for k, _ in matches)) + return matches, unmatched_tracks, unmatched_detections + + +def gate_cost_matrix( + kf, cost_matrix, tracks, detections, track_indices, detection_indices, + gated_cost=INFTY_COST, only_position=False): + """Invalidate infeasible entries in cost matrix based on the state + distributions obtained by Kalman filtering. + + Parameters + ---------- + kf : The Kalman filter. + cost_matrix : ndarray + The NxM dimensional cost matrix, where N is the number of track indices + and M is the number of detection indices, such that entry (i, j) is the + association cost between `tracks[track_indices[i]]` and + `detections[detection_indices[j]]`. + tracks : List[track.Track] + A list of predicted tracks at the current time step. + detections : List[detection.Detection] + A list of detections at the current time step. + track_indices : List[int] + List of track indices that maps rows in `cost_matrix` to tracks in + `tracks` (see description above). + detection_indices : List[int] + List of detection indices that maps columns in `cost_matrix` to + detections in `detections` (see description above). + gated_cost : Optional[float] + Entries in the cost matrix corresponding to infeasible associations are + set this value. Defaults to a very large value. + only_position : Optional[bool] + If True, only the x, y position of the state distribution is considered + during gating. Defaults to False. + + Returns + ------- + ndarray + Returns the modified cost matrix. + + """ + gating_dim = 2 if only_position else 4 + gating_threshold = kalman_filter.chi2inv95[gating_dim] + measurements = np.asarray( + [detections[i].to_xyah() for i in detection_indices]) + for row, track_idx in enumerate(track_indices): + track = tracks[track_idx] + gating_distance = kf.gating_distance( + track.mean, track.covariance, measurements, only_position) + cost_matrix[row, gating_distance > gating_threshold] = gated_cost + return cost_matrix diff --git a/ArithControl/deep_sort/sort/nn_matching.py b/ArithControl/deep_sort/sort/nn_matching.py new file mode 100644 index 0000000..2e7bfea --- /dev/null +++ b/ArithControl/deep_sort/sort/nn_matching.py @@ -0,0 +1,177 @@ +# vim: expandtab:ts=4:sw=4 +import numpy as np + + +def _pdist(a, b): + """Compute pair-wise squared distance between points in `a` and `b`. + + Parameters + ---------- + a : array_like + An NxM matrix of N samples of dimensionality M. + b : array_like + An LxM matrix of L samples of dimensionality M. + + Returns + ------- + ndarray + Returns a matrix of size len(a), len(b) such that eleement (i, j) + contains the squared distance between `a[i]` and `b[j]`. 
+ + """ + a, b = np.asarray(a), np.asarray(b) + if len(a) == 0 or len(b) == 0: + return np.zeros((len(a), len(b))) + a2, b2 = np.square(a).sum(axis=1), np.square(b).sum(axis=1) + r2 = -2. * np.dot(a, b.T) + a2[:, None] + b2[None, :] + r2 = np.clip(r2, 0., float(np.inf)) + return r2 + + +def _cosine_distance(a, b, data_is_normalized=False): + """Compute pair-wise cosine distance between points in `a` and `b`. + + Parameters + ---------- + a : array_like + An NxM matrix of N samples of dimensionality M. + b : array_like + An LxM matrix of L samples of dimensionality M. + data_is_normalized : Optional[bool] + If True, assumes rows in a and b are unit length vectors. + Otherwise, a and b are explicitly normalized to lenght 1. + + Returns + ------- + ndarray + Returns a matrix of size len(a), len(b) such that eleement (i, j) + contains the squared distance between `a[i]` and `b[j]`. + + """ + if not data_is_normalized: + a = np.asarray(a) / np.linalg.norm(a, axis=1, keepdims=True) + b = np.asarray(b) / np.linalg.norm(b, axis=1, keepdims=True) + return 1. - np.dot(a, b.T) + + +def _nn_euclidean_distance(x, y): + """ Helper function for nearest neighbor distance metric (Euclidean). + + Parameters + ---------- + x : ndarray + A matrix of N row-vectors (sample points). + y : ndarray + A matrix of M row-vectors (query points). + + Returns + ------- + ndarray + A vector of length M that contains for each entry in `y` the + smallest Euclidean distance to a sample in `x`. + + """ + distances = _pdist(x, y) + return np.maximum(0.0, distances.min(axis=0)) + + +def _nn_cosine_distance(x, y): + """ Helper function for nearest neighbor distance metric (cosine). + + Parameters + ---------- + x : ndarray + A matrix of N row-vectors (sample points). + y : ndarray + A matrix of M row-vectors (query points). + + Returns + ------- + ndarray + A vector of length M that contains for each entry in `y` the + smallest cosine distance to a sample in `x`. + + """ + distances = _cosine_distance(x, y) + return distances.min(axis=0) + + +class NearestNeighborDistanceMetric(object): + """ + A nearest neighbor distance metric that, for each target, returns + the closest distance to any sample that has been observed so far. + + Parameters + ---------- + metric : str + Either "euclidean" or "cosine". + matching_threshold: float + The matching threshold. Samples with larger distance are considered an + invalid match. + budget : Optional[int] + If not None, fix samples per class to at most this number. Removes + the oldest samples when the budget is reached. + + Attributes + ---------- + samples : Dict[int -> List[ndarray]] + A dictionary that maps from target identities to the list of samples + that have been observed so far. + + """ + + def __init__(self, metric, matching_threshold, budget=None): + + + if metric == "euclidean": + self._metric = _nn_euclidean_distance + elif metric == "cosine": + self._metric = _nn_cosine_distance + else: + raise ValueError( + "Invalid metric; must be either 'euclidean' or 'cosine'") + self.matching_threshold = matching_threshold + self.budget = budget + self.samples = {} + + def partial_fit(self, features, targets, active_targets): + """Update the distance metric with new data. + + Parameters + ---------- + features : ndarray + An NxM matrix of N features of dimensionality M. + targets : ndarray + An integer array of associated target identities. + active_targets : List[int] + A list of targets that are currently present in the scene. 
+ + """ + for feature, target in zip(features, targets): + self.samples.setdefault(target, []).append(feature) + if self.budget is not None: + self.samples[target] = self.samples[target][-self.budget:] + self.samples = {k: self.samples[k] for k in active_targets} + + def distance(self, features, targets): + """Compute distance between features and targets. + + Parameters + ---------- + features : ndarray + An NxM matrix of N features of dimensionality M. + targets : List[int] + A list of targets to match the given `features` against. + + Returns + ------- + ndarray + Returns a cost matrix of shape len(targets), len(features), where + element (i, j) contains the closest squared distance between + `targets[i]` and `features[j]`. + + """ + cost_matrix = np.zeros((len(targets), len(features))) + for i, target in enumerate(targets): + cost_matrix[i, :] = self._metric(self.samples[target], features) + return cost_matrix diff --git a/ArithControl/deep_sort/sort/preprocessing.py b/ArithControl/deep_sort/sort/preprocessing.py new file mode 100644 index 0000000..5493b12 --- /dev/null +++ b/ArithControl/deep_sort/sort/preprocessing.py @@ -0,0 +1,73 @@ +# vim: expandtab:ts=4:sw=4 +import numpy as np +import cv2 + + +def non_max_suppression(boxes, max_bbox_overlap, scores=None): + """Suppress overlapping detections. + + Original code from [1]_ has been adapted to include confidence score. + + .. [1] http://www.pyimagesearch.com/2015/02/16/ + faster-non-maximum-suppression-python/ + + Examples + -------- + + >>> boxes = [d.roi for d in detections] + >>> scores = [d.confidence for d in detections] + >>> indices = non_max_suppression(boxes, max_bbox_overlap, scores) + >>> detections = [detections[i] for i in indices] + + Parameters + ---------- + boxes : ndarray + Array of ROIs (x, y, width, height). + max_bbox_overlap : float + ROIs that overlap more than this values are suppressed. + scores : Optional[array_like] + Detector confidence score. + + Returns + ------- + List[int] + Returns indices of detections that have survived non-maxima suppression. + + """ + if len(boxes) == 0: + return [] + + boxes = boxes.astype(np.float) + pick = [] + + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + boxes[:, 0] + y2 = boxes[:, 3] + boxes[:, 1] + + area = (x2 - x1 + 1) * (y2 - y1 + 1) + if scores is not None: + idxs = np.argsort(scores) + else: + idxs = np.argsort(y2) + + while len(idxs) > 0: + last = len(idxs) - 1 + i = idxs[last] + pick.append(i) + + xx1 = np.maximum(x1[i], x1[idxs[:last]]) + yy1 = np.maximum(y1[i], y1[idxs[:last]]) + xx2 = np.minimum(x2[i], x2[idxs[:last]]) + yy2 = np.minimum(y2[i], y2[idxs[:last]]) + + w = np.maximum(0, xx2 - xx1 + 1) + h = np.maximum(0, yy2 - yy1 + 1) + + overlap = (w * h) / area[idxs[:last]] + + idxs = np.delete( + idxs, np.concatenate( + ([last], np.where(overlap > max_bbox_overlap)[0]))) + + return pick diff --git a/ArithControl/deep_sort/sort/track.py b/ArithControl/deep_sort/sort/track.py new file mode 100644 index 0000000..1624caa --- /dev/null +++ b/ArithControl/deep_sort/sort/track.py @@ -0,0 +1,173 @@ +# vim: expandtab:ts=4:sw=4 + + +class TrackState: + """ + Enumeration type for the single target track state. Newly created tracks are + classified as `tentative` until enough evidence has been collected. Then, + the track state is changed to `confirmed`. Tracks that are no longer alive + are classified as `deleted` to mark them for removal from the set of active + tracks. 
+ + """ + + Tentative = 1 + Confirmed = 2 + Deleted = 3 + + +class Track: + """ + A single target track with state space `(x, y, a, h)` and associated + velocities, where `(x, y)` is the center of the bounding box, `a` is the + aspect ratio and `h` is the height. + + Parameters + ---------- + mean : ndarray + Mean vector of the initial state distribution. + covariance : ndarray + Covariance matrix of the initial state distribution. + track_id : int + A unique track identifier. + n_init : int + Number of consecutive detections before the track is confirmed. The + track state is set to `Deleted` if a miss occurs within the first + `n_init` frames. + max_age : int + The maximum number of consecutive misses before the track state is + set to `Deleted`. + feature : Optional[ndarray] + Feature vector of the detection this track originates from. If not None, + this feature is added to the `features` cache. + + Attributes + ---------- + mean : ndarray + Mean vector of the initial state distribution. + covariance : ndarray + Covariance matrix of the initial state distribution. + track_id : int + A unique track identifier. + hits : int + Total number of measurement updates. + age : int + Total number of frames since first occurance. + time_since_update : int + Total number of frames since last measurement update. + state : TrackState + The current track state. + features : List[ndarray] + A cache of features. On each measurement update, the associated feature + vector is added to this list. + + """ + + def __init__(self, mean, cls_, covariance, track_id, n_init, max_age, + feature=None): + self.mean = mean + self.cls_ = cls_ + self.covariance = covariance + self.track_id = track_id + self.hits = 1 + self.age = 1 + self.time_since_update = 0 + + self.state = TrackState.Tentative + self.features = [] + if feature is not None: + self.features.append(feature) + + self._n_init = n_init + self._max_age = max_age + + def to_tlwh(self): + """Get current position in bounding box format `(top left x, top left y, + width, height)`. + + Returns + ------- + ndarray + The bounding box. + + """ + ret = self.mean[:4].copy() + ret[2] *= ret[3] + ret[:2] -= ret[2:] / 2 + return ret + + def to_tlbr(self): + """Get current position in bounding box format `(min x, miny, max x, + max y)`. + + Returns + ------- + ndarray + The bounding box. + + """ + ret = self.to_tlwh() + ret[2:] = ret[:2] + ret[2:] + return ret + + def predict(self, kf): + """Propagate the state distribution to the current time step using a + Kalman filter prediction step. + + Parameters + ---------- + kf : kalman_filter.KalmanFilter + The Kalman filter. + + """ + self.mean, self.covariance = kf.predict(self.mean, self.covariance) + self.age += 1 + self.time_since_update += 1 + + def update(self, kf, detection): + """Perform Kalman filter measurement update step and update the feature + cache. + + Parameters + ---------- + kf : kalman_filter.KalmanFilter + The Kalman filter. + detection : Detection + The associated detection. + + """ + self.mean, self.covariance = kf.update( + self.mean, self.covariance, detection.to_xyah()) + self.features.append(detection.feature) + self.cls_ = detection.cls_ + + self.hits += 1 + self.time_since_update = 0 + if self.state == TrackState.Tentative and self.hits >= self._n_init: + self.state = TrackState.Confirmed + + def mark_missed(self): + """Mark this track as missed (no association at the current time step). 
+ """ + if self.state == TrackState.Tentative: + self.state = TrackState.Deleted + elif self.time_since_update > self._max_age: + self.state = TrackState.Deleted + + def is_tentative(self): + """Returns True if this track is tentative (unconfirmed). + """ + return self.state == TrackState.Tentative + + def is_confirmed(self): + """Returns True if this track is confirmed.""" + return self.state == TrackState.Confirmed + + def is_deleted(self): + """Returns True if this track is dead and should be deleted.""" + return self.state == TrackState.Deleted + + # 人工建航,强制确认航迹 + def set_confirmed(self): + """Set this track to confirmed.""" + self.state = TrackState.Confirmed diff --git a/ArithControl/deep_sort/sort/tracker.py b/ArithControl/deep_sort/sort/tracker.py new file mode 100644 index 0000000..6bfaaab --- /dev/null +++ b/ArithControl/deep_sort/sort/tracker.py @@ -0,0 +1,116 @@ +# vim: expandtab:ts=4:sw=4 +from __future__ import absolute_import +import numpy as np +from . import kalman_filter +from . import linear_assignment +from . import iou_matching +from .track import Track + + +class Tracker: + + def __init__(self, metric, max_iou_distance=0.7, max_age=70, n_init=3): + self.metric = metric + self.max_iou_distance = max_iou_distance + self.max_age = max_age + self.n_init = n_init + + self.kf = kalman_filter.KalmanFilter() + self.tracks = [] + self._next_id = 1 + + def predict(self): + """Propagate track state distributions one time step forward. + + This function should be called once every time step, before `update`. + """ + for track in self.tracks: + track.predict(self.kf) + + def update(self, detections): + """Perform measurement update and track management. + + Parameters + ---------- + detections : List[deep_sort.detection.Detection] + A list of detections at the current time step. + + """ + # Run matching cascade. + matches, unmatched_tracks, unmatched_detections = \ + self._match(detections) + + # Update track set. + for track_idx, detection_idx in matches: + self.tracks[track_idx].update( + self.kf, detections[detection_idx]) + for track_idx in unmatched_tracks: + self.tracks[track_idx].mark_missed() + for detection_idx in unmatched_detections: + self._initiate_track(detections[detection_idx]) + self.tracks = [t for t in self.tracks if not t.is_deleted()] + + # Update distance metric. + active_targets = [t.track_id for t in self.tracks if t.is_confirmed()] + features, targets = [], [] + for track in self.tracks: + if not track.is_confirmed(): + continue + features += track.features + targets += [track.track_id for _ in track.features] + track.features = [] + self.metric.partial_fit( + np.asarray(features), np.asarray(targets), active_targets) + + def _match(self, detections): + + def gated_metric(tracks, dets, track_indices, detection_indices): + features = np.array([dets[i].feature for i in detection_indices]) + targets = np.array([tracks[i].track_id for i in track_indices]) + cost_matrix = self.metric.distance(features, targets) + cost_matrix = linear_assignment.gate_cost_matrix( + self.kf, cost_matrix, tracks, dets, track_indices, + detection_indices) + + return cost_matrix + + # Split track set into confirmed and unconfirmed tracks. + confirmed_tracks = [ + i for i, t in enumerate(self.tracks) if t.is_confirmed()] + unconfirmed_tracks = [ + i for i, t in enumerate(self.tracks) if not t.is_confirmed()] + + # Associate confirmed tracks using appearance features. 
matches_a, unmatched_tracks_a, unmatched_detections = \
+            linear_assignment.matching_cascade(
+                gated_metric, self.metric.matching_threshold, self.max_age,
+                self.tracks, detections, confirmed_tracks)
+
+        # Associate remaining tracks together with unconfirmed tracks using IOU.
+        iou_track_candidates = unconfirmed_tracks + [
+            k for k in unmatched_tracks_a if
+            self.tracks[k].time_since_update == 1]
+        unmatched_tracks_a = [
+            k for k in unmatched_tracks_a if
+            self.tracks[k].time_since_update != 1]
+        matches_b, unmatched_tracks_b, unmatched_detections = \
+            linear_assignment.min_cost_matching(
+                iou_matching.iou_cost, self.max_iou_distance, self.tracks,
+                detections, iou_track_candidates, unmatched_detections)
+
+        matches = matches_a + matches_b
+        unmatched_tracks = list(set(unmatched_tracks_a + unmatched_tracks_b))
+        return matches, unmatched_tracks, unmatched_detections
+
+    def _initiate_track(self, detection):
+        mean, covariance = self.kf.initiate(detection.to_xyah())
+        self.tracks.append(Track(
+            mean, detection.cls_, covariance, self._next_id, self.n_init, self.max_age,
+            detection.feature))
+        self._next_id += 1
+
+    # add a pre-confirmed track (manual track initiation)
+    def add_track(self, detection):
+        self._initiate_track(detection)   # create the track
+        self.tracks[-1].set_confirmed()   # force-confirm it
diff --git a/ArithControl/detect_api.py b/ArithControl/detect_api.py
new file mode 100644
index 0000000..5f1a58c
--- /dev/null
+++ b/ArithControl/detect_api.py
@@ -0,0 +1,136 @@
+import sys
+import os
+# locate this file's directory, then add the yolov5 path
+# (the yolov5 checkout lives inside the ArithControl directory)
+current_dir = os.path.dirname(os.path.abspath(__file__))
+yolov5_path = os.path.join(current_dir, 'yolov5')
+if yolov5_path not in sys.path:
+    sys.path.insert(0, yolov5_path)
+
+import numpy as np
+import torch
+# import the concrete yolov5 modules directly
+from models.common import DetectMultiBackend
+from utils.torch_utils import select_device, time_sync
+from utils.general import (LOGGER, check_img_size, non_max_suppression, scale_boxes)
+from utils.augmentations import letterbox
+import cv2
+
+import pathlib
+# checkpoints pickled on Linux store PosixPath objects; remap PosixPath to
+# WindowsPath so they load on Windows (only safe to do on Windows itself)
+if os.name == 'nt':
+    pathlib.PosixPath = pathlib.WindowsPath
+
+class ObjectDetector:
+    def __init__(self, weights='yolov5s.pt', data='data/coco.yaml',
+                 imgsz=(640, 640), conf_thres=0.25, iou_thres=0.45,
+                 max_det=1000, device='0', half=False, dnn=False):
+        """
+        Initialize the object detector.
+
+        :param weights: path to the model weights file
+        :param data: path to the dataset config file
+        :param imgsz: network input size (height, width)
+        :param conf_thres: confidence threshold
+        :param iou_thres: IOU threshold (used for NMS)
+        :param max_det: maximum number of detections per image
+        :param device: device ('cpu', or '0' for GPU 0)
+        :param half: use half precision (FP16)
+        :param dnn: use the OpenCV DNN backend
+        """
+        self.weights = weights
+        self.data = data
+        self.imgsz = imgsz
+        self.conf_thres = conf_thres
+        self.iou_thres = iou_thres
+        self.max_det = max_det
+        self.device = select_device(device)  # pick the device automatically
+        self.half = half
+        self.dnn = dnn
+        self.classes = None
+        self.agnostic_nms = False
+        self.augment = False
+        self.visualize = False
+
+        # load the model
+        self.model = DetectMultiBackend(self.weights, device=self.device, dnn=self.dnn,
+                                        data=self.data, fp16=self.half)
+        self.stride, self.names, self.pt = self.model.stride, self.model.names, self.model.pt
+        self.imgsz = check_img_size(imgsz, s=self.stride)
+
+        # warm up the model
+        self.model.warmup(imgsz=(1, 3, *self.imgsz))
+
+    def detect(self, img):
+        """
+        Run object detection on an input image.
+
+        :param img: input image (numpy array)
+        :return: list of detections, each element is [cls, x1, y1, x2, y2, conf]
+        """
+        # image preprocessing
+        im0 = img  # original image
+
+        # letterbox-resize the image to the network input size
+        # im = cv2.resize(im0, (self.imgsz[1], self.imgsz[0]), interpolation=cv2.INTER_LINEAR)
+        im = letterbox(im0, self.imgsz)[0]
+
+        # HWC to CHW, BGR to RGB
+        im = im.transpose((2, 0, 1))[::-1]
+        im = np.ascontiguousarray(im)
+        t1 = time_sync()
+        im = torch.from_numpy(im).to(self.device)
+        im = im.half() if self.half else im.float()
+        im /= 255  # normalize to 0.0 - 1.0
+        if len(im.shape) == 3:
+            im = im[None]  # add a batch dimension
+
+        # Inference
+        pred = self.model(im, augment=self.augment, visualize=self.visualize)
+
+        # Non-maximum suppression (NMS)
+        pred = non_max_suppression(pred, self.conf_thres, self.iou_thres,
+                                   self.classes, self.agnostic_nms, max_det=self.max_det)
+
+        detections = []
+
+        for i, det in enumerate(pred):
+            if len(det):
+                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
+
+                for *xyxy, conf, cls in reversed(det):
+                    xyxy = [round(x.item()) for x in xyxy]  # round coordinates to integers
+                    detections.append([int(cls)] + xyxy + [float(conf)])
+
+        return detections
+
+
+if __name__ == '__main__':
+    detector = ObjectDetector(weights='./best_3class.pt', imgsz=(1024, 1280))
+    img = cv2.imread('./20250703_bar_017.png')
+    if img is None:
+        print("Error: failed to load image")
+    else:
+        detections = detector.detect(img)
+        for det in detections:
+            cls, x1, y1, x2, y2, conf = det  # each det is [cls, x1, y1, x2, y2, conf]
+
+            # Draw the box (BGR color, thickness 2)
+            cv2.rectangle(img, (x1, y1), (x2, y2), color=(0, 255, 0), thickness=2)
+
+            # Add the confidence label
+            label = f"{conf:.2f}"
+            cv2.putText(img, label, (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX,
+                        fontScale=0.5, color=(0, 255, 0), thickness=1)
+
+            print(f"Detection: {det}")
+
+        # Save the image
+        output_path = './output.jpg'
+        cv2.imwrite(output_path, img)
+        print(f"Saved detection result image to: {output_path}")
\ No newline at end of file
diff --git a/ArithControl/model/best_3class.pt b/ArithControl/model/best_3class.pt
new file mode 100644
index 0000000..e22bdf8
Binary files /dev/null and b/ArithControl/model/best_3class.pt differ
diff --git a/ArithControl/yolov5 b/ArithControl/yolov5
new file mode 160000
index 0000000..567c664
--- /dev/null
+++ b/ArithControl/yolov5
@@ -0,0 +1 @@
+Subproject commit 567c66463e943b731e08c9a9476660c13408f088
diff --git a/H30T_Reader.py b/H30T_Reader.py
new file mode 100644
index 0000000..c5f1eba
--- /dev/null
+++ b/H30T_Reader.py
@@ -0,0 +1,193 @@
+import av
+import cv2
+import re
+from typing import Optional, Tuple, Dict, Any
+import numpy as np
+
+# Reads an H30T video and the subtitle metadata embedded in it
+class H30T_Reader:
+    """
+    Video subtitle reader.
+    Reads video frames and subtitle metadata from the same file.
+    """
+
+    def __init__(self):
+        self.container = None
+        self.video_streams = []
+        self.subtitle_streams = []
+        self.is_open = False
+
+    def open(self, video_path: str) -> bool:
+        """
+        Open a video file.
+
+        Args:
+            video_path (str): path to the video file
+
+        Returns:
+            bool: True if the file was opened successfully
+        """
+        try:
+            self.container = av.open(video_path)
+            self.container.seek(0)
+
+            # Collect all streams
+            streams = self.container.streams
+            self.subtitle_streams = [s for s in streams if s.type == 'subtitle']
+            self.video_streams = [s for s in streams if s.type == 'video']
+
+            self.is_open = True
+            return True
+
+        except Exception as e:
+            print(f"Failed to open video file: {e}")
+            self.is_open = False
+            return False
+
+    def read(self) -> Tuple[Optional[np.ndarray], Optional[Dict[str, Any]]]:
+        """
+        Read the next item.
+
+        Returns:
+            Tuple[Optional[np.ndarray], Optional[Dict[str, Any]]]:
+                (video frame array, subtitle info dict);
+                (None, None) when no more data is available
+        """
+        if not self.is_open or self.container is None:
+            return None, None
+
+        try:
+            # Process packets in read order
+            for packet in self.container.demux():
+                if packet.stream.type == 'subtitle':
+                    # Decode a subtitle packet
+                    for frame in packet.decode():
+                        dialogue_text = frame.dialogue.decode('utf-8', errors='ignore')
+
+                        # Regular expression for the embedded metadata
+                        metadata_pattern = r'FrameCnt: (\d+) ([\d-]+ [\d:.]+)\n\[focal_len: ([\d.]+)\] \[dzoom_ratio: ([\d.]+)\], \[latitude: ([\d.-]+)\] \[longitude: ([\d.-]+)\] \[rel_alt: ([\d.]+) abs_alt: ([\d.]+)\] \[gb_yaw: ([\d.-]+) gb_pitch: ([\d.-]+) gb_roll: ([\d.-]+)\]'
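+                        # Illustrative payload matched by the pattern (values are
+                        # made up): "FrameCnt: 123 2025-04-18 15:02:10.123\n
+                        # [focal_len: 24.00] [dzoom_ratio: 1.00], [latitude: 30.1234]
+                        # [longitude: 114.5678] [rel_alt: 98.7 abs_alt: 123.4]
+                        # [gb_yaw: -12.3 gb_pitch: -45.6 gb_roll: 0.0]"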
+
+                        match = re.match(metadata_pattern, dialogue_text)
+                        if match:
+                            frame_cnt, timestamp, focal_len, dzoom_ratio, lat, lon, rel_alt, abs_alt, yaw, pitch, roll = match.groups()
+
+                            subtitle_info = {
+                                'frame_cnt': int(frame_cnt),
+                                'timestamp': timestamp,
+                                'focal_len': float(focal_len),
+                                'dzoom_ratio': float(dzoom_ratio),
+                                'latitude': float(lat),
+                                'longitude': float(lon),
+                                'rel_alt': float(rel_alt),
+                                'abs_alt': float(abs_alt),
+                                'gb_yaw': float(yaw),
+                                'gb_pitch': float(pitch),
+                                'gb_roll': float(roll),
+                                'raw_text': dialogue_text
+                            }
+
+                            # No video frame is decoded here, so return the
+                            # subtitle together with None for the frame slot
+                            return None, subtitle_info
+                        else:
+                            continue
+
+        except Exception as e:
+            print(f"Failed to read data: {e}")
+            return None, None
+
+        return None, None
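+
+    # NOTE: read() currently decodes subtitle packets only; video packets are
+    # skipped, so the frame slot of the returned tuple is always None. The
+    # player (video_player.py) grabs pixel data separately via cv2.VideoCapture.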
+
+    def close(self):
+        """
+        Close the video file.
+        """
+        if self.container is not None:
+            self.container.close()
+            self.container = None
+        self.is_open = False
+        self.video_streams = []
+        self.subtitle_streams = []
+
+    def __enter__(self):
+        """Context-manager support"""
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Close automatically on exit"""
+        self.close()
+
+    def get_stream_info(self) -> Dict[str, Any]:
+        """
+        Get stream information.
+
+        Returns:
+            Dict[str, Any]: descriptions of the video and subtitle streams
+        """
+        if not self.is_open:
+            return {}
+
+        info = {
+            'video_streams': [],
+            'subtitle_streams': []
+        }
+
+        for stream in self.video_streams:
+            info['video_streams'].append({
+                'index': stream.index,
+                'width': stream.width,
+                'height': stream.height,
+                'fps': float(stream.average_rate) if stream.average_rate else None,
+                'duration': float(stream.duration * stream.time_base) if stream.duration else None
+            })
+
+        for stream in self.subtitle_streams:
+            info['subtitle_streams'].append({
+                'index': stream.index,
+                'language': getattr(stream, 'language', 'unknown')
+            })
+
+        return info
+
+
+# Usage example
+if __name__ == "__main__":
+    # Create a reader instance
+    reader = H30T_Reader()
+
+    # Open the video file
+    if reader.open("DJI_20250418150210_0006_S.MP4"):
+        print("Video file opened")
+
+        # Stream info
+        stream_info = reader.get_stream_info()
+        print("Stream info:", stream_info)
+
+        # Read the data
+        frame_count = 0
+        subtitle_count = 0
+
+        while True:
+            frame, subtitle = reader.read()
+
+            if frame is None and subtitle is None:
+                break
+
+            if frame is not None:
+                frame_count += 1
+                print(f"Read video frame {frame_count} - shape: {frame.shape}")
+
+                # Show the frame (optional)
+                cv2.imshow('Video Frame', frame)
+                if cv2.waitKey(1) & 0xFF == ord('q'):
+                    break
+
+            if subtitle is not None:
+                subtitle_count += 1
+                print(f"Read subtitle {subtitle_count}: {subtitle}")
+
+        # Close the video file
+        reader.close()
+        cv2.destroyAllWindows()
+
+        print(f"Read {frame_count} video frames and {subtitle_count} subtitles in total")
diff --git a/test/20250703_bar_017.png b/test/20250703_bar_017.png
new file mode 100644
index 0000000..1704d3b
Binary files /dev/null and b/test/20250703_bar_017.png differ
diff --git a/test/README.md b/test/README.md
new file mode 100644
index 0000000..ca3b2cf
--- /dev/null
+++ b/test/README.md
@@ -0,0 +1,5 @@
+### 1. Download the yolov5 repository
+git clone https://github.com/ultralytics/yolov5.git
+
+### 2. Run the demo
+python3 detect_api.py
\ No newline at end of file
diff --git a/test/best_3class.pt b/test/best_3class.pt
new file mode 100644
index 0000000..e22bdf8
Binary files /dev/null and b/test/best_3class.pt differ
diff --git a/test/detect_api.py b/test/detect_api.py
new file mode 100644
index 0000000..eb18dbb
--- /dev/null
+++ b/test/detect_api.py
@@ -0,0 +1,136 @@
+import sys
+import os
+from pathlib import Path
+
+# Add the YOLOv5 path to the Python path
+current_dir = Path(__file__).parent
+yolov5_path = current_dir / 'yolov5'
+sys.path.insert(0, str(yolov5_path))
+
+import numpy as np
+import torch
+from models.common import DetectMultiBackend
+from utils.torch_utils import select_device, time_sync
+from utils.general import (LOGGER, check_img_size, non_max_suppression, scale_boxes)
+from utils.augmentations import letterbox
+import cv2
+
+import pathlib
+# Checkpoints trained on Linux pickle PosixPath objects; remap PosixPath to
+# WindowsPath so they load on Windows (only valid when running on Windows)
+pathlib.PosixPath = pathlib.WindowsPath
+
+class ObjectDetector:
+    def __init__(self, weights='yolov5s.pt', data='data/coco.yaml',
+                 imgsz=(640, 640), conf_thres=0.25, iou_thres=0.45,
+                 max_det=1000, device='0', half=False, dnn=False):
+        """
+        Initialize the object detector.
+
+        :param weights: path to the model weights file
+        :param data: path to the dataset config file
+        :param imgsz: network input size (height, width)
+        :param conf_thres: confidence threshold
+        :param iou_thres: IOU threshold (for NMS)
+        :param max_det: maximum number of detections per image
+        :param device: device ('cpu', or '0' for GPU 0)
+        :param half: use half precision (FP16)
+        :param dnn: use the OpenCV DNN backend
+        """
+        self.weights = weights
+        self.data = data
+        self.imgsz = imgsz
+        self.conf_thres = conf_thres
+        self.iou_thres = iou_thres
+        self.max_det = max_det
+        self.device = select_device(device)  # pick the device automatically
+        self.half = half
+        self.dnn = dnn
+        self.classes = None
+        self.agnostic_nms = False
+        self.augment = False
+        self.visualize = False
+
+        # Load the model
+        self.model = DetectMultiBackend(self.weights, device=self.device, dnn=self.dnn,
+                                        data=self.data, fp16=self.half)
+        self.stride, self.names, self.pt = self.model.stride, self.model.names, self.model.pt
+        self.imgsz = check_img_size(imgsz, s=self.stride)
+
+        # Warm up the model
+        self.model.warmup(imgsz=(1, 3, *self.imgsz))
+
+    def detect(self, img):
+        """
+        Run object detection on an input image.
+
+        :param img: input image (numpy array)
+        :return: list of detections, one [x1, y1, x2, y2, conf] per object
+        """
+        # Preprocess
+        im0 = img  # original image
+
+        # Resize the image to the network input size
+        # im = cv2.resize(im0, (self.imgsz[1], self.imgsz[0]), interpolation=cv2.INTER_LINEAR)
+        im = letterbox(im0, self.imgsz)[0]
+
+        # HWC to CHW, BGR to RGB
+        im = im.transpose((2, 0, 1))[::-1]
+        im = np.ascontiguousarray(im)
+        t1 = time_sync()
+        im = torch.from_numpy(im).to(self.device)
+        im = im.half() if self.half else im.float()
+        im /= 255  # normalize to 0.0 - 1.0
+        if len(im.shape) == 3:
+            im = im[None]  # add a batch dimension
+
+        # Inference
+        pred = self.model(im, augment=self.augment, visualize=self.visualize)
+
+        # Non-maximum suppression (NMS)
+        pred = non_max_suppression(pred, self.conf_thres, self.iou_thres,
+                                   self.classes, self.agnostic_nms, max_det=self.max_det)
+
+        detections = []
+
+        for i, det in enumerate(pred):
+            if len(det):
+                det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round()
+
+                for *xyxy, conf, cls in reversed(det):
+                    xyxy = [round(x.item()) for x in xyxy]  # round coordinates to integers
+                    detections.append(xyxy + [float(conf)])
+
+        return detections
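+
+
+# NOTE: unlike ArithControl/detect_api.py, which prepends the class id and
+# returns [cls, x1, y1, x2, y2, conf] rows, this test variant returns plain
+# [x1, y1, x2, y2, conf], which is what the demo below unpacks.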
{output_path}") \ No newline at end of file diff --git a/test/output.jpg b/test/output.jpg new file mode 100644 index 0000000..1019a87 Binary files /dev/null and b/test/output.jpg differ diff --git a/test/yolov5 b/test/yolov5 new file mode 160000 index 0000000..567c664 --- /dev/null +++ b/test/yolov5 @@ -0,0 +1 @@ +Subproject commit 567c66463e943b731e08c9a9476660c13408f088 diff --git a/video_player.py b/video_player.py new file mode 100644 index 0000000..8b2dd50 --- /dev/null +++ b/video_player.py @@ -0,0 +1,345 @@ +import cv2 +import numpy as np +import os +import sys + +from ArithControl import Arith_EOController as Arith + +import os +os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' + +from H30T_Reader import H30T_Reader + +class VideoPlayer: + def __init__(self, video_path): + self.video_path = video_path + self.cap = cv2.VideoCapture(video_path) + self.h30t_reader = H30T_Reader() + self.h30t_reader.open(video_path) + if not self.cap.isOpened(): + print(f"Error: Could not open video file {video_path}") + sys.exit(1) + + # 获取视频属性 + self.total_frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT)) + self.fps = self.cap.get(cv2.CAP_PROP_FPS) + self.frame_width = int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + self.frame_height = int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + + # 播放状态 + self.current_frame = 0 + self.is_playing = True + self.is_seeking = False + self.should_exit = False + + # 鼠标绘制状态 + self.is_drawing = False + self.start_point = None + self.end_point = None + self.tracking_boxes = [] # 存储所有绘制的矩形框 + + + + + # 窗口设置 + self.window_name = "Video Player - Press SPACE to pause/play, LEFT/RIGHT for frame control" + cv2.namedWindow(self.window_name, cv2.WINDOW_NORMAL) + cv2.resizeWindow(self.window_name, 800, 600) + + # 创建highgui进度条 + cv2.createTrackbar('Progress', self.window_name, 0, self.total_frames - 1, self.on_trackbar) + + # 设置键盘回调 + self.setup_keyboard_controls() + + # 设置鼠标回调 + cv2.setMouseCallback(self.window_name, self.mouse_callback) + + + self.ArithModule = Arith.EOController() + + self.arithout = [] + + def on_trackbar(self, pos): + """highgui进度条回调函数""" + if not self.is_seeking: # 避免循环更新 + self.is_seeking = True + self.current_frame = pos + self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame) + self.is_seeking = False + + def mouse_callback(self, event, x, y, flags, param): + """鼠标回调函数 - 处理点击和拖动事件""" + # 检查鼠标是否在视频区域内(排除信息栏) + if y >= self.frame_height: # 鼠标在信息栏区域,忽略 + return + + # 边界检查 + x = max(0, min(x, self.frame_width - 1)) + y = max(0, min(y, self.frame_height - 1)) + + if event == cv2.EVENT_LBUTTONDOWN: + # 鼠标按下 + self.is_drawing = True + self.start_point = (x, y) + self.end_point = (x, y) + print(f"鼠标按下: ({x}, {y})") + + elif event == cv2.EVENT_MOUSEMOVE: + if self.is_drawing: + # 计算拖动距离 + if self.start_point: + dx = x - self.start_point[0] + dy = y - self.start_point[1] + self.drag_distance = (dx * dx + dy * dy) ** 0.5 + + # 更新矩形终点 + self.end_point = (x, y) + + + elif event == cv2.EVENT_LBUTTONUP: + if self.is_drawing: + self.end_point = (x, y) + # 根据框尺寸判断点还是框 + w = self.end_point[0] - self.start_point[0] + h = self.end_point[1] - self.start_point[1] + if w > 5 and h > 5: + self.on_mouse_box([self.start_point[0], self.start_point[1], self.end_point[0] - self.start_point[0], self.end_point[1] - self.start_point[1]]) + else: + self.on_mouse_click(self.start_point) + # 清除状态 + self.is_drawing = False + self.start_point = None + self.end_point = None + + + + def on_mouse_click(self, point): + """鼠标点击事件处理""" + print(f"鼠标点击: ({point})") + + + + + + def on_mouse_box(self, 
+
+        # Custom key map
+        self.custom_key_handlers = {}
+
+        # Custom mouse event handlers
+        self.mouse_click_handler = None
+        self.mouse_drag_start_handler = None
+        self.mouse_drag_end_handler = None
+
+    def add_key_handler(self, key_code, handler_func):
+        """Register a custom key handler"""
+        self.custom_key_handlers[key_code] = handler_func
+
+    def remove_key_handler(self, key_code):
+        """Remove a key handler"""
+        if key_code in self.key_handlers:
+            del self.key_handlers[key_code]
+        if key_code in self.custom_key_handlers:
+            del self.custom_key_handlers[key_code]
+
+    def set_mouse_click_handler(self, handler_func):
+        """Set the mouse click handler"""
+        self.mouse_click_handler = handler_func
+
+    def set_mouse_drag_start_handler(self, handler_func):
+        """Set the mouse drag-start handler"""
+        self.mouse_drag_start_handler = handler_func
+
+    def set_mouse_drag_end_handler(self, handler_func):
+        """Set the mouse drag-end handler"""
+        self.mouse_drag_end_handler = handler_func
+
+    def handle_keyboard_input(self, key):
+        """Dispatch keyboard input"""
+        # Custom handlers take precedence
+        if key in self.custom_key_handlers:
+            self.custom_key_handlers[key]()
+        # Fall back to the defaults
+        elif key in self.key_handlers:
+            self.key_handlers[key]()
+
+    def handle_escape(self):
+        """ESC - quit"""
+        self.should_exit = True
+        print("Exiting...")
+
+    def handle_space(self):
+        """SPACE - toggle play/pause"""
+        self.is_playing = not self.is_playing
+        status = "PLAYING" if self.is_playing else "PAUSED"
+        print(f"Play/Pause toggled: {status}")
+
+    def handle_left_arrow(self):
+        """LEFT - previous frame"""
+        self.current_frame = max(self.current_frame - 1, 0)
+        print(f"Previous frame: {self.current_frame}")
+
+    def handle_right_arrow(self):
+        """RIGHT - next frame"""
+        self.current_frame = min(self.current_frame + 1, self.total_frames - 1)
+        print(f"Next frame: {self.current_frame}")
+
+    def draw_info_overlay(self, frame):
+        """Append an info bar below the video frame"""
+        # Info bar area
+        info_height = 60
+        info_bar = np.zeros((info_height, self.frame_width, 3), dtype=np.uint8)
+        info_bar[:] = (0, 0, 0)  # black background
+
+        # Timing info
+        current_time = self.current_frame / self.fps if self.fps > 0 else 0
+        total_time = self.total_frames / self.fps if self.fps > 0 else 0
+
+        # Text to display
+        time_text = f"Time: {current_time:.1f}s / {total_time:.1f}s"
+        frame_text = f"Frame: {self.current_frame} / {self.total_frames}"
+        status_text = "PLAYING" if self.is_playing else "PAUSED"
+        tracking_text = f"Tracking Boxes: {len(self.tracking_boxes)}"
+
+        # Render the text onto the info bar
+        cv2.putText(info_bar, time_text, (10, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
+        cv2.putText(info_bar, frame_text, (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
+        cv2.putText(info_bar, status_text, (self.frame_width - 150, 25), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
+                    (0, 255, 0) if self.is_playing else (0, 0, 255), 2)
+        cv2.putText(info_bar, tracking_text, (self.frame_width - 200, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
+                    (255, 255, 0), 2)  # tracking-box count in yellow
+
+        # Stack the info bar below the frame
+        combined_frame = np.vstack([frame, info_bar])
+        return combined_frame
+
+    def draw_tracking_boxes(self, frame):
+        """Draw the tracking boxes (display style similar to cv2.selectROI)"""
+        # Boxes belonging to the current frame
+        for roi in self.tracking_boxes:
+            if roi['frame'] == self.current_frame:
+                x1, y1, x2, y2 = roi['bbox_corners']
+                # Rectangle
+                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
+                # ID label
+                cv2.putText(frame, f"ROI {roi['id']}", (x1, y1-10),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
+                # Center point
+                center = roi['center']
+                cv2.circle(frame, center, 3, (0, 255, 0), -1)
+                # Size label
+                width, height = roi['bbox'][2], roi['bbox'][3]
+                cv2.putText(frame, f"{width}x{height}", (x1, y2+20),
+                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
+
+        # Live preview of the rectangle currently being drawn
+        if self.is_drawing and self.start_point and self.end_point:
+            x1, y1 = min(self.start_point[0], self.end_point[0]), min(self.start_point[1], self.end_point[1])
+            x2, y2 = max(self.start_point[0], self.end_point[0]), max(self.start_point[1], self.end_point[1])
+
+            # Semi-transparent fill
+            overlay = frame.copy()
+            cv2.rectangle(overlay, (x1, y1), (x2, y2), (255, 0, 0), -1)
+            cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
+
+            # Border
+            cv2.rectangle(frame, (x1, y1), (x2, y2), (255, 0, 0), 2)
+
+            # Size readout
+            width, height = x2 - x1, y2 - y1
+            cv2.putText(frame, f"ROI: {width} x {height}", (10, 30),
+                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
+
+        return frame
+
+    def play(self):
+        while not self.should_exit:
+            if self.is_playing:
+                # Advance by exactly one frame per loop iteration
+                ret, frame = self.cap.read()
+                _, subtitle = self.h30t_reader.read()
+                if not ret:
+                    # End of video: restart from the beginning
+                    self.current_frame = 0
+                    self.cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+                    continue
+                self.current_frame = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES))
+            else:
+                # Paused or stepping: re-grab the current frame
+                self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.current_frame)
+                ret, frame = self.cap.read()
+                if not ret:
+                    break
+
+            # Run the algorithm pipeline
+            self.arithout = self.ArithModule.run(frame)
+
+            # Draw the algorithm output
+            frame = self.draw_arith_result(frame)
+
+            # Draw the info overlay
+            display_frame = self.draw_info_overlay(frame)
+
+            # Show the frame
+            cv2.imshow(self.window_name, display_frame)
+
+            # Sync the trackbar position (avoiding feedback loops)
+            if not self.is_seeking:
+                cv2.setTrackbarPos('Progress', self.window_name, self.current_frame)
+
+            # Keyboard input; waitKey(0) blocks while paused
+            key = cv2.waitKey(int(1000/self.fps) if self.is_playing else 0) & 0xFF
+            self.handle_keyboard_input(key)
+
+        self.cleanup()
+
+    def draw_arith_result(self, frame):
+        # run() returns None when there are no detections; guard against it
+        for box in (self.arithout or []):
+            x1, y1, x2, y2, cls, track_id = box
+            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
+            cv2.putText(frame, f"ID: {track_id}", (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
+        return frame
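+    # The [x1, y1, x2, y2, cls, track_id] layout unpacked above mirrors what
+    # Arith_EOController.run() forwards from DeepSort.update(); adjust
+    # draw_arith_result if the tracker's output format changes.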
+
+    def cleanup(self):
+        self.cap.release()
+        cv2.destroyAllWindows()
+
+
+def main():
+    video_path = 'DJI_20250418150210_0006_S.MP4'
+
+    # Check that the file exists
+    if not os.path.exists(video_path):
+        print(f"Error: Video file {video_path} does not exist")
+        sys.exit(1)
+
+    # Create and run the video player
+    player = VideoPlayer(video_path)
+    player.play()
+
+
+if __name__ == "__main__":
+    main()