# train.py — a ~200-line image-classification example: training ResNet18 on CIFAR-10.
# (Originally a blog post; many similar tutorials exist online, published here anyway.)
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision.models as models
import torchvision.transforms as T
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import argparse
import random
import torch
import time
import sys
import os
def parse_args():
    """Define and parse the command-line options for a training run."""
    p = argparse.ArgumentParser(description='Train a model')
    # Input geometry and hardware selection.
    p.add_argument('--img_h', type=int, default=32, help='image height')
    p.add_argument('--img_w', type=int, default=32, help='image width')
    p.add_argument('--gpu', type=str, default='0', help='gpu id')
    # Model / optimization choices.
    p.add_argument('--model', type=str, default='ResNet18', help='model')
    p.add_argument('--optim', type=str, default='Adam', help='optimizer')
    p.add_argument('--batch_size', type=int, default=128, help='batch size')
    p.add_argument('--epochs', type=int, default=10, help='epochs')
    p.add_argument('--test_interval', type=int, default=1, help='test interval')
    p.add_argument('--log_interval', type=int, default=50, help='log interval')
    p.add_argument('--lr', type=float, default=0.00005, help='learning rate')
    p.add_argument('--momentum', type=float, default=0.9, help='momentum')
    p.add_argument('--weight_decay', type=float, default=0.0005, help='weight decay')
    # Logging, loss, and reproducibility.
    p.add_argument('--log_path', type=str, default='./log', help='log path')
    p.add_argument('--criterion', type=str, default='CrossEntropyLoss', help='criterion')
    p.add_argument('--seed', type=int, default=10636, help='seed')
    p.add_argument('--num_workers', type=int, default=4, help='num_workers')
    return p.parse_args()
