首页/知识库/VAE:变分自编码器

VAE:变分自编码器

✍️ AI Master📅 创建 2026-04-12📖 18 min 阅读
💡

文章摘要

从概率建模到潜空间采样,理解变分自编码器的生成原理

1自编码器到变分自编码器

经典自编码器(Autoencoder)由编码器和解码器组成:编码器将高维输入 x 压缩为低维表示 z,解码器从 z 重构回 x。它本质上是做数据压缩与降维,而非真正的生成模型。问题在于,普通自编码器的潜空间是离散的、不连续的,随机采样一个 z 点往往解码出无意义的内容。变分自编码器(VAE)的关键突破在于将潜变量 z 视为概率分布而非确定值。编码器不再输出单个向量,而是输出一个分布的参数(均值 mu 和方差 sigma),解码器则从该分布采样 z 来重构 x。这使得潜空间变得连续且平滑,任意采样点都能解码出合理的样本,VAE 因此成为真正的生成模型。从信息论角度看,VAE 在压缩效率和重构精度之间引入了 KL 散度正则,迫使潜空间接近标准正态分布。

python
# 经典自编码器 vs VAE 编码器对比
import torch
import torch.nn as nn

class ClassicEncoder(nn.Module):
    def __init__(self, input_dim=784, latent_dim=32):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim)  # 输出确定向量
        )
    def forward(self, x):
        return self.fc(x)  # z 是确定的

class VAE_Encoder(nn.Module):
    def __init__(self, input_dim=784, latent_dim=32):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU()
        )
        self.mu = nn.Linear(256, latent_dim)      # 均值
        self.logvar = nn.Linear(256, latent_dim)  # 对数方差
    def forward(self, x):
        h = self.fc(x)
        return self.mu(h), self.logvar(h)  # 输出分布参数
python
# 潜空间可视化:经典 AE vs VAE
import matplotlib.pyplot as plt
import numpy as np

def visualize_latent_space(encoder, data_loader, is_vae=True):
    z_all, labels = [], []
    for x, y in data_loader:
        if is_vae:
            mu, _ = encoder(x.view(x.size(0), -1))
            z_all.append(mu.detach().numpy())
        else:
            z = encoder(x.view(x.size(0), -1))
            z_all.append(z.detach().numpy())
        labels.append(y.numpy())
    z_all = np.vstack(z_all)
    labels = np.concatenate(labels)
    plt.figure(figsize=(10, 8))
    for digit in range(10):
        mask = labels == digit
        plt.scatter(z_all[mask, 0], z_all[mask, 1],
                    s=5, alpha=0.5, label=str(digit))
    plt.legend()
    title = "VAE Latent Space" if is_vae else "Classic AE Latent Space"
    plt.title(title)
特性经典自编码器变分自编码器 VAE

编码器输出

确定向量 z

分布参数 mu, sigma

潜空间

离散、不连续

连续、平滑

能否采样生成

不能

正则化

KL 散度正则

损失函数

重构误差

重构误差 + KL

生成质量

不适用

中等(偏模糊)

对数方差 logvar 比直接学方差 sigma 更稳定,因为 logvar 的取值范围是全体实数,无需额外的约束。

经典自编码器不能直接用于生成,因为潜空间中存在大量空白区域,随机采样大概率落入无意义区域。

2概率图模型视角

理解 VAE 最优雅的方式是通过概率图模型(Probabilistic Graphical Model)。VAE 假设数据的生成过程如下:首先从先验分布 p(z)(通常为标准正态分布 N(0, I))中采样潜变量 z,然后通过条件分布 p_theta(x | z) 生成观测数据 x。这里的 theta 是解码器的可学习参数。然而,我们只有观测数据 x,不知道对应的 z。根据贝叶斯定理,后验分布 p(z | x) = p(x | z) * p(z) / p(x),其中边缘似然 p(x) = 积分 p(x | z) * p(z) dz 是不可计算的(intractable)。这就是变分推断大显身手的地方。我们引入一个近似后验分布 q_phi(z | x)(编码器,参数为 phi),用它来逼近真实的后验 p(z | x)。这样,整个 VAE 框架就可以理解为:编码器学习近似后验 q_phi(z | x),解码器学习生成分布 p_theta(x | z),两者联合优化使得近似后验尽可能接近真实后验。

python
# 概率图模型的数值理解
import torch
from torch.distributions import Normal

