首页/知识库/生成模型评估:FID, IS, CLIP Score

生成模型评估:FID, IS, CLIP Score

✍️ AI Master📅 创建 2026-04-12📖 18 min 阅读
💡

文章摘要

如何客观评估生成模型的质量,理解主流评估指标的原理与应用

1生成模型评估的挑战

生成模型评估是生成式 AI 领域最困难的课题之一。与分类任务不同,生成任务没有唯一的正确答案,评估必须同时考虑多个维度:样本质量、多样性、模式覆盖以及与条件输入的一致性。传统的像素级指标如 MSE 或 PSNR 无法反映人类对图像质量的感知,因为两张视觉上几乎相同的图片在像素层面可能有巨大差异。这催生了基于深度特征的评估方法,通过预训练网络提取高维特征空间中的统计量来衡量生成质量。另一个核心挑战是模式崩溃问题,生成器可能只学会生成少数几种高质量样本而忽略数据分布的多样性。一个好的评估指标需要同时捕捉质量和多样性,这对设计提出了极高的要求。

python
import torch
import torch.nn.functional as F
from torchvision import models

# 加载预训练 Inception 网络作为特征提取器
inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
inception.fc = torch.nn.Identity()  # 移除分类头
inception.eval()

def extract_features(images: torch.Tensor) -> torch.Tensor:
    """提取 2048 维 Inception 特征"""
    with torch.no_grad():
        features = inception(images)
    return features  # shape: (N, 2048)
python
import numpy as np
from scipy import linalg

def compute_statistics(features: np.ndarray):
    """计算特征的均值向量和协方差矩阵"""
    mu = np.mean(features, axis=0)
    sigma = np.cov(features, rowvar=False)
    return mu, sigma

# 使用示例
real_features = extract_features(real_images).numpy()
fake_features = extract_features(fake_images).numpy()
mu_r, sigma_r = compute_statistics(real_features)
mu_f, sigma_f = compute_statistics(fake_features)
print(f"Real mu shape: {mu_r.shape}, Sigma shape: {sigma_r.shape}")
评估维度描述典型指标

图像质量

单张图像的逼真度

IS, CLIP Score

多样性

生成样本的丰富程度

FID, Precision/Recall

模式覆盖

是否覆盖真实数据分布

Recall, Coverage

条件一致性

生成结果与文本/条件匹配

CLIP Score, R-precision

选择评估指标时,应根据具体任务场景决定,单一指标无法全面反映生成模型性能

不要仅依赖单一指标来评估模型,质量与多样性需要综合权衡

2Inception Score (IS) 原理与计算

Inception Score 是最早广泛使用的生成图像评估指标之一,由 Salimans 等人在 2016 年提出。其核心思想利用预训练 Inception 网络对生成图像进行预测,从两个维度进行评估。第一个维度是预测概率的清晰度,高质量的图像应该让分类器给出高置信度的预测,这通过计算 KL 散度来衡量。第二个维度是生成样本的多样性,如果模型只生成同一类别的图像,类别边缘分布会很集中,多样性指标就会很低。IS 的计算公式为 exp(E[KL(p(y|x) || p(y))]),其中 p(y|x) 是条件类别分布,p(y) 是边缘类别分布。IS 越高说明生成质量越好且多样性越丰富。然而 IS 存在明显缺陷,它只评估生成图像而不与真实数据比较,因此无法检测模式崩溃的严重程度,也无法识别生成了真实数据中不存在的类别。

python
import torch
import torch.nn.functional as F
import numpy as np

def compute_inception_score(preds: np.ndarray, splits: int = 10):
    """
    计算 Inception Score
    preds: (N, 1000) Inception 网络预测概率
    """
    scores = []
    N = preds.shape[0]
    split_size = N // splits

    for i in range(splits):
        part = preds[i * split_size : (i + 1) * split_size]
        # KL 散度: KL(p(y|x) || p(y))
        marginal = np.mean(part, axis=0)  # p(y)
        kl = part * (np.log(part) - np.log(marginal))  # KL per sample
        kl = np.sum(kl, axis=1)  # sum over classes
        scores.append(np.exp(np.mean(kl)))  # exp(E[KL])

    return np.mean(scores), np.std(scores)

print(f"IS = {mean_is:.2f} +/- {std_is:.2f}")
python
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

