1. 工业 AI 概述:从自动化到智能化
工业 4.0 浪潮下,AI 正在重塑制造业的每一个环节。传统的自动化产线依赖预设规则和固定阈值,而 AI 驱动的异常检测系统能够从海量数据中自主学习正常模式,实时识别偏离行为。异常检测是工业 AI 最核心的应用场景之一,涵盖视觉缺陷检测、时序传感器分析、多模态数据融合等多个方向。当前主流方法包括基于统计的孤立森林、基于深度学习的自编码器、基于特征嵌入的 PaDiM 和 PatchCore 等。这些技术正在从实验室走向产线,为制造业带来降本增效的显著收益。
# 工业异常检测架构概览
import torch
from torch import nn
class IndustrialAnomalyDetector(nn.Module):
    """Reconstruction-style anomaly detector built on a truncated ResNet.

    The encoder is a torchvision ResNet without its last two children. The
    bottleneck pools the feature map to 1x1 and projects it to
    ``latent_dim``; the decoder maps the latent vector back to 512 channels
    and applies a single transposed convolution, so the reconstruction has
    shape ``(batch, 3, 2, 2)`` rather than the input resolution.

    Args:
        backbone: Name of the torchvision ResNet to use. Only "resnet18" and
            "resnet34" are accepted because the bottleneck input width is
            fixed at 512 channels.
        latent_dim: Size of the latent vector produced by the bottleneck.

    Raises:
        ValueError: If ``backbone`` is not a supported name.
    """

    # Supported backbone name -> name of its torchvision weights enum.
    _BACKBONE_WEIGHTS = {
        "resnet18": "ResNet18_Weights",
        "resnet34": "ResNet34_Weights",
    }

    def __init__(self, backbone: str = "resnet18", latent_dim: int = 128):
        super().__init__()
        self.encoder = self._build_backbone(backbone)
        self.bottleneck = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, latent_dim),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.Unflatten(1, (512, 1, 1)),
            nn.ConvTranspose2d(512, 3, kernel_size=4, stride=2, padding=1),
        )

    def _build_backbone(self, name):
        """Return the convolutional trunk of the requested ResNet.

        The name is validated before torchvision is imported so that an
        unsupported value fails fast instead of silently building resnet18.
        """
        if name not in self._BACKBONE_WEIGHTS:
            supported = ", ".join(sorted(self._BACKBONE_WEIGHTS))
            raise ValueError(
                f"Unsupported backbone {name!r}; expected one of: {supported}"
            )
        from torchvision import models

        weights = getattr(models, self._BACKBONE_WEIGHTS[name]).IMAGENET1K_V1
        model = getattr(models, name)(weights=weights)
        # Drop the last two children (global pooling and the fc classifier).
        return nn.Sequential(*list(model.children())[:-2])

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for an image batch ``x``."""
        features = self.encoder(x)
        latent = self.bottleneck(features)
        reconstruction = self.decoder(latent)
        return reconstruction, latent


# 异常分数计算模块
import numpy as np
from sklearn.neighbors import NearestNeighbors
class AnomalyScorer:
"""计算异常分数并生成热力图"""
def __init__(self, n_neighbors: int = 5):
self.nn = NearestNeighbors(n_neighbors=n_neighbors, metric="euclidean")
self.memory_bank = None
def fit(self, features: np.ndarray):
self.nn.fit(features)
self.memory_bank = features
print(f"Memory bank built with {len(features)} normal samples")
def score(self, query: np.ndarray) -> np.ndarray:
distances, _ = self.nn.kneighbors(query)
anomaly_scores = distances.mean(axis=1)
normalized = (anomaly_scores - anomaly_scores.min()) / (anomaly_scores.max() - anomaly_scores.min() + 1e-8)
return normalized
def generate_heatmap(self, scores: np.ndarray, shape: tuple) -> np.ndarray:
heatmap = scores.reshape(shape)
from scipy.ndimage import gaussian_filter
return gaussian_filter(heatmap, sigma=3)| 技术方向 | 典型算法 | 适用场景 | 检测精度 |
|---|---|---|---|
视觉缺陷检测 | PaDiM, PatchCore | 表面缺陷、划痕 | 95-99% |
时序异常检测 | LSTM-AE, OmniAnomaly | 振动、温度传感器 | 90-96% |
多模态融合 | Cross-modal AE | 复合工业场景 | 93-98% |
边缘部署 | MobileNet-AD | 产线实时检测 | 88-94% |
工业 AI 项目的第一步是明确定义正常与异常的边界,这决定了后续所有技术选型。
不要直接用 ImageNet 预训练模型处理工业图像,工业场景的纹理分布与自然图像差异巨大。
2. 视觉缺陷检测:从像素到缺陷
视觉缺陷检测是工业异常检测中最成熟的应用方向。传统方法依赖人工设计的特征如 HOG、LBP 和 Gabor 滤波器,但这些方法对光照变化和复杂纹理极其敏感。深度学习方法通过端到端学习图像的深层语义表示,能够捕捉人眼难以察觉的微小缺陷。主流架构包括基于重构的自编码器、基于特征嵌入的 PatchCore 和基于归一化流的 CFlow。MVTec AD 数据集作为工业视觉异常检测的基准,包含 15 类工业对象和超过 5000 张图像,涵盖划痕、污渍、形变等多种缺陷类型。在真实产线中,还需要处理光照不均匀、相机抖动、产品位姿变化等工程挑战。
# PatchCore 核心实现
import torch
from torch.nn import functional as F
class PatchCoreMemoryBank:
    """Patch-level memory bank with nearest-neighbour anomaly scoring.

    Args:
        coreset_ratio: Fraction of all patches kept in the memory bank. The
            subset is chosen by uniform random sampling.
    """

    def __init__(self, coreset_ratio: float = 0.1):
        self.coreset_ratio = coreset_ratio
        self.memory = None

    def build_memory(self, features: torch.Tensor):
        """Fill the memory bank from a ``(B, C, H, W)`` feature tensor.

        Raises:
            ValueError: If ``features`` contains no patches.
        """
        _, channels, _, _ = features.shape
        patches = features.permute(0, 2, 3, 1).reshape(-1, channels)
        if len(patches) == 0:
            raise ValueError("features contains no patches")
        # Random subsampling; keep at least one patch so that a small ratio
        # cannot produce an empty bank that breaks the nearest-neighbour step.
        n_samples = max(1, int(len(patches) * self.coreset_ratio))
        indices = torch.randperm(len(patches))[:n_samples]
        self.memory = patches[indices]
        print(f"Memory bank: {self.memory.shape}")

    def compute_anomaly_map(self, query_features: torch.Tensor) -> torch.Tensor:
        """Return a ``(B, H, W)`` map of nearest-neighbour distances.

        Both query patches and stored patches are L2-normalized before the
        pairwise distance is computed.

        Raises:
            RuntimeError: If ``build_memory`` has not been called yet.
        """
        if self.memory is None:
            raise RuntimeError("Memory bank is empty; call build_memory() first")
        B, C, H, W = query_features.shape
        patches = query_features.permute(0, 2, 3, 1).reshape(-1, C)
        patches = F.normalize(patches, dim=1)
        memory = F.normalize(self.memory, dim=1)
        # Distance from every query patch to its closest stored patch.
        dists = torch.cdist(patches, memory)
        min_dists = dists.min(dim=1).values
        return min_dists.reshape(B, H, W)


# MVTec AD 数据加载与预处理
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
class MVTecDataset(Dataset):
def __init__(self, root: str, class_name: str, is_train: bool = True):
self.root = root
self.class_name = class_name
base_dir = os.path.join(root, class_name, "train" if is_train else "test")
self.image_paths = []
self.mask_paths = []
for subdir in os.listdir(base_dir):
for fname in os.listdir(os.path.join(base_dir, subdir)):
self.image_paths.append(os.path.join(base_dir, subdir, fname))
mask_path = os.path.join(root, class_name, "ground_truth", subdir, fname.replace(".png", "_mask.png"))
self.mask_paths.append(mask_path if os.path.exists(mask_path) else None)
self.transform = T.Compose([
T.Resize((256, 256)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
img = Image.open(self.image_paths[idx]).convert("RGB")
img = self.transform(img)
return img| 数据集 | 类别数 | 图像数 | 缺陷类型 | 标注级别 |
|---|---|---|---|---|
MVTec AD | 15 | 5354 | 10种缺陷 | 像素级 |
VisA | 12 | 10821 | 多种工业缺陷 | 像素级 |
BTAD | 3 | 2830 | 工业部件缺陷 | 像素级 |
KolektorSDD | 1 | 500 | 电子元件缺陷 | 像素级 |
使用数据增强时,只对正常样本进行增强。异常样本极少,增强可能引入虚假模式。
PatchCore 的记忆库大小直接影响显存占用,产线部署时建议将 coreset_ratio 控制在 0.01 到 0.05 之间。
3. 时序异常检测:传感器数据的秘密
工业设备运行过程中会产生大量时序传感器数据,包括振动、温度、压力、电流等信号。这些时序数据蕴含着设备健康状态的关键信息。时序异常检测的目标是在设备发生故障之前,识别出传感器信号中的异常模式。与静态图像不同,时序数据具有时间依赖性和多变量耦合性。常用的方法包括 LSTM 自编码器、TCN 时序卷积网络、以及基于注意力机制的 Transformer 架构。SWaT 和 WADI 数据集是工业控制系统异常检测的常用基准,包含数十个传感器变量和数十种攻击场景。在时序检测中,关键点检测比单纯的整体分类更具实用价值。
# LSTM 自编码器时序异常检测
import torch
import torch.nn as nn
class LSTMAutoEncoder(nn.Module):
    """Sequence autoencoder built from two stacked LSTMs.

    The encoder's final hidden state of the top layer is repeated along the
    time axis and fed to the decoder, which is initialized with the
    encoder's final ``(h, c)`` state. A linear layer maps decoder outputs
    back to ``input_dim``.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 64, num_layers: int = 2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.encoder = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.decoder = nn.LSTM(hidden_dim, hidden_dim, num_layers, batch_first=True)
        self.output_layer = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        """Reconstruct ``x`` (batch-first) and return a tensor of the same shape."""
        n_batch, n_steps = x.size(0), x.size(1)
        state_shape = (self.num_layers, n_batch, self.hidden_dim)
        initial_state = (
            torch.zeros(state_shape, device=x.device),
            torch.zeros(state_shape, device=x.device),
        )
        _, final_state = self.encoder(x, initial_state)
        # Top-layer hidden state, tiled once per time step as decoder input.
        summary = final_state[0][-1]
        decoder_input = summary.unsqueeze(1).repeat(1, n_steps, 1).contiguous()
        decoded, _ = self.decoder(decoder_input, final_state)
        return self.output_layer(decoded)

    def anomaly_score(self, x: torch.Tensor) -> torch.Tensor:
        """Return the per-time-step mean squared reconstruction error."""
        residual = x - self.forward(x)
        return residual.pow(2).mean(dim=-1)


