# kok202
# TensorFlow - MNIST CNN 분류
#
# 2019. 1. 24. 21:48 [정리] 직무별 개념 정리/딥러닝

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import input_data






# ---- Hyperparameters ----
LEARNING_RATE = 0.001  # Adam step size; per the author's note, 0.1 overshoots
BATCH_SIZE = 100       # examples per optimizer step
TRAIN_EPOCH = 10       # number of passes over the training split
NUM_CLASSES = 10       # width of the one-hot label vectors / output logits

# ---- Dataset ----
# TF tutorials MNIST helper, stored under "MNIST_data/".
# one_hot=True requests one-hot labels, matching the [None, NUM_CLASSES] label placeholder.
Mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)






# ---- Graph inputs ----
# x: each row is a flattened 28x28 image; y_: one-hot labels of width NUM_CLASSES.
x = tf.placeholder(tf.float32, [None, 28 * 28])
y_ = tf.placeholder(tf.float32, [None, NUM_CLASSES])

# ---- Convolutional layers ----
# Reshape flat rows into (batch, 28, 28, 1) images; the [1, h, w, 1] strides/ksize
# below follow conv2d/max_pool's default NHWC layout.
x_img = tf.reshape(x, [-1, 28, 28, 1])

# Block 1: 3x3 conv (1 -> 32 channels), ReLU, 2x2 max-pool with stride 2 (28x28 -> 14x14).
LayerConv1_f = tf.Variable(tf.random_normal([3, 3, 1, 32], stddev=0.01))
LayerConv1_c = tf.nn.conv2d(x_img, LayerConv1_f, strides=[1, 1, 1, 1], padding='SAME')
LayerConv1_h = tf.nn.relu(LayerConv1_c)
LayerConv1_p = tf.nn.max_pool(LayerConv1_h, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

# Block 2: 3x3 conv (32 -> 64 channels), ReLU, 2x2 max-pool with stride 2 (14x14 -> 7x7).
LayerConv2_f = tf.Variable(tf.random_normal([3, 3, 32, 64], stddev=0.01))
LayerConv2_c = tf.nn.conv2d(LayerConv1_p, LayerConv2_f, strides=[1, 1, 1, 1], padding='SAME')
LayerConv2_h = tf.nn.relu(LayerConv2_c)
LayerConv2_p = tf.nn.max_pool(LayerConv2_h, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

# ---- Fully connected output layer ----
# Flatten the 7x7x64 feature maps and project them to NUM_CLASSES unnormalized logits.
FullyX = tf.reshape(LayerConv2_p, [-1, 7 * 7 * 64])
LayerFully1_w = tf.Variable(tf.random_normal([7 * 7 * 64, NUM_CLASSES]))
LayerFully1_b = tf.Variable(tf.random_normal([NUM_CLASSES]))
hypo = tf.matmul(FullyX, LayerFully1_w) + LayerFully1_b

# ---- Loss, optimizer, metric ----
# `hypo` holds raw logits; softmax is applied inside the loss op.
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=hypo, labels=y_))
train = tf.train.AdamOptimizer(LEARNING_RATE).minimize(cost)

# Prediction is correct when the arg-max logit matches the arg-max of the one-hot label.
# tf.argmax replaces the deprecated tf.arg_max alias (removed in TF2).
check = tf.equal(tf.argmax(hypo, 1), tf.argmax(y_, 1))
accrcy = tf.reduce_mean(tf.cast(check, tf.float32))






# ---- Training ----
# The context manager closes the session even if training or evaluation raises.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Mini-batches per epoch; loop-invariant, so computed once.
    total_batch = Mnist.train.num_examples // BATCH_SIZE

    for epoch in range(TRAIN_EPOCH):
        # Running mean of the per-batch loss over this epoch.
        avg_cost = 0
        for _ in range(total_batch):
            # Draw the next mini-batch from the Mnist training split (the original read
            # from an undefined name `notMnist`, which raised NameError).
            batch_xs, batch_ys = Mnist.train.next_batch(BATCH_SIZE)
            c, _ = sess.run([cost, train], feed_dict={x: batch_xs, y_: batch_ys})
            avg_cost += c / total_batch
        print('Epoch : ', '%04d' % (epoch + 1), '\tcost : ', '{:.9f}'.format(avg_cost))

    # ---- Evaluation on the test split ----
    print('test accuracy : ', accrcy.eval(session=sess, feed_dict={x: Mnist.test.images, y_: Mnist.test.labels}))

# '[정리] 직무별 개념 정리 > 딥러닝' 카테고리의 다른 글
#
# CNN Calculator  (0) 2019.01.24
# KLDivergence  (0) 2019.01.24
# PyTorch - MNIST VAE  (0) 2019.01.24