风格矩阵代表图像自身各通道之间的相关性,代码如下:
def gram_matrix(input_tensor):
"""
获取风格矩阵
"""
# 爱因斯坦求和约定(Einstein summation convention) b:批次大小,i:高、j:宽、channels:特征通道数,bijc和bijd表示沿着高、宽维度进行,即对同一位置的特征向量相乘
# 结果矩阵形状为(batch_size, channels, channels),因为每个通道做点积,再求和得到1个值。 而channel1与channel2..channel_x, channel2与channel1...channel_x、...最终会得到一个矩阵。
result = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
input_shape = tf.shape(input_tensor)
# 获取height * weight = 位置数
num_locations = tf.cast(input_shape[1] * input_shape[2], tf.float32)
# 标准化处理,确保输出的Gram矩阵不会因为图像大小不同而由较大的数值差值。
return result / num_locations
示例
假设input_tensor
有 3 个通道(channels=3),并且为了简化,假设输入张量的形状是 (1, 2, 2, 3)
,即:
batch_size=1
height=2
width=2
channels=3
假设输入张量 input_tensor 是:
[
[
[[1, 2, 3], [4, 5, 6]], # 第一个位置 (i=0, j=0) 和 第二个位置 (i=0, j=1)
[[7, 8, 9], [10, 11, 12]] # 第三个位置 (i=1, j=0) 和 第四个位置 (i=1, j=1)
]
]
这个张量形状为 (1, 2, 2, 3)
,即 (batch_size=1, height=2, width=2, channels=3)。
每个位置(i, j)
的通道 c
上有一个值,如 (1, 2, 3)
表示在(i=0, j=0)
位置上的三个通道的值。
Gram矩阵的生成
第一步:通道之间的乘积
einsum
表达式 'bijc,bijd->bcd'
将两个input_tensor
相乘,并在i 和 j
上进行求和。我们来看具体的操作。
在每个 (i, j) 位置上,我们取出通道 c 和 d 的值相乘,并对所有 (i, j) 位置的乘积求和。
通道 0 和 通道 0 (c=0, d=0)
在位置 (0, 0):1 * 1 = 1
在位置 (0, 1):4 * 4 = 16
在位置 (1, 0):7 * 7 = 49
在位置 (1, 1):10 * 10 = 100
求和:1 + 16 + 49 + 100 = 166
通道 0 和 通道 1 (c=0, d=1)
在位置 (0, 0):1 * 2 = 2
在位置 (0, 1):4 * 5 = 20
在位置 (1, 0):7 * 8 = 56
在位置 (1, 1):10 * 11 = 110
求和:2 + 20 + 56 + 110 = 188
通道 0 和 通道 2 (c=0, d=2)
在位置 (0, 0):1 * 3 = 3
在位置 (0, 1):4 * 6 = 24
在位置 (1, 0):7 * 9 = 63
在位置 (1, 1):10 * 12 = 120
求和:3 + 24 + 63 + 120 = 210
通道 1 和 通道 1 (c=1, d=1)
在位置 (0, 0):2 * 2 = 4
在位置 (0, 1):5 * 5 = 25
在位置 (1, 0):8 * 8 = 64
在位置 (1, 1):11 * 11 = 121
求和:4 + 25 + 64 + 121 = 214
通道 1 和 通道 2 (c=1, d=2)
在位置 (0, 0):2 * 3 = 6
在位置 (0, 1):5 * 6 = 30
在位置 (1, 0):8 * 9 = 72
在位置 (1, 1):11 * 12 = 132
求和:6 + 30 + 72 + 132 = 240
通道 2 和 通道 2 (c=2, d=2)
在位置 (0, 0):3 * 3 = 9
在位置 (0, 1):6 * 6 = 36
在位置 (1, 0):9 * 9 = 81
在位置 (1, 1):12 * 12 = 144
求和:9 + 36 + 81 + 144 = 270
第二步:构建 Gram 矩阵
根据上面的计算结果,我们可以构建最终的 Gram 矩阵。矩阵的形状是 (channels, channels),在本例中是 (3, 3):
[
[166, 188, 210], # 通道 0 和其他通道的相关性
[188, 214, 240], # 通道 1 和其他通道的相关性
[210, 240, 270] # 通道 2 和其他通道的相关性
]
最后一步:除以位置数
根据代码中的 num_locations = tf.cast(input_shape[1] * input_shape[2], tf.float32)
,位置数为 height * width = 2 * 2 = 4
。
最终的 Gram 矩阵需要除以这个位置数:
[
[166/4, 188/4, 210/4], # => [41.5, 47.0, 52.5]
[188/4, 214/4, 240/4], # => [47.0, 53.5, 60.0]
[210/4, 240/4, 270/4] # => [52.5, 60.0, 67.5]
]
最终 Gram 矩阵为:
[
[41.5, 47.0, 52.5],
[47.0, 53.5, 60.0],
[52.5, 60.0, 67.5]
]
这个 Gram 矩阵显示了每个通道之间的相关性。例如:
(0, 1)
位置的47.0
表示通道 0 和通道 1 在所有位置上的特征值相关性。
对角线上的元素(如 41.5、53.5、67.5
)表示通道自身的相关性,即通道内的特征强度。
。