首页/知识库/A/B 测试与模型迭代

A/B 测试与模型迭代

✍️ AI Master📅 创建 2026-04-12📖 16 min 阅读
💡

文章摘要

从 A/B 测试到灰度发布,掌握模型迭代的最佳实践

1为什么需要 A/B 测试

在机器学习工程化中,模型上线只是开始,真正的挑战在于持续验证和优化。A/B 测试是验证模型改进效果的黄金标准,它通过将用户随机分流到不同版本,在真实生产环境中比较各版本的核心指标。与离线评估不同,A/B 测试能够捕捉到离线指标无法反映的复杂用户行为:比如一个新推荐模型可能在 NDCG 上提升了 2%,但用户停留时长却下降了,这种 tradeoff 只有通过线上实验才能发现。很多团队在模型迭代中踩过这样的坑:离线指标全线飘红,上线后核心业务指标反而下降。原因可能包括训练数据与线上数据分布不一致、离线评估指标与业务目标不匹配、或者模型在特定用户群体上存在严重偏差。A/B 测试的价值就在于它是最终的裁判,无论离线实验结果多么漂亮,线上 A/B 测试的结论才具有最终说服力。建立完善的 A/B 测试文化,是每个成熟 ML 团队的必修课。

python
from enum import Enum
from dataclasses import dataclass
from typing import Dict, Optional

class ExperimentStatus(Enum):
    DRAFT = "draft"
    RUNNING = "running"
    STOPPED = "stopped"
    CONCLUDED = "concluded"


@dataclass
class ExperimentConfig:
    """A/B 测试实验配置"""
    experiment_id: str
    name: str
    hypothesis: str
    metric_primary: str
    metrics_secondary: list[str]
    traffic_split: Dict[str, float]  # {"control": 0.5, "variant": 0.5}
    min_sample_size: int
    duration_days: int
    status: ExperimentStatus = ExperimentStatus.DRAFT

    def validate(self) -> bool:
        total = sum(self.traffic_split.values())
        return abs(total - 1.0) < 1e-6 and self.min_sample_size > 0
python
import hashlib
from typing import List


def consistent_hash(user_id: str, buckets: int = 100) -> int:
    """为同一用户始终返回相同的分桶结果"""
    return int(hashlib.md5(user_id.encode()).hexdigest(), 16) % buckets


def assign_variant(user_id: str, config: ExperimentConfig) -> Optional[str]:
    """根据流量分配配置为用户分配实验变体"""
    if config.status != ExperimentStatus.RUNNING:
        return None
    bucket = consistent_hash(user_id)
    cumulative = 0.0
    for variant, proportion in config.traffic_split.items():
        cumulative += proportion * 100
        if bucket < cumulative:
            return variant
    return None
评估方式速度成本可靠性适用阶段

离线评估

分钟级

依赖数据质量

模型开发

Shadow 模式

小时级

高(不暴露用户)

集成测试

A/B 测试

天级

最高

生产验证

多臂老虎机

天级

持续优化

灰度发布

天级

渐进上线

在启动 A/B 测试之前,先用 Shadow 模式跑几天,在不影响用户的前提下验证新模型是否正常运行。

不要在 A/B 测试中途随意修改实验配置或提前查看结果做决策,这会严重 inflate Type I error。

2实验设计基础

一个好的 A/B 测试实验始于严谨的设计。首先需要明确实验的核心假设:你相信新模型会比旧模型更好,具体好在哪里?这个假设必须是可量化的,比如 "新排序模型能将点击率提升至少 3%"。接下来要选定主指标和辅助指标。主指标是实验成败的判定标准,通常只有一个,避免多指标测试带来的多重比较问题。辅助指标用于观察新模型的副作用,比如转化率提升了但退货率是否也上升了。样本量计算是实验设计中最关键的步骤之一。样本量过小会导致统计功效不足,即使模型真的有效也检测不出来;样本量过大则会浪费资源并延长实验周期。样本量取决于基线转化率、期望检测的最小效应量、显著性水平和统计功效。随机化策略同样重要,必须确保用户被公平且一致地分配到各个实验组。常用的方法是基于用户 ID 的哈希分桶,保证同一个用户在整个实验期间始终看到同一个版本。此外,还需要考虑新奇效应,即用户因为看到新界面而产生的短期行为变化,这需要在分析时通过设置足够长的预热期来排除。