# 完整的 IS 计算流程
@torch.no_grad()
def get_inception_predictions(
    generator: torch.nn.Module,
    inception: torch.nn.Module,
    num_samples: int = 50000,
    batch_size: int = 100,
    device: str = "cuda"
):
    """从生成器采样并通过 Inception 获取预测"""
    all_preds = []
    num_batches = num_samples // batch_size

    for _ in range(num_batches):
        noise = torch.randn(batch_size, 100, device=device)
        fake_images = generator(noise)
        fake_images = F.interpolate(fake_images, size=(299, 299), mode="bilinear")
        preds = torch.softmax(inception(fake_images), dim=1)
        all_preds.append(preds.cpu().numpy())

    return np.concatenate(all_preds, axis=0)

preds = get_inception_predictions(G, inception, num_samples=50000)
mean_is, std_is = compute_inception_score(preds)
print(f"Inception Score: {mean_is:.2f} +/- {std_is:.2f}")
IS 特性说明

优点

计算简单,不需要真实图像参考

缺点1

无法检测模式崩溃

缺点2

对 Inception 类别集合有偏

缺点3

无法衡量与真实数据的差距

适用场景

快速初步评估生成质量

IS 的标准做法是使用 50000 张生成图像分 10 组计算,取均值和标准差作为最终结果

IS 不包含与真实数据的比较,可能出现生成了高质量但完全偏离真实分布的图像却得到高分的情况

3Fréchet Inception Distance (FID)

Fréchet Inception Distance 由 Heusel 等人在 2017 年提出,目前是最主流的生成模型评估指标。FID 的核心思想是将真实图像和生成图像分别通过 Inception 网络映射到 2048 维特征空间,然后假设两个特征集合都服从多元高斯分布,计算两个高斯分布之间的 Fréchet 距离。FID 同时衡量了生成质量和多样性,因为它比较的是整个分布而非单个样本。FID 值越小表示两个分布越接近,理论下限为零。与 IS 相比,FID 的关键优势在于它直接比较真实数据和生成数据的分布差异,对模式崩溃更加敏感。然而 FID 也有局限性,它依赖于高斯分布假设,而实际特征分布可能并非严格高斯分布。此外 FID 对样本数量敏感,样本量不足时估计会有较大偏差,通常需要至少一万张图像才能获得稳定的估计值。计算协方差矩阵的平方根是 FID 计算的核心步骤,需要使用矩阵平方根运算。

python
import numpy as np
from scipy import linalg

def calculate_fid(
    mu1: np.ndarray, sigma1: np.ndarray,
    mu2: np.ndarray, sigma2: np.ndarray
) -> float:
    """
    计算 Fréchet Inception Distance
    mu1, sigma1: 真实图像特征的均值和协方差
    mu2, sigma2: 生成图像特征的均值和协方差
    """
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)

    # 处理数值不稳定产生的复数部分
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
    return float(fid)

# 示例: mu shape=(2048,), sigma shape=(2048, 2048)
fid_value = calculate_fid(mu_real, sigma_real, mu_fake, sigma_fake)
print(f"FID: {fid_value:.2f}")  # 越低越好
python
import torch
import numpy as np
from torch.utils.data import DataLoader

@torch.no_grad()
def compute_fid_features(
    model: torch.nn.Module,
    dataloader: DataLoader,
    device: str = "cuda"
) -> tuple:
    """从数据加载器提取特征并计算统计量"""
    features_list = []
    model.eval()

    for batch in dataloader:
        images = batch[0].to(device)
        if images.shape[1] == 1:
            images = images.repeat(1, 3, 1, 1)  # 灰度转 RGB
        images = F.interpolate(images, size=(299, 299), mode="bilinear")
        features = model(images)
        features_list.append(features.cpu().numpy())

    all_features = np.concatenate(features_list, axis=0)
    mu = np.mean(all_features, axis=0)
    sigma = np.cov(all_features, rowvar=False)
    return mu, sigma

# 完整 FID 计算
mu_r, sigma_r = compute_fid_features(inception, real_loader)
mu_f, sigma_f = compute_fid_features(inception, fake_loader)
fid = calculate_fid(mu_r, sigma_r, mu_f, sigma_f)
print(f"Fréchet Inception Distance: {fid:.2f}")
FID 范围生成质量评价

0-10

极高,接近真实数据分布

10-30

