本次实战使用的数据集是mnist。tensorflow提供了一个库,可以直接用在下载MNIST,见下面代码。
1 | from tensorflow.examples.tutorials.mnist import input_data |
运行上面的代码,会自动下载数据集并将文件解压到当前代码所在的同级目录下。one_hot=True表示将样本标签转化为one-hot编码
1 | #返回各子集样本数 |
1 | #返回标签,第1张图片的one-hot编码 |
1 | import tensorflow as tf |
上面的代码训练效果实在惨不忍睹,训练了半天还是连20%都没过。所以,又去抄了大佬的代码学习,他分成三个文件:mnist_forward、mnist_backward、mnist_test。
mnist_forward
1 | import tensorflow as tf |
mnist_backward
1 | import tensorflow as tf |
mnist_test
1 | #coding:utf-8 |