pytorch官方代码:https://github.com/lcy0604/EraseNet
论文:2010.EraseNet: End-to-End Text Removal in the Wild 网盘提取码:0719
一、图片文字去除效果
图10 SCUT-EnsText 真实数据集的去除
第一列原图带文字、第二列为去除后的标签,剩下的列都是不同的算法去除效果 (pix2pix, scennetextEraser ,EnsNet, 本文EraseNet)
图11 合成的
数据集文字图片去除效果比较
图12 与 inpanting方法比较去除效果
二、方法概述
模型设计了一个两阶段的从粗到细的(h a two-stage ·coarse-to-refine generator network)生成器
网络和一个局部全局鉴别器
网络(a local-global discriminator network.)。(本文中作者改进了SN-GAN,并提出名为 local-global SN-Patch-GAN
的架构
一个额外的语义分割网络
头与整个算法一体的,用于感知(perceive)文字区域。
同时,借助外部预训练好的VGG-16
网络抽取特征,用来监督生成的去除文字的图片(fake samples)与标签图片(ground-truths)的高级语义的差异(discrepancies of high-level semantics.)
图8 判别器架构
图9 不同算法效果对比
训练细节
单个NVIDIA 2080TI GPU
, batch size =4
用
数据集
SCUT-EnsText : 华南理工大学提出与搜集见抬头代码库
2016年提出的 Synthetic data for text localisation in natural images 用来合成数据集
三、本地自己数据集
实验结果
购物图转化
推理代码
# -*- coding: utf-8 -*-
# @Time : 2023/7/6 20:36
# @Author : XyZeng
import os
import math
import argparse
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from PIL import Image
import numpy as np
from torch.autograd import Variable
from torchvision.utils import save_image
from data.dataloader import ErasingData,ImageTransform
from models.sa_gan import STRnet2
parser = argparse.ArgumentParser()
parser.add_argument('--numOfWorkers', type=int, default=0,
help='workers for dataloader')
parser.add_argument('--modelsSavePath', type=str, default='',
help='path for saving models')
parser.add_argument('--logPath', type=str,
default='')
parser.add_argument('--batchSize', type=int, default=16)
parser.add_argument('--loadSize', type=int, default=512,
help='image loading size')
parser.add_argument('--dataRoot', type=str,
default='./')
parser.add_argument('--pretrained',type=str, default='./model.pth', help='pretrained models for finetuning')
parser.add_argument('--savePath', type=str, default='./output')
args = parser.parse_args()
cuda = torch.cuda.is_available()
if cuda:
print('Cuda is available!')
cudnn.benchmark = True
def visual(image):
im =(image).transpose(1,2).transpose(2,3).detach().cpu().numpy()
Image.fromarray(im[0].astype(np.uint8)).show()
batchSize = args.batchSize
loadSize = (args.loadSize, args.loadSize)
dataRoot = args.dataRoot
savePath = args.savePath
import torch.nn.functional as F
os.makedirs(savePath,exist_ok=True)
netG = STRnet2(3)
netG.load_state_dict(torch.load(args.pretrained))
if cuda:
netG = netG.cuda()
for param in netG.parameters():
param.requires_grad = False
print('OK!')
import time
start = time.time()
netG.eval()
ImgTrans=ImageTransform(args.loadSize)
def get_img_tensor(path):
img = Image.open(path)
Image.Resampling.BICUBIC (3), Image.Resampling.BOX (4) o
img=img.convert('RGB').resize((args.loadSize,args.loadSize) ,2)
inputImage = ImgTrans(img).unsqueeze(0)
# mask = ImgTrans(mask.convert('RGB'))
# inputImage = F.interpolate(inputImage, size=(512,512), mode='bilinear') # Adjust size to 115
print('inputImage',inputImage.size())
return inputImage
if __name__ == '__main__':
inpur_dir=r'example\all_images' # 改为'./你需要转换的图片目录'
for name in os.listdir(inpur_dir):
path=os.path.join(inpur_dir,name)
imgs=get_img_tensor(path)
if cuda:
imgs = imgs.cuda()
# masks = masks.cuda()
'''
看论文喝源码能发现5个输出的对应
'''
out1, out2, out3, g_images,mm = netG(imgs)
g_image = g_images.data.cpu()
mm = mm.data.cpu()
# save_image(g_image_with_mask, result_with_mask+path[0])
dir,name=os.path.split(path)
out_path=os.path.join(savePath,name)
mask_path= os.path.join(savePath,name+'_mask.png')
save_image(g_image, out_path)
save_image(mm,mask_path)
print(out_path,mask_path)
# break