以下是一张图片,数据增广之后的示意图:
问题是这样的,当数据增广后,我们怎么知道图片变成什么样了呢,或者说我们输入到网络中的图片长什么样?对,解法很简单,就是在图片输入到网络时,把图片保存一下。
这里特把此过程记录一下,大家也可以参考下,快速展示训练input图片,在训练代码中修改3处即可,这里举例是基于timm训练代码框架,其他的也是一样。三处如下:
- 创建路径以及保存图片的函数
count = 0
save_path = r"./output/train_picture"
if os.path.exists(save_path):
shutil.rmtree(save_path)
os.makedirs(save_path, exist_ok=True)
def save_train_picture(input_batch):
global count, save_path
from torchvision import transforms
from PIL import Image
input_len = len(input_batch)
for input in input_batch:
tensor_to_pil = transforms.ToPILImage()(input)
tensor_to_pil.save(os.path.join(save_path, f"saved_image_{count}.png"))
count += 1
- 训练时保存图片
训练取数据时加上两行代码,表示保存并不经过以下训练,即只是保存训练图片。
3. 结束后跳出程序
循环之后结束程序即可。
以上就是三步搞定可视化训练图片,后续会继续研究可视化,如基于wandb的特征可视化,均为简易可操作版,敬请期待~
∼ O n e p e r s o n g o f a s t e r , a g r o u p o f p e o p l e c a n g o f u r t h e r ∼ \sim_{One\ person\ go\ faster,\ a\ group\ of\ people\ can\ go\ further}\sim ∼One person go faster, a group of people can go further∼