class ProbabilisticVAE:
    def __init__(self, latent_dim=32):
        self.prior = Normal(torch.zeros(latent_dim),
                           torch.ones(latent_dim))  # p(z)

    def prior_sample(self, n_samples=1):
        # 从先验 p(z) 采样
        return self.prior.sample((n_samples,))

    def log_likelihood(self, x_recon, x, var=0.1):
        # 计算对数似然 log p(x|z) 的下界
        dist = Normal(x_recon, var)
        return dist.log_prob(x).sum(dim=-1).mean()

    def kl_to_prior(self, mu, logvar):
        # KL[q(z|x) || p(z)] 的解析解
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1).mean()
python
# 边缘似然不可计算性的演示
import numpy as np
from scipy.integrate import nquad

def demonstrate_intractability():
    # 在 1 维情况下演示为什么 p(x) 难以计算
    # p(z) = N(0, 1)
    # p(x|z) = N(f(z), 0.1) 其中 f 是神经网络

    # 简单情况:f(z) = 2*z
    def p_xz(z, x=1.0):
        mean = 2 * z
        var = 0.1
        return (1/np.sqrt(2*np.pi*var)) * np.exp(-(x-mean)2/(2*var))

    def p_z(z):
        return (1/np.sqrt(2*np.pi)) * np.exp(-z2/2)

    def integrand(z, x=1.0):
        return p_xz(z, x) * p_z(z)

    # 数值积分计算 p(x)
    result, _ = nquad(integrand, [[-10, 10]])
    print(f"p(x=1.0) = {result:.6f}")
    print(f"维度灾难:32 维需要 20^32 次积分评估!")
符号含义类型参数化方式

p(z)

先验分布

已知

N(0, I),固定

p_theta(x|z)

生成分布

可学习

解码器神经网络

p(z|x)

真实后验

不可计算

需要近似

q_phi(z|x)

近似后验

可学习

编码器神经网络

p(x)

边缘似然

不可计算

VAE 优化目标

theta, phi

模型参数

可学习

联合优化

先验选择为标准正态分布不是唯一选择,也可以用混合高斯或 VampPrior 来提升生成能力。

边缘似然 p(x) 在高维空间中是不可计算的,这就是为什么 VAE 必须通过变分下界来间接优化。

3重参数化技巧

VAE 训练的核心难题在于:我们需要通过采样 z 来计算重构损失,但采样操作不可导,梯度无法反向传播。重参数化技巧(Reparameterization Trick)是 Kingma 和 Welling 在 2013 年提出的关键创新。其思想非常巧妙:与其直接从 N(mu, sigma^2) 中采样 z,不如先从标准正态分布 N(0, I) 中采样一个辅助变量 epsilon,然后通过确定性变换 z = mu + sigma * epsilon 得到 z。这样,z 的计算变成了一个完全可导的操作(mu 和 sigma 都是网络的输出,epsilon 是外部采样),梯度可以顺畅地从 z 回传到编码器。这个技巧的深远意义在于,它将随机性转移到了输入端,使得整个 VAE 可以用标准的反向传播算法端到端训练。没有重参数化,VAE 只能用 REINFORCE 等高方差梯度估计器,训练会极其困难。

python
# 重参数化技巧实现
import torch
import torch.nn as nn