# 多变量时序异常检测 - OmniAnomaly 简化版
import torch
import torch.nn as nn
import torch.nn.functional as F
class OmniAnomaly(nn.Module):
def __init__(self, input_dim: int, hidden_dim: int = 100, latent_dim: int = 20):
super().__init__()
self.encoder_rnn = nn.GRU(input_dim, hidden_dim, batch_first=True)
self.fc_mu = nn.Linear(hidden_dim, latent_dim)
self.fc_var = nn.Linear(hidden_dim, latent_dim)
self.decoder_rnn = nn.GRU(latent_dim, hidden_dim, batch_first=True)
self.fc_out = nn.Linear(hidden_dim, input_dim)
def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, x):
enc_out, _ = self.encoder_rnn(x)
mu = self.fc_mu(enc_out)
logvar = self.fc_var(enc_out)
z = self.reparameterize(mu, logvar)
dec_out, _ = self.decoder_rnn(z)
recon = self.fc_out(dec_out)
# VAE loss
recon_loss = F.mse_loss(recon, x, reduction="none").sum(dim=-1)
kl_loss = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=-1)
return recon, recon_loss.mean(), kl_loss.mean()| 方法 | 时间建模 | 变量关系 | 计算复杂度 | 适用场景 |
|---|---|---|---|---|
LSTM-AE | 序列依赖 | 隐式学习 | 中等 | 单变量/少变量 |
TCN-AD | 因果卷积 | 并行计算 | 低 | 长时序 |
OmniAnomaly | 概率生成 | 变量相关性 | 高 | 多变量耦合 |
TranAD | 自注意力 | 全局依赖 | 高 | 长程依赖 |
时序异常检测中,滑动窗口大小的选择至关重要。窗口太小会丢失长期依赖,太大会稀释局部异常信号。
传感器数据中的周期性变化不应被误判为异常,务必在训练数据中包含完整的生产周期。
4. 预测性维护:从异常到预见
预测性维护是工业异常检测的最终目标。与传统的事后维修和定期维护不同,预测性维护通过分析设备退化趋势,在故障发生前安排维护窗口。核心流程包括:剩余使用寿命预测、退化阶段划分、维护策略优化。PHM 数据集和 NASA C-MAPSS 发动机退化数据集是该领域的标准基准。深度学习在 RUL 预测中展现出强大能力,尤其是 CNN-LSTM 混合架构和 Transformer 模型。一个完整的预测性维护系统不仅要预测故障时间,还需要输出置信区间和维护建议,帮助工厂制定最优的生产计划。
# RUL 预测模型 - CNN-LSTM 混合架构
import torch
import torch.nn as nn
class RULPredictor(nn.Module):
    """CNN-LSTM regressor that outputs one value per input window.

    Input is ``(batch, time, n_sensors)``. Two 1-D convolutions and a max
    pool (which halves the time axis) feed a two-layer LSTM; the LSTM output
    at the last step is mapped to a scalar by a small MLP.

    Args:
        n_sensors: Number of input channels per time step.
        seq_len: Accepted for interface compatibility; not used by the model.
        hidden: LSTM hidden size.
    """

    def __init__(self, n_sensors: int, seq_len: int = 30, hidden: int = 64):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv1d(n_sensors, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool1d(2),
        )
        self.lstm = nn.LSTM(64, hidden, num_layers=2, batch_first=True, dropout=0.2)
        self.fc = nn.Sequential(
            nn.Linear(hidden, 32),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(32, 1),
        )

    def forward(self, x):
        """Return a ``(batch,)`` tensor of predictions for ``x``."""
        # Conv1d expects channels first, the LSTM expects time first.
        cnn_out = self.cnn(x.permute(0, 2, 1))
        lstm_in = cnn_out.permute(0, 2, 1)
        lstm_out, _ = self.lstm(lstm_in)
        last_hidden = lstm_out[:, -1, :]
        rul = self.fc(last_hidden)
        return rul.squeeze(-1)

    def predict_with_uncertainty(self, x, n_samples: int = 50):
        """Return the mean and std of ``n_samples`` stochastic forward passes.

        Dropout is the only source of randomness, so the model is switched
        to training mode for the duration of the sampling; otherwise an
        eval-mode model would yield identical samples and a std of zero.
        Gradients are disabled and the previous mode is restored afterwards.
        """
        was_training = self.training
        self.train()
        try:
            with torch.no_grad():
                predictions = torch.stack([self(x) for _ in range(n_samples)])
        finally:
            self.train(was_training)
        mean = predictions.mean(dim=0)
        std = predictions.std(dim=0)
        return mean, std


