0%

深入剖析 fully_connected_feed.py

简介

fully_connected_feed.py 是 TensorFlow 中训练 MNIST 数据集的简单示例。通过学习这份代码,可以帮助我们更好地理解 Python 的语法以及 TensorFlow 的工作流程。

代码详细注释

由于 Python 是由社区推动的开源且免费的开发语言,不受商业公司控制,因此 Python 的改进往往比较激进,不兼容的情况时有发生。为了确保你能顺利过渡到新版本,Python 特别提供了 `__future__` 模块,让你在旧版本中试验新版本的一些特性。

1
2
3
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
1
2
3
4
5
# pylint: disable=missing-docstring
import argparse
import os
import sys
import time
1
2
3
4
5
6
7
8
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.examples.tutorials.mnist import mnist

# Basic model parameters as external flags.
# Holds the parsed command-line flags shared by the functions in this file.
# It stays None until the `__main__` block assigns it the namespace returned
# by parser.parse_known_args().
FLAGS = None
1
2
3
4
5
def placeholder_inputs(batch_size):
    """Create the placeholders that a batch of images and labels is fed into.

    Args:
        batch_size: Number of examples per batch; used as the leading
            dimension of both placeholders.

    Returns:
        A pair (images_placeholder, labels_placeholder): a tf.float32
        placeholder of shape (batch_size, mnist.IMAGE_PIXELS) and a tf.int32
        placeholder whose shape argument is batch_size.
    """
    image_shape = (batch_size, mnist.IMAGE_PIXELS)
    images_placeholder = tf.placeholder(tf.float32, shape=image_shape)
    # `(batch_size)` is a parenthesised scalar, not a one-element tuple.
    labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size))
    return images_placeholder, labels_placeholder
1
2
3
4
5
6
7
8
def fill_feed_dict(data_set, images_pl, labels_pl):
    """Fetch the next batch from `data_set` and map it onto the placeholders.

    Args:
        data_set: Object exposing `next_batch(batch_size, fake_data)` that
            returns an (images, labels) pair.
        images_pl: Placeholder to be fed with the images.
        labels_pl: Placeholder to be fed with the labels.

    Returns:
        A dict mapping each placeholder to the corresponding batch values,
        suitable for the `feed_dict` argument of `sess.run`.
    """
    next_batch = data_set.next_batch(FLAGS.batch_size, FLAGS.fake_data)
    images_feed, labels_feed = next_batch
    return {images_pl: images_feed, labels_pl: labels_feed}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
    """Run one pass of whole batches over `data_set` and print the precision.

    Only `data_set.num_examples // FLAGS.batch_size` full batches are
    evaluated, so any remainder smaller than one batch is not counted.

    Args:
        sess: Session used to run `eval_correct`.
        eval_correct: Op whose result is added to the running count of
            correct predictions for each batch.
        images_placeholder: Placeholder fed with the batch images.
        labels_placeholder: Placeholder fed with the batch labels.
        data_set: Data set to evaluate; must provide `num_examples` and
            `next_batch` (used through fill_feed_dict).
    """
    batch_size = FLAGS.batch_size
    steps_per_epoch = data_set.num_examples // batch_size
    # Number of examples actually evaluated (whole batches only).
    num_examples = steps_per_epoch * batch_size

    true_count = 0  # Running total of correct predictions.
    for _ in xrange(steps_per_epoch):
        batch_feed = fill_feed_dict(data_set,
                                    images_placeholder,
                                    labels_placeholder)
        true_count += sess.run(eval_correct, feed_dict=batch_feed)

    precision = float(true_count) / num_examples
    print(' Num examples: %d Num correct: %d Precision @ 1: %0.04f' %
          (num_examples, true_count, precision))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def run_training():
    """Train the MNIST model for FLAGS.max_steps steps.

    Loads the data sets, builds the graph (inference, loss, training and
    evaluation ops), then runs the training loop. Every 100 steps the loss is
    printed and a summary is written to FLAGS.log_dir. Every 1000 steps, and
    on the final step, a checkpoint is saved and the model is evaluated
    against the training, validation and test sets.
    """
    # Load the training, validation and test data. Without this assignment
    # every later use of `data_sets` would raise NameError.
    data_sets = input_data.read_data_sets(FLAGS.input_data_dir,
                                          FLAGS.fake_data)

    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default():
        # Generate placeholders for the images and labels.
        images_placeholder, labels_placeholder = placeholder_inputs(
            FLAGS.batch_size)

        # Build a Graph that computes predictions from the inference model.
        logits = mnist.inference(images_placeholder,
                                 FLAGS.hidden1,
                                 FLAGS.hidden2)

        # Add to the Graph the Ops for loss calculation.
        loss = mnist.loss(logits, labels_placeholder)

        # Add to the Graph the Ops that calculate and apply gradients.
        train_op = mnist.training(loss, FLAGS.learning_rate)

        # Add the Op to compare the logits to the labels during evaluation.
        eval_correct = mnist.evaluation(logits, labels_placeholder)

        # Build the summary Tensor based on the TF collection of Summaries.
        summary = tf.summary.merge_all()

        # Add the variable initializer Op.
        init = tf.global_variables_initializer()

        # Create a saver for writing training checkpoints.
        saver = tf.train.Saver()

        # Create a session for running Ops on the Graph.
        sess = tf.Session()

        # Instantiate a SummaryWriter to output summaries and the Graph.
        summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

        # And then after everything is built:

        # Run the Op to initialize the variables.
        sess.run(init)

        # Start the training loop.
        for step in xrange(FLAGS.max_steps):
            start_time = time.time()

            # Fill a feed dictionary with the actual set of images and labels
            # for this particular training step.
            feed_dict = fill_feed_dict(data_sets.train,
                                       images_placeholder,
                                       labels_placeholder)

            # Run one step of the model. The value returned for `train_op`
            # is discarded; only the loss value is kept.
            _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

            duration = time.time() - start_time

            # Write the summaries and print an overview fairly often.
            if step % 100 == 0:
                # Print status to stdout.
                print('Step %d: loss = %.2f (%.3f sec)' %
                      (step, loss_value, duration))
                # Update the events file.
                summary_str = sess.run(summary, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str, step)
                summary_writer.flush()

            # Save a checkpoint and evaluate the model periodically, and
            # always on the last step.
            if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')
                saver.save(sess, checkpoint_file, global_step=step)
                # Evaluate against the training set.
                print('Training Data Eval:')
                do_eval(sess,
                        eval_correct,
                        images_placeholder,
                        labels_placeholder,
                        data_sets.train)
                # Evaluate against the validation set.
                print('Validation Data Eval:')
                do_eval(sess,
                        eval_correct,
                        images_placeholder,
                        labels_placeholder,
                        data_sets.validation)
                # Evaluate against the test set.
                print('Test Data Eval:')
                do_eval(sess,
                        eval_correct,
                        images_placeholder,
                        labels_placeholder,
                        data_sets.test)