class ReparameterizedVAE(nn.Module):
    def __init__(self, input_dim=784, latent_dim=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU()
        )
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, 512), nn.ReLU(),
            nn.Linear(512, input_dim), nn.Sigmoid()
        )

    def reparameterize(self, mu, logvar):
        # z = mu + sigma * eps
        std = torch.exp(0.5 * logvar)  # sigma = exp(logvar/2)
        eps = torch.randn_like(std)     # eps ~ N(0, I)
        return mu + std * eps           # z 的确定性计算

    def forward(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        z = self.reparameterize(mu, logvar)  # 可导!
        x_recon = self.decoder(z)
        return x_recon, mu, logvar
python
# 对比:不可导采样 vs 重参数化
import torch
from torch.autograd import gradcheck

def non_differentiable_sampling(mu, logvar):
    # 错误做法:直接采样,梯度断联
    std = torch.exp(0.5 * logvar)
    z = torch.normal(mu, std)  # 不可导!
    return z

def reparameterized_sampling(mu, logvar):
    # 正确做法:重参数化
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    z = mu + std * eps  # 完全可导!
    return z

# 验证可导性
mu = torch.randn(4, 32, requires_grad=True, dtype=torch.float64)
logvar = torch.randn(4, 32, requires_grad=True, dtype=torch.float64)
# 重参数化版本的梯度流是完整的
z = reparameterized_sampling(mu, logvar)
z.sum().backward()
print(f"mu.grad 非空: {mu.grad is not None}")  # True
方法采样公式可导性梯度方差训练效果

直接采样

z ~ N(mu, sigma^2)

不可导

N/A

无法训练

REINFORCE

z ~ N(mu, sigma^2)

可导(得分函数)

极高

训练极不稳定

重参数化

z = mu + sigma * eps

完全可导

稳定高效

Gumbel-Softmax

用于离散潜变量

可导(近似)

离散 VAE

Concrete 分布

连续松弛离散变量

可导

可学习离散结构

实现时推荐使用 logvar 而不是 sigma,可以避免数值溢出问题,sigma = exp(logvar / 2) 始终为正。

重参数化要求潜变量分布是连续且可微的,对于离散潜变量需要使用 Gumbel-Softmax 等替代方案。

4ELBO 损失推导

VAE 的训练目标是最大化对数边缘似然 log p(x),但由于 p(x) 不可计算,我们转而优化其对数证据下界(Evidence Lower Bound, ELBO)。推导过程从 KL 散度的定义出发:KL[q(z|x) || p(z|x)] = E_q[log q(z|x) - log p(z|x)]。利用贝叶斯定理展开 log p(z|x) = log p(x|z) + log p(z) - log p(x),代入后得到 KL = E_q[log q(z|x)] - E_q[log p(x|z)] - E_q[log p(z)] + log p(x)。整理可得 log p(x) - KL[q(z|x) || p(z|x)] = E_q[log p(x|z)] - KL[q(z|x) || p(z)]。由于 KL 散度非负,左侧的 E_q[log p(x|z)] - KL[q(z|x) || p(z)] 就是 log p(x) 的下界,即 ELBO。最大化 ELBO 等价于同时做两件事:最大化重构项 E_q[log p(x|z)](解码器尽可能准确地重构输入)和最小化 KL 散度项 KL[q(z|x) || p(z)](编码器输出分布接近先验)。这两项目标之间存在天然张力,正是这种张力塑造了 VAE 独特的潜空间结构。

python
# ELBO 损失的 PyTorch 实现
import torch
import torch.nn.functional as F

def compute_elbo(x_recon, x, mu, logvar, reduction="mean"):
    # ELBO = E_q[log p(x|z)] - KL[q(z|x) || p(z)]
    # 等价于: -重构误差 - KL 散度
    # 重构项: 假设 p(x|z) 是伯努利分布(二值图像)
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction="none")
    recon_loss = recon_loss.view(recon_loss.size(0), -1).sum(dim=-1)
    # KL 散度项: KL[q(z|x) || p(z)] 的解析解
    # 当 q = N(mu, diag(sigma^2)), p = N(0, I)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
    elbo = -(recon_loss + kl_loss)  # 最大化 ELBO = 最小化负 ELBO
    if reduction == "mean":
        return elbo.mean(), recon_loss.mean(), kl_loss.mean()
    return elbo, recon_loss, kl_loss
python
# ELBO 各项的数值分析
import torch
import matplotlib.pyplot as plt

