卷积的计算 - im2col 2
flyfish
import numpy as np
# Allow up to 200 characters per printed line so wide array rows are not wrapped.
np.set_printoptions(linewidth=200)
# Naming note: F = filter = kernel
def im2col(images, kernel_size, stride=1, padding=0):
    """Unfold the sliding convolution windows of ``images`` into a 2-D matrix.

    Args:
        images: NumPy array with 2, 3 or 4 dimensions. A 2-D array is treated
            as ``(H, W)`` and reshaped to ``(1, 1, H, W)``; a 3-D array is
            treated as ``(N, H, W)`` and reshaped to ``(N, 1, H, W)``; a 4-D
            array is used as ``(N, C, H, W)``.
        kernel_size: Either a tuple or an array-like object exposing
            ``ndim``/``shape``/``reshape``. Only the last two extents
            (kernel height and width) are used. Tuples of length 2, 3 or 4
            and arrays with 2, 3 or 4 dimensions are accepted.
        stride: An int used for both directions, or a ``(vertical,
            horizontal)`` tuple.
        padding: An int used for both directions, a ``(vertical,
            horizontal)`` tuple, or the string ``"same"`` (padding chosen so
            the output height/width equals the input height/width).

    Returns:
        A float array of shape ``(C * k_h * k_w, N * O_h * O_w)`` where
        ``O_h``/``O_w`` are the output height and width. Each column holds
        one receptive field, flattened channel-first.

    Raises:
        ValueError: If ``padding`` is not an int, a tuple, or ``"same"``.

    Side effects:
        Prints the computed output height and width.
    """
    # Normalise the images to 4-D (N, C, H, W).
    if images.ndim == 2:
        images = images.reshape(1, 1, *images.shape)
    elif images.ndim == 3:
        N, I_h, I_w = images.shape
        images = images.reshape(N, 1, I_h, I_w)
    N, C, I_h, I_w = images.shape

    # Normalise the kernel description to 4 extents; only k_h, k_w are needed.
    if isinstance(kernel_size, tuple):
        if len(kernel_size) == 2:
            kernel_size = (1, 1, *kernel_size)
        elif len(kernel_size) == 3:
            kernel_size = (kernel_size[0], 1, *kernel_size[1:])
        _, _, k_h, k_w = kernel_size
    else:
        if kernel_size.ndim == 2:
            kernel_size = kernel_size.reshape(1, 1, *kernel_size.shape)
        elif kernel_size.ndim == 3:
            n_filters, f_h, f_w = kernel_size.shape
            kernel_size = kernel_size.reshape(n_filters, 1, f_h, f_w)
        _, _, k_h, k_w = kernel_size.shape

    # Stride: scalar applies to both directions.
    if isinstance(stride, tuple):
        stride_ud, stride_lr = stride
    else:
        stride_ud = stride
        stride_lr = stride

    # Padding: vertical (up/down) and horizontal (left/right) amounts.
    if isinstance(padding, tuple):
        pad_ud, pad_lr = padding
    elif isinstance(padding, (int, np.integer)):
        pad_ud = padding
        pad_lr = padding
    elif isinstance(padding, str) and padding == "same":
        # May be fractional (x.5); the output size below uses 2*pad, which is
        # integral, and the actual padding is rounded up further down.
        pad_ud = 0.5*((I_h - 1)*stride_ud - I_h + k_h)
        pad_lr = 0.5*((I_w - 1)*stride_lr - I_w + k_w)
    else:
        # Previously this fell through and failed later with an unbound
        # local variable; fail early with a clear message instead.
        raise ValueError(
            "padding must be an int, a (vertical, horizontal) tuple, "
            "or 'same'; got {!r}".format(padding)
        )
    pad_zero = (0, 0)

    # Output spatial size, computed from the (possibly fractional) padding.
    O_h = int((I_h - k_h + 2*pad_ud)//stride_ud + 1)
    O_w = int((I_w - k_w + 2*pad_lr)//stride_lr + 1)

    # Pad symmetrically with zeros; fractional amounts are rounded up so the
    # padded image is always large enough for every window.
    pad_ud = int(np.ceil(pad_ud))
    pad_lr = int(np.ceil(pad_lr))
    pad_ud = (pad_ud, pad_ud)
    pad_lr = (pad_lr, pad_lr)
    images = np.pad(images, [pad_zero, pad_zero, pad_ud, pad_lr], "constant")

    # For each kernel offset (h, w), one strided slice gathers that offset's
    # value from every window at once.
    cols = np.empty((N, C, k_h, k_w, O_h, O_w))
    for h in range(k_h):
        h_lim = h + stride_ud*O_h
        for w in range(k_w):
            w_lim = w + stride_lr*O_w
            cols[:, :, h, w, :, :] \
                = images[:, :, h:h_lim:stride_ud, w:w_lim:stride_lr]

    # Rows: (C, k_h, k_w) flattened; columns: (N, O_h, O_w) flattened.
    cols = cols.transpose(1, 2, 3, 0, 4, 5).reshape(C*k_h*k_w, N*O_h*O_w)
    print("H:",O_h)
    print("W:",O_w)
    return cols
# Demo input: one 3-channel 6x6 image in NCHW layout holding the values 1..108.
image = np.arange(108).reshape((1, 3, 6, 6)) + 1
print("input:", image)

# Unfold with a (3, 3, 3) kernel description, stride 1 and no padding,
# then show the resulting matrix followed by its shape.
r = im2col(image, (3, 3, 3), stride=1, padding=0)
print(r, r.shape, sep="\n")
运行结果
图中 Input的值是
input: [[[[ 1 2 3 4 5 6]
[ 7 8 9 10 11 12]
[ 13 14 15 16 17 18]
[ 19 20 21 22 23 24]
[ 25 26 27 28 29 30]
[ 31 32 33 34 35 36]]
[[ 37 38 39 40 41 42]
[ 43 44 45 46 47 48]
[ 49 50 51 52 53 54]
[ 55 56 57 58 59 60]
[ 61 62 63 64 65 66]
[ 67 68 69 70 71 72]]
[[ 73 74 75 76 77 78]
[ 79 80 81 82 83 84]
[ 85 86 87 88 89 90]
[ 91 92 93 94 95 96]
[ 97 98 99 100 101 102]
[103 104 105 106 107 108]]]]
输出的形状
H: 4
W: 4
卷积核是3*3 ,K =3
通道是3 ,C=3
(K*K*C, H*W) = (3*3*3, 4*4) = (27, 16)
经过im2col结果是
[[ 1. 2. 3. 4. 7. 8. 9. 10. 13. 14. 15. 16. 19. 20. 21. 22.]
[ 2. 3. 4. 5. 8. 9. 10. 11. 14. 15. 16. 17. 20. 21. 22. 23.]
[ 3. 4. 5. 6. 9. 10. 11. 12. 15. 16. 17. 18. 21. 22. 23. 24.]
[ 7. 8. 9. 10. 13. 14. 15. 16. 19. 20. 21. 22. 25. 26. 27. 28.]
[ 8. 9. 10. 11. 14. 15. 16. 17. 20. 21. 22. 23. 26. 27. 28. 29.]
[ 9. 10. 11. 12. 15. 16. 17. 18. 21. 22. 23. 24. 27. 28. 29. 30.]
[ 13. 14. 15. 16. 19. 20. 21. 22. 25. 26. 27. 28. 31. 32. 33. 34.]
[ 14. 15. 16. 17. 20. 21. 22. 23. 26. 27. 28. 29. 32. 33. 34. 35.]
[ 15. 16. 17. 18. 21. 22. 23. 24. 27. 28. 29. 30. 33. 34. 35. 36.]
[ 37. 38. 39. 40. 43. 44. 45. 46. 49. 50. 51. 52. 55. 56. 57. 58.]
[ 38. 39. 40. 41. 44. 45. 46. 47. 50. 51. 52. 53. 56. 57. 58. 59.]
[ 39. 40. 41. 42. 45. 46. 47. 48. 51. 52. 53. 54. 57. 58. 59. 60.]
[ 43. 44. 45. 46. 49. 50. 51. 52. 55. 56. 57. 58. 61. 62. 63. 64.]
[ 44. 45. 46. 47. 50. 51. 52. 53. 56. 57. 58. 59. 62. 63. 64. 65.]
[ 45. 46. 47. 48. 51. 52. 53. 54. 57. 58. 59. 60. 63. 64. 65. 66.]
[ 49. 50. 51. 52. 55. 56. 57. 58. 61. 62. 63. 64. 67. 68. 69. 70.]
[ 50. 51. 52. 53. 56. 57. 58. 59. 62. 63. 64. 65. 68. 69. 70. 71.]
[ 51. 52. 53. 54. 57. 58. 59. 60. 63. 64. 65. 66. 69. 70. 71. 72.]
[ 73. 74. 75. 76. 79. 80. 81. 82. 85. 86. 87. 88. 91. 92. 93. 94.]
[ 74. 75. 76. 77. 80. 81. 82. 83. 86. 87. 88. 89. 92. 93. 94. 95.]
[ 75. 76. 77. 78. 81. 82. 83. 84. 87. 88. 89. 90. 93. 94. 95. 96.]
[ 79. 80. 81. 82. 85. 86. 87. 88. 91. 92. 93. 94. 97. 98. 99. 100.]
[ 80. 81. 82. 83. 86. 87. 88. 89. 92. 93. 94. 95. 98. 99. 100. 101.]
[ 81. 82. 83. 84. 87. 88. 89. 90. 93. 94. 95. 96. 99. 100. 101. 102.]
[ 85. 86. 87. 88. 91. 92. 93. 94. 97. 98. 99. 100. 103. 104. 105. 106.]
[ 86. 87. 88. 89. 92. 93. 94. 95. 98. 99. 100. 101. 104. 105. 106. 107.]
[ 87. 88. 89. 90. 93. 94. 95. 96. 99. 100. 101. 102. 105. 106. 107. 108.]]