1
2
3
4
5
6
7
8
def main(_):
    """Reset the log directory, then run training.

    The single positional argument (passed by tf.app.run) is ignored.
    """
    log_dir = FLAGS.log_dir
    # Remove the logs left over from any previous run.
    if tf.gfile.Exists(log_dir):
        tf.gfile.DeleteRecursively(log_dir)
    # Recreate an empty log directory.
    tf.gfile.MakeDirs(log_dir)
    # Start the actual training.
    run_training()

Python 中的 argparse 模块主要负责命令行解析,其主要作用是在用 python 调用 .py 脚本文件时可以传入参数。比如对于以下程序,我们可以在命令行输入 python fully_connected_feed.py --learning_rate 0.1 --max_steps 200。其中 ArgumentParser() 的作用是创建一个 parser 实例;parser.add_argument() 的作用是添加一个命令行参数:它的第一个参数为命令行参数的名称,关键字参数 type 指定参数类型,default 指定默认值(命令行中没有给出该参数时取默认值),help 为使用 --help 时显示的说明文字。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# This guard keeps the flag parsing and training below from running when the
# file is imported as a module instead of being executed as a script.
if __name__ == '__main__':
    tmp_root = os.getenv('TEST_TMPDIR', '/tmp')

    # One (flag name, add_argument keyword options) entry per flag, in the
    # order in which the flags are registered.
    flag_specs = [
        ('--learning_rate', dict(
            type=float,
            default=0.01,
            help='Initial learning rate.')),
        ('--max_steps', dict(
            type=int,
            default=2000,
            help='Number of steps to run trainer.')),
        ('--hidden1', dict(
            type=int,
            default=128,
            help='Number of units in hidden layer 1.')),
        ('--hidden2', dict(
            type=int,
            default=32,
            help='Number of units in hidden layer 2.')),
        ('--batch_size', dict(
            type=int,
            default=100,
            help='Batch size. Must divide evenly into the dataset sizes.')),
        ('--input_data_dir', dict(
            type=str,
            default=os.path.join(tmp_root, 'tensorflow/mnist/input_data'),
            help='Directory to put the input data.')),
        ('--log_dir', dict(
            type=str,
            default=os.path.join(
                tmp_root, 'tensorflow/mnist/logs/fully_connected_feed'),
            help='Directory to put the log data.')),
        ('--fake_data', dict(
            default=False,
            help='If true, uses fake data for unit testing.',
            action='store_true')),
    ]

    parser = argparse.ArgumentParser()
    for flag_name, flag_options in flag_specs:
        parser.add_argument(flag_name, **flag_options)

    # parse_known_args() returns the namespace of recognised flags, stored in
    # FLAGS, and the list of arguments it did not recognise (`unparsed`).
    FLAGS, unparsed = parser.parse_known_args()
    # Hand control to main(), forwarding the program name plus the
    # unrecognised arguments.
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

本文标题:深入剖析 fully_connected_feed.py

文章作者:Zhikun Zhang

发布时间:2017年09月21日 - 10:07:41

最后更新:2020年05月16日 - 01:49:13

原始链接:http://zhangzhk.com/2017/09/21/understanding-fully-connected-feed/

许可协议: 署名-非商业性使用-禁止演绎 4.0 国际 转载请保留原文链接及作者。