def analyze_elbo_balance():
    # 分析重构项和 KL 项在不同训练阶段的占比
    epochs = range(1, 101)
    recon_losses = []
    kl_losses = []
    # 模拟训练过程中的变化趋势
    for epoch in epochs:
        recon = 200 * (0.3 + 0.7 * torch.exp(-torch.tensor(epoch/30))).item()
        kl = 50 * (1 - torch.exp(-torch.tensor(epoch/20))).item()
        recon_losses.append(recon)
        kl_losses.append(kl)
    plt.figure(figsize=(10, 4))
    plt.plot(epochs, recon_losses, label="Reconstruction Loss")
    plt.plot(epochs, kl_losses, label="KL Loss")
    total = [r + k for r, k in zip(recon_losses, kl_losses)]
    plt.plot(epochs, total, "--", label="Total Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("ELBO Components During Training")
ELBO 分量数学形式作用值过大的后果值过小的后果

重构项

E_q[log p(x|z)]

保证生成质量

忽略先验约束

退化为普通 AE

KL 项

-KL[q||p]

正则化潜空间

潜空间坍缩

潜空间过于分散

总 ELBO

重构 - KL

联合优化目标

需要平衡两项

需要平衡两项

对数似然

log p(x)

理论最优目标

不可直接计算

不可直接计算

训练初期 KL 项往往占主导(posterior collapse),可以逐步增大 KL 权重(KL annealing)来缓解。

如果 KL 项直接归零,说明发生了 posterior collapse,VAE 退化为普通自编码器,完全丧失了生成能力。

5beta-VAE 与解耦表示

标准 VAE 的 ELBO 中,重构项和 KL 项的权重是固定的 1:1 比例。beta-VAE(Higgins et al., 2017)引入了一个超参数 beta,将 ELBO 修改为 ELBO_beta = E_q[log p(x|z)] - beta * KL[q(z|x) || p(z)]。当 beta > 1 时,KL 正则化更强,迫使潜变量的各个维度学习数据中独立的生成因子(disentangled factors of variation)。例如在人脸数据上,理想的解耦表示中一个维度控制头发颜色、另一个控制微笑程度、第三个控制头部姿态等。beta-VAE 的理论基础是信息瓶颈:更强的 KL 约束限制了互信息 I(x, z),迫使潜变量只编码最必要的信息。实验表明,适度增大 beta(如 beta = 4)可以学到明显解耦的表示,但过大的 beta 会导致重构质量严重下降。后续研究提出了 Annealed VAE、FactorVAE 等改进方案,在解耦度和重构质量之间寻找更好的平衡。

python
# beta-VAE 损失函数
class BetaVAELoss:
    def __init__(self, beta=4.0, use_annealing=False):
        self.beta = beta
        self.use_annealing = use_annealing
        self.current_beta = beta

    def update_beta(self, epoch, max_epoch, max_beta=4.0):
        # KL annealing: 逐步增大 beta
        if self.use_annealing:
            self.current_beta = max_beta * min(1.0, epoch / max_epoch)

    def __call__(self, x_recon, x, mu, logvar):
        recon = F.binary_cross_entropy(x_recon, x, reduction="none")
        recon = recon.view(recon.size(0), -1).sum(dim=-1).mean()
        kl = -0.5 * torch.sum(
            1 + logvar - mu.pow(2) - logvar.exp(), dim=-1).mean()
        return recon + self.current_beta * kl, recon, kl

# 解耦度评估(MIG 指标)
def compute_mig(z_samples, factors):
    # Mutual Information Gap - 解耦表示的量化指标
    from sklearn.metrics import mutual_info_score
    n_latent = z_samples.shape[1]
    n_factor = factors.shape[1]
    mi_matrix = np.zeros((n_latent, n_factor))
    for i in range(n_latent):
        for j in range(n_factor):
            mi_matrix[i, j] = mutual_info_score(
                np.digitize(z_samples[:, i], 20), factors[:, j])
    sorted_mi = np.sort(mi_matrix, axis=0)[::-1]
    gaps = sorted_mi[0] - sorted_mi[1]
    return gaps.mean()  # MIG 越高,解耦越好
python
# beta-VAE 训练循环与可视化
import torch
import matplotlib.pyplot as plt

def train_beta_vae(model, dataloader, beta=4.0, epochs=50, lr=1e-3):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for x, _ in dataloader:
            optimizer.zero_grad()
            x_recon, mu, logvar = model(x)
            recon = F.binary_cross_entropy(x_recon, x.view(x.size(0), -1),
                                            reduction="sum")
            kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            loss = recon + beta * kl
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        losses.append(epoch_loss / len(dataloader.dataset))
        if (epoch+1) % 10 == 0:
            print(f"Epoch {epoch+1}: loss={losses[-1]:.2f}")
    return losses

# 对比不同 beta 值的训练曲线
for b in [1.0, 2.0, 4.0, 10.0]:
    losses = train_beta_vae(vae, loader, beta=b, epochs=50)
    plt.plot(losses, label=f"beta={b}")
plt.legend()
plt.xlabel("Epoch")
plt.ylabel("ELBO Loss")
plt.title("beta-VAE Training Curves")
beta 值重构质量解耦程度潜空间紧凑度适用场景

beta = 0

最优

无解耦

不紧凑

退化为普通 AE

beta = 1

弱解耦

紧凑

标准 VAE

beta = 4

中等

良好解耦

非常紧凑

推荐起点

beta = 10

较差

强解耦

极度紧凑

需要解耦表示

beta > 20

很差

过解耦

信息严重丢失

通常不推荐

实践中 beta 从 0 开始逐步 annealing 到 4 是最稳定的策略,可以避免训练初期的 posterior collapse。

beta 过大时重构质量会急剧下降,需要在解耦度和生成质量之间做取舍,没有银弹。

6VAE vs GAN vs Diffusion

在生成模型的三大支柱中,VAE、GAN 和 Diffusion 各有其独特的哲学和技术路径。VAE 基于概率推断,通过最大化 ELBO 来学习数据的隐式表示,优势在于训练稳定、天然提供似然估计下界、潜空间可用于编辑和插值。劣势是生成样本偏模糊,因为 MSE 损失倾向于输出平均值。GAN 通过生成器和判别器的零和博弈学习数据分布,生成样本极其锐利逼真,但训练不稳定、模式崩溃、缺乏显式的似然估计。Diffusion 模型结合了 VAE 的稳定性和 GAN 的生成质量,通过渐进式去噪实现高质量生成,但采样速度慢。三者的数学框架完全不同:VAE 是变分推断,GAN 是博弈论,Diffusion 是非平衡热力学。近年来也出现了混合模型如 VQ-VAE + Transformer(DALL-E)、VQ-GAN 等,试图兼取各家之长。理解这三种范式的差异,对于选择合适的生成模型至关重要。

python
# 三大模型的采样过程对比
import torch
import time

def sample_vae(decoder, n=16, latent_dim=32):
    # VAE 采样:一步完成
    z = torch.randn(n, latent_dim)
    with torch.no_grad():
        return decoder(z)

def sample_gan(generator, n=16, latent_dim=32):
    # GAN 采样:一步完成
    z = torch.randn(n, latent_dim)
    with torch.no_grad():
        return generator(z)

def sample_diffusion(model, n=16, steps=1000):
    # Diffusion 采样:多步迭代
    x = torch.randn(n, 3, 64, 64)
    for t in reversed(range(steps)):
        with torch.no_grad():
            noise_pred = model(x, t)
            x = denoise_step(x, noise_pred, t)
    return x

# 计时对比
latent_dim = 64
vae_time = time.time()
_ = sample_vae(decoder, latent_dim=latent_dim)
print(f"VAE 采样: {time.time()-vae_time:.4f}s")
python
# 生成模型综合评估
import matplotlib.pyplot as plt
import numpy as np

def compare_generative_models():
    models = ["VAE", "GAN", "Diffusion", "VQ-VAE", "Flow"]
    dims = ["质量", "多样性", "速度", "稳定性", "似然", "可编辑性"]
    scores = np.array([
        [6, 7, 9, 9, 7, 8],   # VAE
        [9, 5, 9, 3, 1, 5],   # GAN
        [10, 9, 3, 8, 8, 6],  # Diffusion
        [8, 7, 8, 9, 5, 7],   # VQ-VAE
        [7, 7, 6, 8, 9, 6],   # Flow
    ])
    fig, axes = plt.subplots(1, len(models), figsize=(20, 4))
    for i, (model, score) in enumerate(zip(models, scores)):
        angles = np.linspace(0, 2*np.pi, len(dims), endpoint=False)
        score_closed = np.append(score, score[0])
        angles_closed = np.append(angles, angles[0])
        ax = axes[i]
        ax.plot(angles_closed, score_closed, "o-")
        ax.fill(angles_closed, score_closed, alpha=0.2)
        ax.set_xticks(angles)
        ax.set_xticklabels(dims, fontsize=8)
        ax.set_ylim(0, 10)
        ax.set_title(model)
    plt.tight_layout()
维度VAEGANDiffusion

数学框架

变分推断

博弈论

非平衡热力学

训练目标

ELBO 最大化

对抗损失

噪声预测 MSE

生成质量

中等

优秀

极佳

训练稳定性

采样速度

极快(一步)

极快(一步)

慢(多步迭代)

似然估计

有下界

不可行

可行

模式崩溃

不会

不会

潜空间可编辑

天然支持

不直接支持

不直接支持

典型应用

表征学习

图像生成

高质量生成

选择模型时,如果需要可解释的潜空间和稳定训练,VAE 是最佳起点;如果追求极致视觉效果,Diffusion 是首选。

GAN 的模式崩溃问题在复杂数据集上很难完全解决,VAE 和 Diffusion 更可靠。

7PyTorch 实战:MNIST 生成与潜空间插值

本节从零构建一个完整的 VAE,在 MNIST 手写数字数据集上训练,实现数字生成和潜空间插值。MNIST 是验证 VAE 实现的黄金数据集,因为它简单(28x28 灰度图、60000 张训练样本)但足以验证管线正确性。我们的模型使用全连接网络(非卷积),潜维度设为 2,这样可以直接可视化二维潜空间。训练 20 个 epoch 即可看到不错的重构效果。关键实现细节包括:使用 sigmoid 激活确保输出在 [0, 1] 范围、KL 散度的数值稳定计算(避免 logvar 过大导致溢出)、以及采样时使用模型评估模式关闭 dropout。训练完成后,我们将展示三个经典实验:从先验分布随机采样生成新数字、在潜空间中沿直线插值实现数字渐变、以及遍历潜空间网格生成所有数字变体。这些实验直观地展示了 VAE 学到的连续且结构化的潜空间。

python
# 完整 VAE 实现:MNIST 生成
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

class VAE(nn.Module):
    def __init__(self, latent_dim=2):
        super().__init__()
        self.latent_dim = latent_dim
        # 编码器
        self.enc = nn.Sequential(
            nn.Linear(784, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU()
        )
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)
        # 解码器
        self.dec = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, 512), nn.ReLU(),
            nn.Linear(512, 784), nn.Sigmoid()
        )

    def encode(self, x):
        h = self.enc(x.view(x.size(0), -1))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def decode(self, z):
        return self.dec(z)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def loss(self, x_recon, x, mu, logvar):
        recon = F.binary_cross_entropy(x_recon, x.view(x.size(0), -1),
                                        reduction="sum")
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon + kl
python
# 训练 + 生成 + 潜空间插值
def train_and_generate():
    # 数据
    transform = torchvision.transforms.ToTensor()
    train_ds = torchvision.datasets.MNIST("./data", train=True,
                                          download=True, transform=transform)
    loader = DataLoader(train_ds, batch_size=128, shuffle=True)
    model = VAE(latent_dim=2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # 训练
    for epoch in range(20):
        model.train()
        total_loss = 0
        for x, _ in loader:
            optimizer.zero_grad()
            x_recon, mu, logvar = model(x)
            loss = model.loss(x_recon, x, mu, logvar)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}: loss={total_loss/len(train_ds):.2f}")
    # 从先验采样生成
    model.eval()
    with torch.no_grad():
        z = torch.randn(16, 2)
        gen = model.decode(z).view(-1, 1, 28, 28)
    # 潜空间插值
    z1 = torch.tensor([[2.0, -1.0]])  # 某个数字
    z2 = torch.tensor([[-2.0, 1.0]])  # 另一个数字
    alphas = torch.linspace(0, 1, 10).unsqueeze(1)
    z_interp = z1 * (1 - alphas) + z2 * alphas
    with torch.no_grad():
        interp_imgs = model.decode(z_interp).view(-1, 1, 28, 28)
    # 可视化
    fig, axes = plt.subplots(1, 10, figsize=(15, 2))
    for i, ax in enumerate(axes):
        ax.imshow(interp_imgs[i].squeeze(), cmap="gray")
        ax.axis("off")
    plt.suptitle("Latent Space Interpolation")
实验操作预期结果验证要点

随机生成

z ~ N(0,I) -> 解码

清晰可辨的数字

不同 z 生成不同数字

潜空间插值

z = alpha*z1 + (1-alpha)*z2

数字平滑渐变

无突变或跳变

潜空间网格

遍历 z1, z2 网格

按数字类别分区

同类数字聚集

重构测试

x -> encode -> decode

接近原图

保留关键特征

KL 散度检查

训练过程中监控

不趋近于 0

避免 posterior collapse

潜维度设为 2 虽然可视化方便,但会限制生成质量。实际应用中使用 32-128 维更合适。

MNIST 过于简单,VAE 实现正确后务必在更复杂的数据集(如 CIFAR-10、CelebA)上验证。

继续你的 AI 学习之旅

浏览更多 AI 知识库文章,或者探索 GitHub 上的优质 AI 项目