高斯分布的采样

如果说最常用的分布是哪个,第一个当属均匀分布,第二个就非高斯分布莫属了,在机器学习中常见的采样方法中介绍了均匀分布的采样,本文来介绍一下高斯分布的采样.

首先,假设随机变量 z 服从标准正态分布 N(0,1) ,令:

x = \sigma \cdot z + \mu

x 服从均值为 \mu ,方差为 \sigma^2 的高斯分布 N(\mu,\sigma^2) .因此,任意高斯分布都可以由标准正态分布通过拉伸和平移得到,所以这里只考虑标准正态分布的采样.

由于高斯分布的累积分布函数并不是一个初等函数,所以使用逆变换法计算比较麻烦,Box-Muller算法给出了简单的解决方案.

Box-Muller算法

假设 x,y 是两个服从标准正态分布的独立随机变量,他们的联合概率密度为

p ( x , y ) = \frac { 1 } { 2 \pi } e ^ { - \frac { x ^ { 2 } + y ^ { 2 } } { 2 } }

考虑 (x,y) 在圆盘 \{ ( x , y ) | x ^ { 2 } + y ^ { 2 } \leqslant R ^ { 2 } \} 上的概率

F ( R ) = \int _ { x ^ { 2 } + y ^ { 2 } \leqslant R ^ { 2 } } \frac { 1 } { 2 \pi } e ^ { - \frac { x ^ { 2 } + y ^ { 2 } } { 2 } } d x d y

通过极坐标变换将 (x,y) 转化为 ( r , \theta ) ,可以很容易求得二重积分:

F ( R ) = 1 - \mathrm { e } ^ { - \frac { R ^ { 2 } } { 2 } } , R \geqslant 0

这里 F(R) 可以看成是极坐标中 r 的累积分布函数.由于 F(R) 的计算公式比较简单,逆函数也容易求得,所以可以利用逆变换法来对 r 进行采样;对于 \theta ,在 [0,2\pi] 上进行均匀采样即可.这样就得到了 (r,\theta) ,经过坐标变换即可得到符合标准正态分布的 (x,y) .具体采样过程如下:

  1. 产生 [0,1] 上的两个独立的均匀分布随机数 u_1,u_2
  2. \left\{ \begin{array} { l } { x = \sqrt { - 2 \ln \left( u _ { 1 } \right) } \cos 2 \pi u _ { 2 } } \\ { y = \sqrt { - 2 \ln \left( u _ { 1 } \right) } \sin 2 \pi u _ { 2 } } \end{array} \right. ,则 x,y 服从标准正态分布,并且是相互独立的.

Box-muller算法由于需要计算三角函数,相对来说还是比较耗时,而Marsaglia polar method则避开了三角函数的计算,因而更快,其具体采样操作如下:

  1. 在单位圆盘 \{ ( x , y ) | x ^ { 2 } + y ^ { 2 } \leq 1 \} 上产生均匀分布随机数对 (x,y) (在矩阵 \{ ( x , y ) | - 1 \leqslant x , y \leqslant 1 \} 上利用拒绝采用法即可得到).
  2. s = x ^ { 2 } + y ^ { 2 } ,则 x \sqrt { \frac { - 2 \ln s } { s } } , y \sqrt { \frac { - 2 \ln s } { s } } 是两个服从标准正态分布的样本,其中 \frac { x } { \sqrt { s } } , \frac { y } { \sqrt { s } } 用来代替Box-muller算法中的cos和sin操作.

编程实现

import matplotlib.pyplot as plt
import numpy as np

def getNormal(SampleSize):
    iid = np.random.uniform(0,1,SampleSize)
    normal1 = np.cos(2*np.pi*iid[0:int(SampleSize/2-1)])*np.sqrt(-2*np.log(iid[int(SampleSize/2):SampleSize-1]))
    normal2 = np.sin(2*np.pi*iid[0:int(SampleSize/2-1)])*np.sqrt(-2*np.log(iid[int(SampleSize/2):SampleSize-1]))
    return np.hstack((normal1,normal2))

# 生成10000000个数,观察它们的分布情况
SampleSize = 10000000
normal = getNormal(SampleSize)
plt.hist(normal,np.linspace(-4,4,81),facecolor="green")
plt.show()

image

posted @ 2019/01/22 15:22:46