diff --git a/.gitignore b/.gitignore index 55be276..e6945a1 100644 --- a/.gitignore +++ b/.gitignore @@ -152,3 +152,4 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +/*.pth diff --git a/code/minist/Minist.py b/code/minist/Minist.py index 61a6347..b8bff3c 100644 --- a/code/minist/Minist.py +++ b/code/minist/Minist.py @@ -16,7 +16,7 @@ random_seed = 1 torch.manual_seed(random_seed) - +#下载数据,预处理 train_loader = torch.utils.data.DataLoader( torchvision.datasets.MNIST('./data/', train=True, download=True, transform=torchvision.transforms.Compose([ @@ -34,12 +34,11 @@ test_loader = torch.utils.data.DataLoader( ])), batch_size=batch_size_test, shuffle=True) +#查看数据 examples = enumerate(test_loader) batch_idx, (example_data, example_targets) = next(examples) print(example_targets) print(example_data.shape) - - fig = plt.figure() for i in range(6): plt.subplot(2,3,i+1) @@ -50,9 +49,12 @@ for i in range(6): plt.yticks([]) plt.show() +train_losses = [] +train_counter = [] +test_losses = [] +test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)] - - +#定义网络 class Net(nn.Module): def __init__(self): super(Net, self).__init__() @@ -70,6 +72,10 @@ class Net(nn.Module): x = self.fc2(x) return F.log_softmax(x) +network = Net() +optimizer = optim.SGD(network.parameters(), lr=learning_rate, + momentum=momentum) +# 定义训练 def train(epoch): network.train() for batch_idx, (data, target) in enumerate(train_loader): @@ -90,6 +96,7 @@ def train(epoch): #train(1) +#测试 def test(): network.eval() test_loss = 0 @@ -105,15 +112,10 @@ def test(): print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset))) +test() + -network = Net() -optimizer = optim.SGD(network.parameters(), lr=learning_rate, - momentum=momentum) -train_losses = [] -train_counter = [] -test_losses = [] -test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)] diff --git a/code/minist/data/MNIST/raw/train-images-idx3-ubyte.gz b/code/minist/data/MNIST/raw/train-images-idx3-ubyte.gz deleted file mode 100644 index d7fc67e..0000000 Binary files a/code/minist/data/MNIST/raw/train-images-idx3-ubyte.gz and /dev/null differ diff --git a/code/minist/test.py b/code/minist/test.py new file mode 100644 index 0000000..a835fe9 --- /dev/null +++ b/code/minist/test.py @@ -0,0 +1,57 @@ +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import matplotlib.pyplot as plt +import torch +import cv2 +import torchvision.transforms as transforms +#定义网络 +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.conv2_drop = nn.Dropout2d() + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x) + +network = Net() + + + + +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") +net = Net() +net.load_state_dict(torch.load('model.pth')) +net.to(device) +net.eval() + + +img = cv2.imread('./data/8.png') +img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) +img = img/255 +# cv2.namedWindow('img',0) +# cv2.imshow('img',img) +# cv2.waitKey(0) +transf = transforms.ToTensor() +imgTensor = transf(img).unsqueeze(0) +imgTensor = imgTensor.type(torch.FloatTensor).to(device) +out = net(imgTensor) + + +print(out.data.max(1, keepdim=True)[1][0].item()) + +dummy_input = torch.randn(1, 1,28,28).to(device)#输入大小 #data type nchw +torch.onnx.export(net, dummy_input, "ministNet.onnx", verbose=True, input_names=['input_111'], output_names=['output_111']) + + + + diff --git a/code/minist/testOnnx.py b/code/minist/testOnnx.py new file mode 100644 index 0000000..3733180 --- /dev/null +++ b/code/minist/testOnnx.py @@ -0,0 +1,27 @@ +import cv2 +import torchvision.transforms as transforms +import numpy as np +net = cv2.dnn.readNetFromONNX("ministNet.onnx") # 加载训练好的识别模型 +img = cv2.imread('./data/8.png') +img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) +img = img/255 + +transf = transforms.ToTensor() +imgTensor = transf(img).unsqueeze(0) + + + +im = img[np.newaxis, np.newaxis,:, :] +im = im.astype(np.float32) +outNames = net.getUnconnectedOutLayersNames() +net.setInput(im) +out = net.forward(outNames) + + +ind = np.where(out[0][0]==np.max(out[0][0])) +print(ind[0]) + + + + + diff --git a/data/8.png b/data/8.png new file mode 100644 index 0000000..8099d4f Binary files /dev/null and b/data/8.png differ diff --git a/model.pth b/model.pth index 3d7a757..8834831 100644 Binary files a/model.pth and b/model.pth differ diff --git a/optimizer.pth b/optimizer.pth index 1f74a1e..ad49ab8 100644 Binary files a/optimizer.pth and b/optimizer.pth differ