1为什么需要 A/B 测试
在机器学习工程化中,模型上线只是开始,真正的挑战在于持续验证和优化。A/B 测试是验证模型改进效果的黄金标准,它通过将用户随机分流到不同版本,在真实生产环境中比较各版本的核心指标。与离线评估不同,A/B 测试能够捕捉到离线指标无法反映的复杂用户行为:比如一个新推荐模型可能在 NDCG 上提升了 2%,但用户停留时长却下降了,这种 tradeoff 只有通过线上实验才能发现。很多团队在模型迭代中踩过这样的坑:离线指标全线飘红,上线后核心业务指标反而下降。原因可能包括训练数据与线上数据分布不一致、离线评估指标与业务目标不匹配、或者模型在特定用户群体上存在严重偏差。A/B 测试的价值就在于它是最终的裁判,无论离线实验结果多么漂亮,线上 A/B 测试的结论才具有最终说服力。建立完善的 A/B 测试文化,是每个成熟 ML 团队的必修课。
from enum import Enum
from dataclasses import dataclass
from typing import Dict, Optional
class ExperimentStatus(Enum):
DRAFT = "draft"
RUNNING = "running"
STOPPED = "stopped"
CONCLUDED = "concluded"
@dataclass
class ExperimentConfig:
"""A/B 测试实验配置"""
experiment_id: str
name: str
hypothesis: str
metric_primary: str
metrics_secondary: list[str]
traffic_split: Dict[str, float] # {"control": 0.5, "variant": 0.5}
min_sample_size: int
duration_days: int
status: ExperimentStatus = ExperimentStatus.DRAFT
def validate(self) -> bool:
total = sum(self.traffic_split.values())
return abs(total - 1.0) < 1e-6 and self.min_sample_size > 0import hashlib
from typing import List
def consistent_hash(user_id: str, buckets: int = 100) -> int:
"""为同一用户始终返回相同的分桶结果"""
return int(hashlib.md5(user_id.encode()).hexdigest(), 16) % buckets
def assign_variant(user_id: str, config: ExperimentConfig) -> Optional[str]:
"""根据流量分配配置为用户分配实验变体"""
if config.status != ExperimentStatus.RUNNING:
return None
bucket = consistent_hash(user_id)
cumulative = 0.0
for variant, proportion in config.traffic_split.items():
cumulative += proportion * 100
if bucket < cumulative:
return variant
return None| 评估方式 | 速度 | 成本 | 可靠性 | 适用阶段 |
|---|---|---|---|---|
离线评估 | 分钟级 | 低 | 依赖数据质量 | 模型开发 |
Shadow 模式 | 小时级 | 中 | 高(不暴露用户) | 集成测试 |
A/B 测试 | 天级 | 高 | 最高 | 生产验证 |
多臂老虎机 | 天级 | 中 | 高 | 持续优化 |
灰度发布 | 天级 | 中 | 高 | 渐进上线 |
在启动 A/B 测试之前,先用 Shadow 模式跑几天,在不影响用户的前提下验证新模型是否正常运行。
不要在 A/B 测试中途随意修改实验配置或提前查看结果做决策,这会严重 inflate Type I error。
2实验设计基础
一个好的 A/B 测试实验始于严谨的设计。首先需要明确实验的核心假设:你相信新模型会比旧模型更好,具体好在哪里?这个假设必须是可量化的,比如 "新排序模型能将点击率提升至少 3%"。接下来要选定主指标和辅助指标。主指标是实验成败的判定标准,通常只有一个,避免多指标测试带来的多重比较问题。辅助指标用于观察新模型的副作用,比如转化率提升了但退货率是否也上升了。样本量计算是实验设计中最关键的步骤之一。样本量过小会导致统计功效不足,即使模型真的有效也检测不出来;样本量过大则会浪费资源并延长实验周期。样本量取决于基线转化率、期望检测的最小效应量、显著性水平和统计功效。随机化策略同样重要,必须确保用户被公平且一致地分配到各个实验组。常用的方法是基于用户 ID 的哈希分桶,保证同一个用户在整个实验期间始终看到同一个版本。此外,还需要考虑新奇效应,即用户因为看到新界面而产生的短期行为变化,这需要在分析时通过设置足够长的预热期来排除。
import math
from scipy.stats import norm
def calculate_sample_size(
baseline_rate: float,
min_detectable_effect: float,
alpha: float = 0.05,
power: float = 0.80
) -> int:
"""计算每组所需的最小样本量(比例检验)"""
z_alpha = norm.ppf(1 - alpha / 2)
z_beta = norm.ppf(power)
p1 = baseline_rate
p2 = baseline_rate * (1 + min_detectable_effect)
p_bar = (p1 + p2) / 2
numerator = (z_alpha * math.sqrt(2 * p_bar * (1 - p_bar)) +
z_beta * math.sqrt(p1 * (1 - p1) + p2 * (1 - p2))) 2
denominator = (p2 - p1) 2
return int(math.ceil(numerator / denominator))
# 基线转化率 10%,期望检测 5% 相对提升
n = calculate_sample_size(0.10, 0.05)
print(f"Each group needs: {n} samples")import pandas as pd
import numpy as np
def generate_experiment_data(
n_control: int,
n_variant: int,
conversion_rate_control: float = 0.10,
conversion_rate_variant: float = 0.108
) -> pd.DataFrame:
"""生成模拟实验数据"""
np.random.seed(42)
control = pd.DataFrame({
"user_id": [f"u_{i}" for i in range(n_control)],
"group": "control",
"converted": np.random.binomial(1, conversion_rate_control, n_control)
})
variant = pd.DataFrame({
"user_id": [f"u_{i + n_control}" for i in range(n_variant)],
"group": "variant",
"converted": np.random.binomial(1, conversion_rate_variant, n_variant)
})
return pd.concat([control, variant], ignore_index=True)
data = generate_experiment_data(50000, 50000)
print(data.groupby("group")["converted"].mean())| 参数 | 典型值 | 说明 |
|---|---|---|
显著性水平 alpha | 0.05 | Type I error 容忍度 |
统计功效 power | 0.80 | 检测到真实效应的概率 |
最小可检测效应 | 2%-10% | 业务上认为有意义的变化 |
预热期 | 1-3 天 | 排除新奇效应 |
实验周期 | 7-14 天 | 覆盖完整的用户行为周期 |
实验周期至少覆盖一个完整的业务周期(比如一周),以消除周末效应等周期性波动。
样本量计算基于的基线率和效应量是估计值,实际运行中需要持续监控并确保数据质量。
3统计显著性与功效
统计显著性是 A/B 测试分析的核心概念,它回答了一个关键问题:观察到的差异有多大可能是真实存在的,而非随机波动造成的。p 值是最常用的显著性指标,它表示在原假设(两组无差异)为真的情况下,观察到当前差异或更极端差异的概率。当 p 值小于预设的显著性水平(通常 0.05)时,我们拒绝原假设,认为实验结果统计显著。然而,统计显著不等于业务显著。一个拥有百万级样本的实验可能检测出 0.01% 的差异并判定为显著,但这个差异在业务上可能毫无意义。因此,除了 p 值,还需要关注效应量(effect size)和置信区间。置信区间提供了效应量可能范围的估计,比单一的 p 值包含更多信息。统计功效是另一个常被忽视的重要概念,它表示当真实效应存在时,实验能够检测到的概率。功效不足是 A/B 测试失败的最常见原因之一。当功效只有 50% 时,即使新模型真的更好,你也有 50% 的概率得出 "无显著差异" 的错误结论。在实践中,建议在实验设计阶段就确保功效不低于 80%,这通常意味着需要足够的样本量和合理的实验周期。贝叶斯方法近年来在 A/B 测试中也越来越流行,它直接给出 "B 版本优于 A 版本的概率",比频率派的 p 值更直观易懂。
from scipy import stats
import numpy as np
def analyze_ab_test(
conversions_control: int,
n_control: int,
conversions_variant: int,
n_variant: int,
alpha: float = 0.05
) -> dict:
"""频率派 A/B 测试分析"""
p1 = conversions_control / n_control
p2 = conversions_variant / n_variant
lift = (p2 - p1) / p1 * 100
z_stat, p_value = stats.proportions_ztest(
[conversions_variant, conversions_control],
[n_variant, n_control],
alternative="two-sided"
)
# 计算效应量 (Cohen's h)
h = 2 * np.arcsin(np.sqrt(p2)) - 2 * np.arcsin(np.sqrt(p1))
# 计算 95% 置信区间
se = np.sqrt(p1 * (1 - p1) / n_control + p2 * (1 - p2) / n_variant)
ci_low = (p2 - p1) - 1.96 * se
ci_high = (p2 - p1) + 1.96 * se
return {
"p_value": round(p_value, 6),
"significant": p_value < alpha,
"lift_pct": round(lift, 2),
"effect_size_h": round(h, 4),
"ci_95": [round(ci_low, 6), round(ci_high, 6)]
}
result = analyze_ab_test(5000, 50000, 5400, 50000)
print(result)import numpy as np
from scipy.stats import beta
def bayesian_ab_analysis(
conversions_a: int, n_a: int,
conversions_b: int, n_b: int,
n_samples: int = 100000
) -> dict:
"""贝叶斯 A/B 测试分析"""
# 使用 Beta 先验 (Jeffreys prior)
posterior_a = beta.rvs(
conversions_a + 0.5, n_a - conversions_a + 0.5,
size=n_samples
)
posterior_b = beta.rvs(
conversions_b + 0.5, n_b - conversions_b + 0.5,
size=n_samples
)
prob_b_better = (posterior_b > posterior_a).mean()
expected_lift = ((posterior_b - posterior_a) / posterior_a).mean() * 100
ci_lift = np.percentile(
(posterior_b - posterior_a) / posterior_a * 100,
[2.5, 97.5]
)
return {
"prob_b_better": round(prob_b_better, 4),
"expected_lift_pct": round(expected_lift, 2),
"credible_interval_95": [round(ci_lift[0], 2), round(ci_lift[1], 2)]
}
result = bayesian_ab_analysis(5000, 50000, 5400, 50000)
print(result)| 概念 | 频率派解释 | 贝叶斯解释 | 实际含义 |
|---|---|---|---|
p 值 | H0 为真时观察到此结果的概率 | 不适用 | 越小越支持有差异 |
置信区间 | 95% CI 覆盖真实值的频率 | 真实值在此区间的概率 | 区间不含 0 则显著 |
功效 | 真实效应存在时检测到的概率 | 不适用 | 建议不低于 80% |
效应量 | 标准化差异大小 | 后验差异分布 | 区分统计与业务显著性 |
先验 | 无 | 实验前的信念 | 可用历史数据构建 |
始终同时报告 p 值、效应量和置信区间,三者结合才能给出完整的实验结论。
不要 peeking——在实验未达到预定样本量之前就反复查看结果并提前终止,这会严重扭曲 p 值。
4多臂老虎机算法
传统的 A/B 测试在实验期间对各个版本分配固定比例的流量,这意味着即使某个版本明显更差,它仍然会持续浪费流量。多臂老虎机(Multi-Armed Bandit, MAB)算法通过在探索(exploration)和利用(exploitation)之间动态平衡,解决了这个效率问题。想象你在赌场面对多台老虎机,每台机器的中奖概率不同但你不知道。如果你一直尝试同一台机器(纯利用),可能错过更好的选择;如果你不停地换机器尝试(纯探索),又会浪费很多筹码。MAB 算法就是帮你在这两者之间找到最优平衡。最常用的算法包括 Epsilon-Greedy、Thompson Sampling 和 Upper Confidence Bound。Epsilon-Greedy 最简单:以概率 epsilon 随机探索,以概率 1-epsilon 选择当前最优版本。Thompson Sampling 则是贝叶斯方法,它维护每个版本转化率的后验分布,每次从后验中采样,选择采样值最大的版本。MAB 特别适合需要持续优化的场景,比如推荐系统中的算法切换、广告投放策略优化等。与固定流量的 A/B 测试相比,MAB 可以在实验期间自动将更多流量分配给表现更好的版本,从而在实验过程中就获得更高的整体收益。但 MAB 也有局限性:它的统计推断不如传统 A/B 测试严谨,难以给出明确的 "哪个版本更好" 的结论;同时,如果环境发生变化(比如季节性波动),算法需要时间来适应。
import numpy as np
from typing import List
class ThompsonSampling:
"""Thompson Sampling 多臂老虎机"""
def __init__(self, n_arms: int, alpha_prior: float = 1.0,
beta_prior: float = 1.0):
self.n_arms = n_arms
self.alpha = np.full(n_arms, alpha_prior) # successes + prior
self.beta = np.full(n_arms, beta_prior) # failures + prior
def select_arm(self) -> int:
"""从 Beta 后验采样,选择值最大的 arm"""
samples = np.random.beta(self.alpha, self.beta)
return int(np.argmax(samples))
def update(self, arm: int, reward: float):
"""更新指定 arm 的后验分布"""
if reward > 0:
self.alpha[arm] += 1
else:
self.beta[arm] += 1
def get_estimates(self) -> np.ndarray:
"""返回各 arm 的期望估计"""
return self.alpha / (self.alpha + self.beta)
# 模拟: 3 个版本,真实转化率分别为 10%, 12%, 8%
ts = ThompsonSampling(3)
true_rates = [0.10, 0.12, 0.08]
for _ in range(10000):
arm = ts.select_arm()
reward = float(np.random.random() < true_rates[arm])
ts.update(arm, reward)
print(f"Estimates: {ts.get_estimates()}")import numpy as np
from typing import Tuple
class EpsilonGreedy:
"""Epsilon-Greedy 多臂老虎机"""
def __init__(self, n_arms: int, epsilon: float = 0.1,
decay: float = 0.999):
self.n_arms = n_arms
self.epsilon = epsilon
self.decay = decay
self.counts = np.zeros(n_arms)
self.values = np.zeros(n_arms)
def select_arm(self) -> int:
if np.random.random() < self.epsilon:
return int(np.random.randint(self.n_arms))
# 打破平局: 随机选择最优的
max_val = np.max(self.values)
best_arms = np.where(self.values == max_val)[0]
return int(np.random.choice(best_arms))
def update(self, arm: int, reward: float):
self.counts[arm] += 1
n = self.counts[arm]
self.values[arm] += (reward - self.values[arm]) / n
self.epsilon *= self.decay
def regret(self, true_rates: np.ndarray) -> float:
"""计算累积 regret"""
best_rate = np.max(true_rates)
total_reward = np.sum(self.counts * self.values)
optimal_reward = np.sum(self.counts) * best_rate
return optimal_reward - total_reward| 算法 | 探索策略 | 计算复杂度 | 适用场景 | 优势 |
|---|---|---|---|---|
Epsilon-Greedy | 固定概率随机 | O(1) | 简单快速决策 | 实现简单,易调试 |
Thompson Sampling | 后验采样 | O(n) | 转化率优化 | 自动平衡探索/利用 |
UCB | 置信上界 | O(n) | 有明确置信区间 | 理论保证强 |
Softmax | Boltzmann 分布 | O(n) | 需要平滑选择 | 概率分配更精细 |
固定 A/B | 无探索 | O(1) | 严谨统计推断 | 结论明确可靠 |
Thompson Sampling 在大多数实际场景中表现最优,特别是当你有多次实验机会时,它的 regret 增长是对数级的。
MAB 算法不适合需要明确统计结论的场景。如果你需要向管理层汇报 "A 比 B 好 X%,p<0.05",还是用传统 A/B 测试。
5灰度发布与渐进式部署
灰度发布是模型从测试环境走向全量生产的关键过渡阶段。即使 A/B 测试结果显示新模型显著优于旧模型,直接 100% 切换仍然存在风险:新模型可能在某些极端场景下表现异常,或者在特定用户群体上存在未发现的偏差。灰度发布的核心思想是 "小步快跑":先让 1% 的用户使用新模型,确认没有严重问题后逐步扩大到 5%、10%、25%、50%,最终全量。每个阶段都需要密切监控关键指标,包括业务指标和技术指标。业务指标关注转化率、用户满意度等,技术指标关注延迟、错误率、资源消耗等。渐进式部署需要强大的基础设施支持,包括流量路由、实时指标监控、快速回滚能力和自动化告警。现代云原生架构中的 Service Mesh(如 Istio)和 API Gateway 都提供了精细的流量控制能力,可以实现基于用户属性、地域、设备类型等维度的灰度策略。一个典型的灰度发布流程通常需要 1-2 周,每个阶段至少运行 24-48 小时以覆盖完整的业务周期。在灰度期间,如果发现新模型在某些场景下表现不佳,可以立即暂停扩展并回退到上一阶段,而不需要完全回滚。这种渐进式的方法既保证了发布的安全性,又不影响迭代的速度。
from dataclasses import dataclass
from typing import Dict, List
import time
@dataclass
class CanaryStage:
"""灰度发布的一个阶段"""
traffic_pct: float
min_duration_hours: int
success_criteria: Dict[str, float]
class CanaryDeployment:
"""管理灰度发布流程"""
def __init__(self, model_name: str):
self.model_name = model_name
self.stages: List[CanaryStage] = [
CanaryStage(1, 24, {"error_rate": 0.01, "latency_p99": 500}),
CanaryStage(5, 24, {"error_rate": 0.01, "latency_p99": 500}),
CanaryStage(25, 48, {"error_rate": 0.005, "latency_p99": 450}),
CanaryStage(50, 48, {"error_rate": 0.005, "latency_p99": 450}),
CanaryStage(100, 24, {"error_rate": 0.005, "latency_p99": 400}),
]
self.current_stage = 0
def check_stage_health(self, metrics: Dict[str, float]) -> bool:
"""检查当前阶段的健康指标"""
criteria = self.stages[self.current_stage].success_criteria
return all(
metrics.get(k, float("inf")) <= v
for k, v in criteria.items()
)
def promote(self) -> bool:
"""推进到下一个阶段"""
if self.current_stage >= len(self.stages) - 1:
return False
self.current_stage += 1
return True# Istio VirtualService 灰度流量配置
apiVersion: networking.istio.io/v1beta3
kind: VirtualService
metadata:
name: model-serving-route
spec:
hosts:
- model-api.example.com
http:
- route:
- destination:
host: model-api
subset: stable
weight: 95
- destination:
host: model-api
subset: canary
weight: 5
timeout: 2s
retries:
attempts: 3
perTryTimeout: 500ms
---
apiVersion: networking.istio.io/v1beta3
kind: DestinationRule
metadata:
name: model-api-versions
spec:
host: model-api.example.com
subsets:
- name: stable
labels:
version: v1.2.0
- name: canary
labels:
version: v1.3.0| 阶段 | 流量比例 | 持续时间 | 关注重点 | 回滚策略 |
|---|---|---|---|---|
Canary 1% | 1% | 24h | 致命错误 | 立即回滚到 0% |
Canary 5% | 5% | 24h | 性能与错误率 | 回滚到 1% |
Canary 25% | 25% | 48h | 业务指标趋势 | 回滚到 5% |
Canary 50% | 50% | 48h | 用户反馈 | 回滚到 25% |
Full Rollout | 100% | 24h | 全量稳定性 | 回滚到 50% |
灰度发布的每个阶段都应该有明确的退出标准(success criteria)和回滚触发条件,提前写好 Runbook。
不要在非工作时间推进灰度阶段,万一出问题需要团队快速响应。选择工作日的上午推进新阶段。
6模型回滚策略
即使有了完善的 A/B 测试和灰度发布流程,模型上线后仍然可能出现意外问题:数据分布的突然变化、依赖服务的故障、或者某个边界场景下的推理错误。因此,建立快速可靠的模型回滚机制是 ML 工程化的最后一道安全防线。回滚策略需要在问题发生之前就设计好,而不是临时应对。首先需要定义清晰的回滚触发条件:哪些指标异常到什么程度需要回滚?这通常包括技术指标(错误率飙升、延迟超标)和业务指标(转化率下降、用户投诉激增)。其次,回滚操作本身必须足够快,理想情况下应该在分钟级别完成。这就要求模型服务架构支持蓝绿部署或多版本共存,新的请求可以瞬间切换到旧版本模型。回滚不仅仅是切换模型版本,还需要考虑数据一致性、用户状态和缓存清理等问题。比如,如果新模型修改了用户的推荐队列,回滚时是否需要恢复旧队列?如果新模型写入了特定的用户画像标签,这些标签是否需要清理?最后,每次回滚都应该被视为一次学习机会。进行根因分析,找到模型失败的原因,改进训练流程或数据质量,然后重新迭代。一个成熟的 ML 团队不会因为回滚而感到挫败,他们会因为快速发现问题并安全回滚而感到自豪。
import time
from enum import Enum
from typing import Optional
import threading
class ModelState(Enum):
ACTIVE = "active"
ROLLED_BACK = "rolled_back"
ROLLING_BACK = "rolling_back"
class ModelRollbackManager:
"""模型回滚管理器"""
def __init__(self, model_name: str):
self.model_name = model_name
self.state = ModelState.ACTIVE
self.current_version: Optional[str] = None
self.previous_version: Optional[str] = None
self.rollback_thresholds = {
"error_rate": 0.05,
"latency_p99_ms": 1000,
"conversion_drop_pct": -10,
}
self._lock = threading.Lock()
def check_and_trigger_rollback(self, metrics: dict) -> bool:
"""检查指标,必要时触发回滚"""
with self._lock:
if self.state == ModelState.ROLLING_BACK:
return False
triggered = False
if metrics.get("error_rate", 0) > self.rollback_thresholds["error_rate"]:
triggered = True
if metrics.get("latency_p99_ms", 0) > self.rollback_thresholds["latency_p99_ms"]:
triggered = True
if metrics.get("conversion_drop_pct", 0) < self.rollback_thresholds["conversion_drop_pct"]:
triggered = True
if triggered:
return self._execute_rollback()
return False
def _execute_rollback(self) -> bool:
if not self.previous_version:
return False
self.state = ModelState.ROLLING_BACK
# 切换流量到旧版本
time.sleep(1) # 模拟流量切换
self.current_version, self.previous_version = (
self.previous_version, self.current_version
)
self.state = ModelState.ROLLED_BACK
return Truefrom dataclasses import dataclass
from typing import List
import json
@dataclass
class RollbackEvent:
"""回滚事件记录"""
model_name: str
from_version: str
to_version: str
trigger_reason: str
metrics_snapshot: dict
timestamp: str
operator: str
class RollbackHistory:
"""维护回滚历史并生成分析报告"""
def __init__(self):
self.events: List[RollbackEvent] = []
def record(self, event: RollbackEvent):
self.events.append(event)
def get_rollback_rate(self, model_name: str) -> float:
total = sum(1 for e in self.events if e.model_name == model_name)
return total / max(1, 10) # 假设 10 次发布
def common_reasons(self) -> List[str]:
reasons = [e.trigger_reason for e in self.events]
from collections import Counter
return [r for r, _ in Counter(reasons).most_common(5)]
def generate_report(self) -> dict:
return {
"total_rollbacks": len(self.events),
"common_reasons": self.common_reasons(),
"recent_events": [
json.dumps({
"model": e.model_name,
"from": e.from_version,
"to": e.to_version,
"reason": e.trigger_reason
})
for e in self.events[-5:]
]
}| 回滚触发条件 | 阈值建议 | 响应时间要求 | 自动化程度 |
|---|---|---|---|
错误率飙升 | 超过 5% | 1 分钟内 | 全自动 |
延迟超标 | P99 超过 1s | 1 分钟内 | 全自动 |
转化率下降 | 相对下降 10% | 15 分钟内 | 半自动 |
数据漂移 | PSI 超过 0.25 | 1 小时内 | 告警+人工 |
用户投诉激增 | 工单量翻倍 | 30 分钟内 | 告警+人工 |
定期进行回滚演练(Chaos Engineering),验证回滚流程的可靠性和速度,就像消防演习一样重要。
回滚不等于放弃。每次回滚后必须做根因分析,否则同样的问题会在下次发布时重现。
7实战:MLflow + A/B 测试框架
理论终究要落实到代码。本节通过一个完整的实战案例,演示如何将 MLflow 与 A/B 测试框架结合,构建端到端的模型迭代流水线。MLflow 是 Databricks 开源的 ML 生命周期管理平台,提供了模型注册、版本管理、实验追踪和模型服务部署等功能。在我们的实战方案中,MLflow 负责模型版本的注册和追踪,A/B 测试框架负责流量分配和效果评估,两者通过 API 集成形成完整的迭代闭环。首先需要搭建 MLflow Model Registry,将训练好的模型注册为版本化的 "Model" 对象。每个模型版本都有明确的生命周期状态:None(开发中)、Staging(测试中)、Production(生产中)、Archived(归档)。当新模型通过离线评估后,注册为 Staging 状态,然后通过 A/B 测试验证其线上效果。验证通过后,自动将模型版本标记为 Production,触发灰度发布流程。这个流程可以通过 MLflow 的 Webhook 与 CI/CD 流水线集成,实现从模型注册到线上部署的完全自动化。A/B 测试的流量分配可以通过自定义中间件实现,该中间件查询实验配置,根据用户 ID 决定路由到哪个模型版本。评估指标通过日志系统收集,定期运行统计分析脚本,自动生成实验报告。整个闭环的关键在于自动化:从模型注册、实验创建、流量分配到结果分析,每一步都应该是可编程的、可追溯的、可复现的。
import mlflow
from mlflow.tracking import MlflowClient
class MLflowModelRegistry:
"""基于 MLflow 的模型注册与部署管理"""
def __init__(self, tracking_uri: str = "http://localhost:5000"):
self.client = MlflowClient(tracking_uri=tracking_uri)
def register_model(
self,
model_name: str,
run_id: str,
description: str = ""
) -> str:
"""注册模型并返回版本号"""
result = self.client.create_model_version(
name=model_name,
source=f"runs:/{run_id}/model",
run_id=run_id,
description=description
)
return result.version
def promote_to_staging(self, model_name: str, version: str):
"""将模型版本推进到 Staging"""
self.client.transition_model_version_stage(
name=model_name, version=version, stage="Staging"
)
def promote_to_production(self, model_name: str, version: str):
"""将模型版本推进到 Production"""
self.client.transition_model_version_stage(
name=model_name, version=version, stage="Production"
)
def get_production_version(self, model_name: str) -> str:
versions = self.client.get_latest_versions(
model_name, stages=["Production"]
)
return versions[0].version if versions else Noneimport mlflow
import requests
from typing import Optional
class ABTestModelRouter:
"""集成 MLflow 的 A/B 测试模型路由器"""
def __init__(self, model_name: str, experiment_id: str):
self.model_name = model_name
self.experiment_id = experiment_id
self.client = mlflow.tracking.MlflowClient()
self._model_cache = {}
def get_model_uri(self, stage: str = "Production") -> Optional[str]:
"""获取指定阶段模型的 Serving URI"""
versions = self.client.get_latest_versions(
self.model_name, stages=[stage]
)
if not versions:
return None
return versions[0].source
def predict(self, user_id: str, features: dict) -> dict:
"""根据实验配置路由到对应模型并返回预测结果"""
variant = self._get_assigned_variant(user_id)
if variant == "control":
uri = self.get_model_uri("Production")
else:
uri = self.get_model_uri("Staging")
if not uri:
raise RuntimeError(f"No model found for variant: {variant}")
response = requests.post(
f"{uri}/invocations",
json={"inputs": [features]},
timeout=5
)
prediction = response.json()
self._log_prediction(user_id, variant, prediction)
return prediction
def _log_prediction(self, user_id, variant, prediction):
pass # 记录到实验分析系统
def _get_assigned_variant(self, user_id: str) -> str:
import hashlib
bucket = int(hashlib.md5(user_id.encode()).hexdigest(), 16) % 100
return "variant" if bucket < 10 else "control" # 10% 流量| 组件 | 职责 | 工具选择 | 关键指标 |
|---|---|---|---|
模型注册 | 版本管理与状态追踪 | MLflow Registry | 注册成功率 |
流量分配 | 用户分流与一致性 | 自定义中间件 | 分流均匀性 |
指标收集 | 业务与技术指标日志 | ELK / DataDog | 数据完整性 |
统计分析 | 显著性检验与报告 | 自定义脚本 | 分析准确率 |
部署编排 | 版本切换与回滚 | K8s + Istio | 部署成功率 |
将 MLflow Webhook 与 Slack/飞书集成,模型版本状态变更时自动通知团队,保持信息透明。
MLflow Model Registry 的 stage 转换是异步操作,在高并发场景下需要加锁防止并发冲突。