tf.GradientTape梯度求解利器的示例分析

这篇文章主要介绍tf.GradientTape梯度求解利器的示例分析,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!

我们提供的服务有:网站建设、成都网站建设、微信公众号开发、网站优化、网站认证、遵义ssl等。为1000多家企事业单位解决了网站和推广的问题。提供周到的售前咨询和贴心的售后服务,是有科学管理、有技术的遵义网站制作公司

tf.GradientTape定义在tensorflow/python/eager/backprop.py文件中,从文件路径也可以大概看出,GradientTape是eager模式下计算梯度用的,而eager模式(eager模式的具体介绍请参考文末链接)是TensorFlow 2.0的默认模式,因此tf.GradientTape是官方大力推荐的用法。下面就来具体介绍GradientTape的原理和使用。

Tape在英文中是胶带,磁带的含义,用在这里是由于eager模式带来的影响。在TensorFlow 1.x静态图时代,我们知道每个静态图都有两部分,一部分是前向图,另一部分是反向图。反向图就是用来计算梯度的,用在整个训练过程中。而TensorFlow 2.0默认是eager模式,每行代码顺序执行,没有了构建图的过程(也取消了control_dependency的用法)。但也不能每行都计算一下梯度吧?计算量太大,也没必要。因此,需要一个上下文管理器(context manager)来连接需要计算梯度的函数和变量,方便求解同时也提升效率。

举个例子:计算y=x^2在x = 3时的导数:

x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)
  y = x * x


# y’ = 2*x = 2*3 = 6

dy_dx = g.gradient(y, x) 

例子中的watch函数把需要计算梯度的变量x加进来了。GradientTape默认只监控由tf.Variable创建的traiable=True属性(默认)的变量。上面例子中的x是constant,因此计算梯度需要增加g.watch(x)函数。当然,也可以设置不自动监控可训练变量,完全由自己指定,设置watch_accessed_variables=False就行了(一般用不到)。

GradientTape也可以嵌套多层用来计算高阶导数,例如:

x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)
  with tf.GradientTape() as gg:
    gg.watch(x)
    y = x * x

  # y’ = 2*x = 2*3 =6   

  dy_dx = gg.gradient(y, x)

# y’’ = 2    
d2y_dx2 = g.gradient(dy_dx, x) 

另外,默认情况下GradientTape的资源在调用gradient函数后就被释放,再次调用就无法计算了。所以如果需要多次计算梯度,需要开启persistent=True属性,例如:

x = tf.constant(3.0)
with tf.GradientTape(persistent=True) as g:
  g.watch(x)
  y = x * x
  z = y * y

# z = y^2 = x^4, z’ = 4*x^3 = 4*3^3

dz_dx = g.gradient(z, x)  

# y’ = 2*x = 2*3 = 6

dy_dx = g.gradient(y, x)  

del g  # 删除这个上下文胶带

最后,一般在网络中使用时,不需要显式调用watch函数,使用默认设置,GradientTape会监控可训练变量,例如:

with tf.GradientTape() as tape:
    predictions = model(images)
    loss = loss_object(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)

这样即可计算出所有可训练变量的梯度,然后进行下一步的更新。对于TensorFlow 2.0,推荐大家使用这种方式计算梯度,并且可以在eager模式下查看具体的梯度值。

以上是“tf.GradientTape梯度求解利器的示例分析”这篇文章的所有内容,感谢各位的阅读!希望分享的内容对大家有帮助,更多相关知识,欢迎关注创新互联行业资讯频道!

网页题目:tf.GradientTape梯度求解利器的示例分析
网站网址:https://www.cdcxhl.com/article6/gijoog.html

成都网站建设公司_创新互联,为您提供建站公司搜索引擎优化品牌网站建设自适应网站外贸网站建设网站设计公司

广告

声明:本网站发布的内容(图片、视频和文字)以用户投稿、用户转载内容为主,如果涉及侵权请尽快告知,我们将会在第一时间删除。文章观点不代表本网站立场,如需处理请联系客服。电话:028-86922220;邮箱:631063699@qq.com。内容未经允许不得转载,或转载时需注明来源: 创新互联

成都定制网站建设