优秀,常见于先进 GAN

30-60

良好,多数实用模型

60-100

一般,有明显改进空间

100+

较差,模式崩溃或质量低

FID 计算时应确保真实图像和生成图像数量一致,且至少使用 5000 张图像以获得稳定结果

矩阵平方根计算在协方差矩阵接近奇异时会数值不稳定,需检查特征维度是否远小于样本数量

4CLIP Score 文本-图像对齐评估

随着文本到图像生成模型的快速发展,评估生成图像与文本提示之间的一致性成为一个关键问题。CLIP Score 利用 CLIP 模型的跨模态理解能力,计算图像特征和文本特征在共享嵌入空间中的余弦相似度。CLIP 在大规模图像-文本对上进行了对比学习训练,能够有效衡量图文语义匹配程度。CLIP Score 的计算非常直接,将图像和文本分别通过 CLIP 的图像编码器和文本编码器得到特征向量,然后计算两者的余弦相似度。该指标的优势在于不需要人工标注,可以自动大规模评估。但 CLIP Score 也有局限,它对 CLIP 训练数据中常见的概念评估效果更好,对罕见概念或新颖组合可能不够准确。此外 CLIP Score 只衡量语义对齐,不评估图像质量,因此需要与 FID 等质量指标配合使用。实践中通常报告 CLIP Score 和 FID 两个指标来全面评估文本到图像生成模型。

python
import torch
import clip

# 加载 CLIP 模型
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

@torch.no_grad()
def compute_clip_score(
    images: torch.Tensor,
    texts: list[str],
    batch_size: int = 32
) -> torch.Tensor:
    """
    计算批量图像-文本对的 CLIP Score
    images: (N, 3, 224, 224) 预处理后的图像
    texts: N 个文本提示
    """
    text_tokens = clip.tokenize(texts).to(device)
    image_features = model.encode_image(images)
    text_features = model.encode_text(text_tokens)

    # 归一化后计算余弦相似度
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    scores = (image_features * text_features).sum(dim=-1)
    # 缩放到 [0, 100] 范围
    scores = scores * 100.0
    return scores

scores = compute_clip_score(image_batch, text_prompts)
print(f"Mean CLIP Score: {scores.mean().item():.2f}")
python
from transformers import CLIPProcessor, CLIPModel
import torch

# 使用 HuggingFace transformers 的替代方案
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

def evaluate_t2i_model(
    generator, prompts: list[str],
    images_per_prompt: int = 5
) -> dict:
    """全面评估文本到图像生成模型"""
    all_scores = []

    for prompt in prompts:
        images = generator.generate(prompt, num_images=images_per_prompt)
        inputs = processor(
            text=[prompt] * len(images),
            images=images,
            return_tensors="pt",
            padding=True
        )
        outputs = clip_model(**inputs)
        logits_per_image = outputs.logits_per_image
        scores = torch.diag(logits_per_image).numpy()
        all_scores.extend(scores.tolist())

    return {
        "mean_clip_score": np.mean(all_scores),
        "std_clip_score": np.std(all_scores),
        "min_score": np.min(all_scores),
        "max_score": np.max(all_scores)
    }

results = evaluate_t2i_model(my_generator, test_prompts)
print(f"CLIP Score: {results['mean_clip_score']:.2f} +/- {results['std_clip_score']:.2f}")
指标评估对象数值范围方向

FID

图像质量+多样性

[0, +inf)

越低越好

IS

质量+多样性(无参考)

[1, +inf)

越高越好

CLIP Score

图文语义对齐

[0, 100]

越高越好

R-Precision

图文检索精度

[0, 1]

越高越好

建议使用多个不同的文本提示集进行评估,包括简单描述和复杂场景,以获得更全面的 CLIP Score

CLIP Score 高不代表图像质量好,可能生成模糊但语义相关的图像也得到高分,必须与 FID 配合使用

5Precision 和 Recall 用于生成模型

