rbm
RBM
RBM 也就是 Restricted Boltzmann Machine ,这个概念由来已久了。我是在《人工神经网络》课上对这个有了一些了解。这个问题之前我知道,是来源于 DBN 的 pretrain 和 DBMs 。在深度学习中, RBM 可以说是具有很强的推动作用,记得上《模式识别》课时,老师在讲到深度神经网络使用 BP 算法的时候出现 Gradient vanish ,导致深度的网络是无法训练的, RBM 的可以说是作为训练 DBN 的救世主的形象出现的,通过 RBM 进行 pretrain 可以有效的解决 Gradient vanish 的问题。所以说这个 RBM 是个很重要的概念,那我们来看看这个东西吧。
Structure
RBM structure
玻尔兹曼机是一个双层的,层间是一个全连的结构,层内也是全连的结构,但是这个是一个很复杂的结构。秉着简化模型的理念,就出现了限制玻尔兹曼机,限制就是去掉层内部的链接,简化了之后的模型看起来就很想我们常见的神经网络的模型。如上图所示,我们可以看出这个结构和我们所接触的前馈的神经网络还是有所不同的,因为层间的链接是一个无向边,也就是说这是一个对传的网络结构。
我认为去掉层内之间的链接,就使得,这个模型不在 model 一个观测里面各个特征之间的关系,而是 model 观测和观测之间的关系,所以这个模型可以看作是无监督的特征的提取。
但是我们还有走的稍微深入一点,来追根溯源一下,这个模型本来是来自什么地方呢。
Dynamic system
这个问题可以追溯到动态系统那个地方去,为什么这么说呢?其实我也不知道,《人工神经网络》课上老师告诉我的。所以我们就用动态系统,或者暂时把自己当做是控制系的学生来看看这个问题吧。
第一个重要的理论是 lyapunov function ,这个函数可以说是动态系统的一个最重要的指标了。为什么这么说呢?因为这个函数定义了一个动态系统稳定的条件。下面我们就来看看到底是怎么定义的吧。
对于一个动态系统,我们都可以定义一个能量函数$\mathcal{U}$,如果一个系统:
$$\frac{\partial \mathcal{U}}{ \partial t } \leq 0$$
我们就说这个系统是稳定的,也就是说在经过一段时间后这个是同会收敛到稳定状态。这里需要注意的是这里的是能量函数对时间的导数,这里描述的是一个动态的变化的量之间的关系。
但是我们要用我们机器学习的观点来看这个问题,一般我们搞出一个模型,是必要伴随着参数学习的过程。也就是说,我们的模型变化的是参数,是参数在随着时间在变化,那我们怎么把上面的问题迁移到我们的问题上呢?数学上对这个已经有很好的办法了那就是链式求导法:
$$\frac{\partial \mathcal{U}}{ \partial t } =\frac{\partial \mathcal{U}}{ \partial \omega } \frac{\partial \omega}{ \partial t } \leq 0 $$。
这样我们就很好的对应到我们学习的过程中了,$\frac{\partial \mathcal{U}}{ \partial \omega }$是我们设计的能量函数对参数的固有导数,$\frac{\partial \omega}{ \partial t }$是我们学习过程中参数的变化过程。
知道了这么多了,我们来 specialized 这个模型到 RBM 。
Prove the stability of RBM
第一步,我们来定义 RBM 的能量函数:
$$\mathcal{U}=-\sum_{j \in \mathbb{V}}a_j v_j - \sum_{i \in \mathbb{H}} b_i h_i -\sum_{i \in \mathbb{H},j \in \mathbb{V}}v_j w_{ji} h_i$$
下面我们就来 证明这个系统的稳定性。 PS :这里用到的参数更新的公式来至于 Hinton 老爷子关于 RBM 的文章。
(1)让我们来看第一个参数$b$:
\begin{equation}\label{st}
\begin{split}
\frac{\partial \mathcal{U}}{ \partial a_j } & =-v_j^k \
\frac{\partial a_j}{ \partial t } & =v_j^k-v_j^{k+1}\
\frac{\partial \mathcal{U}}{ \partial t } & = \frac{\partial \mathcal{U}}{ \partial a_j } \frac{\partial a_j}{ \partial t }
=-v_j^k(v_j^k-v_j^{k+1})
\end{split}
\end{equation}
那我们就讨论一下这个东西到底是不是小于等于0的。首先需要知道的是这个网络是二值的网络,也就是说 v,h 的取值是在 { 1,0} 的。所以:
1.$v_j^k=0$,这个我们可以很轻松的得到:
$$\frac{\partial \mathcal{U}}{ \partial t } =\frac{\partial \mathcal{U}}{ \partial a_j } \frac{\partial a_j}{ \partial t }=0(0-v_j^{k+1})=0 \leq 0$$
2.$v_j^k=1$,然后,无论$v_j^{k+1}$取什么值,$v_j^k-v_j^{k+1} \geq 0$,同样我们也就知道了:
$$\frac{\partial \mathcal{U}}{ \partial t } =\frac{\partial \mathcal{U}}{ \partial a_j } \frac{\partial a_j}{ \partial t }=-1(v_j^k-v_j^{k+1}) \leq 0$$
同理,我们可以证明$\frac{\partial \mathcal{U}}{ \partial b_j } \frac{\partial b_j}{ \partial t }$和$\frac{\partial \mathcal{U}}{ \partial w_{ji} } \frac{\partial w_{ji}}{ \partial t }$都是满足之前我们定义的稳定性的条件的。
既然我们知道了这个系统的稳定性了,那我们就转换一下身份吧,不要再用动态系统的观点来看这个问题了。下面我们将从概率的角度来看这个问题。
Why it call Boltzmann?
这个我记得,我上物理课的时候,听说了波尔兹曼分布这个东西。感觉就是一个和一个系统的能量有关的概率问题。那我们就定义一下在这里的概率分布:
\begin{equation}\label{blz}
P(v,h)=\frac{1}{\mathcal{Z}} e^{-\mathcal{U}}
\end{equation}
由于这里涉及的是概率的问题,所以毫无疑问是要归一化的$\mathcal{Z}$就是归一化因子。
\begin{equation}\label{zz}
\mathcal{Z}=\sum_{v,h}e^{-\mathcal{U}}
\end{equation}
用概率的观点来看问题,一切就变得奇怪起来了。以前,我们知道了$v$就可以根据非线性变换就可以得到$h$了,但是现在不是这样了,我们根据概率$P(h=1 \mid v)$来得到$h$。所以我就说,一切变得奇怪起来了。
既然是要按照一定的概率来更新,那我们就得知道这个概率到底是什么呢?好吗,那我们来推导一下:
\begin{equation}\label{w3}
\begin{split}
P(h_i=1 \mid v)
&= P(h_i=1 \mid h_{-i},v) \\
& =\frac{P(h_i=1 , h_{-i},v)}{P(h_{-i},v)}\\
&=\frac{P(h_i=1 , h_{-i},v)}{P(h_i=1 , h_{-i},v) + P(h_i=0 , h_{-i},v)}\\
&=\frac{\frac{1}{\mathcal{Z}}e^{\mathcal{H}_{-i} +(b_i+\sum_j v_j w_{ji})}}
{\frac{1}{\mathcal{Z}}e^{ \mathcal{H}_{-i}+(b_i+\sum_j v_j w_{ji})}+\frac{1}{\mathcal{Z}}e^{ \mathcal{H}_{-i}+0}}\\
&=\frac{1}{1+e^{-(b_i+\sum_j v_j w_{ji})}}\\
&=sigmoid(b_i+\sum_j v_j w_{ji})
\end{split}
\end{equation}
其中:
\begin{equation}\label{w5}
\mathcal{H}_{-i}=\sum_{j \in \mathbb{V}}a_j v_j + \sum_{i \in \mathbb{H}-h_i} b_i h_i +\sum_{i \in \mathbb{H}-h_i,j \in \mathbb{V}}v_j w_{ji} h_i
\end{equation}
同理可证:
\begin{equation}\label{w4}
P(v_j=1 \mid h)=sigmoid(a_j + \sum_i w_{ji}h_i)
\end{equation}
写了这么多,感觉跟饶了一大圈又回到了原点一样。多眼熟的 sigmoid 函数,但是在这里是不一样的,这里是不一样的,这里的输出不是 sigmoid 的输出,这里的 sigmoid 代表的是输出为1的概率,输出到底是什么呢?这样牵扯概率的更新,一下子就与众不同了。
Loss function
但是我们这里并不是关注的礼盒概率分布$P(v,h)$,我们关注的是$P(v)$这个边缘概率分布,换句话说,这个$P(v)$才是我们的 Loss function 。我们在这就不证明了,直接给我结果:
\begin{equation}\label{7}
P(v)=\sum_h P(v,h)=\frac{1}{\mathcal{Z}} \prod_{j \in \mathbb{V}}e^{a_j v_j} \prod_{i \in \mathbb{H}}(1+e^{b_i + \sum_{j \in \mathbb{V}}v_j w_{ji}})
\end{equation}
但是我们并不喜欢这样连乘的损失函数,我们希望是求和式,这里就对$P(v)$取对数:
\begin{equation}\label{8}
\mathcal{L}(\omega)=\ln P(v) =\sum_{j \in \mathbb{V}} e^{a_j v_j} + \sum_{i \in \mathbb{H}} \ln(1 + e^{b_i + \sum_{j \in \mathbb{V}} v_j w_{ji}})-\ln (\mathcal{Z})
\end{equation}
根据梯度下降法:
$$\omega^k =\omega^{k-1}+ \eta\frac{\partial \mathcal{L}(\omega)}{\partial \omega}$$
我们就可以得到我们想要的各个参数的梯度变化值,这里就不一一证明了:
\begin{equation}\label{9}
\frac{\partial \mathcal{L}(\omega)}{\partial w_{ji}}=P(h_i=1 \mid v)v_j-\sum_{j \in \mathbb{V}} P(v)P(h_i=1 \mid v) v_j
\end{equation}
\begin{equation}\label{l9}
\frac{\partial \mathcal{L}(\omega)}{\partial a_j}=v_j-\sum_{j \in \mathbb{V}} P(v) v_j
\end{equation}
\begin{equation}\label{l99}
\frac{\partial \mathcal{L}(\omega)}{\partial b_i}=P(h_i=1 \mid v)-\sum_{j \in \mathbb{V}} P(v)P(h_i=1 \mid v)
\end{equation}
看到这个梯度我们就可以使用梯度下降法来进行学习了,但是麻烦的问题来了这个梯度里面有一个$\sum\limits_{j \in \mathbb{V}}$这一项,显然对于我们来说,这个是不好计算的,因为我们不知道$v$的分布。那怎么办呢?运用很直观的办法,那就是$MCMC$采样来进行估计,但是这里收敛速度是十分重要的,因为$MCMC$可能在很多步的转移之后才会$burn in$。
怎么才能得到更加有效的算法呢?我们使用 RBM 去拟合训练数据的分布,那我们是不是可以以训练数据为起点,这样可以更好的达到 RBM 的分布。
基于这个想法,咱们还是来看 Hinton 老爷子的 contrast divergence 算法吧!
Contrast Divergence
这就是大名鼎鼎的 CD 算法了!这里基于之前的假设,以样本$v^0$为初始点,然后在使用 Gibbs 采样$k$ 次得到$v^k$,最后之前的公式就变成了:
\begin{equation}\label{10}
\frac{\partial \mathcal{L}(\omega)}{\partial w_{ji}} \approx P(h_i=1 \mid v^0)v_j^0- P(h_i=1 \mid v^k) v_j^k
\end{equation}
\begin{equation}\label{11}
\frac{\partial \mathcal{L}(\omega)}{\partial a_j} \approx v_j^0- v_j^k
\end{equation}
\begin{equation}\label{12}
\frac{\partial \mathcal{L}(\omega)}{\partial b_i}\approx P(h_i=1 \mid v^0)-P(h_i=1 \mid v^k)
\end{equation}
然后一切就迎刃而解了:
1 | \begin{algorithm} |