关于ViT中pos embed的可视化
在ViT中有一个position embedding部分,为什么要有这一部分呢?
在NLP中,不同词转化为Token之后有一个位置编码的模块,这是因为不同词汇之间是有顺序的,但是在视觉领域,图像与图像之间是没有顺序的,ViT将每一幅图划分为一个个patch,如下图所示,每一个patch就对应于NLP中的一个Token,而且从图中也可以直观的感受到每一个patch都是有位置的,所以在每一个特征维度上都加入了一个position embedding模块,最后我们可视化一下Google预训练后position embedding的结果
可视化左上角的patch
可视化所有patch
这是一幅图中所有patch可视化的结果,但是因为patch太多,不是很清晰,但是还是可以看出大体的位置效果
注意 假设我们的Patch一共是576个,那么计算出来的每一个可视化图都是576维也就是 24 × 24 24 imes 24 24×24,每一维度都是计算余弦相似度。以左上角的第一幅图为例,先计算第一维与自己的余弦相似度,在计算他与其他575维的余弦相似度,最后得到576个值,reshape为 ( 24 , 24 ) (24, 24) (24,24)可视化即可,通过可视化结果我们可以发现他与自己的相似度最高,与他同行或者同列的相似度次之,其余的相似度最小。下面是余弦相似度的计算公式
s i m i l a r i t y = c o s ( θ ) = A ⋅ B ∣ ∣ A ∣ ∣ ∣ ∣ B ∣ ∣ = ∑ i = 1 n A i B i ∑ i = 1 n ( A i ) 2 ∑ i = 1 n ( B i ) 2 similarity = cos( heta) = frac{Acdot B}{||A||||B||} = frac{sum_{i=1}^{n}A_{i}B_{i}}{sqrt{sum_{i=1}^{n}(A_{i})^2}sqrt{sum_{i=1}^{n}(B_{i})^2}} similarity=cos(θ)=∣∣A∣∣∣∣B∣∣A⋅B=∑i=1n(Ai)2 ∑i=1n(Bi)2 ∑i=1nAiBi
下面直接给出代码
首先下载预训练模型,,然后放到py相同的文件夹下运行即可
# show position embedding picture import numpy as np import matplotlib.pyplot as plt from tqdm import tqdm def bit_product_sum(x, y): return sum([item[0] * item[1] for item in zip(x, y)]) def cosine_similarity(x, y, norm=False): """ 计算两个向量x和y的余弦相似度 """ assert len(x) == len(y), "len(x) != len(y)" xy = x.dot(y) x2y2 = np.linalg.norm(x, ord=2) * np.linalg.norm(x, ord=2) sim = xy/x2y2 return sim data = np.load("imagenet21k+imagenet2012_ViT-B_16.npz") pos = data[Transformer/posembed_input/pos_embedding].reshape(577, 768)[1:, :] # 576, 768 cos = np.zeros((576, 576)) # 只计算左上角的值 # for i in tqdm(range(1)): # for j in range(576): # cos[i, j] = cosine_similarity(pos[i, :], pos[j, :]) # cos = cos[0, :].reshape(24, 24) # plt.imshow(cos) # plt.show() # 计算所有 for i in tqdm(range(576)): for j in range(576): cos[i, j] = cosine_similarity(pos[i, :], pos[j, :]) fig, axs = plt.subplots(nrows=24, ncols=24, figsize=(24, 24), subplot_kw={ xticks: [], yticks: []}) i=0 cos = cos.reshape(576, 24, 24) for ax in axs.flat: ax.imshow(cos[i, :, :], cmap=viridis) i+=1 plt.tight_layout() plt.show()