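AttnGAN Code Reproduction 2024
Table of contents
- AttnGAN Code Reproduction 2024
- Introduction
- Environment
- Python dependencies
- Dataset
- Training
- Pre-train DAMSM
- Train AttnGAN
- Sampling
- B_VALIDATION set to False (default)
- B_VALIDATION set to True
- Reference blog post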
Introduction
Paper: https://arxiv.org/pdf/1711.10485.pdf
Code, Python 2.7 (the paper's original source): https://github.com/taoxugit/AttnGAN
Python 2.7 is effectively obsolete and is not recommended.
Code, Python 3: https://github.com/davidstap/AttnGAN
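Environment
Name | Version | Notes |
---|---|---|
NVIDIA-SMI 531.79 | CUDA Version: 12.1 | Windows 11 |
torch | 2.3.1+cu121 | latest at the time of writing |
torchvision | 0.18.1+cu121 | |
Python | 3.10.14 | |
Laptop GPU (4060) | 8 GB VRAM | |
Note: torch can no longer be installed for Python 2.7.
Reference blog post: Torch, torchvision, and Python version compatibility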
Python dependencies
python-dateutil
easydict
pandas
torchfile
nltk
scikit-image
pyyaml
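Before training, it can help to confirm that every package in the list above is importable. The following is a minimal sketch, not part of the AttnGAN repository: `REQUIRED` and `find_missing` are hypothetical names, and the mapping from pip package name to import name is filled in by hand because the two differ for several packages.

```python
import importlib.util

# Map each pip package name from the list above to the module name it is
# imported under (the two differ for several of these packages).
REQUIRED = {
    "python-dateutil": "dateutil",
    "easydict": "easydict",
    "pandas": "pandas",
    "torchfile": "torchfile",
    "nltk": "nltk",
    "scikit-image": "skimage",
    "pyyaml": "yaml",
}


def find_missing(packages):
    """Return the pip names whose import module cannot be located."""
    missing = []
    for pip_name, module_name in packages.items():
        if importlib.util.find_spec(module_name) is None:
            missing.append(pip_name)
    return missing


if __name__ == "__main__":
    missing = find_missing(REQUIRED)
    if missing:
        print("Missing packages: pip install " + " ".join(missing))
    else:
        print("All dependencies found.")
```

The check only locates the modules without importing them, so it stays fast even for heavy packages such as pandas.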
Dataset
- Preprocessed metadata for birds: https://drive.google.com/file/d/1O_LtUP9sch09QH3s_EBAgLEctBQ5JBSJ/view, save it to data
- Download the bird image data: http://www.vision.caltech.edu/datasets/cub_200_2011/ and extract it to data/birds/
Training
Pre-train DAMSM
Run command
cd code
python pretrain_DAMSM.py --cfg cfg/DAMSM/bird.yml --gpu 0
Configuration file
code/cfg/DAMSM/bird.yml
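TRAIN:
  FLAG: True
  NET_E: ''
  BATCH_SIZE: 48 # reduce this if you hit "RuntimeError: CUDA out of memory"
  MAX_EPOCH: 600 # number of training epochs
  SNAPSHOT_INTERVAL: 50 # save the model every 50 epochs
Do not put Chinese text (including comments) in the yml file; it causes an error when the file is read!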
Problem 1
re_img = transforms.Resize(imsize[i])(img)
IndexError: list index out of range
code/datasets.py
# Old code
if i < (cfg.TREE.BRANCH_NUM - 1):
# New code (changing this to TREE.BRANCH_NUM - 2 instead causes many problems later on)
if i < len(imsize):
Problem 2
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
In code/pretrain_DAMSM.py, the message tells us to change loss[0] to loss.item().
Problem 3
TypeError: load() missing 1 required positional argument: 'Loader'
code/miscc/config.py
# Old code
yaml_cfg = edict(yaml.load(f))
# New code
yaml_cfg = edict(yaml.safe_load(f))
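Problem 4
RuntimeError: masked_fill only supports boolean masks, but got dtype Byte
code/miscc/losses.py
# Old code (failed on Windows here; the author found it works on Linux, which most likely reflects an older PyTorch version rather than the OS)
data.masked_fill_(masks, -float('inf'))
# New code (required by recent PyTorch, which only accepts boolean masks)
data.masked_fill_(masks.bool(), -float('inf'))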
Problem 5
TypeError: pyramid_expand() got an unexpected keyword argument 'multichannel'
code/miscc/utils.py
# Old code
skimage.transform.pyramid_expand(one_map, sigma=20,
                                 upscale=vis_size // att_sze,
                                 multichannel=True)
# New code
skimage.transform.pyramid_expand(one_map, sigma=20,
                                 upscale=vis_size // att_sze,
                                 channel_axis=-1)
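If the same code has to run against both older and newer scikit-image releases, one option is to choose the keyword at run time by inspecting the function signature. This is a minimal sketch under that assumption: `channel_kwargs` is a hypothetical helper, and `old_expand` / `new_expand` are stand-ins for the two signatures so that the example runs without scikit-image installed.

```python
import inspect


def channel_kwargs(func):
    """Pick the keyword that tells `func` the last axis holds channels.

    Returns {"channel_axis": -1} when the function accepts that parameter,
    otherwise falls back to the older {"multichannel": True}.
    """
    params = inspect.signature(func).parameters
    if "channel_axis" in params:
        return {"channel_axis": -1}
    return {"multichannel": True}


# Stand-ins for the old and new scikit-image signatures (hypothetical,
# used only so that this sketch runs without scikit-image installed).
def old_expand(image, upscale=2, sigma=None, multichannel=False):
    return "old"


def new_expand(image, upscale=2, sigma=None, channel_axis=None):
    return "new"


print(channel_kwargs(old_expand))  # {'multichannel': True}
print(channel_kwargs(new_expand))  # {'channel_axis': -1}
```

In code/miscc/utils.py the call would then pass `**channel_kwargs(skimage.transform.pyramid_expand)` in place of the hard-coded keyword. If only one scikit-image version needs to be supported, the direct replacement is simpler.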
Problem 6
OSError: cannot open resource
code/miscc/utils.py
# Old code
fnt = ImageFont.truetype('Pillow/Tests/fonts/FreeMono.ttf', 50)
# New code (requires "import os" at the top of the file)
script_dir = os.path.dirname(os.path.abspath(__file__))
font_path = os.path.join(script_dir, 'FreeMono.ttf')
fnt = ImageFont.truetype(font_path, 50)
Copy eval/FreeMono.ttf to code/miscc/FreeMono.ttf.
Train AttnGAN
Run command
python main.py --cfg cfg/bird_attn2.yml --gpu 0
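Pretrained model download
If you skipped the Pre-train DAMSM step, you can download the pretrained DAMSM model directly and save it to DAMSMencoders.
Configuration file
code/cfg/bird_attn2.yml
TRAIN:
  FLAG: True
  ...
  NET_E: '../DAMSMencoders/bird/text_encoder200.pth' # location of the pretrained DAMSM model; can be customized
Note: do not add Chinese comments to the configuration file.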
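Problem 1
FileNotFoundError: [Errno 2] No such file or directory: '../DAMSMencoders/bird/image_encoder200.pth'
The .pth file cannot be found. Comparing the paths shows that the file is actually at …/DAMSMencoders/birds/image_encoder200.pth (the directory is birds, not bird).
- Fix: change the path in the configuration so that it matches the actual directory name.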
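A path mismatch like this can be caught before training starts. The sketch below is a hypothetical preflight check, not part of the repository. It assumes that the image encoder file name is derived from NET_E by replacing `text_encoder` with `image_encoder`, which is consistent with the config value and the error message above but should be verified against the code.

```python
import os


def missing_encoder_files(net_e, base_dir="."):
    """Return the encoder checkpoint paths that do not exist on disk.

    Assumption: the image encoder sits next to the text encoder and its file
    name is obtained by replacing 'text_encoder' with 'image_encoder'.
    """
    candidates = [net_e, net_e.replace("text_encoder", "image_encoder")]
    missing = []
    for rel_path in candidates:
        full_path = os.path.normpath(os.path.join(base_dir, rel_path))
        if not os.path.isfile(full_path):
            missing.append(full_path)
    return missing


if __name__ == "__main__":
    # Run from the code/ directory so that the relative path resolves the
    # same way it does for main.py.
    net_e = "../DAMSMencoders/bird/text_encoder200.pth"
    for path in missing_encoder_files(net_e):
        print("Not found:", path)
```

An empty result means both checkpoint files are where the configuration expects them.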
Problem 2
AttributeError: '_MultiProcessingDataLoaderIter' object has no attribute 'next'
code/trainer.py
# Old code
data = data_iter.next()
# New code
data = next(data_iter)
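The fix works because Python 3 iterators implement `__next__` and are advanced with the built-in `next()`, while the `.next()` method is a Python 2 convention. The sketch below illustrates the pattern with a plain list standing in for the DataLoader (an assumption made so that it runs without torch); `endless_batches` is a hypothetical helper and assumes a non-empty loader.

```python
def endless_batches(loader):
    """Yield batches forever, restarting the loader when it is exhausted."""
    data_iter = iter(loader)
    while True:
        try:
            # Built-in next() works on any iterator; the Python 2 style
            # data_iter.next() method does not exist on Python 3 iterators.
            yield next(data_iter)
        except StopIteration:
            data_iter = iter(loader)  # start a new pass over the data


# A plain list stands in for a DataLoader here (an assumption made so the
# sketch runs without torch); any re-iterable object behaves the same way.
batches = endless_batches([1, 2, 3])
first_five = [next(batches) for _ in range(5)]
print(first_five)  # [1, 2, 3, 1, 2]
```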
Problem 3
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU
code/cfg/bird_attn2.yml
# Old setting
BATCH_SIZE: 20 # 22
# New setting
BATCH_SIZE: 10 # 20 22
Warning
UserWarning: This overload of add_ is deprecated
code/trainer.py
# Before
avg_p.mul_(0.999).add_(0.001, p.data)
# After
avg_p.mul_(0.999).add_(p.data, alpha=0.001)
Note: this is only a warning, so the change is optional.
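Both forms of the call compute the same update: the running average is scaled by 0.999 and then 0.001 times the current value is added. A plain-Python sketch of that arithmetic follows, with lists of floats standing in for tensors; `ema_update` is a hypothetical name, and `1 - decay` is used in place of the literal 0.001.

```python
def ema_update(avg_params, params, decay=0.999):
    """Update running averages in place: avg = decay * avg + (1 - decay) * p."""
    for i, p in enumerate(params):
        avg_params[i] = decay * avg_params[i] + (1.0 - decay) * p
    return avg_params


avg = [1.0, 0.0]
ema_update(avg, [2.0, 10.0])
print(avg)  # approximately [1.001, 0.01]
```

With a decay this close to 1, the average moves slowly, so a single noisy step has little effect on it.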
If you do not want to run the Train AttnGAN step, you can download the already trained AttnGAN model and save it to models/.
Sampling
Run command
python main.py --cfg cfg/eval_bird.yml --gpu 0
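B_VALIDATION set to False (default)
Uses the text in ./data/birds/example_filenames.txt (custom sentences are allowed) as input to generate sample images. This mode is of little practical use and also interferes with measuring the evaluation metrics.
Configuration file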
code/cfg/eval_bird.yml
B_VALIDATION: False
BATCH_SIZE: 100
Error 1
Length of all samples has to be greater than 0, but found an element in 'lengths' that is <= 0
- Isolate the cause by changing one variable at a time
- Keep trimming data/birds/example_filenames.txt to find the input samples that trigger the error
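The manual trimming can also be replaced by a scripted scan. The sketch below assumes that the culprit is a caption line containing non-ASCII or undecodable characters, which is only one possible cause of this error. It also assumes that each entry in the list file maps to `<data_dir>/<entry>.txt`; `find_suspect_captions` is a hypothetical helper.

```python
import os


def find_suspect_captions(filenames_path, data_dir):
    """Return (key, line_number) pairs for caption lines with non-ASCII text.

    Assumption: each line of `filenames_path` is a key such as
    'text/138.Tree_Swallow/Tree_Swallow_0030_134942' and the captions live in
    '<data_dir>/<key>.txt'.
    """
    suspects = []
    with open(filenames_path, "r", encoding="utf-8") as f:
        keys = [line.strip() for line in f if line.strip()]
    for key in keys:
        caption_path = os.path.join(data_dir, key + ".txt")
        if not os.path.isfile(caption_path):
            continue  # skip list entries without a caption file
        # errors="replace" turns undecodable bytes into U+FFFD instead of raising
        with open(caption_path, "r", encoding="utf-8", errors="replace") as f:
            for line_number, line in enumerate(f, start=1):
                if not line.isascii():
                    suspects.append((key, line_number))
    return suspects
```

Run from the code/ directory, a call such as `find_suspect_captions("../data/birds/example_filenames.txt", "../data/birds")` would list the entries worth inspecting by hand.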
Runs correctly
example_captions
text/180.Wilson_Warbler/Wilson_Warbler_0007_175618
text/180.Wilson_Warbler/Wilson_Warbler_0024_175278
text/180.Wilson_Warbler/Wilson_Warbler_0074_175645
text/180.Wilson_Warbler/Wilson_Warbler_0107_175320
text/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0001_163813
text/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0008_164001
text/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0016_164060
text/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0035_163587
text/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0101_164324
text/165.Chestnut_sided_Warbler/Chestnut_Sided_Warbler_0103_163669
Raises the error
example_captions
text/138.Tree_Swallow/Tree_Swallow_0002_136792
text/138.Tree_Swallow/Tree_Swallow_0008_135352
text/138.Tree_Swallow/Tree_Swallow_0030_134942
text/138.Tree_Swallow/Tree_Swallow_0050_135104
text/138.Tree_Swallow/Tree_Swallow_0117_134925
text/098.Scott_Oriole/Scott_Oriole_0002_795829
text/098.Scott_Oriole/Scott_Oriole_0014_795827
text/098.Scott_Oriole/Scott_Oriole_0018_795840
text/098.Scott_Oriole/Scott_Oriole_0046_92371
Checking the source file
data/birds/text/138.Tree_Swallow/Tree_Swallow_0030_134942.txt
the blue backed, white bellied baby bird has a very fat little belly
this is a bird with a white belly and breast and a blue back and head.
the bird has a small white body with blue and green colored crown and coverts.
a small round bird with a white and blue body.
this��bird��has��a��white��belly,��dark��blue��wings,��and��a��small,��short��bill.
this bird is blue with white and has a very short beak.
this is a blue bird with a white throat, breast, belly and abdomen and a small black pointed beak
this small bird has a white beast and blue crest and back.
this bird has wings that are blue and has a white belly
this bird has wings that are blue and has a white belly
A really sneaky one: after replacing �� with spaces, everything runs correctly!
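The replacement can be scripted rather than done by hand. This is a minimal sketch, assuming the garbled characters are read as the Unicode replacement character U+FFFD when the file is decoded as UTF-8; `clean_caption` and `clean_caption_file` are hypothetical helpers, and the file is rewritten in place, so keep a backup of the dataset first.

```python
import re


def clean_caption(line):
    """Replace U+FFFD replacement characters with spaces and tidy whitespace."""
    line = line.replace("\ufffd", " ")
    return re.sub(r"[ \t]+", " ", line).strip()


def clean_caption_file(path):
    """Rewrite a caption file in place with the garbled characters removed."""
    with open(path, "r", encoding="utf-8", errors="replace") as f:
        lines = f.read().splitlines()
    with open(path, "w", encoding="utf-8") as f:
        for line in lines:
            f.write(clean_caption(line) + "\n")


print(clean_caption("this\ufffd\ufffdbird\ufffd\ufffdhas\ufffd\ufffda\ufffd\ufffdwhite\ufffd\ufffdbelly"))
# this bird has a white belly
```

Collapsing runs of spaces is a side effect of this sketch; it also normalizes spacing on lines that were already fine.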
B_VALIDATION set to True
Uses the text in data/birds/text as input to generate sample images for measuring the evaluation metrics.
code/cfg/eval_bird.yml
# New setting
B_VALIDATION: True
# New setting (8 GB VRAM)
BATCH_SIZE: 20
Run command
python main.py --cfg cfg/eval_bird.yml --gpu 0
Result
Make a new folder: ../models/bird_AttnGAN2/valid/single/156.White_eyed_Vireo
step: 100
Total time for training: 91.08317875862122
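Reference blog post
AttnGAN代码复现(详细步骤+避坑指南)文本生成图像 (AttnGAN code reproduction: detailed steps and pitfall guide, text-to-image generation)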