两百行代码实现图像分类：用 CIFAR10 训练 ResNet18

水一下这个月的稿子：这类代码网上肯定很多，但既然已经写了，就发出来吧。

# train.py
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision.models as models
import torchvision.transforms as T
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import argparse
import random
import torch
import time
import sys
import os

def parse_args():
    """Build the training script's command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with one attribute per option listed below.
    """
    # (flag, type, default, help) — registered in this exact order so that the
    # generated ``--help`` text is unchanged.
    option_specs = (
        ('--img_h', int, 32, 'image height'),
        ('--img_w', int, 32, 'image width'),
        ('--gpu', str, '0', 'gpu id'),
        ('--model', str, 'ResNet18', 'model'),
        ('--optim', str, 'Adam', 'optimizer'),
        ('--batch_size', int, 128, 'batch size'),
        ('--epochs', int, 10, 'epochs'),
        ('--test_interval', int, 1, 'test interval'),
        ('--log_interval', int, 50, 'log interval'),
        ('--lr', float, 0.00005, 'learning rate'),
        ('--momentum', float, 0.9, 'momentum'),
        ('--weight_decay', float, 0.0005, 'weight decay'),
        ('--log_path', str, './log', 'log path'),
        ('--criterion', str, 'CrossEntropyLoss', 'criterion'),
        ('--seed', int, 10636, 'seed'),
        ('--num_workers', int, 4, 'num_workers'),
    )
    cli = argparse.ArgumentParser(description='Train a model')
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()

def set_seed(seed):
    """Seed every random generator the run uses: PyTorch (CPU and all CUDA devices), NumPy and ``random``."""
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)

class Logger():
    """Tee writer: echoes every message to ``stream`` and also appends it to a log file.

    Meant to be installed as ``sys.stdout`` so ``print`` output is shown on the
    terminal and persisted to ``log_path`` at the same time.

    Args:
        log_path: Path of the log file. It is opened in ``"w"`` mode, so an
            existing file is truncated.
        stream: Terminal-side stream. The default is evaluated when the class
            is defined, i.e. it is the ``sys.stdout`` that existed at import time.
    """
    def __init__(self, log_path, stream=sys.stdout):
        self.terminal = stream
        self.log_path = log_path
        # Explicit UTF-8 so non-ASCII messages do not depend on the platform locale.
        self.log = open(self.log_path, "w", encoding="utf-8")

    def write(self, message):
        """Write ``message`` to the terminal stream and to the log file."""
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        """Flush both destinations.

        Previously a no-op, which meant ``print(..., flush=True)`` and the
        interpreter's shutdown flush never pushed buffered text to the log file.
        """
        self.terminal.flush()
        # Python flushes sys.stdout at exit, possibly after close(); don't raise then.
        if not self.log.closed:
            self.log.flush()

    def isatty(self):
        """Mirror the terminal stream; some libraries probe ``sys.stdout.isatty()``."""
        return self.terminal.isatty()

    def close(self):
        """Close the log file. The terminal stream is left open."""
        self.log.close()

