用tensorflow搭建全连接神经网络实现mnist数据集的识别-创新互联

I 前向传播网络搭建

创新互联公司坚持“要么做到,要么别承诺”的工作理念,服务领域包括:成都网站制作、成都做网站、企业官网、英文网站、手机端网站、网站推广等服务,满足客户于互联网时代的沧县网站设计、移动媒体设计的需求,帮助企业找到有效的互联网解决方案。努力成为您成熟可靠的网络建设合作伙伴!

在mnist_forward.py中搭建两层全连接网络,这里面就是定义层数,节点数,激活函数这些。

输入节点数目就是mnist数据集的图片28*28大小,用784行的向量作为输入。

第一层y1=relu(x*w1+b1 )其中y1为500行的向量。那么w1里面就有784*500个变量啦~~b1是500个变量。然后经过一个relu激活函数。

第二层就是从500节点变换到10个节点的输出,输出为标签,表示0-9手写数字出现的概率。y=y1*w2+b2。w2就是500*10的矩阵。b2是10行的向量。没有激活函数。

这里面w1 b1 w2 b2就是要训练的参数

采用了正则化

正则化就是在损失函数中给每个参数w加上权重,引入模型复杂度指标,从而抑制模型噪声,减少过拟合。这里使用的是L2正则化,即w的L2范数也是loss的一部分,也就是说在求解最优w的过程中,要使得w的值尽量在0附近。

用tensorflow搭建全连接神经网络实现mnist数据集的识别

import tensorflow as tf

INPUT_NODE = 784

OUTPUT_NODE = 10

LAYER1_NODE = 500

def get_weight(shape,regularizer):

w = tf.Variable(tf.truncated_normal(shape,stddev=0.1))

# 截断正态分布

if regularizer != None: tf.add_to_collection('losses',tf.contrib.layers.l2_regularizer(regularizer)(w))

# 使用正则化 L2范数 将每个参数的正则化损失加到总损失中

return w

def get_bias(shape):

b = tf.Variable(tf.zeros(shape))

return b

def forward(x,regularizer):

w1 = get_weight([INPUT_NODE,LAYER1_NODE],regularizer)

b1 = get_bias([LAYER1_NODE])

y1 = tf.nn.relu(tf.matmul(x,w1) + b1)

w2 = get_weight([LAYER1_NODE,OUTPUT_NODE],regularizer)

b2 = get_bias([OUTPUT_NODE])

y = tf.matmul(y1,w2) + b2

return y

II误差反向传播

在mnist_backward.py中读入mnist数据集,计算误差,进行误差反向传播,实现模型的训练,得到网络参数并保存在模型中

2.1 loss

loss的计算先用softmax把输出的10行向量变成概率分布,再与真实的输出标签进行对比,求交叉熵。cross entropy 可以看作是两个概率分布函数之间的距离。距离越小,说明预测越准确,loss越小。

2.2 学习率

学习率是每次沿着梯度下降方向进行参数更新的步长,步长过大会导致在最优点震荡,步长过小会导致学习速度太慢。这里采用了指数衰减的步长。在训练初始阶段,步长较大,较快收敛,在最优点附近,步长较小,能够得到较精确的最优解。

2.3 滑动平均

记录一段时间内模型中所有参数w和b的各自的平均值。用于增强模型的泛化能力。

import tensorflow as tf

import mnist_forward

import os无锡妇科医院 http://www.bhnnk120.com/

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

from tensorflow.examples.tutorials.mnist import input_data

BATCH_SIZE = 200 #每次输入的图片数

LEARNING_RATE_BASE = 0.1 #初始学习率

LEARNING_RATE_DECAY = 0.99 #学习率衰减率

REGULARIZER = 0.0001 #正则化系数

STEPS = 10000 #训练轮数

MOVING_AVERAGE_DECAY = 0.99

MODEL_SAVE_PATH="./model/"

MODEL_NAME = "mnist_model"

def backward(mnist):

x = tf.placeholder(tf.float32,[None,mnist_forward.INPUT_NODE])

y_ = tf.placeholder(tf.float32,[None,mnist_forward.OUTPUT_NODE])

y = mnist_forward.forward(x,REGULARIZER)

global_step = tf.Variable(0,trainable = False)

# step计数 不可训练的参数

ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits = y, labels = tf.argmax(y_,1))

cem = tf.reduce_mean(ce)

loss = cem + tf.add_n(tf.get_collection('losses'))

learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE,global_step,mnist.train.num_examples/BATCH_SIZE,LEARNING_RATE_DECAY,staircase=True)

train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss,global_step = global_step)

ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,global_step)

ema_op = ema.apply(tf.trainable_variables())

# ema.apply()对括号内参数求滑动平均

# tf.trainable_variables() 将所有可以被训练的参数汇总为list 也就是[w1 b1 w2 b2]

with tf.control_dependencies([train_step, ema_op]):

train_op = tf.no_op(name='train')

# 该函数实现将滑动平均和训练过程同步运行。

saver = tf.train.Saver()

with tf.Session() as sess:

init_op = tf.global_variables_initializer()

sess.run(init_op)

for i in range(STEPS):

xs,ys = mnist.train.next_batch(BATCH_SIZE)

_, loss_value, step = sess.run([train_op,loss,global_step],feed_dict={x:xs,y_:ys})

if i%1000 == 0:

print("After %d training steps, loss on training batch is %g." %(step,loss_value))

saver.save(sess,os.path.join(MODEL_SAVE_PATH,MODEL_NAME),global_step=global_step)

if __name__ == '__main__':

mnist = input_data.read_data_sets('./data/',one_hot=True)

backward(mnist)

III 运行代码

在Terminal里面激活tensorflow,运行python mnist_backward.py

就可以输出训练过程的loss,每1000步打印一次loss。从下图可以看出,loss逐渐减小。

另外有需要云服务器可以了解下创新互联cdcxhl.cn,海内外云服务器15元起步,三天无理由+7*72小时售后在线,公司持有idc许可证,提供“云服务器、裸金属服务器、高防服务器、香港服务器、美国服务器、虚拟主机、免备案服务器”等云主机租用服务以及企业上云的综合解决方案,具有“安全稳定、简单易用、服务可用性高、性价比高”等特点与优势,专为企业上云打造定制,能够满足用户丰富、多元化的应用场景需求。

分享名称:用tensorflow搭建全连接神经网络实现mnist数据集的识别-创新互联
文章出自:https://www.cdcxhl.com/article30/cshcpo.html

成都网站建设公司_创新互联,为您提供动态网站网站导航网站营销关键词优化网站设计公司全网营销推广

广告

声明:本网站发布的内容(图片、视频和文字)以用户投稿、用户转载内容为主,如果涉及侵权请尽快告知,我们将会在第一时间删除。文章观点不代表本网站立场,如需处理请联系客服。电话:028-86922220;邮箱:631063699@qq.com。内容未经允许不得转载,或转载时需注明来源: 创新互联

成都做网站