# 退化阶段划分 - 隐马尔可夫模型
import numpy as np
from hmmlearn import hmm
class DegradationHMM:
def __init__(self, n_states: int = 4):
self.n_states = n_states
self.model = hmm.GaussianHMM(
n_components=n_states,
covariance_type="full",
n_iter=100
)
self.state_labels = ["健康", "轻度退化", "中度退化", "严重退化"]
def fit(self, sequences: np.ndarray, lengths: np.ndarray):
self.model.fit(sequences, lengths=lengths)
print(f"HMM trained with {self.n_states} degradation states")
def decode_state(self, observation: np.ndarray) -> tuple:
states = self.model.predict(observation)
probs = self.model.predict_proba(observation)
current_state = states[-1]
confidence = probs[-1, current_state]
return self.state_labels[current_state], confidence
def transition_probability(self) -> np.ndarray:
return self.model.transmat_| 维护策略 | 触发条件 | 成本 | 停机时间 | 适用范围 |
|---|---|---|---|---|
事后维修 | 设备故障后 | 低 | 长 | 非关键设备 |
定期维护 | 固定周期 | 中 | 中 | 一般设备 |
状态维护 | 阈值触发 | 中高 | 短 | 关键设备 |
预测性维护 | 退化预测 | 高 | 最短 | 高价值设备 |
RUL 预测的标签设计很关键。分段线性标签比线性标签更符合真实的设备退化规律。
预测性维护模型需要在不同工况下验证泛化能力,单一工况的模型在产线切换时可能完全失效。
5. 多模态工业数据融合
现代工业场景中,单一数据源往往无法全面反映设备状态。多模态数据融合将视觉图像、振动信号、温度数据、声学信号等多种传感器信息整合,通过跨模态的互补信息提升异常检测的准确性和鲁棒性。多模态融合的关键挑战在于不同模态的采样频率、数据维度和语义层次差异巨大。常用的融合策略包括早期融合、晚期融合和混合融合。深度学习中的跨模态注意力机制和对比学习方法(如 CLIP 的工业适配版本)正在为这一领域带来突破。多模态方法在半导体制造、风力发电机监测、航空航天等领域展现出显著优势。
# 跨模态注意力融合模块
import torch
import torch.nn as nn
import torch.nn.functional as F
class CrossModalAttention(nn.Module):
    """Fuse one visual and one temporal feature vector per sample.

    Each modality is projected to ``fusion_dim`` and treated as a
    length-one sequence. The same attention module is applied in both
    directions, the two attended vectors are concatenated, and an MLP maps
    the result to 64 dimensions.
    """

    def __init__(self, visual_dim: int, temporal_dim: int, fusion_dim: int = 128):
        super().__init__()
        self.visual_proj = nn.Linear(visual_dim, fusion_dim)
        self.temporal_proj = nn.Linear(temporal_dim, fusion_dim)
        self.cross_attn = nn.MultiheadAttention(fusion_dim, num_heads=4, batch_first=True)
        mlp_layers = [
            nn.Linear(fusion_dim * 2, fusion_dim),
            nn.LayerNorm(fusion_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(fusion_dim, 64),
        ]
        self.fusion = nn.Sequential(*mlp_layers)

    def forward(self, visual_feat: torch.Tensor, temporal_feat: torch.Tensor) -> torch.Tensor:
        """Return the fused ``(batch, 64)`` representation."""
        visual_tok = self.visual_proj(visual_feat).unsqueeze(1)
        temporal_tok = self.temporal_proj(temporal_feat).unsqueeze(1)
        # Cross attention: visual queries temporal, then temporal queries visual.
        visual_ctx = self.cross_attn(visual_tok, temporal_tok, temporal_tok)[0]
        temporal_ctx = self.cross_attn(temporal_tok, visual_tok, visual_tok)[0]
        joined = torch.cat((visual_ctx.squeeze(1), temporal_ctx.squeeze(1)), dim=-1)
        return self.fusion(joined)


# 多模态工业数据集构建
import torch
from torch.utils.data import Dataset
import numpy as np
class MultiModalIndustrialDataset(Dataset):
def __init__(self, image_dir: str, sensor_file: str, acoustic_dir: str):
import pandas as pd
self.sensor_data = pd.read_csv(sensor_file).values.astype(np.float32)
self.image_paths = [f"{image_dir}/img_{i}.png" for i in range(len(self.sensor_data))]
self.acoustic_paths = [f"{acoustic_dir}/audio_{i}.npy" for i in range(len(self.sensor_data))]
self.labels = np.zeros(len(self.sensor_data))
def __len__(self):
return len(self.sensor_data)
def __getitem__(self, idx):
from PIL import Image
import torchvision.transforms as T
img = Image.open(self.image_paths[idx]).convert("RGB")
img = T.Compose([T.Resize(224), T.ToTensor()])(img)
sensor = torch.tensor(self.sensor_data[idx])
acoustic = torch.tensor(np.load(self.acoustic_paths[idx]))
return {"image": img, "sensor": sensor, "acoustic": acoustic}| 模态组合 | 采样率 | 特征维度 | 融合策略 | 典型应用 |
|---|---|---|---|---|
视觉+振动 | 图像 30fps + 振动 10kHz | 512 + 128 | 交叉注意力 | 轴承监测 |
视觉+温度 | 图像 30fps + 温度 1Hz | 512 + 32 | 晚期融合 | 焊接质量 |
声学+振动 | 音频 44kHz + 振动 10kHz | 64 + 128 | 早期融合 | 齿轮箱 |
三模态融合 | 多种传感器混合 | 512 + 128 + 64 | 混合融合 | 半导体制造 |
多模态融合中,先用单模态模型建立基线,再逐步增加模态。这样可以量化每个模态的贡献度。
不同模态的时间对齐是多模态工业数据的最大陷阱。务必在数据预处理阶段完成精确的时钟同步。
6. 边缘部署挑战:从云端到产线
工业异常检测模型最终需要部署到产线边缘设备上,实现毫秒级实时推理。边缘部署面临三大核心挑战:模型体积受限、推理延迟要求严格、环境条件恶劣。模型压缩技术包括知识蒸馏、量化感知训练、剪枝和神经架构搜索。TensorRT、ONNX Runtime 和 OpenVINO 是工业边缘推理的主流框架。此外,边缘设备还需要处理持续学习和模型更新的问题。联邦学习可以在不集中数据的前提下,利用多个工厂的数据联合训练模型,既保护了数据隐私,又提升了模型泛化能力。
# 模型量化与 ONNX 导出
import torch
import torch.nn as nn
def quantize_model(model: nn.Module, calib_loader: torch.utils.data.DataLoader):
    """Apply post-training static quantization to ``model`` in place.

    The model is put in eval mode, prepared with the default "fbgemm"
    qconfig, calibrated by running every batch of ``calib_loader`` through
    it, and converted. The parameter size of the original model is printed.

    Args:
        model: Module to quantize; it is modified in place.
        calib_loader: Iterable of batches that can be passed directly to
            ``model``.

    Returns:
        The same ``model`` object after conversion.
    """
    model.eval()
    # Measure before prepare/convert: the reported figure is labelled
    # "original", so it must be taken while the float parameters still exist.
    original_size = sum(p.numel() * p.element_size() for p in model.parameters())
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    torch.quantization.prepare(model, inplace=True)
    with torch.no_grad():
        for batch in calib_loader:
            model(batch)
    torch.quantization.convert(model, inplace=True)
    print(f"Original model size: {original_size / 1024:.1f} KB")
    return model
def export_to_onnx(model: nn.Module, dummy_input: torch.Tensor, path: str):
    """Export ``model`` to an ONNX file at ``path``.

    The graph is traced with ``dummy_input``, uses opset 14, names its
    tensors "input" and "output", and marks axis 0 of the input as a
    dynamic "batch" axis.
    """
    export_options = {
        "export_params": True,
        "opset_version": 14,
        "input_names": ["input"],
        "output_names": ["output"],
        "dynamic_axes": {"input": {0: "batch"}},
    }
    torch.onnx.export(model, dummy_input, path, **export_options)
    print(f"Model exported to {path}")


# 边缘设备推理服务 - FastAPI
from fastapi import FastAPI
import torch
import numpy as np
from pydantic import BaseModel
import base64
from io import BytesIO
from PIL import Image
import torchvision.transforms as T
# Application object for the inference service.
app = FastAPI(title="工业异常检测 API")
# Load the TorchScript model once at import time and put it in eval mode;
# request handlers reuse this module-level instance.
model = torch.jit.load("anomaly_detector.pt")
model.eval()
class InferenceRequest(BaseModel):
    """Request body accepted by the ``/detect`` endpoint."""

    # Base64-encoded image bytes.
    image_b64: str
    # Scores strictly greater than this value are reported as anomalies.
    threshold: float = 0.5
@app.post("/detect")
def detect_anomaly(req: InferenceRequest):
img_bytes = base64.b64decode(req.image_b64)
img = Image.open(BytesIO(img_bytes)).convert("RGB")
transform = T.Compose([
T.Resize((256, 256)),
T.ToTensor(),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
tensor = transform(img).unsqueeze(0)
with torch.no_grad():
score = model(tensor).item()
is_anomaly = score > req.threshold
return {"anomaly_score": round(score, 4), "is_anomaly": is_anomaly}| 部署平台 | 推理延迟 | 模型大小限制 | 精度损失 | 成本 |
|---|---|---|---|---|
NVIDIA Jetson | <10ms | <500MB | <1% | 高 |
Intel OpenVINO | <15ms | <200MB | <2% | 中 |
Raspberry Pi | <50ms | <100MB | 3-5% | 低 |
FPGA 加速 | <5ms | 可定制 | <0.5% | 极高 |
边缘部署前,务必在目标设备上实测推理延迟。服务器上的测试结果不能直接迁移到边缘设备。
INT8 量化在某些视觉异常检测任务中会导致精度显著下降,建议先在验证集上评估 INT8 量化后的精度,再决定是否采用。
7. 实战:工业缺陷检测完整项目
本章将整合前面所学的技术,从零构建一个完整的工业缺陷检测系统。我们选择 MVTec AD 数据集中的 bottle 类别作为示例,使用 PatchCore 算法实现异常检测和定位。整个项目包含数据准备、模型构建、记忆库建立、推理评估和部署五个阶段。在真实场景中,还需要加入数据版本管理、模型监控告警和自动化回滚机制。一个成熟的工业 AI 系统,模型本身只占 20% 的工作量,剩余 80% 是数据管道、基础设施和运维体系的构建。这也是为什么 MLOps 在工业 AI 项目中如此重要。
# 完整的 PatchCore 训练与评估流程
import torch
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
import numpy as np
class CompletePatchCore:
    """End-to-end PatchCore-style pipeline: features, memory bank, image AUC.

    Args:
        backbone: Feature extractor name. Only "resnet18" is implemented.
        coreset_ratio: Fraction of training patches kept in the memory bank
            (uniform random subsampling).

    Raises:
        ValueError: If ``backbone`` is not "resnet18".
    """

    def __init__(self, backbone: str = "resnet18", coreset_ratio: float = 0.01):
        if backbone != "resnet18":
            # Fail loudly rather than silently building a different network.
            raise ValueError(
                f"Unsupported backbone {backbone!r}; only 'resnet18' is implemented"
            )
        self.feature_extractor = self._build_extractor()
        self.memory = None
        self.coreset_ratio = coreset_ratio

    def _build_extractor(self):
        """Return ImageNet-pretrained resnet18 without its last two children."""
        model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        layers = list(model.children())
        return torch.nn.Sequential(*layers[:-2]).eval()

    def extract_features(self, images: torch.Tensor) -> torch.Tensor:
        """Run the extractor on ``images`` without tracking gradients."""
        with torch.no_grad():
            return self.feature_extractor(images)

    def train(self, normal_images: torch.Tensor, batch_size: int = 32):
        """Build the memory bank from normal images, processed in batches."""
        all_features = torch.cat(
            [
                self.extract_features(normal_images[i:i + batch_size]).cpu()
                for i in range(0, len(normal_images), batch_size)
            ],
            dim=0,
        )
        # Flatten to one row per spatial location.
        flat = all_features.permute(0, 2, 3, 1).reshape(-1, all_features.shape[1])
        # Keep at least one patch so a small ratio cannot empty the bank.
        n_select = max(1, int(len(flat) * self.coreset_ratio))
        indices = torch.randperm(len(flat))[:n_select]
        self.memory = flat[indices]
        print(f"Memory bank: {self.memory.shape}")

    def evaluate(self, test_images: torch.Tensor, labels: torch.Tensor) -> dict:
        """Return image-level ROC AUC over ``test_images``.

        The score of an image is the largest nearest-neighbour distance
        among its patches.

        Raises:
            RuntimeError: If ``train`` has not been called yet.
        """
        if self.memory is None:
            raise RuntimeError("Memory bank is empty; call train() first")
        # Normalizing the memory does not depend on the test image: do it once.
        mem = F.normalize(self.memory, dim=1)
        scores = []
        for img in test_images:
            # Move to CPU to match the memory bank, which train() stores on CPU.
            feat = self.extract_features(img.unsqueeze(0)).cpu()
            patches = feat.permute(0, 2, 3, 1).reshape(-1, feat.shape[1])
            patches = F.normalize(patches, dim=1)
            dists = torch.cdist(patches, mem)
            scores.append(dists.min(dim=1).values.max().item())
        auc = roc_auc_score(labels.numpy(), scores)
        return {"image_auc": auc, "n_test": len(test_images)}


# MLOps 流水线 - 模型版本管理与监控
import mlflow
import json
from datetime import datetime
class ModelRegistry:
def __init__(self, experiment_name: str = "industrial_anomaly_detection"):
mlflow.set_experiment(experiment_name)
self.experiment = mlflow.get_experiment_by_name(experiment_name)
def log_model(self, model: object, metrics: dict, tags: dict):
with mlflow.start_run() as run:
mlflow.log_params(tags)
mlflow.log_metrics(metrics)
mlflow.log_artifact("config.json")
model_uri = mlflow.pytorch.log_model(model, "model")
run_id = run.info.run_id
print(f"Model logged: run_id={run_id}")
return run_id
def promote_to_production(self, run_id: str):
client = mlflow.tracking.MlflowClient()
client.transition_model_version_stage(
name="PatchCoreDetector",
version=1,
stage="Production"
)
print(f"Model {run_id} promoted to Production")
def monitor_drift(self, current_data: np.ndarray, reference_data: np.ndarray) -> dict:
from scipy.stats import ks_2statistic
drift_scores = {}
for i in range(current_data.shape[1]):
stat, p_value = ks_2statistic(reference_data[:, i], current_data[:, i])
drift_scores[f"feature_{i}"] = {"ks_stat": stat, "p_value": p_value}
return drift_scores| 项目阶段 | 关键产出 | 时间占比 | 常见风险 |
|---|---|---|---|
数据采集 | 标注数据集 | 15% | 样本不均衡、标注不一致 |
模型开发 | 训练好的模型 | 20% | 过拟合、泛化差 |
系统构建 | 推理服务 | 25% | 延迟超标、并发瓶颈 |
测试验证 | 评估报告 | 20% | 测试集偏差 |
部署运维 | 上线系统 | 20% | 数据漂移、性能衰减 |
工业项目中,数据质量比模型复杂度重要 10 倍。先确保数据干净,再考虑换模型。
不要把训练集和测试集的数据来源混在一起。真实产线的数据分布可能和实验室完全不同。