python
import math
from scipy.stats import norm


def calculate_sample_size(
    baseline_rate: float,
    min_detectable_effect: float,
    alpha: float = 0.05,
    power: float = 0.80
) -> int:
    """计算每组所需的最小样本量(比例检验)"""
    z_alpha = norm.ppf(1 - alpha / 2)
    z_beta = norm.ppf(power)
    p1 = baseline_rate
    p2 = baseline_rate * (1 + min_detectable_effect)
    p_bar = (p1 + p2) / 2
    numerator = (z_alpha * math.sqrt(2 * p_bar * (1 - p_bar)) +
                 z_beta * math.sqrt(p1 * (1 - p1) + p2 * (1 - p2)))  2
    denominator = (p2 - p1)  2
    return int(math.ceil(numerator / denominator))

# 基线转化率 10%,期望检测 5% 相对提升
n = calculate_sample_size(0.10, 0.05)
print(f"Each group needs: {n} samples")
python
import pandas as pd
import numpy as np


def generate_experiment_data(
    n_control: int,
    n_variant: int,
    conversion_rate_control: float = 0.10,
    conversion_rate_variant: float = 0.108
) -> pd.DataFrame:
    """生成模拟实验数据"""
    np.random.seed(42)
    control = pd.DataFrame({
        "user_id": [f"u_{i}" for i in range(n_control)],
        "group": "control",
        "converted": np.random.binomial(1, conversion_rate_control, n_control)
    })
    variant = pd.DataFrame({
        "user_id": [f"u_{i + n_control}" for i in range(n_variant)],
        "group": "variant",
        "converted": np.random.binomial(1, conversion_rate_variant, n_variant)
    })
    return pd.concat([control, variant], ignore_index=True)

data = generate_experiment_data(50000, 50000)
print(data.groupby("group")["converted"].mean())
参数典型值说明

显著性水平 alpha

0.05

Type I error 容忍度

统计功效 power

0.80

检测到真实效应的概率

最小可检测效应

2%-10%

业务上认为有意义的变化

预热期

1-3 天

排除新奇效应

实验周期

7-14 天

覆盖完整的用户行为周期

实验周期至少覆盖一个完整的业务周期(比如一周),以消除周末效应等周期性波动。

样本量计算基于的基线率和效应量是估计值,实际运行中需要持续监控并确保数据质量。

3统计显著性与功效

统计显著性是 A/B 测试分析的核心概念,它回答了一个关键问题:观察到的差异有多大可能是真实存在的,而非随机波动造成的。p 值是最常用的显著性指标,它表示在原假设(两组无差异)为真的情况下,观察到当前差异或更极端差异的概率。当 p 值小于预设的显著性水平(通常 0.05)时,我们拒绝原假设,认为实验结果统计显著。然而,统计显著不等于业务显著。一个拥有百万级样本的实验可能检测出 0.01% 的差异并判定为显著,但这个差异在业务上可能毫无意义。因此,除了 p 值,还需要关注效应量(effect size)和置信区间。置信区间提供了效应量可能范围的估计,比单一的 p 值包含更多信息。统计功效是另一个常被忽视的重要概念,它表示当真实效应存在时,实验能够检测到的概率。功效不足是 A/B 测试失败的最常见原因之一。当功效只有 50% 时,即使新模型真的更好,你也有 50% 的概率得出 "无显著差异" 的错误结论。在实践中,建议在实验设计阶段就确保功效不低于 80%,这通常意味着需要足够的样本量和合理的实验周期。贝叶斯方法近年来在 A/B 测试中也越来越流行,它直接给出 "B 版本优于 A 版本的概率",比频率派的 p 值更直观易懂。

python
from scipy import stats
import numpy as np


