手写数字识别项目由五个文件组成:配置文件(config.py)、数据处理(dataset.py)、神经网络(lenet5.py)、训练与测试(train.py)以及执行脚本(main.py)。
config.py
import torch

# Prefer the first CUDA GPU when one is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Training hyper-parameters.
batch_size = 128
train_epochs = 50
lr = 0.001

# Filesystem locations: the values below are placeholders ("your own path")
# and must be replaced with real directories before running.
data_path = r'你自己的路径'
model_save_path = r'你自己的路径/lenet5.pth'

# Report which device was selected.
print(device)
dataset.py
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
from config import batch_size,data_path
def get_dataset(download=False):
    """Build the MNIST training and test DataLoaders.

    Args:
        download: forwarded to ``datasets.MNIST``; when True torchvision
            fetches the dataset into ``data_path`` if it is missing.
            Defaults to False, matching the previous hard-coded behaviour.

    Returns:
        ``(train_loader, test_loader)``: two DataLoaders that yield batches of
        ``batch_size`` samples converted with ``transforms.ToTensor()``.
    """
    to_tensor = transforms.ToTensor()
    train_set = datasets.MNIST(root=data_path, download=download, train=True, transform=to_tensor)
    test_set = datasets.MNIST(root=data_path, download=download, train=False, transform=to_tensor)

    # Shuffle only the training data; evaluation does not need a random order,
    # and a fixed order keeps test passes reproducible.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader
def main():
    """Smoke-test get_dataset() by printing the first training and test batch."""
    train_loader, test_loader = get_dataset()
    first_train_batch = next(iter(train_loader))
    first_test_batch = next(iter(test_loader))
    print('train', first_train_batch)
    print('test', first_test_batch)


if __name__ == '__main__':
    main()
lenet5.py
import torch
import torch.nn as nn
class Lenet5(nn.Module):
    """LeNet-5 style convolutional network that outputs 10 class logits.

    Expects single-channel 28x28 inputs. The classifier's first layer takes
    16 * 5 * 5 features, which is exactly what the conv stack yields for 28x28:
        conv1: Conv2d(1->6, 5x5, padding=2) + ReLU + MaxPool(2)    28x28 -> 14x14
        conv2: Conv2d(6->16, 5x5, no padding) + ReLU + MaxPool(2)  14x14 -> 5x5
        fc:    Linear 400 -> 120 -> 84 -> 10 with ReLU in between

    The submodule layout (``conv1`` / ``conv2`` / ``fc`` Sequentials) determines
    the ``state_dict`` keys. Keep it unchanged so saved weights still load.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1,
                      out_channels=6,
                      kernel_size=(5, 5),
                      stride=(1, 1),
                      padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1), padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.fc = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) raw logits (no softmax)."""
        out = self.conv1(x)
        out = self.conv2(out)
        # Flatten per sample, keeping the batch dimension. The previous
        # view(-1, 16*5*5) could silently regroup features across samples for
        # non-28x28 inputs and return the wrong number of rows. Now a size
        # mismatch surfaces as an error in the first Linear layer.
        out = torch.flatten(out, 1)
        return self.fc(out)
def main():
    """Shape check: run a random batch of two 1x28x28 images through Lenet5."""
    dummy_batch = torch.randn(2, 1, 28, 28)
    network = Lenet5()
    logits = network(dummy_batch)
    print(logits.shape)


if __name__ == '__main__':
    main()
main.py
# Entry-point script: starts training when executed directly.
from train import train


def _run():
    """Launch a full training run."""
    train()


if __name__ == '__main__':
    _run()
train.py
from lenet5 import Lenet5
from dataset import get_dataset
from config import device, train_epochs, lr, model_save_path
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
import torch
def _train_one_epoch(net, loader, optimizer, criterion, epoch, device):
    """Run one optimisation pass over ``loader``, printing the loss every 100 batches and on the last batch."""
    # Re-enable training mode every epoch: _evaluate() leaves the model in eval mode.
    net.train()
    num_batches = len(loader)
    for idx, (x, labels) in enumerate(loader):
        x = x.to(device)
        labels = labels.to(device)
        predict = net(x)
        optimizer.zero_grad()
        loss = criterion(predict, labels)
        loss.backward()
        optimizer.step()
        if idx % 100 == 0 or idx + 1 == num_batches:
            print('epoch:{}, idx:{}, loss:{:.3f}'.format(epoch, idx, loss.item()))


def _evaluate(net, loader, device):
    """Return the fraction of samples in ``loader`` whose argmax prediction matches the label."""
    net.eval()
    correct_num = 0
    total_num = 0
    # No gradients are needed for evaluation; skipping autograd saves memory and time.
    with torch.no_grad():
        for x, labels in loader:
            x = x.to(device)
            labels = labels.to(device)
            predict = net(x).argmax(-1)
            correct_num += (predict == labels).sum().item()
            total_num += len(x)
    return correct_num / total_num


def train():
    """Train Lenet5 on MNIST, checkpointing and evaluating after every epoch.

    Uses ``device``, ``train_epochs``, ``lr`` and ``model_save_path`` from
    ``config``. After each epoch the weights are written to
    ``model_save_path``, overwriting the previous checkpoint, and the
    test-set accuracy is printed.
    """
    mnist_train, mnist_test = get_dataset()
    net = Lenet5().to(device)
    optimizer = Adam(net.parameters(), lr=lr)
    criterion = CrossEntropyLoss().to(device)
    for epoch in range(train_epochs):
        _train_one_epoch(net, mnist_train, optimizer, criterion, epoch, device)
        torch.save(net.state_dict(), model_save_path)
        accuracy = _evaluate(net, mnist_test, device)
        print('epoch:{}, test correct rate:{:.1f}%'.format(epoch, accuracy * 100))


if __name__ == '__main__':
    train()