传统的 Precision 和 Recall 概念被引入到生成模型评估中,用于分别衡量生成样本的质量和多样性。这一方法由 Kynkaanniemi 等人在 2019 年提出,核心思想是在特征空间中为真实数据构建流形,然后判断生成样本是否落在该流形内。Precision 定义为落在真实数据流形内的生成样本比例,反映生成质量。Recall 定义为被生成数据流形覆盖的真实样本比例,反映多样性。这种方法的优势在于能够独立分析质量和多样性两个维度,帮助研究者更精确地定位模型的问题。例如高 Precision 低 Recall 意味着模型生成了高质量但单一类型的样本,即存在模式崩溃。相反低 Precision 高 Recall 意味着模型覆盖了广泛的模式但单个样本质量不高。计算过程需要在特征空间中使用 K 近邻方法估计流形覆盖范围,对计算资源有一定要求。

python
import torch

def compute_precision_recall(
    real_features: torch.Tensor,
    fake_features: torch.Tensor,
    k: int = 5
) -> tuple:
    """
    计算生成模型的 Precision 和 Recall
    基于 K 近邻的流形估计方法
    """
    # 计算真实特征到自身第 k 近邻的距离
    real_dists = torch.cdist(real_features, real_features)
    real_dists = torch.sort(real_dists, dim=1)[0]
    real_radii = real_dists[:, k]  # 第 k 近邻距离

    # 计算生成特征到真实特征的距离
    fake_to_real = torch.cdist(fake_features, real_features)
    # Precision: 生成样本有多少在真实流形内
    precision = (fake_to_real < real_radii.unsqueeze(0)).any(dim=1).float().mean()

    # 计算生成特征到自身第 k 近邻的距离
    fake_dists = torch.cdist(fake_features, fake_features)
    fake_dists = torch.sort(fake_dists, dim=1)[0]
    fake_radii = fake_dists[:, k]
    # Recall: 真实样本有多少被生成分布覆盖
    real_to_fake = torch.cdist(real_features, fake_features)
    recall = (real_to_fake < fake_radii.unsqueeze(0)).any(dim=1).float().mean()

    return precision.item(), recall.item()

p, r = compute_precision_recall(real_feats, fake_feats, k=5)
print(f"Precision: {p:.4f}, Recall: {r:.4f}")
python
import numpy as np
import torch

class ImprovedPrecisionRecall:
    """改进的 Precision/Recall 计算(Mani 等人 2020)"""

    def __init__(self, k: int = 5, num_center: int = 20000):
        self.k = k
        self.num_center = num_center

    def _compute_nearest_distances(self, X, Y):
        """计算 X 中每个点到 Y 的第 k 近邻距离"""
        dists = torch.cdist(X, Y)
        knn_dists = torch.topk(dists, k=self.k + 1, largest=False)[0]
        return knn_dists[:, -1]  # 第 k 近邻距离

    def fit(self, real_features: torch.Tensor, fake_features: torch.Tensor):
        """拟合真实和生成分布的流形"""
        r_idx = torch.randperm(len(real_features))[:self.num_center]
        f_idx = torch.randperm(len(fake_features))[:self.num_center]

        self.real_centers = real_features[r_idx]
        self.fake_centers = fake_features[f_idx]

        self.real_radii = self._compute_nearest_distances(
            self.real_centers, real_features
        )
        self.fake_radii = self._compute_nearest_distances(
            self.fake_centers, fake_features
        )

    def compute(self, real_features, fake_features):
        """计算 Precision 和 Recall"""
        precision = self._compute_fraction_inside(
            fake_features, self.real_centers, self.real_radii
        )
        recall = self._compute_fraction_inside(
            real_features, self.fake_centers, self.fake_radii
        )
        return precision, recall

pr = ImprovedPrecisionRecall(k=5)
pr.fit(real_feats, fake_feats)
precision, recall = pr.compute(real_feats, fake_feats)
f_score = 2 * precision * recall / (precision + recall + 1e-8)
print(f"P={precision:.3f}, R={recall:.3f}, F={f_score:.3f}")
场景PrecisionRecall问题诊断

模式崩溃

生成质量高但种类少

过度泛化

覆盖广但质量差

理想状态

质量与多样性兼备

完全失败

模型未收敛或训练不当

K 值的选择会影响结果,较小的 K 对异常值更敏感,较大的 K 估计更平滑但计算成本更高,建议 K=3 到 5

Precision 和 Recall 计算在高维特征空间中对距离度量敏感,确保特征已归一化且维度适当

6人工评估方法