def analyze_ab_test(
    conversions_control: int,
    n_control: int,
    conversions_variant: int,
    n_variant: int,
    alpha: float = 0.05
) -> dict:
    """频率派 A/B 测试分析"""
    p1 = conversions_control / n_control
    p2 = conversions_variant / n_variant
    lift = (p2 - p1) / p1 * 100

    z_stat, p_value = stats.proportions_ztest(
        [conversions_variant, conversions_control],
        [n_variant, n_control],
        alternative="two-sided"
    )

    # 计算效应量 (Cohen's h)
    h = 2 * np.arcsin(np.sqrt(p2)) - 2 * np.arcsin(np.sqrt(p1))

    # 计算 95% 置信区间
    se = np.sqrt(p1 * (1 - p1) / n_control + p2 * (1 - p2) / n_variant)
    ci_low = (p2 - p1) - 1.96 * se
    ci_high = (p2 - p1) + 1.96 * se

    return {
        "p_value": round(p_value, 6),
        "significant": p_value < alpha,
        "lift_pct": round(lift, 2),
        "effect_size_h": round(h, 4),
        "ci_95": [round(ci_low, 6), round(ci_high, 6)]
    }

result = analyze_ab_test(5000, 50000, 5400, 50000)
print(result)
python
import numpy as np
from scipy.stats import beta


def bayesian_ab_analysis(
    conversions_a: int, n_a: int,
    conversions_b: int, n_b: int,
    n_samples: int = 100000
) -> dict:
    """贝叶斯 A/B 测试分析"""
    # 使用 Beta 先验 (Jeffreys prior)
    posterior_a = beta.rvs(
        conversions_a + 0.5, n_a - conversions_a + 0.5,
        size=n_samples
    )
    posterior_b = beta.rvs(
        conversions_b + 0.5, n_b - conversions_b + 0.5,
        size=n_samples
    )

    prob_b_better = (posterior_b > posterior_a).mean()
    expected_lift = ((posterior_b - posterior_a) / posterior_a).mean() * 100
    ci_lift = np.percentile(
        (posterior_b - posterior_a) / posterior_a * 100,
        [2.5, 97.5]
    )

    return {
        "prob_b_better": round(prob_b_better, 4),
        "expected_lift_pct": round(expected_lift, 2),
        "credible_interval_95": [round(ci_lift[0], 2), round(ci_lift[1], 2)]
    }

result = bayesian_ab_analysis(5000, 50000, 5400, 50000)
print(result)
概念频率派解释贝叶斯解释实际含义

p 值

H0 为真时观察到此结果的概率

不适用

越小越支持有差异

置信区间

95% CI 覆盖真实值的频率

真实值在此区间的概率

区间不含 0 则显著

功效

真实效应存在时检测到的概率

不适用

建议不低于 80%

效应量

标准化差异大小

后验差异分布

区分统计与业务显著性

先验

实验前的信念

可用历史数据构建

始终同时报告 p 值、效应量和置信区间,三者结合才能给出完整的实验结论。

不要 peeking——在实验未达到预定样本量之前就反复查看结果并提前终止,这会严重扭曲 p 值。

4多臂老虎机算法

传统的 A/B 测试在实验期间对各个版本分配固定比例的流量,这意味着即使某个版本明显更差,它仍然会持续浪费流量。多臂老虎机(Multi-Armed Bandit, MAB)算法通过在探索(exploration)和利用(exploitation)之间动态平衡,解决了这个效率问题。想象你在赌场面对多台老虎机,每台机器的中奖概率不同但你不知道。如果你一直尝试同一台机器(纯利用),可能错过更好的选择;如果你不停地换机器尝试(纯探索),又会浪费很多筹码。MAB 算法就是帮你在这两者之间找到最优平衡。最常用的算法包括 Epsilon-Greedy、Thompson Sampling 和 Upper Confidence Bound。Epsilon-Greedy 最简单:以概率 epsilon 随机探索,以概率 1-epsilon 选择当前最优版本。Thompson Sampling 则是贝叶斯方法,它维护每个版本转化率的后验分布,每次从后验中采样,选择采样值最大的版本。MAB 特别适合需要持续优化的场景,比如推荐系统中的算法切换、广告投放策略优化等。与固定流量的 A/B 测试相比,MAB 可以在实验期间自动将更多流量分配给表现更好的版本,从而在实验过程中就获得更高的整体收益。但 MAB 也有局限性:它的统计推断不如传统 A/B 测试严谨,难以给出明确的 "哪个版本更好" 的结论;同时,如果环境发生变化(比如季节性波动),算法需要时间来适应。

python
import numpy as np
from typing import List