def draw_loss_acc(loss, acc, args):
    """Plot the training-loss and test-accuracy curves and save them as PNG files.

    Args:
        loss: Mean training loss, one entry per epoch (floats or 0-d tensors —
            anything ``float()`` accepts).
        acc: Test accuracy in percent, one entry per evaluation.
        args: Parsed CLI namespace. ``args.log_path`` is the output directory;
            ``args.test_interval`` (optional) is used to place each accuracy
            point at the epoch where it was measured.

    Writes ``0_loss.png`` and ``0_acc.png`` into ``args.log_path``. An empty
    series is skipped instead of crashing on ``max([])``.
    """
    loss = [float(v) for v in loss]
    acc = [float(v) for v in acc]

    if loss:
        fig, ax = plt.subplots()
        ax.plot(range(len(loss)), loss)
        # Keep 0 in view, and don't invert the y-range if a loss is negative.
        ax.axis([0, len(loss), min(0.0, min(loss)), max(loss) + 0.1])
        ax.set_xlabel('epoch')
        ax.set_ylabel('loss')
        fig.savefig(os.path.join(args.log_path, "0_loss.png"))
        plt.close(fig)

    if acc:
        # train() evaluates after the first epoch and after every
        # test_interval-th epoch; plot each accuracy at that epoch index so the
        # "epoch" axis is truthful. Fall back to 0..len(acc)-1 if counts differ.
        interval = getattr(args, 'test_interval', 1)
        eval_epochs = [e for e in range(len(loss))
                       if e == 0 or (interval > 0 and (e + 1) % interval == 0)]
        x_acc = eval_epochs if len(eval_epochs) == len(acc) else list(range(len(acc)))
        fig, ax = plt.subplots()
        ax.plot(x_acc, acc)
        ax.axis([0, x_acc[-1] + 1, 0, 100])
        ax.set_xlabel('epoch')
        ax.set_ylabel('acc')
        fig.savefig(os.path.join(args.log_path, "0_acc.png"))
        plt.close(fig)