尽管自动化指标如 FID 和 CLIP Score 已被广泛采用,人工评估仍然是生成模型评估的金标准。自动化指标可能与人主观感知存在偏差,特别是对于创意性任务,机器指标难以完全捕捉人类的审美标准。常用的人工评估方法包括平均意见评分、两两比较、真实与生成图像辨别和条件一致性评分。平均意见评分让评估者对生成图像从一到五分进行打分,结果直观但成本高且受主观因素影响。两两比较让评估者在两张图像中选择更偏好的一张,可以减少评分偏差但需要更多比较次数。真实与生成图像辨别任务测试人类能否区分真实图像和生成图像,如果人类无法区分则说明生成质量极高。条件一致性评分专门针对文本到图像生成,评估者判断生成图像是否符合文本描述。近年来研究者尝试将人工评估与自动化指标结合,通过人类反馈来校准自动指标,提高其与人感知的一致性。

python
import csv
from dataclasses import dataclass
from typing import List

@dataclass
class HumanEvalSample:
    image_id: str
    prompt: str
    mos_score: float      # 1-5 平均意见评分
    realism_score: float  # 1-5 真实度评分
    alignment_score: float  # 1-5 图文一致性评分
    is_real_guess: bool   # 人类猜是否真实

def analyze_human_eval(samples: List[HumanEvalSample]) -> dict:
    """分析人工评估结果"""
    n = len(samples)
    mos = sum(s.mos_score for s in samples) / n
    realism = sum(s.realism_score for s in samples) / n
    alignment = sum(s.alignment_score for s in samples) / n
    # 人类辨别准确率(理想情况应接近 50%)
    accuracy = sum(1 for s in samples if s.is_real_guess) / n

    return {
        "num_evaluations": n,
        "mean_opinion_score": round(mos, 3),
        "mean_realism_score": round(realism, 3),
        "mean_alignment_score": round(alignment, 3),
        "human_accuracy": round(accuracy, 3)
    }

# 从 CSV 读取评估结果
results = []
with open("human_eval_results.csv") as f:
    reader = csv.DictReader(f)
    for row in reader:
        results.append(HumanEvalSample(**row))

stats = analyze_human_eval(results)
print(f"MOS: {stats['mean_opinion_score']}, Human Acc: {stats['human_accuracy']}")
python
import numpy as np
from scipy import stats

def pairwise_preference_analysis(
    preferences: np.ndarray,
    model_names: list[str]
) -> dict:
    """
    分析两两比较实验结果
    preferences[i,j] = model_i 胜过 model_j 的次数
    """
    n_models = len(model_names)
    total_comparisons = preferences + preferences.T

    # Elo 风格的胜率计算
    win_rates = np.zeros(n_models)
    for i in range(n_models):
        for j in range(n_models):
            if i != j and total_comparisons[i, j] > 0:
                win_rates[i] += preferences[i, j] / total_comparisons[i, j]
    win_rates /= (n_models - 1)

    # 统计显著性检验
    significance = {}
    for i in range(n_models):
        for j in range(i + 1, n_models):
            n = total_comparisons[i, j]
            if n > 0:
                p_value = stats.binom_test(
                    preferences[i, j], n, p=0.5
                )
                significance[f"{model_names[i]} vs {model_names[j]}"] = {
                    "preference_ratio": preferences[i, j] / n,
                    "p_value": round(p_value, 4),
                    "significant": p_value < 0.05
                }

    return {"win_rates": dict(zip(model_names, win_rates)), "significance": significance}

# 示例: preferences[i,j] 表示模型 i 胜过模型 j 的次数
prefs = np.array([[0, 60, 45], [40, 0, 30], [55, 70, 0]])
result = pairwise_preference_analysis(prefs, ["ModelA", "ModelB", "ModelC"])
print(result)
评估方法成本可靠性适用场景

平均意见评分

通用质量评估

两两比较

中高

模型排序对比

真实与生成辨别

评估逼真度

条件一致性评分

文本到图像评估

FID/IS 自动指标

快速迭代优化

进行人工评估时应确保评估者数量充足,建议至少 20 名评估者,并使用 Cohen Kappa 系数评估评分者间一致性

人工评估结果受评估者文化背景和经验影响,跨文化比较时需要特别谨慎,建议在报告评估结果时注明评估者构成

7torchmetrics 实战计算

