|
|
|
|
移动端

一个单层的基础神经网络实现手写字识别

这篇文章是我根据个人的一些理解来写的,后续如果发现有错误,我会在新文章说出来,但这篇文章不做保留,方便后续检查思考记录的时候知道到底怎么踩坑的。

作者:kumfo来源:segmentfault|2017-12-28 14:44

年前最后一场技术盛宴 | 1月27日与京东、日志易技术大咖畅聊智能化运维发展趋势!


一个单层的基础神经网络实现手写字识别


先上代码

  1. import tensorflow 
  2.  
  3. from tensorflow.examples.tutorials.mnist import input_data 
  4.  
  5. import matplotlib.pyplot as plt 
  6.  
  7.  
  8.  
  9. # 普通的神经网络学习 
  10.  
  11. # 学习训练类 
  12.  
  13. class Normal: 
  14.  
  15.  
  16.  
  17.     weight = [] 
  18.  
  19.     biases = [] 
  20.  
  21.  
  22.  
  23.     def __init__(self): 
  24.  
  25.         self.times = 1000 
  26.  
  27.         self.mnist = [] 
  28.  
  29.         self.session = tensorflow.Session() 
  30.  
  31.         self.xs = tensorflow.placeholder(tensorflow.float32, [None, 784]) 
  32.  
  33.         self.ys = tensorflow.placeholder(tensorflow.float32, [None, 10]) 
  34.  
  35.         self.save_path = 'learn/result/normal.ckpt' 
  36.  
  37.  
  38.  
  39.     def run(self): 
  40.  
  41.         self.import_data() 
  42.  
  43.         self.train() 
  44.  
  45.         self.save() 
  46.  
  47.  
  48.  
  49.     def _setWeight(self,weight): 
  50.  
  51.         self.weight = weight 
  52.  
  53.  
  54.  
  55.     def _setBiases(self,biases): 
  56.  
  57.         self.biases = biases 
  58.  
  59.  
  60.  
  61.     def _getWeight(self): 
  62.  
  63.         return self.weight 
  64.  
  65.  
  66.  
  67.     def _getBiases(self): 
  68.  
  69.         return self.biases 
  70.  
  71.     # 训练 
  72.  
  73.     def train(self): 
  74.  
  75.  
  76.  
  77.         prediction = self.add_layer(self.xs, 784, 10, activation_function=tensorflow.nn.softmax) 
  78.  
  79.  
  80.  
  81.         cross_entropy = tensorflow.reduce_mean( 
  82.  
  83.             -tensorflow.reduce_sum( 
  84.  
  85.                 self.ys * tensorflow.log(prediction) 
  86.  
  87.                 , reduction_indices=[1]) 
  88.  
  89.         ) 
  90.  
  91.         train_step = tensorflow.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 
  92.  
  93.  
  94.  
  95.         self.session.run(tensorflow.global_variables_initializer()) 
  96.  
  97.  
  98.  
  99.         for i in range(self.times): 
  100.  
  101.             batch_xs, batch_ys = self.mnist.train.next_batch(100) 
  102.  
  103.             self.session.run(train_step, feed_dict={self.xs: batch_xs, self.ys: batch_ys}) 
  104.  
  105.             if i % 50 == 0: 
  106.  
  107.                 # images 变换为 labels,images相当于x,labels相当于y 
  108.  
  109.                 accurary = self.computer_accurary( 
  110.  
  111.                     self.mnist.test.images, 
  112.  
  113.                     self.mnist.test.labels, 
  114.  
  115.                     prediction 
  116.  
  117.                 ) 
  118.  
  119.  
  120.  
  121.     # 数据导入 
  122.  
  123.     def import_data(self): 
  124.  
  125.         self.mnist = input_data.read_data_sets('MNIST_data', one_hot=True
  126.  
  127.  
  128.  
  129.     # 数据保存 
  130.  
  131.     def save(self): 
  132.  
  133.         saver = tensorflow.train.Saver() 
  134.  
  135.         path = saver.save(self.session,self.save_path) 
  136.  
  137.  
  138.  
  139.     # 添加隐藏层 
  140.  
  141.     def add_layer(self,inputs,input_size,output_size,activation_function=None): 
  142.  
  143.  
  144.  
  145.         weight = tensorflow.Variable(tensorflow.random_normal([input_size,output_size]),dtype=tensorflow.float32,name='weight'
  146.  
  147.  
  148.  
  149.         biases = tensorflow.Variable(tensorflow.zeros([1,output_size]) + 0.1,dtype=tensorflow.float32,name='biases'
  150.  
  151.         Wx_plus_b = tensorflow.matmul(inputs,weight) + biases 
  152.  
  153.  
  154.  
  155.         self._setBiases(biases) 
  156.  
  157.         self._setWeight(weight) 
  158.  
  159.  
  160.  
  161.         if activation_function is None: 
  162.  
  163.             outputs = Wx_plus_b 
  164.  
  165.         else
  166.  
  167.             outputs = activation_function(Wx_plus_b,) 
  168.  
  169.  
  170.  
  171.         return outputs 
  172.  
  173.  
  174.  
  175.  
  176.  
  177.     # 计算结果数据与实际数据的正确率 
  178.  
  179.     def computer_accurary(self,x_data,y_data,tf_prediction): 
  180.  
  181.  
  182.  
  183.         prediction = self.session.run(tf_prediction,feed_dict={self.xs:x_data,self.ys:y_data}) 
  184.  
  185.  
  186.  
  187.         # 返回两个矩阵中最大值的索引号位置,然后进行相应位置的值大小比较并在此位置设置为True/False 
  188.  
  189.         correct_predition = tensorflow.equal(tensorflow.argmax(prediction,1),tensorflow.argmax(y_data,1)) 
  190.  
  191.  
  192.  
  193.         # 进行数据格式转换,然后进行降维求平均值 
  194.  
  195.         accurary = tensorflow.reduce_mean(tensorflow.cast(correct_predition,tensorflow.float32)) 
  196.  
  197.  
  198.  
  199.         result = self.session.run(accurary,feed_dict={self.xs:x_data,self.ys:y_data}) 
  200.  
  201.  
  202.  
  203.         return result 
  204.  
  205.  
  206.  
  207. # 识别类 
  208.  
  209. class NormalRead(Normal): 
  210.  
  211.  
  212.  
  213.     input_size = 784 
  214.  
  215.     output_size = 10 
  216.  
  217.  
  218.  
  219.     def run(self): 
  220.  
  221.         self.import_data() 
  222.  
  223.         self.getSaver() 
  224.  
  225.         origin_input = self._getInput() 
  226.  
  227.         output = self.recognize(origin_input) 
  228.  
  229.  
  230.  
  231.         self._showImage(origin_input) 
  232.  
  233.         self._showOutput(output
  234.  
  235.         pass 
  236.  
  237.  
  238.  
  239.     # 显示识别结果 
  240.  
  241.     def _showOutput(self,output): 
  242.  
  243.         number = output.index(1) 
  244.  
  245.         print('识别到的数字:',number) 
  246.  
  247.  
  248.  
  249.     # 显示被识别图片 
  250.  
  251.     def _showImage(self,origin_input): 
  252.  
  253.         data = [] 
  254.  
  255.         tmp = [] 
  256.  
  257.         i = 1 
  258.  
  259.         # 原数据转换为可显示的矩阵 
  260.  
  261.         for v in origin_input[0]: 
  262.  
  263.             if i %28 == 0: 
  264.  
  265.                 tmp.append(v) 
  266.  
  267.                 data.append(tmp) 
  268.  
  269.                 tmp = [] 
  270.  
  271.             else
  272.  
  273.                 tmp.append(v) 
  274.  
  275.             i += 1 
  276.  
  277.  
  278.  
  279.         plt.figure() 
  280.  
  281.         plt.imshow(data, cmap='binary')  # 黑白显示 
  282.  
  283.         plt.show() 
  284.  
  285.  
  286.  
  287.  
  288.  
  289.     def _setBiases(self,biases): 
  290.  
  291.         self.biases = biases 
  292.  
  293.         pass 
  294.  
  295.  
  296.  
  297.     def _setWeight(self,weight): 
  298.  
  299.         self.weight = weight 
  300.  
  301.         pass 
  302.  
  303.  
  304.  
  305.     def _getBiases(self): 
  306.  
  307.         return self.biases 
  308.  
  309.  
  310.  
  311.     def _getWeight(self): 
  312.  
  313.         return self.weight 
  314.  
  315.  
  316.  
  317.     # 获取训练模型 
  318.  
  319.     def getSaver(self): 
  320.  
  321.         weight = tensorflow.Variable(tensorflow.random_normal([self.input_size, self.output_size]), dtype=tensorflow.float32,name='weight'
  322.  
  323.  
  324.  
  325.         biases = tensorflow.Variable(tensorflow.zeros([1, self.output_size]) + 0.1, dtype=tensorflow.float32, name='biases'
  326.  
  327.  
  328.  
  329.         saver = tensorflow.train.Saver() 
  330.  
  331.         saver.restore(self.session,self.save_path) 
  332.  
  333.  
  334.  
  335.         self._setWeight(weight) 
  336.  
  337.         self._setBiases(biases) 
  338.  
  339.  
  340.  
  341.     def recognize(self,origin_input): 
  342.  
  343.         input = tensorflow.placeholder(tensorflow.float32,[None,784]) 
  344.  
  345.         weight = self._getWeight() 
  346.  
  347.         biases = self._getBiases() 
  348.  
  349.  
  350.  
  351.         result = tensorflow.matmul(input,weight) + biases 
  352.  
  353.         resultSof = tensorflow.nn.softmax(result,) # 把结果集使用softmax进行激励 
  354.  
  355.         resultSig = tensorflow.nn.sigmoid(resultSof,) # 把结果集以sigmoid函数进行激励,用于后续分类 
  356.  
  357.         output = self.session.run(resultSig,{input:origin_input}) 
  358.  
  359.  
  360.  
  361.         output = output[0] 
  362.  
  363.  
  364.  
  365.         # 对识别结果进行分类处理 
  366.  
  367.         output_tmp = [] 
  368.  
  369.         for item in output
  370.  
  371.             if item < 0.6: 
  372.  
  373.                 output_tmp.append(0) 
  374.  
  375.             else : 
  376.  
  377.                 output_tmp.append(1) 
  378.  
  379.  
  380.  
  381.         return output_tmp 
  382.  
  383.  
  384.  
  385.     def _getInput(self): 
  386.  
  387.         inputs, y = self.mnist.train.next_batch(100); 
  388.  
  389.         return [inputs[50]] 

以上是程序,整个程序基于TensorFlow来实现的,具体的TensorFlow安装我就不说了。

整个训练过程不做多说,我发现网上关于训练的教程很多,但是训练结果的教程很少。

整个程序里,通过tensorflow.train.Saver()的save进行训练结果模型进行存储,然后再用tensorflow.train.Saver()的restore进行模型恢复然后取到训练好的weight和baises。

这里要注意的一个地方是因为一次性随机取出100张手写图片进行批量训练的,我在取的时候其实也是批量随机取100张,但是我传入识别的是一张,通过以下这段程序:

  1. def _getInput(self): 
  2.  
  3.         inputs, y = self.mnist.train.next_batch(100); 
  4.  
  5.         return [inputs[50]] 

注意一下return这里的数据结构,其实是取这批量的第50张,实际上这段程序写成:

  1. def _getInput(self): 
  2.  
  3.         inputs, y = self.mnist.train.next_batch(1); 
  4.  
  5.         return [inputs[0]] 

会更好。

因为识别的时候是需要用到训练的隐藏层来进行的,所以在此我虽然识别的是一张图片,但是我必须要传入一个批量数据的这样一个结构。

然后再识别的地方,我使用了两个激励函数:

  1. resultSof = tensorflow.nn.softmax(result,) # 把结果集使用softmax进行激励 
  2.  
  3. resultSig = tensorflow.nn.sigmoid(resultSof,) # 把结果集以sigmoid函数进行激励,用于后续分类 

这里的话,第一个softmax激励后的数据我发现得到的是以e为底的指数形式,转换成普通的浮点数来看,不是很清楚到底是什么,那么我在做识别数字判断的时候就不方便,所以再通过了一次sigmoid的激励。

后续我通过一个循环判断进行一次实际上的分类,这个原因首先要说到识别结果形式:

  1. [0,0,0,0,0,0,0,0,1,0] 

像以上这个数据,表示的是8,也就是说,数组下表第几位为1就表示是几,如0的表示:

  1. [1,0,0,0,0,0,0,0,0,0] 

而sigmoid函数在这个地方其实就是对每个位置的数据进行了分类,我发现如果分类值小于0.52这样的数据其实代表的是否,也就是说此位置的值对应的是0,大于0.52应该对应的是真,也就是1;而我在程序里取的是0.6为界限做判断。

实际上,这个界限值应该是在神经网络训练的时候取的,而不是看识别结果来进行凭感觉取的(虽然训练的时候的参数也是凭感觉取的)

这篇文章是我根据个人的一些理解来写的,后续如果发现有错误,我会在新文章说出来,但这篇文章不做保留,方便后续检查思考记录的时候知道到底怎么踩坑的。

以下是我上次写的sigmoid函数的文章:

https://segmentfault.com/a/11...

关于其他激励函数,可以网上找资料进行了解,很多基础性的数学知识,放到一些比较具体的应用,会显得非常的有意思。

【编辑推荐】

  1. 一种Python全局配置规范以及其魔改
  2. 30行Python代码刷王者荣耀金币
  3. 2017年大数据年终盘点:开源工具、MySQL和Python是最大赢家!
  4. 别@微信团队了,我用Python给自己戴上了圣诞帽!
  5. 5个酷毙的Python工具
【责任编辑:庞桂玉 TEL:(010)68476606】

点赞 0
分享:
大家都在看
猜你喜欢

读 书 +更多

高质量程序设计指南:C++/C语言(第3版)

本书以轻松幽默的笔调向读者论述了高质量软件开发方法与C++/C编程规范。它是作者多年从事软件开发工作的经验总结。本书共17章,第1章到第4...

订阅51CTO邮刊

点击这里查看样刊

订阅51CTO邮刊