""" 文件说明: """ import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torchvision import datasets, transforms import wandb from tqdm import * from model import CNN_Net
def train(model, device, train_loader, optimizer, epoch):
    """Run one epoch of training with NLL loss.

    Args:
        model: network whose forward pass returns log-probabilities
            (trained with ``F.nll_loss``).
        device: torch.device each batch is moved to.
        train_loader: DataLoader yielding ``(data, target)`` batches.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: current 1-based epoch number, shown on the progress bar.
    """
    model.train()
    n_batches = len(train_loader)
    # `desc` surfaces the previously-unused `epoch` argument on the bar.
    for batch_idx, (data, target) in tqdm(enumerate(train_loader),
                                          total=n_batches,
                                          desc=f'Epoch {epoch}'):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
def test(model, device, test_loader):
    """Evaluate `model` on the full test set and log metrics to wandb.

    Accumulates summed NLL loss and correct-prediction counts over every
    batch, then reports the dataset-level average loss and accuracy.
    """
    model.eval()
    total_loss = 0.0
    n_correct = 0
    n_samples = len(test_loader.dataset)
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output = model(data)
            # Sum (not mean) per batch so we can normalize once over the
            # whole dataset below — robust to a ragged final batch.
            total_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            n_correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss = total_loss / n_samples
    # tqdm.write keeps the report from mangling an active progress bar.
    tqdm.write('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, n_correct, n_samples, 100. * n_correct / n_samples))
    wandb.log({'test_loss': test_loss, 'accuracy': n_correct / n_samples})
def main():
    """Train CNN_Net on MNIST, evaluating and logging to wandb each epoch."""
    config = {
        'BATCH_SIZE': 64,
        'TEST_BATCH_SIZE': 1000,
        'EPOCHS': 10,
        'LR': 0.01,
        'MOMENTUM': 0,
        'SEED': 17,
    }
    wandb.init(project='explore-gradients', reinit=True, config=config)

    use_cuda = torch.cuda.is_available()
    torch.manual_seed(config['SEED'])
    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    # Shared pipeline; (0.1307, 0.3081) are the usual MNIST mean/std —
    # hoisted so train and test are guaranteed to normalize identically.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=mnist_transform),
        batch_size=config['BATCH_SIZE'], shuffle=True, **kwargs)
    # shuffle=False: evaluation metrics are order-independent, and shuffling
    # the test set only consumes RNG state and hurts reproducibility.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=mnist_transform),
        batch_size=config['TEST_BATCH_SIZE'], shuffle=False, **kwargs)

    model = CNN_Net().to(device)
    # log='all' records gradients and parameters every sync step.
    wandb.watch(model, log='all')
    optimizer = optim.SGD(model.parameters(), lr=config['LR'],
                          momentum=config['MOMENTUM'])

    for epoch in range(1, config['EPOCHS'] + 1):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
# Script entry point: run training only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()