You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
190 lines
6.1 KiB
190 lines
6.1 KiB
2 weeks ago
|
import argparse
|
||
|
import os
|
||
|
import time
|
||
|
|
||
|
import numpy as np
|
||
|
import matplotlib.pyplot as plt
|
||
|
import torch
|
||
|
import torch.backends.cudnn as cudnn
|
||
|
import torchvision
|
||
|
|
||
|
from model import Net
|
||
|
|
||
|
parser = argparse.ArgumentParser(description="Train on market1501")
|
||
|
parser.add_argument("--data-dir",default='data',type=str)
|
||
|
parser.add_argument("--no-cuda",action="store_true")
|
||
|
parser.add_argument("--gpu-id",default=0,type=int)
|
||
|
parser.add_argument("--lr",default=0.1, type=float)
|
||
|
parser.add_argument("--interval",'-i',default=20,type=int)
|
||
|
parser.add_argument('--resume', '-r',action='store_true')
|
||
|
args = parser.parse_args()
|
||
|
|
||
|
# device
|
||
|
device = "cuda:{}".format(args.gpu_id) if torch.cuda.is_available() and not args.no_cuda else "cpu"
|
||
|
if torch.cuda.is_available() and not args.no_cuda:
|
||
|
cudnn.benchmark = True
|
||
|
|
||
|
# data loading
|
||
|
root = args.data_dir
|
||
|
train_dir = os.path.join(root,"train")
|
||
|
test_dir = os.path.join(root,"test")
|
||
|
transform_train = torchvision.transforms.Compose([
|
||
|
torchvision.transforms.RandomCrop((128,64),padding=4),
|
||
|
torchvision.transforms.RandomHorizontalFlip(),
|
||
|
torchvision.transforms.ToTensor(),
|
||
|
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||
|
])
|
||
|
transform_test = torchvision.transforms.Compose([
|
||
|
torchvision.transforms.Resize((128,64)),
|
||
|
torchvision.transforms.ToTensor(),
|
||
|
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||
|
])
|
||
|
trainloader = torch.utils.data.DataLoader(
|
||
|
torchvision.datasets.ImageFolder(train_dir, transform=transform_train),
|
||
|
batch_size=64,shuffle=True
|
||
|
)
|
||
|
testloader = torch.utils.data.DataLoader(
|
||
|
torchvision.datasets.ImageFolder(test_dir, transform=transform_test),
|
||
|
batch_size=64,shuffle=True
|
||
|
)
|
||
|
num_classes = max(len(trainloader.dataset.classes), len(testloader.dataset.classes))
|
||
|
|
||
|
# net definition
|
||
|
start_epoch = 0
|
||
|
net = Net(num_classes=num_classes)
|
||
|
if args.resume:
|
||
|
assert os.path.isfile("./checkpoint/ckpt.t7"), "Error: no checkpoint file found!"
|
||
|
print('Loading from checkpoint/ckpt.t7')
|
||
|
checkpoint = torch.load("./checkpoint/ckpt.t7")
|
||
|
# import ipdb; ipdb.set_trace()
|
||
|
net_dict = checkpoint['net_dict']
|
||
|
net.load_state_dict(net_dict)
|
||
|
best_acc = checkpoint['acc']
|
||
|
start_epoch = checkpoint['epoch']
|
||
|
net.to(device)
|
||
|
|
||
|
# loss and optimizer
|
||
|
criterion = torch.nn.CrossEntropyLoss()
|
||
|
optimizer = torch.optim.SGD(net.parameters(), args.lr, momentum=0.9, weight_decay=5e-4)
|
||
|
best_acc = 0.
|
||
|
|
||
|
# train function for each epoch
|
||
|
def train(epoch):
|
||
|
print("\nEpoch : %d"%(epoch+1))
|
||
|
net.train()
|
||
|
training_loss = 0.
|
||
|
train_loss = 0.
|
||
|
correct = 0
|
||
|
total = 0
|
||
|
interval = args.interval
|
||
|
start = time.time()
|
||
|
for idx, (inputs, labels) in enumerate(trainloader):
|
||
|
# forward
|
||
|
inputs,labels = inputs.to(device),labels.to(device)
|
||
|
outputs = net(inputs)
|
||
|
loss = criterion(outputs, labels)
|
||
|
|
||
|
# backward
|
||
|
optimizer.zero_grad()
|
||
|
loss.backward()
|
||
|
optimizer.step()
|
||
|
|
||
|
# accumurating
|
||
|
training_loss += loss.item()
|
||
|
train_loss += loss.item()
|
||
|
correct += outputs.max(dim=1)[1].eq(labels).sum().item()
|
||
|
total += labels.size(0)
|
||
|
|
||
|
# print
|
||
|
if (idx+1)%interval == 0:
|
||
|
end = time.time()
|
||
|
print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
|
||
|
100.*(idx+1)/len(trainloader), end-start, training_loss/interval, correct, total, 100.*correct/total
|
||
|
))
|
||
|
training_loss = 0.
|
||
|
start = time.time()
|
||
|
|
||
|
return train_loss/len(trainloader), 1.- correct/total
|
||
|
|
||
|
def test(epoch):
|
||
|
global best_acc
|
||
|
net.eval()
|
||
|
test_loss = 0.
|
||
|
correct = 0
|
||
|
total = 0
|
||
|
start = time.time()
|
||
|
with torch.no_grad():
|
||
|
for idx, (inputs, labels) in enumerate(testloader):
|
||
|
inputs, labels = inputs.to(device), labels.to(device)
|
||
|
outputs = net(inputs)
|
||
|
loss = criterion(outputs, labels)
|
||
|
|
||
|
test_loss += loss.item()
|
||
|
correct += outputs.max(dim=1)[1].eq(labels).sum().item()
|
||
|
total += labels.size(0)
|
||
|
|
||
|
print("Testing ...")
|
||
|
end = time.time()
|
||
|
print("[progress:{:.1f}%]time:{:.2f}s Loss:{:.5f} Correct:{}/{} Acc:{:.3f}%".format(
|
||
|
100.*(idx+1)/len(testloader), end-start, test_loss/len(testloader), correct, total, 100.*correct/total
|
||
|
))
|
||
|
|
||
|
# saving checkpoint
|
||
|
acc = 100.*correct/total
|
||
|
if acc > best_acc:
|
||
|
best_acc = acc
|
||
|
print("Saving parameters to checkpoint/ckpt.t7")
|
||
|
checkpoint = {
|
||
|
'net_dict':net.state_dict(),
|
||
|
'acc':acc,
|
||
|
'epoch':epoch,
|
||
|
}
|
||
|
if not os.path.isdir('checkpoint'):
|
||
|
os.mkdir('checkpoint')
|
||
|
torch.save(checkpoint, './checkpoint/ckpt.t7')
|
||
|
|
||
|
return test_loss/len(testloader), 1.- correct/total
|
||
|
|
||
|
# plot figure
|
||
|
x_epoch = []
|
||
|
record = {'train_loss':[], 'train_err':[], 'test_loss':[], 'test_err':[]}
|
||
|
fig = plt.figure()
|
||
|
ax0 = fig.add_subplot(121, title="loss")
|
||
|
ax1 = fig.add_subplot(122, title="top1err")
|
||
|
def draw_curve(epoch, train_loss, train_err, test_loss, test_err):
|
||
|
global record
|
||
|
record['train_loss'].append(train_loss)
|
||
|
record['train_err'].append(train_err)
|
||
|
record['test_loss'].append(test_loss)
|
||
|
record['test_err'].append(test_err)
|
||
|
|
||
|
x_epoch.append(epoch)
|
||
|
ax0.plot(x_epoch, record['train_loss'], 'bo-', label='train')
|
||
|
ax0.plot(x_epoch, record['test_loss'], 'ro-', label='val')
|
||
|
ax1.plot(x_epoch, record['train_err'], 'bo-', label='train')
|
||
|
ax1.plot(x_epoch, record['test_err'], 'ro-', label='val')
|
||
|
if epoch == 0:
|
||
|
ax0.legend()
|
||
|
ax1.legend()
|
||
|
fig.savefig("train.jpg")
|
||
|
|
||
|
# lr decay
|
||
|
def lr_decay():
|
||
|
global optimizer
|
||
|
for params in optimizer.param_groups:
|
||
|
params['lr'] *= 0.1
|
||
|
lr = params['lr']
|
||
|
print("Learning rate adjusted to {}".format(lr))
|
||
|
|
||
|
def main():
|
||
|
for epoch in range(start_epoch, start_epoch+40):
|
||
|
train_loss, train_err = train(epoch)
|
||
|
test_loss, test_err = test(epoch)
|
||
|
draw_curve(epoch, train_loss, train_err, test_loss, test_err)
|
||
|
if (epoch+1)%20==0:
|
||
|
lr_decay()
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
|
||
|
main()
|