import os

import pylab
import tensorflow as tf  # import the tensorflow library
from tensorflow.examples.tutorials.mnist import input_data
# 1 Load the MNIST data set from "MNIST_data/"; one_hot=True encodes each
#   label as a 10-element indicator vector.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
print("训练数据总数:", mnist.train.num_examples)  # number of training examples
# Number of test examples (`mnist.test.nu` does not exist and raised AttributeError).
print(mnist.test.num_examples)
# 2 Start from a clean slate: clear the default graph stack and reset the
#   global default graph so ops from earlier runs do not accumulate.
tf.reset_default_graph()

# 3 Graph inputs. The leading None dimension lets any batch size be fed.
#   x: flattened 28*28 = 784-pixel images; y: one-hot labels for the digits 0-9.
x = tf.placeholder(dtype=tf.float32, shape=(None, 784))
y = tf.placeholder(dtype=tf.float32, shape=(None, 10))

# 4 Trainable parameters of the linear model: W maps the 784 inputs to 10
#   class scores (normally-distributed initial values); b is a per-class bias
#   starting at zero. W is created before b, as before, so variable names are unchanged.
W = tf.Variable(initial_value=tf.random_normal(shape=(784, 10)))
b = tf.Variable(initial_value=tf.zeros(shape=(10,)))
# 5 Forward pass: linear class scores (logits), turned into per-class
#   probabilities by softmax (each row sums to 1). `pred` is what the
#   accuracy computation uses.
logits = tf.matmul(x, W) + b
pred = tf.nn.softmax(logits)

# 6 Training loss: cross-entropy between the predicted distribution and the
#   one-hot labels y, averaged over the batch. This is the error of one forward
#   pass; gradient descent adjusts W and b to minimise it.
#   The log-probabilities come from tf.nn.log_softmax(logits) instead of
#   tf.log(pred): when the softmax underflows to 0, tf.log(pred) gives -inf and
#   0 * -inf = NaN, which would poison the cost and every later update.
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.nn.log_softmax(logits), axis=1))
# 7 Optimizer: plain gradient descent on `cost` with a fixed step size.
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)

# 8 Training schedule: number of passes over the training set, examples per
#   mini-batch, and how often (in epochs) progress is printed.
training_epochs, batch_size, display_step = 25, 100, 1

# Checkpoint destination and the saver used to write it (constructed after
# the model variables W and b have been defined).
model_path = "log/521model.ckpt"
saver = tf.train.Saver()
# 9 Train, evaluate and save the model.
# Create the checkpoint directory up front: Saver.save fails when the parent
# directory of model_path is missing, and failing here avoids wasting a full
# training run. An empty dirname means the current directory, which exists.
save_dir = os.path.dirname(model_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # initialise W and b

    # Number of whole mini-batches per epoch (55000 // 100 = 550 for the
    # standard training split); constant, so computed once.
    total_batch = mnist.train.num_examples // batch_size
    for epoch in range(training_epochs):
        avg_cost = 0.
        # One pass over the training data, batch_size examples at a time.
        for _ in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run one optimisation step (backprop) and fetch the batch loss.
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs, y: batch_ys})
            # Running mean of the per-batch losses for this epoch.
            avg_cost += c / total_batch
        # Report progress every display_step epochs.
        if (epoch + 1) % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))
    print(" Finished!")

    # Test the model: a prediction is correct when the most probable class
    # matches the index of the 1 in the one-hot label.
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # Accuracy = fraction of correct predictions over the test set.
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

    # Save the trained variables to disk.
    save_path = saver.save(sess, model_path)
    print("Model saved in file: %s" % save_path)