def _train_one_epoch(model, args, epoch, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader`` and return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    num_batches = 0
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() gives a plain Python float, so the running sum holds no tensors.
        loss_value = loss.item()
        running_loss += loss_value
        num_batches += 1

        if (i + 1) % args.log_interval == 0:
            print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f LR: %.6f' % (epoch + 1, args.epochs, i + 1, len(train_loader), loss_value, optimizer.param_groups[0]['lr']))
    # max(..., 1) avoids dividing by zero when the loader yields no batches.
    return running_loss / max(num_batches, 1)


@torch.no_grad()
def _evaluate(model, test_loader, device):
    """Return ``(correct, total)`` top-1 counts over ``test_loader`` with autograd disabled."""
    model.eval()
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    return correct, total


def _save_checkpoint(model, log_path, tag):
    """Save ``model`` to ``{log_path}/{tag}.pth`` (state_dict) and ``{log_path}/{tag}.pt`` (TorchScript)."""
    torch.save(model.state_dict(), f'{log_path}/{tag}.pth')
    # TorchScript ("static") export, loadable without this Python source.
    script_model = torch.jit.script(model)
    torch.jit.save(script_model, f'{log_path}/{tag}.pt')


def train(model, args, train_loader, test_loader, criterion, optimizer, scheduler):
    """Train ``model`` and periodically evaluate and checkpoint it.

    The model is moved to CUDA when available, otherwise CPU. Each epoch runs one
    pass over ``train_loader`` and then calls ``scheduler.step()``. The model is
    evaluated on ``test_loader`` after the first epoch and after every
    ``args.test_interval``-th epoch.

    Files written to ``args.log_path``:

    * ``epoch_{n}.pth`` / ``epoch_{n}.pt`` at every evaluation;
    * ``best.pth`` / ``best.pt`` whenever the test accuracy strictly improves;
    * the loss/accuracy plots produced by ``draw_loss_acc``.

    Args:
        args: Namespace providing ``epochs``, ``log_interval``, ``test_interval``
            and ``log_path``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    best_correct = 0
    best_epoch = 0
    eval_total = 0
    logger_loss = []
    logger_acc = []
    for epoch in range(args.epochs):
        # Train the model.
        logger_loss.append(_train_one_epoch(model, args, epoch, train_loader, criterion, optimizer, device))
        scheduler.step()

        # Evaluate the model.
        if (epoch + 1) % args.test_interval == 0 or epoch == 0:
            correct, eval_total = _evaluate(model, test_loader, device)
            acc = 100 * correct / eval_total if eval_total else 0.0
            print('Epoch [%d/%d], Accuracy of the model on the test images: %d %%' % (epoch + 1, args.epochs, acc))
            _save_checkpoint(model, args.log_path, f'epoch_{epoch+1}')
            logger_acc.append(acc)
            if correct > best_correct:
                best_epoch = epoch
                best_correct = correct
                _save_checkpoint(model, args.log_path, 'best')
    # With zero epochs (or an empty test set) there is nothing to report or plot.
    if eval_total:
        print('Best epoch :%d, Accuracy of the best model on the test images: %d %%' % (best_epoch + 1, 100 * best_correct / eval_total))
    if logger_loss:
        draw_loss_acc(logger_loss, logger_acc, args)

def main(args):
    """Set up logging, data, model, optimiser, loss and LR schedule, then run ``train``.

    Side effects:

    * creates ``args.log_path`` if needed;
    * replaces ``sys.stdout`` with a ``Logger`` that tees into a timestamped
      ``.log`` file there;
    * sets ``CUDA_VISIBLE_DEVICES``;
    * seeds the RNGs;
    * downloads CIFAR-10 into ``./cifar10`` when it is missing.

    Raises:
        ValueError: if ``args.model``, ``args.optim`` or ``args.criterion`` is
            not one of the supported choices.
    """
    os.makedirs(args.log_path, exist_ok=True)

    log_file = str(time.time()) + ".log"
    logger = Logger(os.path.join(args.log_path, log_file))
    sys.stdout = logger
    # NOTE(review): TORCH_USE_CUDA_DSA is normally a PyTorch build-time option;
    # setting it at runtime may have no effect — confirm before relying on it.
    os.environ["TORCH_USE_CUDA_DSA"] = "True"
    # Set before any CUDA work below so the device mask can take effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    set_seed(args.seed)

    if args.model == 'ResNet18':
        # torchvision >= 0.13 replaced ``pretrained=True`` with a weights enum.
        # IMAGENET1K_V1 is the checkpoint ``pretrained=True`` loads.
        if hasattr(models, 'ResNet18_Weights'):
            model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        else:
            model = models.resnet18(pretrained=True)
        num_fc_in = model.fc.in_features
        # Swap the classification head for a 10-way linear layer (CIFAR-10 classes).
        model.fc = torch.nn.Linear(in_features=num_fc_in, out_features=10)
    else:
        raise ValueError('model not support')

    # ImageNet mean/std, matching the statistics of the pretrained backbone.
    train_transform = T.Compose([
        T.Resize((args.img_h, args.img_w)),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    test_transform = T.Compose([
        T.Resize((args.img_h, args.img_w)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    num_workers = args.num_workers
    train_dataset = torchvision.datasets.CIFAR10(root='./cifar10', train=True, transform=train_transform, download=True)
    test_dataset = torchvision.datasets.CIFAR10(root='./cifar10', train=False, transform=test_transform)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=num_workers)

    if args.optim == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.optim == 'Adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise ValueError('optimizer not support')

    if args.criterion == 'CrossEntropyLoss':
        criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
    elif args.criterion == 'NLLLoss':
        # The model emits raw logits, but NLLLoss expects log-probabilities.
        # Without log_softmax the loss is unbounded below and training diverges.
        nll_loss = torch.nn.NLLLoss()

        def criterion(outputs, labels):
            return nll_loss(F.log_softmax(outputs, dim=1), labels)
    else:
        raise ValueError('criterion not support')

    # step_size must be >= 1: with --epochs 1, epochs // 2 == 0 and StepLR would
    # hit a modulo-by-zero on the first scheduler.step().
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=max(1, args.epochs // 2), gamma=0.1)

    train(model, args, train_loader, test_loader, criterion, optimizer, scheduler)

if __name__ == "__main__":
    # Script entry point: parse the command line, run the training pipeline,
    # then print a completion marker.
    args = parse_args()
    main(args)
    print('Done')

版权声明:
作者:MWHLS
链接:http://panwj.top/4912.html
来源:无镣之涯
文章版权归作者所有,未经允许请勿转载。

THE END
分享
二维码
打赏
< <上一篇
下一篇>>
文章目录
关闭
目 录