class ThompsonSampling:
    """Thompson Sampling 多臂老虎机"""

    def __init__(self, n_arms: int, alpha_prior: float = 1.0,
                 beta_prior: float = 1.0):
        self.n_arms = n_arms
        self.alpha = np.full(n_arms, alpha_prior)  #  successes + prior
        self.beta = np.full(n_arms, beta_prior)     #  failures + prior

    def select_arm(self) -> int:
        """从 Beta 后验采样,选择值最大的 arm"""
        samples = np.random.beta(self.alpha, self.beta)
        return int(np.argmax(samples))

    def update(self, arm: int, reward: float):
        """更新指定 arm 的后验分布"""
        if reward > 0:
            self.alpha[arm] += 1
        else:
            self.beta[arm] += 1

    def get_estimates(self) -> np.ndarray:
        """返回各 arm 的期望估计"""
        return self.alpha / (self.alpha + self.beta)


# 模拟: 3 个版本,真实转化率分别为 10%, 12%, 8%
ts = ThompsonSampling(3)
true_rates = [0.10, 0.12, 0.08]
for _ in range(10000):
    arm = ts.select_arm()
    reward = float(np.random.random() < true_rates[arm])
    ts.update(arm, reward)
print(f"Estimates: {ts.get_estimates()}")
python
import numpy as np
from typing import Tuple


class EpsilonGreedy:
    """Epsilon-Greedy 多臂老虎机"""

    def __init__(self, n_arms: int, epsilon: float = 0.1,
                 decay: float = 0.999):
        self.n_arms = n_arms
        self.epsilon = epsilon
        self.decay = decay
        self.counts = np.zeros(n_arms)
        self.values = np.zeros(n_arms)

    def select_arm(self) -> int:
        if np.random.random() < self.epsilon:
            return int(np.random.randint(self.n_arms))
        # 打破平局: 随机选择最优的
        max_val = np.max(self.values)
        best_arms = np.where(self.values == max_val)[0]
        return int(np.random.choice(best_arms))

    def update(self, arm: int, reward: float):
        self.counts[arm] += 1
        n = self.counts[arm]
        self.values[arm] += (reward - self.values[arm]) / n
        self.epsilon *= self.decay

    def regret(self, true_rates: np.ndarray) -> float:
        """计算累积 regret"""
        best_rate = np.max(true_rates)
        total_reward = np.sum(self.counts * self.values)
        optimal_reward = np.sum(self.counts) * best_rate
        return optimal_reward - total_reward
算法探索策略计算复杂度适用场景优势

Epsilon-Greedy

固定概率随机

O(1)

简单快速决策

实现简单,易调试

Thompson Sampling

后验采样

O(n)

转化率优化

自动平衡探索/利用

UCB

置信上界

O(n)

有明确置信区间

理论保证强

Softmax

Boltzmann 分布

O(n)

需要平滑选择

概率分配更精细

固定 A/B

无探索

O(1)

严谨统计推断

结论明确可靠

Thompson Sampling 在大多数实际场景中表现最优,特别是当你有多次实验机会时,它的 regret 增长是对数级的。

MAB 算法不适合需要明确统计结论的场景。如果你需要向管理层汇报 "A 比 B 好 X%,p<0.05",还是用传统 A/B 测试。

5灰度发布与渐进式部署

灰度发布是模型从测试环境走向全量生产的关键过渡阶段。即使 A/B 测试结果显示新模型显著优于旧模型,直接 100% 切换仍然存在风险:新模型可能在某些极端场景下表现异常,或者在特定用户群体上存在未发现的偏差。灰度发布的核心思想是 "小步快跑":先让 1% 的用户使用新模型,确认没有严重问题后逐步扩大到 5%、10%、25%、50%,最终全量。每个阶段都需要密切监控关键指标,包括业务指标和技术指标。业务指标关注转化率、用户满意度等,技术指标关注延迟、错误率、资源消耗等。渐进式部署需要强大的基础设施支持,包括流量路由、实时指标监控、快速回滚能力和自动化告警。现代云原生架构中的 Service Mesh(如 Istio)和 API Gateway 都提供了精细的流量控制能力,可以实现基于用户属性、地域、设备类型等维度的灰度策略。一个典型的灰度发布流程通常需要 1-2 周,每个阶段至少运行 24-48 小时以覆盖完整的业务周期。在灰度期间,如果发现新模型在某些场景下表现不佳,可以立即暂停扩展并回退到上一阶段,而不需要完全回滚。这种渐进式的方法既保证了发布的安全性,又不影响迭代的速度。

