手写数字识别算法的研究意义(104.人工智能手写汉字数字识别)

在前一章,对手写汉字数字的数据进行了预处理,本文准备使用ResNet18卷积神经网络来实现手写汉字数字的识别。关于ResNet网络模型和代码可以参看:101.人工智能——构建残差网络ResNet18网络模型。

自定数据集

#自定义数据集 import paddle import paddle.vision.transforms as T class MyDataset(paddle.io.Dataset): def __init__(self,data,mode="train"): self.data=data self.mode=mode def transform(self,mode): if mode=="train": return T.Compose([ T.ToTensor(), T.Normalize() ]) else: return T.Compose([ T.ToTensor(), T.Normalize() ]) def __getitem__(self,idx): img=mpimg.imread(os.path.join(datadir,self.data[idx][0])) img=self.transform(self.mode)(img) label=self.data[idx][1] label=np.array(label).astype("int64") label=np.reshape(label,(1)) return img,label def __len__(self): return len(self.data) #返回数据集 train_dataset=MyDataset(traindata,"train") val_dataset=MyDataset(valdata,"val") test_dataset=MyDataset(testdata,"test") train_loader=paddle.io.DataLoader(train_dataset,batch_size=16,shuffle=True) val_loader=paddle.io.DataLoader(val_dataset,batch_size=16,shuffle=False) test_loader=paddle.io.DataLoader(test_dataset,batch_size=16,shuffle=False) #查看数据集形状 print(train_dataset.data.shape,val_dataset.data.shape,test_dataset.data.shape) #查看批次数据 for i,data in enumerate(train_loader()): img,label=data print(img.shape,label.shape) break

#运行结果 (10500, 2) (3000, 2) (1500, 2) [16, 3, 64, 64] [16, 1]

构建ResNet18网络模型

详细代码本文这里省略。可以参看:101.人工智能——构建残差网络ResNet18网络模型。

模型训练

#查看模型结构 model = Model_ResNet18(in_channels=3, num_classes=15, use_residual=True) params_info = paddle.summary(model, (1, 3, 64, 64)) print(params_info)

#开始训练 def train(model,train_loader,epochs=10): #定义优化器 opt=paddle.optimizer.Adam(learning_rate=0.001,parameters=model.parameters()) bestacc=0 ############################ # # 读取参数文件(恢复训练,1:加载最后的训练模型文件和优化参数文件) # params_dict = paddle.load("models/finally.pdparams") # opt_dict = paddle.load("models/finally.pdopt") # # 加载参数到模型 # model.set_state_dict(params_dict) # # 加载参数到优化器 # opt.set_state_dict(opt_dict) ############################ print('start training ... ') model.train() for epoch in range(epochs): for i,data in enumerate(train_loader()): img,label=data #转换数据类型 img=paddle.to_tensor(img) label=paddle.to_tensor(label) #前向计算,获取损失函数值 out=model(img) loss=F.cross_entropy(out,label) avg_loss=paddle.mean(loss) #反向传播,更新参数,清空梯度 avg_loss.backward() opt.step() opt.clear_gradients() print(f"epoch:{epoch} loss:{avg_loss.numpy()}") #验证模型 model.eval() accs=[] losses=[] for i,data in enumerate(val_loader()): img,label=data img=paddle.to_tensor(img) label=paddle.to_tensor(label) out=model(img) loss=F.cross_entropy(out,label) avg_loss=paddle.mean(loss) acc=paddle.metric.accuracy(out,label) accs.append(acc.numpy()) losses.append(avg_loss.numpy()) print(f"epoch:{epoch},loss:{np.mean(losses)},acc:{np.mean(accs)}") #保存最佳模型 if np.mean(accs)>bestacc: bestacc=np.mean(accs) print(f"save best model,acc:{bestacc},epoch:{epoch}") paddle.save(model.state_dict(),"models/rawnum_resnet18_best.pdparams") paddle.save(opt.state_dict(), 'models/rawnum_resnet18_best.pdopt') model.train() #恢复训练模式 train(model,train_loader,epochs=10)

#训练过程,在CPU环境下训练时间很长……但从训练结果来看,准确率还在达到90%以上。 epoch:0 loss:[0.10153195] epoch:0,loss:0.23171958327293396,acc:0.9301861524581909 save best model,acc:0.9301861524581909,epoch:0 epoch:1 loss:[0.02970559] epoch:1,loss:0.28862079977989197,acc:0.9135638475418091 epoch:2 loss:[0.8410681] epoch:2,loss:0.3803146481513977,acc:0.904920220375061 ………………

模型预测

#加载模型、预测模型 model_dict=paddle.load("models/rawnum_resnet18_best.pdparams") model.load_dict(model_dict) model.eval() #随机取一条测试数据 idx=np.random.randint(len(test_dataset)) img,label=test_dataset[idx] img=np.reshape(img,(1,3,64,64)) label=np.reshape(label,(1,1)) #print(img.shape,label.shape,label.item()) results=model(paddle.to_tensor(img)) predictlabel=np.argmax(results.numpy()) #最大值的索引,用argmax print(f"predict:{predictlabel},label:{label.item()}")

#预测结果:随机运行5次 predict:5,label:5 predict:10,label:10 predict:6,label:5 predict:1,label:1 predict:5,label:5

从预测结果来看,和准确率90%以上是相符的,没有达到100%的识别效果。

手写数字识别算法的研究意义(104.人工智能手写汉字数字识别)(1)

,

免责声明:本文仅代表文章作者的个人观点,与本站无关。其原创性、真实性以及文中陈述文字和内容未经本站证实,对本文以及其中全部或者部分内容文字的真实性、完整性和原创性本站不作任何保证或承诺,请读者仅作参考,并自行核实相关内容。文章投诉邮箱:anhduc.ph@yahoo.com

    分享
    投诉
    首页