1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
def run_training():
    """Build the MNIST graph and train it for ``FLAGS.max_steps`` steps.

    Every 100 steps the current loss is printed and the merged summaries are
    written to ``FLAGS.log_dir``. Every 1000 steps, and on the final step, a
    checkpoint is saved under ``FLAGS.log_dir`` with prefix ``model.ckpt`` and
    the model is evaluated with ``do_eval`` on the training, validation and
    test splits.

    The session and the summary writer are closed when training finishes,
    including when training ends with an exception.

    NOTE(review): ``data_sets`` is not defined in this function. It is
    presumably a module-level global holding the train/validation/test
    splits; confirm this against the rest of the file.
    """
    with tf.Graph().as_default():
        # placeholder_inputs is given the batch size; fill_feed_dict later
        # feeds these placeholders on every step.
        images_placeholder, labels_placeholder = placeholder_inputs(
            FLAGS.batch_size)

        # Model, loss, optimizer step and accuracy ops, all built by mnist.*.
        logits = mnist.inference(images_placeholder, FLAGS.hidden1,
                                 FLAGS.hidden2)
        loss = mnist.loss(logits, labels_placeholder)
        train_op = mnist.training(loss, FLAGS.learning_rate)
        eval_correct = mnist.evaluation(logits, labels_placeholder)

        summary = tf.summary.merge_all()
        init = tf.global_variables_initializer()
        saver = tf.train.Saver()

        # Loop-invariant: the checkpoint prefix is the same on every save.
        checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')

        # The session is a context manager so its resources are released
        # when this block exits.
        with tf.Session() as sess:
            summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)
            try:
                sess.run(init)

                for step in range(FLAGS.max_steps):
                    start_time = time.time()

                    feed_dict = fill_feed_dict(data_sets.train,
                                               images_placeholder,
                                               labels_placeholder)
                    # train_op's result is unused; only the loss is reported.
                    _, loss_value = sess.run([train_op, loss],
                                             feed_dict=feed_dict)

                    duration = time.time() - start_time

                    if step % 100 == 0:
                        print('Step %d: loss = %.2f (%.3f sec)'
                              % (step, loss_value, duration))
                        summary_str = sess.run(summary, feed_dict=feed_dict)
                        summary_writer.add_summary(summary_str, step)
                        summary_writer.flush()

                    # Checkpoint and evaluate periodically and on the last
                    # step, so the final model is always saved and scored.
                    is_last_step = (step + 1) == FLAGS.max_steps
                    if (step + 1) % 1000 == 0 or is_last_step:
                        saver.save(sess, checkpoint_file, global_step=step)
                        for split_name, data_set in (
                                ('Training', data_sets.train),
                                ('Validation', data_sets.validation),
                                ('Test', data_sets.test)):
                            print('%s Data Eval:' % split_name)
                            do_eval(sess, eval_correct, images_placeholder,
                                    labels_placeholder, data_set)
            finally:
                # Flush any pending events and release the event file.
                summary_writer.close()
|