python
from dataclasses import dataclass
from typing import Dict, List
import time


@dataclass
class CanaryStage:
    """灰度发布的一个阶段"""
    traffic_pct: float
    min_duration_hours: int
    success_criteria: Dict[str, float]


class CanaryDeployment:
    """管理灰度发布流程"""

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.stages: List[CanaryStage] = [
            CanaryStage(1, 24, {"error_rate": 0.01, "latency_p99": 500}),
            CanaryStage(5, 24, {"error_rate": 0.01, "latency_p99": 500}),
            CanaryStage(25, 48, {"error_rate": 0.005, "latency_p99": 450}),
            CanaryStage(50, 48, {"error_rate": 0.005, "latency_p99": 450}),
            CanaryStage(100, 24, {"error_rate": 0.005, "latency_p99": 400}),
        ]
        self.current_stage = 0

    def check_stage_health(self, metrics: Dict[str, float]) -> bool:
        """检查当前阶段的健康指标"""
        criteria = self.stages[self.current_stage].success_criteria
        return all(
            metrics.get(k, float("inf")) <= v
            for k, v in criteria.items()
        )

    def promote(self) -> bool:
        """推进到下一个阶段"""
        if self.current_stage >= len(self.stages) - 1:
            return False
        self.current_stage += 1
        return True
yaml
# Istio VirtualService 灰度流量配置
apiVersion: networking.istio.io/v1beta3
kind: VirtualService
metadata:
  name: model-serving-route
spec:
  hosts:
    - model-api.example.com
  http:
    - route:
        - destination:
            host: model-api
            subset: stable
          weight: 95
        - destination:
            host: model-api
            subset: canary
          weight: 5
      timeout: 2s
      retries:
        attempts: 3
        perTryTimeout: 500ms
---
apiVersion: networking.istio.io/v1beta3
kind: DestinationRule
metadata:
  name: model-api-versions
spec:
  host: model-api.example.com
  subsets:
    - name: stable
      labels:
        version: v1.2.0
    - name: canary
      labels:
        version: v1.3.0
阶段流量比例持续时间关注重点回滚策略

Canary 1%

1%

24h

致命错误

立即回滚到 0%

Canary 5%

5%

24h

性能与错误率

回滚到 1%

Canary 25%

25%

48h

业务指标趋势

回滚到 5%

Canary 50%

50%

48h

用户反馈

回滚到 25%

Full Rollout

100%

24h

全量稳定性

回滚到 50%

灰度发布的每个阶段都应该有明确的退出标准(success criteria)和回滚触发条件,提前写好 Runbook。

不要在非工作时间推进灰度阶段,万一出问题需要团队快速响应。选择工作日的上午推进新阶段。

6模型回滚策略

即使有了完善的 A/B 测试和灰度发布流程,模型上线后仍然可能出现意外问题:数据分布的突然变化、依赖服务的故障、或者某个边界场景下的推理错误。因此,建立快速可靠的模型回滚机制是 ML 工程化的最后一道安全防线。回滚策略需要在问题发生之前就设计好,而不是临时应对。首先需要定义清晰的回滚触发条件:哪些指标异常到什么程度需要回滚?这通常包括技术指标(错误率飙升、延迟超标)和业务指标(转化率下降、用户投诉激增)。其次,回滚操作本身必须足够快,理想情况下应该在分钟级别完成。这就要求模型服务架构支持蓝绿部署或多版本共存,新的请求可以瞬间切换到旧版本模型。回滚不仅仅是切换模型版本,还需要考虑数据一致性、用户状态和缓存清理等问题。比如,如果新模型修改了用户的推荐队列,回滚时是否需要恢复旧队列?如果新模型写入了特定的用户画像标签,这些标签是否需要清理?最后,每次回滚都应该被视为一次学习机会。进行根因分析,找到模型失败的原因,改进训练流程或数据质量,然后重新迭代。一个成熟的 ML 团队不会因为回滚而感到挫败,他们会因为快速发现问题并安全回滚而感到自豪。

