diff --git a/code/minist/test.py b/code/minist/test.py index a835fe9..f2a37c9 100644 --- a/code/minist/test.py +++ b/code/minist/test.py @@ -5,6 +5,13 @@ import matplotlib.pyplot as plt import torch import cv2 import torchvision.transforms as transforms + +import numpy as np + +import tkinter +import tkinter.messagebox #弹窗库 + +from PyQt5.QtWidgets import * #定义网络 class Net(nn.Module): def __init__(self): @@ -35,6 +42,10 @@ net.to(device) net.eval() +for name in net.state_dict(): + print(name) +print(net.state_dict()['conv1.weight']) + img = cv2.imread('./data/8.png') img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) img = img/255 @@ -49,9 +60,35 @@ out = net(imgTensor) print(out.data.max(1, keepdim=True)[1][0].item()) +tkinter.messagebox.showinfo('pt推理结果', out.data.max(1, keepdim=True)[1][0].item()) + + + + +# onnx model dummy_input = torch.randn(1, 1,28,28).to(device)#输入大小 #data type nchw torch.onnx.export(net, dummy_input, "ministNet.onnx", verbose=True, input_names=['input_111'], output_names=['output_111']) +net = cv2.dnn.readNetFromONNX("ministNet.onnx") # 加载训练好的识别模型 +img = cv2.imread('./data/8.png') +img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) +img = img/255 + +transf = transforms.ToTensor() +imgTensor = transf(img).unsqueeze(0) + + + +im = img[np.newaxis, np.newaxis,:, :] +im = im.astype(np.float32) +outNames = net.getUnconnectedOutLayersNames() +net.setInput(im) +out = net.forward(outNames) + + +ind = np.where(out[0][0]==np.max(out[0][0])) +print(ind[0]) +tkinter.messagebox.showinfo('onnx推理结果',ind[0]) diff --git a/code/minist/testOnnx.py b/code/minist/testOnnx.py deleted file mode 100644 index 3733180..0000000 --- a/code/minist/testOnnx.py +++ /dev/null @@ -1,27 +0,0 @@ -import cv2 -import torchvision.transforms as transforms -import numpy as np -net = cv2.dnn.readNetFromONNX("ministNet.onnx") # 加载训练好的识别模型 -img = cv2.imread('./data/8.png') -img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) -img = img/255 - -transf = transforms.ToTensor() -imgTensor = transf(img).unsqueeze(0) - - - -im = img[np.newaxis, np.newaxis,:, :] -im = im.astype(np.float32) -outNames = net.getUnconnectedOutLayersNames() -net.setInput(im) -out = net.forward(outNames) - - -ind = np.where(out[0][0]==np.max(out[0][0])) -print(ind[0]) - - - - - diff --git a/ministNet.onnx b/ministNet.onnx new file mode 100644 index 0000000..081ee0c Binary files /dev/null and b/ministNet.onnx differ diff --git a/model.pth b/model.pth index 8834831..4ec8aba 100644 Binary files a/model.pth and b/model.pth differ diff --git a/optimizer.pth b/optimizer.pth index ad49ab8..40ca06a 100644 Binary files a/optimizer.pth and b/optimizer.pth differ