import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
from wideresnet import Wide_ResNet
import math
from tensorboardX import SummaryWriter
import numpy as np
from collections import OrderedDict

parser = argparse.ArgumentParser()
parser.add_argument('--lr', default=0.05, type=float, help='learning rate')
parser.add_argument('--total_epoch', default=200, type=int, help='total epoch')
parser.add_argument('--train_batchsize', default=128, type=int, help='training batchsize')
parser.add_argument('--test_batchsize', default=100, type=int, help='testing batchsize')
parser.add_argument('--resume', action='store_true', help='resume from checkpoint')
parser.add_argument('--test', action='store_true', help='testing mode')
parser.add_argument('--ckpt', default='./checkpoint/ckpt.pth', type=str,
                    help='path of the checkpoint to be loaded')
args = parser.parse_args()


class SGDR(object):
    def __init__(self, args):
        """
        1. Initialize the model.
        2. Prepare data for training and testing.
        """
        # set random seed for reproducibility
        torch.manual_seed(2019)

        # parse hyper-parameters
        self.args = args
        args = None  # avoid accidentally using the local name below

        # GPU or CPU mode
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.start_epoch = 0  # start from epoch 0 or the last checkpoint epoch
        print("==> Device:", self.device)

        # Preparing data
        print('==> Preparing data..')
        # Training: image augmentation and normalization
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        # Testing: image normalization only
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        # wrap training and testing data in dataloaders
        trainset = torchvision.datasets.CIFAR10(root='cifar', train=True,
                                                download=False, transform=transform_train)
        self.trainloader = torch.utils.data.DataLoader(trainset,
                                                       batch_size=self.args.train_batchsize,
                                                       shuffle=True, num_workers=4)
        testset = torchvision.datasets.CIFAR10(root='cifar', train=False,
                                               download=False, transform=transform_test)
        self.testloader = torch.utils.data.DataLoader(testset,
                                                      batch_size=self.args.test_batchsize,
                                                      shuffle=False, num_workers=4)
        # classes = ('plane', 'car', 'bird', 'cat', 'deer',
        #            'dog', 'frog', 'horse', 'ship', 'truck')

        # Build the model
        print('==> Building model..')
        self.net = Wide_ResNet(depth=28, widen_factor=20, dropout_rate=0.0, num_classes=10)
        self.net = self.net.to(self.device)
        if self.device == 'cuda':
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = True

    def load_SGDR(self):
        """
        Load a checkpoint to resume training or to run testing.
        """
        print('==> Resuming from checkpoint ' + self.args.ckpt)
        assert os.path.isdir('checkpoint'), 'Error: no checkpoint directory found!'
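        # Note: checkpoints saved on GPU store CUDA tensors, so loading on a
        # CPU-only machine needs map_location='cpu'. Checkpoints saved through
        # nn.DataParallel also prefix every parameter key with 'module.',
        # which must be stripped before loading into a plain (non-wrapped)
        # model; that is what the two branches below handle.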
        ckpt = self.args.ckpt
        if self.device == 'cpu':
            checkpoint = torch.load(ckpt, map_location='cpu')
        else:
            checkpoint = torch.load(ckpt)

        if self.device == 'cpu':
            new_state_dict = OrderedDict()
            for k, v in checkpoint['net'].items():
                name = k[7:]  # remove the 'module.' prefix added by nn.DataParallel
                new_state_dict[name] = v
            self.net.load_state_dict(new_state_dict)
        else:
            self.net.load_state_dict(checkpoint['net'])

        self.best_acc = checkpoint['acc']
        self.start_epoch = checkpoint['epoch']
        print("======== Best acc %.3f ========" % self.best_acc)

    def test_SGDR(self):
        """
        Test the model (can also be used during training as validation).
        """
        # assumes the test set size is divisible by the test batch size
        # (CIFAR-10: 10000 images, batch size 100)
        test_total = len(self.testloader) * self.args.test_batchsize
        predicted_list = torch.zeros(test_total).to(self.device)
        target_list = torch.zeros(test_total).to(self.device)

        self.net.eval()
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(self.testloader):
                # read data from the testing dataloader
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                # forward pass
                outputs = self.net(inputs)
                # save prediction results
                _, predicted_list[batch_idx * self.args.test_batchsize:
                                  (batch_idx + 1) * self.args.test_batchsize] = outputs.max(1)
                target_list[batch_idx * self.args.test_batchsize:
                            (batch_idx + 1) * self.args.test_batchsize] = targets
                print('| Epoch: %03d/%03d | Batch: %03d/%03d |'
                      % (self.start_epoch, self.args.total_epoch,
                         batch_idx, len(self.testloader)))

        # calculate prediction accuracy
        test_acc = predicted_list.eq(target_list).sum().item() / float(test_total)
        print("======== Acc: %.3f ========" % (100 * test_acc))

        # save prediction results to a text file
        os.makedirs("test_results", exist_ok=True)
        np.savetxt(os.path.join("test_results", str(self.start_epoch) + ".txt"),
                   np.column_stack([np.arange(0, test_total),
                                    predicted_list.data.cpu().numpy(),
                                    target_list.data.cpu().numpy()]),
                   fmt="%s", delimiter=",")
        return test_acc

    def train_SGDR(self):
        """
        Train the model with SGDR (warm restarts).
        """
        # loss function: cross-entropy loss
        criterion = nn.CrossEntropyLoss()
        # optimizer: Stochastic Gradient Descent (SGD) with momentum
        optimizer = optim.SGD(self.net.parameters(), lr=self.args.lr,
                              momentum=0.9, weight_decay=5e-4)
        # keep track of important quantities such as loss, accuracy, etc.
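        # SGDR anneals the learning rate with a cosine schedule inside each
        # restart period of length T_restart:
        #     lr = 0.5 * lr_max * (1 + cos(pi * T_cur / T_restart))
        # where T_cur counts (fractional) epochs since the last restart. When a
        # period ends, T_cur resets to 0 (a "warm restart") and the next period
        # is T_mult times longer. With T_mult = 1 and T_restart = total_epoch,
        # as configured below, the run is a single cosine decay from lr_max
        # toward 0 with no restart before training finishes.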
        writer = SummaryWriter(log_dir="./tensorboardX")

        # SGDR parameters
        T_cur = 0.0        # epochs elapsed since the last restart (fractional)
        T_mult = 1.0       # multiplier applied to the next restart period
        T_restart = 200    # length of the current restart period (in epochs)
        T_cur_previous = T_cur
        T_next_restart = T_restart

        loss_list = []
        train_acc_list = []
        lr_list = []
        best_test_acc = 0.0
        test_acc_list = []

        for epoch in range(self.start_epoch, self.start_epoch + self.args.total_epoch):
            # ======= Training phase =======
            self.net.train()
            train_correct = 0
            train_total = 0
            for batch_idx, (inputs, targets) in enumerate(self.trainloader):
                # read data from the dataloader
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                # clear previous gradients
                optimizer.zero_grad()
                # forward pass
                outputs = self.net(inputs)
                # calculate loss
                train_loss = criterion(outputs, targets)
                # back-propagation
                train_loss.backward()
                # update weights
                optimizer.step()

                # calculate prediction accuracy
                _, predicted = outputs.max(1)
                train_total += targets.size(0)
                train_correct += predicted.eq(targets).sum().item()
                train_acc = 100.0 * train_correct / train_total

                # SGDR: decrease the learning rate along the cosine curve;
                # T_cur advances from T_cur_previous to T_cur_previous + 1
                # over the course of one epoch
                T_cur = T_cur_previous + batch_idx / (len(self.trainloader) - 1)
                lr_new = 0.5 * self.args.lr * (1.0 + math.cos(T_cur * math.pi / T_restart))
                # SGDR: apply the new learning rate
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_new

                # write logs (use .item() so no autograd graph is kept alive)
                step = epoch * len(self.trainloader) + batch_idx
                writer.add_scalar('lr', lr_new, step)
                writer.add_scalar('loss', train_loss.item(), step)
                writer.add_scalar('acc', train_acc, step)
                lr_list.append(lr_new)
                train_acc_list.append(train_acc)
                loss_list.append(train_loss.item())
                print('| lr: %.3f/%d %.6f | Epoch: %03d/%03d | Batch: %03d/%03d | Loss: %.6f | Acc: %.3f (%d/%d) |'
                      % (T_cur, T_restart, lr_new, epoch, self.args.total_epoch,
                         batch_idx, len(self.trainloader), train_loss.item(),
                         train_acc, train_correct, train_total))

            T_cur_previous = T_cur

            # SGDR: warm restart at the end of the current period
            if int(epoch + 1) == int(T_next_restart):
                T_cur_previous = 0.0
                T_restart = T_restart * T_mult
                T_next_restart = T_next_restart + T_restart

            # ======= Testing phase =======
            test_acc = self.test_SGDR()
            test_acc_list.append(test_acc)
            writer.add_scalar('test_acc', test_acc, (epoch + 1) * len(self.trainloader))

            # save a checkpoint whenever the current accuracy is the highest so far
            if test_acc > best_test_acc:
                print('Saving the best checkpoint ...')
                # parameters to be saved
                state = {
                    'net': self.net.state_dict(),
                    'acc': test_acc,
                    'epoch': epoch,
                    'lr_list': lr_list,
                    'train_acc_list': train_acc_list,
                    'test_acc_list': test_acc_list,
                    'loss_list': loss_list,
                    'T_mult': T_mult,
                    'T_restart': T_restart,
                }
                if not os.path.isdir('checkpoint'):
                    os.mkdir('checkpoint')
                torch.save(state, './checkpoint/ckpt.pth')
                best_test_acc = test_acc

        writer.close()


if __name__ == "__main__":
    model = SGDR(args)
    if model.args.resume or model.args.test:
        model.load_SGDR()
        if model.args.test:
            model.test_SGDR()
        else:
            model.train_SGDR()
    else:
        model.train_SGDR()
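
# ---------------------------------------------------------------------------
# A minimal sketch of the SGDR schedule in isolation, assuming the same
# defaults as train_SGDR above (lr_max = 0.05, T_restart = 200, T_mult = 1).
# The function name and its arguments are illustrative and not part of the
# training script's CLI; it only previews the per-epoch learning-rate curve,
# e.g. to sanity-check a schedule before committing to a long run. It is
# never called by the main block above.
def preview_sgdr_schedule(lr_max=0.05, T_restart=200, T_mult=1.0, total_epoch=200):
    """Return the learning rate SGDR would use at the start of each epoch."""
    lrs = []
    T_cur, period, next_restart = 0.0, float(T_restart), float(T_restart)
    for epoch in range(total_epoch):
        # cosine annealing within the current restart period
        lrs.append(0.5 * lr_max * (1.0 + math.cos(T_cur * math.pi / period)))
        T_cur += 1.0
        # warm restart: reset T_cur and lengthen the next period by T_mult
        if epoch + 1 == int(next_restart):
            T_cur = 0.0
            period *= T_mult
            next_restart += period
    return lrs

# Example usage (hypothetical): print(preview_sgdr_schedule()[:3]) shows the
# schedule starting at lr_max and decaying along the cosine curve.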