"""Classify a digit image with a small CNN via PyTorch and via ONNX/OpenCV DNN.

The script loads trained weights from ``model.pth`` and classifies
``./data/8.png`` with PyTorch. It then exports the network to
``ministNet.onnx``, re-runs the same image through OpenCV's DNN module, and
shows both predictions in message boxes.
"""
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import torch
import cv2
import torchvision.transforms as transforms
import numpy as np
import tkinter
import tkinter.messagebox  # popup dialog library
from PyQt5.QtWidgets import *

# NOTE(review): optim, plt, transforms and the PyQt5 star import appear unused
# in this script; kept so nothing that relies on them breaks.

MODEL_PATH = 'model.pth'
IMAGE_PATH = './data/8.png'
ONNX_PATH = "ministNet.onnx"
# fc1 expects 20 * 4 * 4 = 320 features, which only a 28x28 input produces
# (28 -conv5-> 24 -pool-> 12 -conv5-> 8 -pool-> 4).
INPUT_SIZE = 28


class Net(nn.Module):
    """Two-conv-layer CNN returning log-probabilities over 10 classes.

    Attribute names must not change: they are the keys of the state dict
    loaded from ``model.pth``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return log-softmax scores of shape (N, 10) for an (N, 1, 28, 28) batch."""
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten per sample (keeping the batch dim) rather than view(-1, 320),
        # so a wrongly sized input fails in fc1 instead of being silently
        # re-batched into extra rows.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Explicit dim: the implicit-dim form is deprecated and warns.
        return F.log_softmax(x, dim=1)


def _load_gray_batch(path, size=INPUT_SIZE):
    """Read *path* as a (1, 1, size, size) float32 grayscale batch scaled to [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
    """
    img = cv2.imread(path)
    if img is None:  # cv2.imread reports failure by returning None
        raise FileNotFoundError(f"could not read image: {path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    if img.shape != (size, size):
        # The network only accepts size x size inputs (see INPUT_SIZE).
        img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
    img = (img / 255).astype(np.float32)
    return img[np.newaxis, np.newaxis, :, :]


def _predict_torch(net, batch, device):
    """Return the class index *net* predicts for the first sample of *batch*."""
    tensor = torch.from_numpy(batch).to(device)
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        out = net(tensor)
    return int(out.argmax(dim=1)[0].item())


def _predict_onnx(onnx_path, batch):
    """Run *batch* through the ONNX model with OpenCV DNN; return the top class index."""
    dnn_net = cv2.dnn.readNetFromONNX(onnx_path)
    dnn_net.setInput(batch)
    outputs = dnn_net.forward(dnn_net.getUnconnectedOutLayersNames())
    # outputs[0] holds the first output layer's scores for the batch;
    # [0] selects the single sample's 10 class scores.
    return int(np.argmax(outputs[0][0]))


def main():
    """Load weights, run PyTorch inference, export to ONNX and re-run via OpenCV."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    net = Net()
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    net.load_state_dict(torch.load(MODEL_PATH, map_location=device))
    net.to(device)
    net.eval()

    # Debug output: parameter names and the first conv layer's weights.
    state = net.state_dict()
    for name in state:
        print(name)
    print(state['conv1.weight'])

    batch = _load_gray_batch(IMAGE_PATH)

    # One hidden Tk root for all dialogs; otherwise messagebox creates an
    # empty main window that stays on screen.
    root = tkinter.Tk()
    root.withdraw()
    try:
        pt_pred = _predict_torch(net, batch, device)
        print(pt_pred)
        tkinter.messagebox.showinfo('pt推理结果', str(pt_pred), parent=root)

        # Export to ONNX using a dummy NCHW input of the expected size.
        dummy_input = torch.randn(1, 1, INPUT_SIZE, INPUT_SIZE).to(device)
        torch.onnx.export(net, dummy_input, ONNX_PATH, verbose=True,
                          input_names=['input_111'], output_names=['output_111'])

        onnx_pred = _predict_onnx(ONNX_PATH, batch)
        print(onnx_pred)
        tkinter.messagebox.showinfo('onnx推理结果', str(onnx_pred), parent=root)
    finally:
        root.destroy()


if __name__ == "__main__":
    main()