python
import time
from enum import Enum
from typing import Optional
import threading


class ModelState(Enum):
    ACTIVE = "active"
    ROLLED_BACK = "rolled_back"
    ROLLING_BACK = "rolling_back"


class ModelRollbackManager:
    """模型回滚管理器"""

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.state = ModelState.ACTIVE
        self.current_version: Optional[str] = None
        self.previous_version: Optional[str] = None
        self.rollback_thresholds = {
            "error_rate": 0.05,
            "latency_p99_ms": 1000,
            "conversion_drop_pct": -10,
        }
        self._lock = threading.Lock()

    def check_and_trigger_rollback(self, metrics: dict) -> bool:
        """检查指标,必要时触发回滚"""
        with self._lock:
            if self.state == ModelState.ROLLING_BACK:
                return False
            triggered = False
            if metrics.get("error_rate", 0) > self.rollback_thresholds["error_rate"]:
                triggered = True
            if metrics.get("latency_p99_ms", 0) > self.rollback_thresholds["latency_p99_ms"]:
                triggered = True
            if metrics.get("conversion_drop_pct", 0) < self.rollback_thresholds["conversion_drop_pct"]:
                triggered = True
            if triggered:
                return self._execute_rollback()
            return False

    def _execute_rollback(self) -> bool:
        if not self.previous_version:
            return False
        self.state = ModelState.ROLLING_BACK
        # 切换流量到旧版本
        time.sleep(1)  # 模拟流量切换
        self.current_version, self.previous_version = (
            self.previous_version, self.current_version
        )
        self.state = ModelState.ROLLED_BACK
        return True
python
from dataclasses import dataclass
from typing import List
import json


@dataclass
class RollbackEvent:
    """回滚事件记录"""
    model_name: str
    from_version: str
    to_version: str
    trigger_reason: str
    metrics_snapshot: dict
    timestamp: str
    operator: str


class RollbackHistory:
    """维护回滚历史并生成分析报告"""

    def __init__(self):
        self.events: List[RollbackEvent] = []

    def record(self, event: RollbackEvent):
        self.events.append(event)

    def get_rollback_rate(self, model_name: str) -> float:
        total = sum(1 for e in self.events if e.model_name == model_name)
        return total / max(1, 10)  # 假设 10 次发布

    def common_reasons(self) -> List[str]:
        reasons = [e.trigger_reason for e in self.events]
        from collections import Counter
        return [r for r, _ in Counter(reasons).most_common(5)]

    def generate_report(self) -> dict:
        return {
            "total_rollbacks": len(self.events),
            "common_reasons": self.common_reasons(),
            "recent_events": [
                json.dumps({
                    "model": e.model_name,
                    "from": e.from_version,
                    "to": e.to_version,
                    "reason": e.trigger_reason
                })
                for e in self.events[-5:]
            ]
        }
回滚触发条件阈值建议响应时间要求自动化程度

错误率飙升

超过 5%

1 分钟内

全自动

延迟超标

P99 超过 1s

1 分钟内

全自动

转化率下降

相对下降 10%

15 分钟内

半自动

数据漂移

PSI 超过 0.25

1 小时内

告警+人工

用户投诉激增

工单量翻倍

30 分钟内

告警+人工

定期进行回滚演练(Chaos Engineering),验证回滚流程的可靠性和速度,就像消防演习一样重要。

回滚不等于放弃。每次回滚后必须做根因分析,否则同样的问题会在下次发布时重现。

7实战:MLflow + A/B 测试框架

