Feature-wise Linear Modulation (FiLM)

FiLM: 特征层面线性调制 (Feature-wise Linear Modulation)

在深度学习领域,我们经常需要模型根据某些外部条件来调整自身的行为。例如,在图像生成任务中,我们可能希望根据一段文字描述(“一只正在草地上奔跑的金色小狗”)来生成相应的图片;在视觉问答(VQA)中,模型需要根据提出的问题来分析图像并给出答案。

如何将这些“条件信息”高效地注入到神经网络中,一直是一个核心问题。一种朴素的方法是在输入层将条件向量与主输入数据拼接(Concatenate)在一起,但这常常导致条件信息在网络深层传递时被稀释或遗忘。

为了解决这个问题,FiLM (Feature-wise Linear Modulation) 应运而生。它是一种思想巧妙且高效的技术,其核心在于:不直接将条件作为数据输入,而是用它来“调制”或“操纵”网络中间层的特征,这种调整通过一个非常简单的仿射变换(affine transformation)来实现,也就是对特征进行缩放 (scale) 和偏移 (shift)。

工作原理

FiLM 层的基本数学公式是:

其中,$x$ 是输入特征,$\gamma$ 和 $\beta$ 是可学习的参数, 由外部的条件信息(比如 文本嵌入、类别标签、时间步等)通过一个小网络预测得到。FiLM 的操作不包含归一化步骤。它直接在原始特征上进行缩放和偏移,目的是让条件信息来“操纵”或“调整”这些特征。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn as nn

class FiLMLayer(nn.Module):
def __init__(self, channels, condition_dim):
super().__init__()
# 条件网络,将条件向量映射到 gamma 和 beta
self.cond_net = nn.Linear(condition_dim, channels * 2)

def forward(self, x, condition):
# x 的形状: [B, C, H, W] (以CNN为例)
# condition 的形状: [B, D]

# 1. 通过条件网络生成 gamma 和 beta
gamma_beta = self.cond_net(condition) # [B, C*2]

# 2. 将输出切分为 gamma 和 beta
gamma, beta = torch.chunk(gamma_beta, 2, dim=1) # 均为 [B, C]

# 3. 调整形状以匹配特征 x
# unsqueeze 用于增加维度以进行广播 (broadcast)
gamma = gamma.unsqueeze(2).unsqueeze(3) # [B, C, 1, 1]
beta = beta.unsqueeze(2).unsqueeze(3) # [B, C, 1, 1]

# 4. 应用 FiLM 调制
return x * gamma + beta

FiLM vs AdaLN

FiLM 经常与 AdaLN (Adaptive Layer Normalization) 相提并论。它们的核心思想相似(都是动态生成仿射变换参数),但有一个关键区别:

  • FiLM: 直接对特征进行仿射变换 γx + β。

  • AdaLN: 先对特征进行归一化,然后再进行仿射变换 γ * Norm(x) + β。

可以认为,AdaLN 是 FiLM 思想在层归一化(Layer Normalization)上的一种特例化应用。AdaLN 在稳定训练的同时注入条件,在许多生成模型(如 DiT)中非常流行。