def set_seed(seed):
    """Seed every RNG this script relies on (python, numpy, torch CPU and all GPUs)
    so that runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # No-op when CUDA is unavailable, safe to call unconditionally.
    torch.cuda.manual_seed_all(seed)
class Logger():
    """Tee-style stream: mirrors everything written to it into a log file
    while still echoing to the original stream.

    Install with ``sys.stdout = Logger(path)`` so every ``print`` lands both
    on the terminal and in ``log_path``.
    """

    def __init__(self, log_path, stream=sys.stdout):
        self.terminal = stream          # the real stream to keep echoing to
        self.log_path = log_path
        self.log = open(self.log_path, "w")

    def write(self, message):
        # Mirror the message to both destinations.
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        # BUGFIX: this was a silent no-op, so buffered log lines could be
        # lost on a crash and print(..., flush=True) had no effect.
        self.terminal.flush()
        self.log.flush()

    def close(self):
        # Flush pending output before releasing the file handle.
        self.flush()
        self.log.close()
def draw_loss_acc(loss, acc, args):
    """Plot per-epoch training loss and test accuracy curves and save them
    as PNGs under args.log_path."""
    # (values, y-label, y-axis upper bound, output filename) for each figure.
    curves = (
        (loss, 'loss', max(loss) + 0.1, "0_loss.png"),
        (acc, 'acc', 100, "0_acc.png"),
    )
    for values, label, y_max, fname in curves:
        plt.plot(range(len(values)), values)
        plt.axis([0, len(values), 0, y_max])
        plt.xlabel('epoch')
        plt.ylabel(label)
        plt.savefig(os.path.join(args.log_path, fname))
        plt.close()
def train(model, args, train_loader, test_loader, criterion, optimizer, scheduler):
    """Train `model` for ``args.epochs`` epochs and periodically evaluate it.

    Every evaluated epoch is checkpointed (state_dict ``.pth`` plus a
    TorchScript ``.pt``) under ``args.log_path``; the best-so-far model is
    additionally saved as ``best.pth``/``best.pt``. Finally the loss and
    accuracy curves are written out via ``draw_loss_acc``.

    Fixes over the original: evaluation runs under ``torch.no_grad()``
    (no autograd bookkeeping), the deprecated ``Variable`` wrapper and
    ``loss.data`` access are gone, and losses are accumulated as python
    floats instead of tensors.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    best_acc = 0
    best_epoch = 0
    logger_loss = []  # mean training loss per epoch
    logger_acc = []   # test accuracy (percent) per evaluated epoch
    for epoch in range(args.epochs):
        # ---- train one epoch ----
        model.train()
        logger_loss.append(0.0)
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() detaches to a python float; accumulating tensors (as
            # `loss.data.cpu()` did) keeps tensor objects alive needlessly.
            logger_loss[epoch] += loss.item()
            if (i + 1) % args.log_interval == 0:
                print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f LR: %.6f' % (epoch + 1, args.epochs, i + 1, len(train_loader), loss.item(), optimizer.param_groups[0]['lr']))
        logger_loss[epoch] = logger_loss[epoch] / (i + 1)
        scheduler.step()
        # ---- evaluate (always on epoch 0, then every test_interval epochs) ----
        if (epoch + 1) % args.test_interval == 0 or epoch == 0:
            model.eval()
            correct = 0
            total = 0
            with torch.no_grad():  # inference only: skip gradient tracking
                for images, labels in test_loader:
                    images = images.to(device)
                    labels = labels.to(device)
                    outputs = model(images)
                    _, predicted = torch.max(outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            print('Epoch [%d/%d], Accuracy of the model on the test images: %d %%' % (epoch + 1, args.epochs, 100 * correct / total))
            torch.save(model.state_dict(), f'{args.log_path}/epoch_{epoch+1}.pth')
            # Also save a TorchScript export usable without the python source.
            script_model = torch.jit.script(model)
            torch.jit.save(script_model, f'{args.log_path}/epoch_{epoch+1}.pt')
            logger_acc.append(100 * correct / total)
            if correct > best_acc:
                best_epoch = epoch
                best_acc = correct
                torch.save(model.state_dict(), f'{args.log_path}/best.pth')
                script_model = torch.jit.script(model)
                torch.jit.save(script_model, f'{args.log_path}/best.pt')
    print('Best epoch :%d, Accuracy of the best model on the test images: %d %%' % (best_epoch + 1, 100 * best_acc / total))
    draw_loss_acc(logger_loss, logger_acc, args)
def main(args):
    """Wire up logging, data, model, optimizer, and loss, then run training."""
    # Mirror stdout into a timestamped log file under args.log_path.
    if not os.path.isdir(args.log_path):
        os.makedirs(args.log_path)
    sys.stdout = Logger(os.path.join(args.log_path, str(time.time()) + ".log"))

    os.environ["TORCH_USE_CUDA_DSA"] = "True"
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    set_seed(args.seed)

    # Model: ImageNet-pretrained ResNet18 with a fresh 10-way classifier head.
    if args.model != 'ResNet18':
        raise ValueError('model not support')
    model = models.resnet18(pretrained=True)
    model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=10)

    # ImageNet normalization statistics, shared by both pipelines.
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_tf = T.Compose([
        T.Resize((args.img_h, args.img_w)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        normalize,
    ])
    test_tf = T.Compose([
        T.Resize((args.img_h, args.img_w)),
        T.ToTensor(),
        normalize,
    ])
    train_set = torchvision.datasets.CIFAR10(root='./cifar10', train=True, transform=train_tf, download=True)
    test_set = torchvision.datasets.CIFAR10(root='./cifar10', train=False, transform=test_tf)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    if args.optim == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise ValueError('optimizer not support')

    if args.criterion == 'CrossEntropyLoss':
        criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
    elif args.criterion == 'NLLLoss':
        criterion = torch.nn.NLLLoss()
    else:
        raise ValueError('criterion not support')

    # Decay the LR by 10x at the halfway point.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.epochs // 2, gamma=0.1)
    train(model, args, train_loader, test_loader, criterion, optimizer, scheduler)
if __name__ == '__main__':
    cli_args = parse_args()
    main(cli_args)
    print('Done')