在实际项目中,推荐使用 torchmetrics 库来简化生成模型评估指标的计算。torchmetrics 提供了 FID、KID、Inception Score 等多种指标的官方实现,支持增量计算和分布式训练。使用 torchmetrics 的主要优势在于 API 统一,不同指标的使用方式一致,且底层优化良好,支持 GPU 加速和大规模数据集处理。FID 指标在 torchmetrics 中的实现基于干净的 InceptionV3 特征提取,自动处理图像预处理和统计计算。KID 即 Kernel Inception Distance 是 FID 的无偏估计版本,使用多项式核的 MMD 统计量。对于文本到图像生成任务,可以结合 torchmetrics 的 CLIP 相关指标和自定义评估流程。在实际部署时,建议建立一个评估流水线,在训练过程中定期计算关键指标并记录到 TensorBoard 或 WandB,以便跟踪模型改进趋势。

python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.kid import KernelInceptionDistance
from torchmetrics.image.inception import InceptionScore

# FID 计算
fid = FrechetInceptionDistance(feature=2048)
fid.update(real_images, real=True)
fid.update(fake_images, real=False)
fid_score = fid.compute()
print(f"FID: {fid_score:.2f}")

# KID (Kernel Inception Distance) - 无偏估计
kid = KernelInceptionDistance(subset_size=1000)
kid.update(real_images, real=True)
kid.update(fake_images, real=False)
kid_mean, kid_std = kid.compute()
print(f"KID: {kid_mean:.4f} +/- {kid_std:.4f}")

# Inception Score
inception = InceptionScore()
inception.update(fake_images)
is_mean, is_std = inception.compute()
print(f"IS: {is_mean:.2f} +/- {is_std:.2f}")
python
import torch
from torchmetrics.image.fid import FrechetInceptionDistance
from torch.utils.tensorboard import SummaryWriter
import os

class GenerationEvaluator:
    """集成化的生成模型评估器"""

    def __init__(self, device: str = "cuda", log_dir: str = "logs"):
        self.device = device
        self.fid = FrechetInceptionDistance(feature=2048).to(device)
        self.writer = SummaryWriter(log_dir)
        os.makedirs("checkpoints", exist_ok=True)

    def evaluate_epoch(
        self,
        generator: torch.nn.Module,
        real_loader: torch.utils.data.DataLoader,
        epoch: int,
        num_samples: int = 5000
    ) -> dict:
        """每个 epoch 评估生成模型"""
        generator.eval()

        # 更新真实数据特征
        for batch in real_loader:
            self.fid.update(batch[0].to(self.device), real=True)

        # 生成图像并更新
        with torch.no_grad():
            noise = torch.randn(num_samples, 100, device=self.device)
            fake_images = generator(noise)
            self.fid.update(fake_images, real=False)

        fid_score = self.fid.compute().item()

        # 记录到 TensorBoard
        self.writer.add_scalar("Metrics/FID", fid_score, epoch)
        self.writer.add_images("Generated", fake_images[:8], epoch)

        # 保存最佳模型
        if fid_score < getattr(self, "best_fid", float("inf")):
            self.best_fid = fid_score
            torch.save(generator.state_dict(), "checkpoints/best_generator.pt")
            print(f"New best FID: {fid_score:.2f} at epoch {epoch}")

        self.fid.reset()
        return {"fid": fid_score, "best_fid": self.best_fid}

evaluator = GenerationEvaluator(device="cuda")
for epoch in range(num_epochs):
    train_one_epoch(generator, dataloader, epoch)
    metrics = evaluator.evaluate_epoch(generator, real_loader, epoch)
    print(f"Epoch {epoch}: FID={metrics['fid']:.2f}")
torchmetrics 指标类名关键参数计算开销

FID

FrechetInceptionDistance

feature: 64/192/768/2048

KID

KernelInceptionDistance

subset_size: 子集大小

中高

IS

InceptionScore

splits: 分组数

LPIPS

LearnedPerceptualImagePatchSimilarity

net_type: alex/vgg

SSIM

StructuralSimilarityIndexMeasure

data_range: 数据范围

极低

使用 torchmetrics 时调用 reset() 方法很重要,否则指标会在多个 epoch 之间累积导致结果错误

FID 计算需要较大内存存储协方差矩阵,2048 维特征约 32MB,在显存有限的情况下可以分批更新或使用较小的 feature 维度

继续你的 AI 学习之旅

浏览更多 AI 知识库文章,或者探索 GitHub 上的优质 AI 项目