理论终究要落实到代码。本节通过一个完整的实战案例,演示如何将 MLflow 与 A/B 测试框架结合,构建端到端的模型迭代流水线。MLflow 是 Databricks 开源的 ML 生命周期管理平台,提供了模型注册、版本管理、实验追踪和模型服务部署等功能。在我们的实战方案中,MLflow 负责模型版本的注册和追踪,A/B 测试框架负责流量分配和效果评估,两者通过 API 集成形成完整的迭代闭环。首先需要搭建 MLflow Model Registry,将训练好的模型注册为版本化的 "Model" 对象。每个模型版本都有明确的生命周期状态:None(开发中)、Staging(测试中)、Production(生产中)、Archived(归档)。当新模型通过离线评估后,注册为 Staging 状态,然后通过 A/B 测试验证其线上效果。验证通过后,自动将模型版本标记为 Production,触发灰度发布流程。这个流程可以通过 MLflow 的 Webhook 与 CI/CD 流水线集成,实现从模型注册到线上部署的完全自动化。A/B 测试的流量分配可以通过自定义中间件实现,该中间件查询实验配置,根据用户 ID 决定路由到哪个模型版本。评估指标通过日志系统收集,定期运行统计分析脚本,自动生成实验报告。整个闭环的关键在于自动化:从模型注册、实验创建、流量分配到结果分析,每一步都应该是可编程的、可追溯的、可复现的。

python
import mlflow
from mlflow.tracking import MlflowClient


class MLflowModelRegistry:
    """基于 MLflow 的模型注册与部署管理"""

    def __init__(self, tracking_uri: str = "http://localhost:5000"):
        self.client = MlflowClient(tracking_uri=tracking_uri)

    def register_model(
        self,
        model_name: str,
        run_id: str,
        description: str = ""
    ) -> str:
        """注册模型并返回版本号"""
        result = self.client.create_model_version(
            name=model_name,
            source=f"runs:/{run_id}/model",
            run_id=run_id,
            description=description
        )
        return result.version

    def promote_to_staging(self, model_name: str, version: str):
        """将模型版本推进到 Staging"""
        self.client.transition_model_version_stage(
            name=model_name, version=version, stage="Staging"
        )

    def promote_to_production(self, model_name: str, version: str):
        """将模型版本推进到 Production"""
        self.client.transition_model_version_stage(
            name=model_name, version=version, stage="Production"
        )

    def get_production_version(self, model_name: str) -> str:
        versions = self.client.get_latest_versions(
            model_name, stages=["Production"]
        )
        return versions[0].version if versions else None
python
import mlflow
import requests
from typing import Optional


class ABTestModelRouter:
    """集成 MLflow 的 A/B 测试模型路由器"""

    def __init__(self, model_name: str, experiment_id: str):
        self.model_name = model_name
        self.experiment_id = experiment_id
        self.client = mlflow.tracking.MlflowClient()
        self._model_cache = {}

    def get_model_uri(self, stage: str = "Production") -> Optional[str]:
        """获取指定阶段模型的 Serving URI"""
        versions = self.client.get_latest_versions(
            self.model_name, stages=[stage]
        )
        if not versions:
            return None
        return versions[0].source

    def predict(self, user_id: str, features: dict) -> dict:
        """根据实验配置路由到对应模型并返回预测结果"""
        variant = self._get_assigned_variant(user_id)
        if variant == "control":
            uri = self.get_model_uri("Production")
        else:
            uri = self.get_model_uri("Staging")

        if not uri:
            raise RuntimeError(f"No model found for variant: {variant}")

        response = requests.post(
            f"{uri}/invocations",
            json={"inputs": [features]},
            timeout=5
        )
        prediction = response.json()
        self._log_prediction(user_id, variant, prediction)
        return prediction

    def _log_prediction(self, user_id, variant, prediction):
        pass  # 记录到实验分析系统

    def _get_assigned_variant(self, user_id: str) -> str:
        import hashlib
        bucket = int(hashlib.md5(user_id.encode()).hexdigest(), 16) % 100
        return "variant" if bucket < 10 else "control"  # 10% 流量
组件职责工具选择关键指标

模型注册

版本管理与状态追踪

MLflow Registry

注册成功率

流量分配

用户分流与一致性

自定义中间件

分流均匀性

指标收集

业务与技术指标日志

ELK / DataDog

数据完整性

统计分析

显著性检验与报告

自定义脚本

分析准确率

部署编排

版本切换与回滚

K8s + Istio

部署成功率

将 MLflow Webhook 与 Slack/飞书集成,模型版本状态变更时自动通知团队,保持信息透明。

MLflow Model Registry 的 stage 转换是异步操作,在高并发场景下需要加锁防止并发冲突。

继续你的 AI 学习之旅

浏览更多 AI 知识库文章,或者探索 GitHub 上的优质 AI 项目