Gradient Ascent and Langevin Dynamics Sampling

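In generative modeling, if we use only the gradient information $\nabla \log p(x)$, it is like searching for the lowest-energy point at absolute zero ($T=0$): all we ever obtain is a single "extremum". Genuine generation (sampling) should instead behave like gas molecules at room temperature ($T>0$), which are steered by the potential energy while still undergoing thermal motion.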

Sampling

Given a known distribution $p(x)$, how do we generate data from it?

  1. Finding the modes (extrema)

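Following the gradient flow $\frac{\mathrm{d}x}{\mathrm{d}t} = \nabla_x \log p(x)$, the corresponding discretized algorithm is gradient ascent (equivalently, gradient descent on the energy):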

$$x_{t+1} = x_t + \tau \nabla_x \log p(x_t)$$

As $t\to\infty$ and $\tau\to 0$, every generated point converges to a local maximum (mode) of the density $p(x)$, so the samples lack diversity.
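A minimal sketch of this collapse, assuming a single illustrative Gaussian target $\mathcal{N}(\mu, \sigma^2)$ whose score is $-(x-\mu)/\sigma^2$ (the names `gradient_ascent`, `gaussian_score`, `mu`, and `sigma` are hypothetical, not from the original). However widely the starting points are spread, the noise-free update drives every particle onto the same mode.

```python
import numpy as np

def gradient_ascent(score, x0, tau, n_steps):
    """Noise-free update x_{t+1} = x_t + tau * score(x_t), applied to every particle."""
    x = np.asarray(x0, dtype=float).copy()
    for _ in range(n_steps):
        x = x + tau * score(x)
    return x

# Hypothetical single-Gaussian target N(mu, sigma^2); its score is -(x - mu) / sigma^2.
mu, sigma = 1.5, 2.0

def gaussian_score(x):
    return -(x - mu) / sigma**2

rng = np.random.default_rng(0)
x0 = rng.uniform(-10.0, 10.0, size=1000)  # widely spread starting points
x_final = gradient_ascent(gaussian_score, x0, tau=0.1, n_steps=2000)

# Every particle ends on the single mode mu: mean close to 1.5, spread essentially 0.
print(x_final.mean(), x_final.std())
```

Each step contracts the distance to $\mu$ by the factor $1 - \tau/\sigma^2$, so the final spread reflects nothing about $p(x)$ except the location of its mode.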

  2. Sampling faithfully from $p(x)$

By the overdamped Langevin equation,

$$\gamma\frac{\mathrm{d}x}{\mathrm{d}t} = -\nabla_x U(x) + \sqrt{2\gamma k_{\mathrm{B}}T}\,z(t)$$

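where $\gamma$ is the friction coefficient, $k_{\mathrm{B}}$ is the Boltzmann constant, $T$ is the temperature, and $z(t)$ is white noise. As $t\to\infty$, the stationary distribution is the Boltzmann distribution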

$$p(x) \propto e^{-\frac{U(x)}{k_{\mathrm{B}}T}}$$
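To check the Boltzmann claim numerically, here is a minimal Euler-Maruyama sketch of the overdamped Langevin equation under an assumed harmonic potential $U(x) = \tfrac12 k x^2$ (the function names and all parameter values are illustrative choices, not from the original). Boltzmann then predicts a Gaussian with variance $k_{\mathrm{B}}T/k$, and setting $k_{\mathrm{B}}T = 0$ reproduces the $T=0$ collapse onto the energy minimum.

```python
import numpy as np

def overdamped_langevin(grad_U, x0, gamma, kBT, dt, n_steps, rng):
    """Euler-Maruyama steps for gamma * dx/dt = -grad U(x) + sqrt(2 * gamma * kBT) * z(t)."""
    x = np.asarray(x0, dtype=float).copy()
    # White noise integrated over one step of length dt has standard deviation sqrt(dt).
    noise_std = np.sqrt(2.0 * kBT * dt / gamma)
    for _ in range(n_steps):
        x = x - (dt / gamma) * grad_U(x) + noise_std * rng.standard_normal(x.shape)
    return x

# Hypothetical harmonic potential U(x) = k * x**2 / 2, so grad U(x) = k * x.
k = 1.0

def grad_U(x):
    return k * x

rng = np.random.default_rng(42)
x0 = rng.uniform(-3.0, 3.0, size=10_000)

results = {}
for kBT in (0.5, 0.0):  # kBT = 0.0 is the "absolute zero" case: no noise at all
    x = overdamped_langevin(grad_U, x0, gamma=2.0, kBT=kBT, dt=0.01, n_steps=3000, rng=rng)
    results[kBT] = x.var()
    # Boltzmann: p(x) ∝ exp(-k x^2 / (2 kBT)), a Gaussian with variance kBT / k
    print(f"kBT={kBT}: empirical variance {x.var():.4f}, Boltzmann prediction {kBT / k:.4f}")
```

With noise the particles settle into a spread that matches the Boltzmann variance up to sampling error and a small discretization bias; without noise they all shrink onto $x = 0$.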

Therefore, if we set $U(x) = -k_{\mathrm{B}}T\log p(x)$, the stationary distribution of $x$ is exactly $p(x)$, and the original equation becomes:

$$\frac{\mathrm{d}x}{\mathrm{d}t} = \frac{k_{\mathrm{B}}T}{\gamma}\nabla_x \log p(x) + \sqrt{\frac{2k_{\mathrm{B}}T}{\gamma}}\,z(t)$$
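The two steps behind this substitution, written out: the Boltzmann factor reduces to $p(x)$ itself, and dividing the overdamped Langevin equation by $\gamma$ (using $-\nabla_x U = k_{\mathrm{B}}T\,\nabla_x \log p$) gives the equation above.

```latex
\begin{aligned}
e^{-U(x)/(k_{\mathrm{B}}T)} &= e^{\log p(x)} = p(x), \\
-\nabla_x U(x) &= k_{\mathrm{B}}T\,\nabla_x \log p(x), \\
\frac{\mathrm{d}x}{\mathrm{d}t} &= -\frac{1}{\gamma}\nabla_x U(x) + \frac{\sqrt{2\gamma k_{\mathrm{B}}T}}{\gamma}\,z(t)
  = \frac{k_{\mathrm{B}}T}{\gamma}\nabla_x \log p(x) + \sqrt{\frac{2k_{\mathrm{B}}T}{\gamma}}\,z(t).
\end{aligned}
```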

The corresponding discretized algorithm is usually called Langevin dynamics sampling:

$$x_{t+1} = x_t + \tau \nabla_x \log p(x_t) + \sqrt{2\tau}\, z_t, \quad z_t \sim \mathcal{N}(0, I)$$

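(Note: implementations usually fold the constant $\frac{k_{\mathrm{B}}T}{\gamma}$ into the step size, i.e. $\tau = \frac{k_{\mathrm{B}}T}{\gamma}\Delta t$ for a time step $\Delta t$. This only rescales time and leaves the stationary distribution unchanged.)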
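A minimal vectorized sketch of this update, assuming a standard normal target so the answer is known in closed form (`langevin_sampling` and `std_normal_score` are hypothetical names, not from the original). It also exposes a property of the plain, unadjusted discretization: for this target the chain is $x' = (1-\tau)x + \sqrt{2\tau}\,z$, whose stationary variance is $1/(1-\tau/2)$ rather than exactly 1.

```python
import numpy as np

def langevin_sampling(score, x0, tau, n_steps, rng):
    """Unadjusted Langevin update: x_{t+1} = x_t + tau * score(x_t) + sqrt(2 tau) * z_t."""
    x = np.asarray(x0, dtype=float).copy()
    for _ in range(n_steps):
        z = rng.standard_normal(x.shape)  # z_t ~ N(0, I), fresh for every step
        x = x + tau * score(x) + np.sqrt(2.0 * tau) * z
    return x

def std_normal_score(x):
    # Score of N(0, 1): d/dx log p(x) = -x
    return -x

rng = np.random.default_rng(0)
x0 = rng.uniform(-8.0, 10.0, size=50_000)  # uniform starting points on [-8, 10]
tau = 0.1
x = langevin_sampling(std_normal_score, x0, tau, n_steps=500, rng=rng)

# Empirical mean/variance versus the stationary variance of the discrete chain.
print(x.mean(), x.var(), 1.0 / (1.0 - tau / 2.0))
```

Unlike the gradient-ascent sketch, the particles keep a spread matching the target; the leftover bias of roughly $\tau/2$ shrinks as $\tau$ is reduced, at the cost of more iterations.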

Intuitive comparison

[Figure: compare.png, gradient ascent vs. Langevin dynamics sampling on the same target p(x)]

import numpy as np
import matplotlib.pyplot as plt

# ==========================================
# 1. Target distribution p(x) (bimodal Gaussian mixture)
# ==========================================
def target_pdf(x):
    # Mixture of two Gaussians (second argument = standard deviation):
    # 0.7 * N(-2, 1) + 0.3 * N(4, 1.5)
    p1 = 0.7 * np.exp(-0.5 * ((x + 2) / 1)**2) / (np.sqrt(2 * np.pi) * 1)
    p2 = 0.3 * np.exp(-0.5 * ((x - 4) / 1.5)**2) / (np.sqrt(2 * np.pi) * 1.5)
    return p1 + p2

def get_score(x):
    # Score function: ∇ log p(x) = p'(x) / p(x)
    # The derivative of each mixture component is written out explicitly.
    p = target_pdf(x)

    # dp/dx, component by component
    grad_p1 = 0.7 * np.exp(-0.5 * ((x + 2) / 1)**2) / (np.sqrt(2 * np.pi) * 1) * (-(x + 2) / 1**2)
    grad_p2 = 0.3 * np.exp(-0.5 * ((x - 4) / 1.5)**2) / (np.sqrt(2 * np.pi) * 1.5) * (-(x - 4) / 1.5**2)

    grad_p = grad_p1 + grad_p2

    # Add a tiny constant to avoid division by zero
    return grad_p / (p + 1e-10)

# ==========================================
# 2. Simulation parameters
# ==========================================
n_particles = 2000              # number of particles
n_steps = 500                   # number of iterations
tau = 0.1                       # step size (learning rate)
noise_scale = np.sqrt(2 * tau)  # Langevin noise coefficient

# Initialize particles uniformly on [-8, 10]
x_opt = np.random.uniform(-8, 10, n_particles)  # for optimization (gradient ascent)
x_sample = x_opt.copy()                          # for sampling (Langevin)

# ==========================================
# 3. Iterative updates
# ==========================================
for _ in range(n_steps):
    score_opt = get_score(x_opt)
    score_sample = get_score(x_sample)

    # --- Algorithm 1: gradient ascent (find the modes) ---
    # x = x + τ * ∇log p(x)
    x_opt = x_opt + tau * score_opt

    # --- Algorithm 2: Langevin dynamics (sampling) ---
    # x = x + τ * ∇log p(x) + √(2τ) * z
    noise = np.random.normal(0, 1, n_particles)
    x_sample = x_sample + tau * score_sample + noise_scale * noise

# ==========================================
# 4. Plot the results
# ==========================================
# Separate y-axes: the collapsed histogram on the left has spikes far taller than p(x),
# so a shared y-axis would flatten the right-hand panel.
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=False)
x_grid = np.linspace(-8, 10, 1000)
true_pdf = target_pdf(x_grid)

# --- Left panel: gradient ascent (finding the modes) ---
axes[0].plot(x_grid, true_pdf, 'r--', label='True p(x)', lw=2)
axes[0].hist(x_opt, bins=50, density=True, color='blue', alpha=0.6, label='Particles')
axes[0].set_title('Case 1: Optimization (Gradient Ascent)\nTarget: Find Modes', fontsize=14)
axes[0].set_xlabel('x')
axes[0].set_ylabel('Density')
axes[0].legend()
axes[0].grid(True, alpha=0.3)
axes[0].text(-7, 0.25, "Particles collapse\nto the peaks!", fontsize=12, color='darkblue')

# --- Right panel: Langevin dynamics (sampling) ---
axes[1].plot(x_grid, true_pdf, 'r--', label='True p(x)', lw=2)
axes[1].hist(x_sample, bins=50, density=True, color='green', alpha=0.6, label='Particles')
axes[1].set_title('Case 2: Sampling (Langevin Dynamics)\nTarget: Cover Distribution', fontsize=14)
axes[1].set_xlabel('x')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
axes[1].text(-7, 0.25, "Particles fill the\nshape of p(x)", fontsize=12, color='darkgreen')

plt.tight_layout()
plt.show()
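Dividing $p'(x)$ by $p(x) + 10^{-10}$ in `get_score` works well enough inside the plotted range, but once $p(x)$ falls far below the $10^{-10}$ guard (around $x = 50$, say) the returned score collapses toward 0 instead of pointing back toward the modes. A sketch of one alternative, assuming the same mixture parameters: for a Gaussian mixture the score is the responsibility-weighted sum of the component scores, $\nabla_x \log p(x) = \sum_k r_k(x)\,\bigl(-(x-\mu_k)/\sigma_k^2\bigr)$, and the responsibilities $r_k(x)$ can be computed in log space. `log_component_densities` and `stable_score` are hypothetical helper names.

```python
import numpy as np

# Mixture from the script above: 0.7 * N(-2, 1^2) + 0.3 * N(4, 1.5^2)
weights = np.array([0.7, 0.3])
means = np.array([-2.0, 4.0])
stds = np.array([1.0, 1.5])

def log_component_densities(x):
    """log(w_k * N(x; mu_k, sigma_k^2)) for every component, last axis indexes k."""
    x = np.asarray(x, dtype=float)[..., None]
    return (np.log(weights) - np.log(np.sqrt(2 * np.pi) * stds)
            - 0.5 * ((x - means) / stds) ** 2)

def stable_score(x):
    """Mixture score from component responsibilities, computed without forming p(x)."""
    log_c = log_component_densities(x)
    log_c -= log_c.max(axis=-1, keepdims=True)  # log-sum-exp shift avoids underflow
    resp = np.exp(log_c)
    resp /= resp.sum(axis=-1, keepdims=True)    # r_k(x): posterior weight of component k
    x = np.asarray(x, dtype=float)[..., None]
    return np.sum(resp * (-(x - means) / stds**2), axis=-1)

print(stable_score(np.array([-2.0, 0.7, 50.0])))
```

It accepts scalars or arrays, so it can replace `get_score` in the update loop unchanged; far from both modes it returns the score of the nearest dominant component instead of 0.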


https://psu.monster/post/2025/2d66641f9ce7

Author: psu
Published: 2025-11-27
Updated: 2025-11-27

