Token-Critic

作用

Token-Critic 用来判断 which visual tokens belong to the original image and which were sampled by the generative transformer. 在非自回归的生成中知道哪些 tokens 需要保留哪些 tokens 需要被替换是一件比较困难的事情,所以 token-critic (一个 transformer 模型) 被提出用来解决这个问题。
token-critic

Token-Critic 解决了三个关键的限制问题:

  1. token 掩码委托给 Token-Critic 模型

    • Token-Critic 模型经过训练,能够区分哪些 token 属于真实分布,哪些不属于
    • 它学习判断每个 visual token 的质量和真实性
  2. 迭代采样过程中的精确掩码

    • 在每次迭代中,Token-Critic 不是简单地丢弃采样的 token,而是能够精确地识别出质量较差的 token
    • 这使得模型能够保留高质量的生成 token,同时替换掉不符合真实分布的 token
  3. 改进的采样程序

    • 传统方法可能会在迭代解码过程中错误地丢弃已经正确采样的 token
    • Token-Critic 通过精确的质量评估,避免了这种误删除,从而提升了整体生成质量

Token-Critic 的结构

Transformer 结构,输出的结果为每个 token 的得分,表示在根据 condition 的情况下生成的 token 的质量。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


def exists(val):
return val is not None


class SelfCritic(nn.Module):
def __init__(self, net):
super().__init__()
self.net = net
self.to_pred = nn.Linear(net.dim, 1)

def forward_with_cond_scale(self, x, *args, **kwargs):
_, embeds = self.net.forward_with_cond_scale(
x, *args, return_embed=True, **kwargs)
return self.to_pred(embeds)

def forward_with_neg_prompt(self, x, *args, **kwargs):
_, embeds = self.net.forward_with_neg_prompt(
x, *args, return_embed=True, **kwargs)
return self.to_pred(embeds)

def forward(self, x, *args, labels=None, **kwargs):
_, embeds = self.net(x, *args, return_embed=True, **kwargs)
logits = self.to_pred(embeds)

if not exists(labels):
return logits

logits = rearrange(logits, '... 1 -> ...')
return F.binary_cross_entropy_with_logits(logits, labels)