核心要点

  • 能列常用 dtype:训练默认 float32,混合精度用 float16/bfloat16,标签用 int64,图像原始像素 uint8,掩码用 bool

  • 能说明算子要求 dtype 一致,混用会报错或隐式提升,用 tf.cast 显式转换

  • 能指出标签需与损失匹配:sparse_categorical_crossentropy 期望 int 类别索引,categorical 期望 one-hot float

  • 能讲混合精度用 tf.keras.mixed_precision:float16 算、float32 累加,配 loss scaling 防梯度下溢

简要回答

TensorFlow 张量支持丰富 dtype

类型 用途
float32 默认训练精度
float16 / bfloat16 混合精度加速
float64 高精度数值(较少)
int32 / int64 索引、标签(int64 常用)
uint8 图像原始像素
bool 掩码、条件
string 文本 token(tf.data)
complex64/128 信号处理

注意

  • 算子对 dtype 有要求,混用会隐式提升或报错
  • tf.cast(x, tf.float32) 显式转换
  • Keras 默认 float32;混合精度 tf.keras.mixed_precision 用 float16 算、float32 累加
  • 标签 sparse_categorical_crossentropy 期望 int64 类别索引

与 NumPy dtype 大致对应,互转时注意 tf.constant(np_arr) 保留 dtype

标准回答

TensorFlow 张量支持丰富 dtype

类型 用途
float32 默认训练精度
float16 / bfloat16 混合精度加速
float64 高精度数值(较少)
int32 / int64 索引、标签(int64 常用)
uint8 图像原始像素
bool 掩码、条件
string 文本 token(tf.data)
complex64/128 信号处理

注意

  • 算子对 dtype 有要求,混用会隐式提升或报错
  • tf.cast(x, tf.float32) 显式转换
  • Keras 默认 float32;混合精度 tf.keras.mixed_precision 用 float16 算、float32 累加
  • 标签 sparse_categorical_crossentropy 期望 int64 类别索引

与 NumPy dtype 大致对应,互转时注意 tf.constant(np_arr) 保留 dtype。详见 TensorFlow 文档。

常见误区

⚠️ 常见踩坑

图像保持 uint8(0~255)未 cast 成 float32 并归一化就喂进网络,数值过大导致激活饱和、loss 爆炸;另一个是纯 float16 全程训练不开 loss scaling,小梯度下溢为 0、参数停更或出现 NaN——正确做法是用 mixed_precision 策略让其自动管理 loss scale。

追问

追问 1混合精度如何配置?

调 tf.keras.mixed_precision.set_global_policy("mixed_float16"),之后各层用 float16 计算、float32 保存变量;用 model.fit 时 Keras 自动加 loss scaling 防下溢。手写训练循环则需 LossScaleOptimizer 包裹优化器。注意最后一层输出建议显式 cast 回 float32 保证 softmax 数值稳定。

追问 2int32 和 int64 标签用哪个?

分类标签两者都可,TF/Keras 的 sparse 交叉熵都接受。int64 是很多 op 的默认整型、兼容性最好;类别数远小于 21 亿时 int32 省一半内存、对大 batch 索引更省显存。GPU 上部分整型 op 仅支持 int32,遇报错可 cast。一般直接用默认 int64 即可。

追问 3string 张量能参与数学运算吗?

不能直接做加减乘等数值运算。tf.string 用于存原始文本、字节,配 tf.strings 模块做切分、正则、编解码,或在 tf.data 管道里读文件路径。要进网络须先经 tokenization / 查表(StringLookup、TextVectorization)映射成数值张量再计算。

延伸学习

与本题相关的知识库文章、术语、工具与行业资讯。

🛠️ AI 工具

  • Pytorch

    Meta 开源的深度学习框架,100K+ stars。以动态计算图和 Pythonic 风格著称,在学术界和工业界都有广泛应用,支持分布式训练、移动端部署和 ONNX 导出

  • Tensorflow

    全球最流行的机器学习框架之一,195K+ stars。Google 开源的端到端 ML 平台,支持 TensorFlow、Keras 等多种 API,覆盖深度学习、强化学习、移动端部署等全场景,是 AI 工程师的必备工具

  • Keras

    深度学习框架,64,020+ stars。高级神经网络 API,支持 TensorFlow、JAX、PyTorch 多后端。以用户友好著称,让深度学习从实验到生产的转化变得简单高效