梯度下降原理讲解(超基础)

需要的背景知识

-高中导数的知识 -高数中偏导数概念

话不多说,直接开始

梯度下降在机器学习和深度学习中都是比较重要的,主要用在训练网络参数更新过程中,说白了就是损失函数找最小值的过程。 直接举个栗子吧

正文开始,别眨眼。。。。。。

假设我们需要拟合一组这样的数据(2.8,3.2) (4.1,5.8)… 目标任务是得到一个关系y=wx+b,能够满足这组数据,类比到分类任务上就是如图,输入一个猫,得到标签“cat”,我们这里时输入2,8得到3.2,这个就很好理解。 ok,对于训练过程,网络在每一次输入一个x后,都会有一个y输出,同时为了拟合真实的关系,也会得到一个误差e,如下图。 我们将所有的e进行求和,得到全局误差总量 由于误差有正有负,这里对误差进行平方,得到这个 其中的x和y时我们丢进去训练的参数,就是图片和标签,这是已知的,这时候就可以将损失函数进一步改写: 这个函数Z就成为了我们常说的损失函数的表达式,只不过这是一个低维的,真实训练过程中的损失函数是多维的,那么对于我们这样一个函数,利用MATLAB可视化后是这样。 我们的目标是,求得所有误差e最小甚至等于0的时候,对应的参数,那么这个参数就是我们的w和b,进而目标就变成了求这个函数的最小值。 以 f(x)=x²+2 为例,我任意的输入两个数 x= 2.8, x=3.2,计算一下可以知道x=2.8的时候,值比x=3.2小,那么我下一次更新就是在2.8的基础上往下减, 不过计算机并不知道怎么减小合适,于是有了这么一个公式: 这里的η是学习率,你带入几个数简单计算一下就会发现,这个公式会让x始终朝着最小值的方向移动更新,这个公式就是梯度下降。 我带入算一下,xn=2.8,η设为0.1,f(x)的倒数未2x,那么xn+1 = 2.8 - 0.1×(2×2.8)=2.24. 对于高维的函数,比如二元二次方程 z = f(x,y) = x²+8y²,他就需要在x和y两个方向上去更新,他的偏导数如下: 以上就是梯度下降的基本原理,欢迎交流~

经验分享 程序员 微信小程序 职场和发展