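The IA-SEG project comes from the paper Improving Nighttime Driving-Scene Segmentation via Dual Image-adaptive Learnable Filters. Its core idea is to add DIAL-Filters to an existing semantic segmentation model. DIAL-Filters consist of two parts: an image-adaptive processing module (IAPM, i.e. the CNN-PP + DIF modules from IA-YOLO) and a learnable guided filter (LGF). The project is implemented in PyTorch. I want to build domain-adaptive detection algorithms in PyTorch, so I analyze the project here. IA-SEG targets semantic segmentation in nighttime scenes and includes both supervised and unsupervised training. This post covers only its core: how to use the IAPM module (CNN-PP and DIF) and the LGF module.
Sections 3 and 4 of this post contain code usage examples.
Besides DIAL-Filters, the IA-SEG paper also proposes an unsupervised learning framework, described at the end of this post. Interested readers can consult the original paper or my translation and walkthrough of the IA-SEG paper.
IA-SEG project: https://github.com/wenyyu/IA-Seg#arxiv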
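1. CNN-PP module
1.1 Overview
The CNN-PP module supplies the filter parameters that the DIP module uses to enhance the image. It is essentially a compact convolutional neural network. Its input is a low-resolution copy of the original image (the code resizes it to 256×256), and its output is the set of DIP parameters.
CNN-PP in IA-SEG has more parameters (278K, predicting 4 filter parameters) than the one in IA-YOLO (165K, predicting 15 filter parameters). Note: IA-SEG and IA-YOLO were both implemented by the same author.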
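The 278K figure and the output shape can be checked directly. The sketch below rebuilds the CNN-PP layer stack from Section 1.2 as a standalone function and counts its trainable parameters. `build_cnn_pp` is a hypothetical name of mine, and num_filter_parameters=4 matches IA-SEG's config.

```python
import torch
import torch.nn as nn


def conv_downsample(in_filters, out_filters, normalization=False):
    # 3x3 conv with stride 2 halves the spatial size
    layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1), nn.LeakyReLU(0.2)]
    if normalization:
        layers.append(nn.InstanceNorm2d(out_filters, affine=True))
    return layers


def build_cnn_pp(num_filter_parameters=4):
    # same layer stack as CNN_PP.model in network/dip.py (hypothetical standalone builder)
    return nn.Sequential(
        nn.Upsample(size=(256, 256), mode='bilinear'),
        nn.Conv2d(3, 16, 3, stride=2, padding=1),             # 256 -> 128
        nn.LeakyReLU(0.2),
        nn.InstanceNorm2d(16, affine=True),
        *conv_downsample(16, 32, normalization=True),         # 128 -> 64
        *conv_downsample(32, 64, normalization=True),         # 64 -> 32
        *conv_downsample(64, 128, normalization=True),        # 32 -> 16
        *conv_downsample(128, 128),                           # 16 -> 8
        nn.Dropout(p=0.5),
        nn.Conv2d(128, num_filter_parameters, 8, padding=0),  # 8x8 -> 1x1
    )


net = build_cnn_pp(4)
n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(n_params)  # 278276, i.e. about 278K

net.eval()  # disable dropout for a deterministic shape check
with torch.no_grad():
    pr = net(torch.rand(2, 3, 375, 500))  # any input size is resized to 256x256 first
print(tuple(pr.shape))  # (2, 4, 1, 1)
```

Most of the weights sit in the last two convolutions: the 128→128 downsampling conv and the 8×8 head. The head alone grows by 128·64 weights for every extra filter parameter.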
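1.2 Implementation
Code: https://github.com/wenyyu/IA-Seg/blob/main/network/dip.py
The full code is below. It relies on an external object, cfg, loaded from the configuration file. CNN_PP uses two of its fields: cfg.num_filter_parameters and cfg.filters.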
#! /usr/bin/env python
# coding=utf-8
import torch
import torch.nn as nn
import numpy as np
from configs.train_config import cfg
import time


def conv_downsample(in_filters, out_filters, normalization=False):
    # 3x3 conv with stride 2 halves the spatial size
    layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1)]
    layers.append(nn.LeakyReLU(0.2))
    if normalization:
        layers.append(nn.InstanceNorm2d(out_filters, affine=True))
    return layers


class CNN_PP(nn.Module):
    def __init__(self, in_channels=3):
        super(CNN_PP, self).__init__()
        # input resized to 256x256 -> five stride-2 convs -> 8x8 -> 8x8 conv -> (N, num_filter_parameters, 1, 1)
        self.model = nn.Sequential(
            nn.Upsample(size=(256, 256), mode='bilinear'),
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(16, affine=True),
            *conv_downsample(16, 32, normalization=True),
            *conv_downsample(32, 64, normalization=True),
            *conv_downsample(64, 128, normalization=True),
            *conv_downsample(128, 128),
            #*discriminator_block(128, 128, normalization=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(128, cfg.num_filter_parameters, 8, padding=0),
        )

    def forward(self, img_input):
        self.Pr = self.model(img_input)  # predicted filter parameters
        self.filtered_image_batch = img_input
        filters = cfg.filters
        filters = [x(img_input, cfg) for x in filters]
        self.filter_parameters = []
        self.filtered_images = []
        # apply the filters in sequence; each one reads its own slice of Pr
        for j, filter in enumerate(filters):
            # with tf.variable_scope('filter_%d' % j):
            # print('    creating filter:', j, 'name:', str(filter.__class__), 'abbr.',
            #       filter.get_short_name())
            # print('      filter_features:', self.Pr.shape)
            self.filtered_image_batch, filter_parameter = filter.apply(
                self.filtered_image_batch, self.Pr)
            self.filter_parameters.append(filter_parameter)
            self.filtered_images.append(self.filtered_image_batch)
            # print('      output:', self.filtered_image_batch.shape)
        return self.filtered_image_batch, self.filtered_images, self.Pr, self.filter_parameters


def DIP():
    model = CNN_PP()
    return model
1.3 Related code
The DIP module depends on the cfg object, whose code is at:
https://github.com/wenyyu/IA-Seg/blob/main/configs/train_config.py
The parts relevant to CNN-PP and DIP are:
import argparse
from network.filters import *
# (excerpt: cfg itself is defined earlier in the file, not shown)
cfg.filters = [ExposureFilter, GammaFilter, ContrastFilter, UsmFilter]
# cfg.filters = []
cfg.num_filter_parameters = 4
# the settings below are all used by the DIF filters
cfg.exposure_begin_param = 0
cfg.gamma_begin_param = 1
cfg.contrast_begin_param = 2
cfg.usm_begin_param = 3
# Gamma = 1/x ~ x
cfg.curve_steps = 8
cfg.gamma_range = 3
cfg.exposure_range = 3.5
cfg.wb_range = 1.1
cfg.color_curve_range = (0.90, 1.10)
cfg.lab_curve_range = (0.90, 1.10)
cfg.tone_curve_range = (0.5, 2)
cfg.defog_range = (0.1, 1.0)
cfg.usm_range = (0.0, 5)
cfg.cont_range = (0.0, 1.0)
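The *_begin_param indices determine which channel of CNN-PP's output Pr (shape (N, 4, 1, 1)) each filter reads. Each filter then squashes that raw value into its own range with tanh_range. The sketch below reproduces this decoding for Pr = 0. That input is an assumption chosen only to make the numbers easy to follow. `decode` is a hypothetical helper rather than part of IA-SEG, and cfg is a plain dict here instead of an EasyDict.

```python
import math
import torch


def tanh_range(l, r, initial=None):
    # same mapping as IA-SEG's util_filters.tanh_range: output lies in (l, r)
    def activation(x):
        bias = 0 if initial is None else math.atanh(2 * (initial - l) / (r - l) - 1)
        return (torch.tanh(x + bias) * 0.5 + 0.5) * (r - l) + l
    return activation


cfg = dict(exposure_begin_param=0, gamma_begin_param=1, contrast_begin_param=2, usm_begin_param=3,
           exposure_range=3.5, gamma_range=3, cont_range=(0.0, 1.0), usm_range=(0.0, 5))


def decode(pr):
    # each filter slices one channel: features[:, begin:begin + 1]
    def sl(key):
        return pr[:, cfg[key]:cfg[key] + 1]
    log_g = math.log(cfg['gamma_range'])
    return {
        'exposure': tanh_range(-cfg['exposure_range'], cfg['exposure_range'], initial=0)(sl('exposure_begin_param')),
        'gamma': torch.exp(tanh_range(-log_g, log_g)(sl('gamma_begin_param'))),  # in (1/3, 3)
        'contrast': tanh_range(*cfg['cont_range'])(sl('contrast_begin_param')),
        'usm': tanh_range(*cfg['usm_range'])(sl('usm_begin_param')),
    }


params = decode(torch.zeros(1, 4, 1, 1))
for name, value in params.items():
    print(name, tuple(value.shape), round(value.item(), 4))
# exposure (1, 1, 1, 1) 0.0
# gamma (1, 1, 1, 1) 1.0
# contrast (1, 1, 1, 1) 0.5
# usm (1, 1, 1, 1) 2.5
```

A zero output from CNN-PP is therefore not an identity mapping. Exposure and gamma are neutral, but the contrast filter still blends halfway toward its contrast curve, and USM sharpens with strength 2.5.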
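It also depends on the DIF implementation, which is described later.
CNN-PP works as a plug-and-play front-end module. It doesn't have to be added to the model definition; it only needs to be wired into the training function. IA-SEG uses CNN-PP as shown below.
Two points stand out. First, the input to CNN-PP is the image scaled to [0, 1] (normalized but not mean/std-standardized). Second, CNN-PP's output is never compared with another image in a loss: CNN-PP is optimized entirely by the loss computed at the end of the forward pass. This differs from the design in IA-YOLO.
See the original author's code for more details.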
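# excerpt from IA-SEG's training script; device, args, trainloader_iter, PSPNet and
# standard_transforms (imported from torchvision.transforms) are defined elsewhere
CNNPP = dip.DIP().to(device)
CNNPP.train()
model = PSPNet(num_classes=args.num_classes, dgf=args.DGF_FLAG).to(device)
model.train()
# a single optimizer updates both the segmentation network and CNN-PP
optimizer = optim.SGD(list(model.parameters()) + list(CNNPP.parameters()),
                      lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
for i_iter in range(args.num_steps):
    optimizer.zero_grad()
    for sub_i in range(args.iter_size):
        _, batch = trainloader_iter.__next__()
        images, labels, _, _ = batch
        images = images.to(device)  # values in [0, 1], not mean/std-standardized
        labels = labels.long().to(device)
        enhanced_images_pre, ci_map, Pr, filter_parameters = CNNPP(images)
        enhanced_images = enhanced_images_pre
        # standardize each enhanced image with ImageNet mean/std before segmentation
        for i_pre in range(enhanced_images_pre.shape[0]):
            enhanced_images[i_pre, ...] = standard_transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(
                enhanced_images_pre[i_pre, ...])
        pred_c = model(enhanced_images)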
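CNN-PP receives gradients only through the segmentation loss, so it is worth checking that those gradients actually reach it. The toy sketch below relies on heavy assumptions. `TinyPP`, a pooled linear layer predicting one exposure value, stands in for CNN-PP, and a 1×1 conv stands in for PSPNet. The sketch only shows that one optimizer over both parameter sets gives the enhancement module non-zero gradients from the cross-entropy loss.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class TinyPP(nn.Module):
    # hypothetical stand-in for CNN-PP: predicts one exposure value per image
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, img):
        raw = self.fc(img.mean(dim=(2, 3)))        # (N, 1)
        exposure = torch.tanh(raw) * 3.5           # equals tanh_range(-3.5, 3.5, initial=0)
        return img * torch.exp(exposure * math.log(2))[:, :, None, None]


pp = TinyPP()
seg = nn.Conv2d(3, 5, 1)                           # stand-in segmentation head, 5 classes
optimizer = torch.optim.SGD(list(seg.parameters()) + list(pp.parameters()), lr=0.1)

images = torch.rand(2, 3, 16, 16)                  # RGB in [0, 1]
labels = torch.randint(0, 5, (2, 16, 16))
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

optimizer.zero_grad()
enhanced = pp(images)
logits = seg((enhanced - mean) / std)              # standardize after enhancement
loss = F.cross_entropy(logits, labels)
loss.backward()
grad_norm = pp.fc.weight.grad.norm().item()
print(grad_norm > 0)                               # the segmentation loss reaches the enhancement module
optimizer.step()
```

Standardizing the whole batch with broadcasting gives the same values as the per-image Normalize loop in the excerpt, without writing into the CNN-PP output in place.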
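2. DIF module
2.1 Overview
DIF stands for Differentiable Image Filters. The module consists of several differentiable filters with tunable hyperparameters: exposure, gamma, contrast and sharpness. The DIF code in IA-SEG is adapted from IA-YOLO's DIP code: the TensorFlow implementation was rewritten in PyTorch. The filters that IA-SEG does not need (Tone Filter and Defog Filter) were commented out or are left unused.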
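The exposure, gamma and contrast filters are simple pixel-wise maps, so a few numbers are enough to check them. The sketch below restates their `process` methods from network/filters.py as plain functions. The function names are mine, and luminance is read from dim 1 on the assumption that tensors are NCHW.

```python
import math
import torch


def exposure(img, e):
    # ExposureFilter.process: scale by 2**e (e in stops)
    return img * torch.exp(e * math.log(2))


def gamma(img, g):
    # GammaFilter.process: clamp non-positive pixels to 1e-5, then raise to the power g
    img = torch.where(img <= 0, torch.full_like(img, 1e-5), img)
    return torch.pow(img, g)


def contrast(img, alpha):
    # ContrastFilter.process, with luminance read from dim 1 (NCHW assumption)
    lum = (0.27 * img[:, 0] + 0.67 * img[:, 1] + 0.06 * img[:, 2])[:, None].clamp(0, 1)
    contrast_lum = -torch.cos(math.pi * lum) * 0.5 + 0.5      # S-shaped curve on luminance
    contrast_img = img / (lum + 1e-6) * contrast_lum
    return (1 - alpha) * img + alpha * contrast_img           # lerp(img, contrast_img, alpha)


img = torch.full((1, 3, 2, 2), 0.25)                          # uniform dark-gray image
p = torch.ones(1, 1, 1, 1)                                    # per-image parameter, shape (N, 1, 1, 1)
print(round(exposure(img, p)[0, 0, 0, 0].item(), 4))          # 0.5: +1 stop doubles the value
print(round(gamma(img, 2 * p)[0, 0, 0, 0].item(), 4))         # 0.0625 = 0.25 ** 2
print(round(contrast(img, p)[0, 0, 0, 0].item(), 4))          # 0.1464: dark tones pushed darker
```

With alpha = 0 the contrast filter returns the input unchanged; with alpha = 1 it applies the full cosine curve. The (N, 1, 1, 1) parameter shape broadcasts one value per image over all channels and pixels.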
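2.2 Implementation
Code: https://github.com/wenyyu/IA-Seg/blob/main/network/filters.py
The code is below, with part of the commented-out code removed (mainly the TensorFlow-era Tone Filter, Defog Filter, etc.).
Note that every differentiable filter inherits from Filter. Of the constructor arguments net and cfg, only cfg is actually used (see the code comments in Section 1.3). The functions rgb2lum, tanh_range and lerp are imported to give the Filter objects their tensor operations.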
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from network.util_filters import rgb2lum, tanh_range, lerp
from network.util_filters import *
import cv2
import math

# device = torch.device("cuda")


class Filter(nn.Module):
    def __init__(self, net, cfg):
        super(Filter, self).__init__()
        self.cfg = cfg
        # self.height, self.width, self.channels = list(map(int, net.get_shape()[1:]))
        # Specified in child classes
        self.num_filter_parameters = None
        self.short_name = None
        self.filter_parameters = None

    def get_short_name(self):
        assert self.short_name
        return self.short_name

    def get_num_filter_parameters(self):
        assert self.num_filter_parameters
        return self.num_filter_parameters

    def get_begin_filter_parameter(self):
        return self.begin_filter_parameter

    def extract_parameters(self, features):
        # output_dim = self.get_num_filter_parameters(
        # ) + self.get_num_mask_parameters()
        # features = ly.fully_connected(
        #     features,
        #     self.cfg.fc1_size,
        #     scope='fc1',
        #     activation_fn=lrelu,
        #     weights_initializer=tf.contrib.layers.xavier_initializer())
        # features = ly.fully_connected(
        #     features,
        #     output_dim,
        #     scope='fc2',
        #     activation_fn=None,
        #     weights_initializer=tf.contrib.layers.xavier_initializer())
        return features[:, self.get_begin_filter_parameter():(self.get_begin_filter_parameter() + self.get_num_filter_parameters())], \
               features[:, self.get_begin_filter_parameter():(self.get_begin_filter_parameter() + self.get_num_filter_parameters())]

    # Should be implemented in child classes
    def filter_param_regressor(self, features):
        assert False

    # Process the whole image, without masking
    # Should be implemented in child classes
    def process(self, img, param, defog, IcA):
        assert False

    def debug_info_batched(self):
        return False

    def no_high_res(self):
        return False

    # Apply the whole filter with masking
    def apply(self,
              img,
              img_features=None,
              defog_A=None,
              IcA=None,
              specified_parameter=None,
              high_res=None):
        assert (img_features is None) ^ (specified_parameter is None)
        if img_features is not None:
            filter_features, mask_parameters = self.extract_parameters(img_features)
            filter_parameters = self.filter_param_regressor(filter_features)
        else:
            assert not self.use_masking()
            filter_parameters = specified_parameter
        if high_res is not None:
            # working on high res...
            pass
        debug_info = {}
        # We only debug the first image of this batch
        if self.debug_info_batched():
            debug_info['filter_parameters'] = filter_parameters
        else:
            debug_info['filter_parameters'] = filter_parameters[0]
        # self.mask_parameters = mask_parameters
        # self.mask = self.get_mask(img, mask_parameters)
        # debug_info['mask'] = self.mask[0]
        #low_res_output = lerp(img, self.process(img, filter_parameters), self.mask)
        low_res_output = self.process(img, filter_parameters, defog_A, IcA)
        if high_res is not None:
            if self.no_high_res():
                high_res_output = high_res
            else:
                self.high_res_mask = self.get_mask(high_res, mask_parameters)
                # high_res_output = lerp(high_res,
                #                        self.process(high_res, filter_parameters, defog, IcA),
                #                        self.high_res_mask)
        else:
            high_res_output = None
        #return low_res_output, high_res_output, debug_info
        return low_res_output, filter_parameters

    def use_masking(self):
        return self.cfg.masking

    def get_num_mask_parameters(self):
        return 6

    # Input: no need for tanh or sigmoid
    # Closer to 1 values are applied by filter more strongly
    # no additional TF variables inside
    # (legacy TensorFlow code: only reached when apply() receives high_res, which CNN_PP never passes)
    def get_mask(self, img, mask_parameters):
        if not self.use_masking():
            print('* Masking Disabled')
            return tf.ones(shape=(1, 1, 1, 1), dtype=tf.float32)
        else:
            print('* Masking Enabled')
        with tf.name_scope(name='mask'):
            # Six parameters for one filter
            filter_input_range = 5
            assert mask_parameters.shape[1] == self.get_num_mask_parameters()
            mask_parameters = tanh_range(
                l=-filter_input_range, r=filter_input_range,
                initial=0)(mask_parameters)
            size = list(map(int, img.shape[1:3]))
            grid = np.zeros(shape=[1] + size + [2], dtype=np.float32)
            shorter_edge = min(size[0], size[1])
            for i in range(size[0]):
                for j in range(size[1]):
                    grid[0, i, j,
                         0] = (i + (shorter_edge - size[0]) / 2.0) / shorter_edge - 0.5
                    grid[0, i, j,
                         1] = (j + (shorter_edge - size[1]) / 2.0) / shorter_edge - 0.5
            grid = tf.constant(grid)
            # Ax + By + C * L + D
            inp = grid[:, :, :, 0, None] * mask_parameters[:, None, None, 0, None] + \
                  grid[:, :, :, 1, None] * mask_parameters[:, None, None, 1, None] + \
                  mask_parameters[:, None, None, 2, None] * (rgb2lum(img) - 0.5) + \
                  mask_parameters[:, None, None, 3, None] * 2
            # Sharpness and inversion
            inp *= self.cfg.maximum_sharpness * mask_parameters[:, None, None, 4,
                                                                None] / filter_input_range
            mask = tf.sigmoid(inp)
            # Strength
            mask = mask * (
                mask_parameters[:, None, None, 5, None] / filter_input_range * 0.5 +
                0.5) * (1 - self.cfg.minimum_strength) + self.cfg.minimum_strength
        print('mask', mask.shape)
        return mask

    # def visualize_filter(self, debug_info, canvas):
    #     # Visualize only the filter information
    #     assert False

    def visualize_mask(self, debug_info, res):
        return cv2.resize(
            debug_info['mask'] * np.ones((1, 1, 3), dtype=np.float32),
            dsize=res,
            interpolation=cv2.INTER_NEAREST)

    def draw_high_res_text(self, text, canvas):
        cv2.putText(
            canvas,
            text, (30, 128),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8, (0, 0, 0),
            thickness=5)
        return canvas


class ExposureFilter(Filter):
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.short_name = 'E'
        self.begin_filter_parameter = cfg.exposure_begin_param
        self.num_filter_parameters = 1

    def filter_param_regressor(self, features):  # param is in (-self.cfg.exposure_range, self.cfg.exposure_range)
        return tanh_range(
            -self.cfg.exposure_range, self.cfg.exposure_range, initial=0)(features)

    def process(self, img, param, defog, IcA):
        # print('      param:', param)
        # print('      param:', torch.exp(param * np.log(2)))
        # return img * torch.exp(torch.tensor(3.31).cuda() * np.log(2))
        return img * torch.exp(param * np.log(2))


class UsmFilter(Filter):  # usm_param is in cfg.usm_range
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.short_name = 'UF'
        self.begin_filter_parameter = cfg.usm_begin_param
        self.num_filter_parameters = 1

    def filter_param_regressor(self, features):
        return tanh_range(*self.cfg.usm_range)(features)

    def process(self, img, param, defog_A, IcA):
        self.channels = 3
        kernel = [[0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633],
                  [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
                  [0.01330373, 0.11098164, 0.22508352, 0.11098164, 0.01330373],
                  [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
                  [0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633]]
        kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
        # (3, 1, 5, 5) depthwise kernel; torch's repeat instead of np.repeat on a tensor
        kernel = kernel.repeat(self.channels, 1, 1, 1)
        # print('      param:', param)
        kernel = kernel.to(img.device)
        # self.weight = nn.Parameter(data=kernel, requires_grad=False)
        # self.weight.to(device)
        output = F.conv2d(img, kernel, padding=2, groups=self.channels)  # Gaussian blur
        img_out = (img - output) * param + img  # unsharp masking
        # img_out = (img - output) * torch.tensor(0.043).cuda() + img
        return img_out


class ContrastFilter(Filter):
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.short_name = 'Ct'
        self.begin_filter_parameter = cfg.contrast_begin_param
        self.num_filter_parameters = 1

    def filter_param_regressor(self, features):
        # return tf.sigmoid(features)
        # return torch.tanh(features)
        return tanh_range(*self.cfg.cont_range)(features)

    def process(self, img, param, defog, IcA):
        # print('      param.shape:', param.shape)
        # luminance = torch.minimum(torch.maximum(rgb2lum(img), 0.0), 1.0)
        luminance = rgb2lum(img)
        zero = torch.zeros_like(luminance)
        one = torch.ones_like(luminance)
        luminance = torch.where(luminance < 0, zero, luminance)
        luminance = torch.where(luminance > 1, one, luminance)
        contrast_lum = -torch.cos(math.pi * luminance) * 0.5 + 0.5
        contrast_image = img / (luminance + 1e-6) * contrast_lum
        return lerp(img, contrast_image, param)
        # return lerp(img, contrast_image, torch.tensor(0.015).cuda())


class ToneFilter(Filter):
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.curve_steps = cfg.curve_steps
        self.short_name = 'T'
        self.begin_filter_parameter = cfg.tone_begin_param
        self.num_filter_parameters = cfg.curve_steps

    def filter_param_regressor(self, features):
        # tone_curve = tf.reshape(
        #     features, shape=(-1, 1, self.cfg.curve_steps))[:, None, None, :]
        tone_curve = tanh_range(*self.cfg.tone_curve_range)(features)
        return tone_curve

    def process(self, img, param, defog, IcA):
        # img = tf.minimum(img, 1.0)
        # param = tf.constant([[0.52, 0.53, 0.55, 1.9, 1.8, 1.7, 0.7, 0.6], [0.52, 0.53, 0.55, 1.9, 1.8, 1.7, 0.7, 0.6],
        #                      [0.52, 0.53, 0.55, 1.9, 1.8, 1.7, 0.7, 0.6], [0.52, 0.53, 0.55, 1.9, 1.8, 1.7, 0.7, 0.6],
        #                      [0.52, 0.53, 0.55, 1.9, 1.8, 1.7, 0.7, 0.6], [0.52, 0.53, 0.55, 1.9, 1.8, 1.7, 0.7, 0.6]])
        # param = tf.constant([[0.52, 0.53, 0.55, 1.9, 1.8, 1.7, 0.7, 0.6]])
        # param = tf.reshape(
        #     param, shape=(-1, 1, self.cfg.curve_steps))[:, None, None, :]
        param = torch.unsqueeze(param, 3)
        # print('      param.shape:', param.shape)
        tone_curve = param
        tone_curve_sum = torch.sum(tone_curve, axis=1) + 1e-30
        # print('      tone_curve_sum.shape:', tone_curve_sum.shape)
        total_image = img * 0
        for i in range(self.cfg.curve_steps):
            total_image += torch.clamp(img - 1.0 * i / self.cfg.curve_steps, 0, 1.0 / self.cfg.curve_steps) \
                           * param[:, i, :, :]
        # p_cons = [0.52, 0.53, 0.55, 1.9, 1.8, 1.7, 0.7, 0.6]
        # for i in range(self.cfg.curve_steps):
        #     total_image += tf.clip_by_value(img - 1.0 * i / self.cfg.curve_steps, 0, 1.0 / self.cfg.curve_steps) \
        #                    * p_cons[i]
        total_image *= self.cfg.curve_steps / tone_curve_sum
        img = total_image
        return img

    # def visualize_filter(self, debug_info, canvas):
    #     curve = debug_info['filter_parameters']
    #     height, width = canvas.shape[:2]
    #     values = np.array([0] + list(curve[0][0][0]))
    #     values /= sum(values) + 1e-30
    #     for j in range(0, self.curve_steps):
    #         values[j + 1] += values[j]
    #     for j in range(self.curve_steps):
    #         p1 = tuple(
    #             map(int, (width / self.curve_steps * j, height - 1 -
    #                       values[j] * height)))
    #         p2 = tuple(
    #             map(int, (width / self.curve_steps * (j + 1), height - 1 -
    #                       values[j + 1] * height)))
    #         cv2.line(canvas, p1, p2, (0, 0, 0), thickness=1)


class GammaFilter(Filter):  # gamma_param is in [1/gamma_range, gamma_range]
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.short_name = 'G'
        self.begin_filter_parameter = cfg.gamma_begin_param
        self.num_filter_parameters = 1

    def filter_param_regressor(self, features):
        log_gamma_range = np.log(self.cfg.gamma_range)
        # return tf.exp(tanh_range(-log_gamma_range, log_gamma_range)(features))
        return torch.exp(tanh_range(-log_gamma_range, log_gamma_range)(features))

    def process(self, img, param, defog_A, IcA):
        # print('      param:', param)
        # param_1 = param.repeat(1, 3)
        zero = torch.zeros_like(img) + 0.00001
        img = torch.where(img <= 0, zero, img)
        # print("GAMMMA", param)
        return torch.pow(img, param)
        # return torch.pow(img, torch.tensor(0.51).cuda())
        # param_1 = tf.tile(param, [1, 3])
        # return tf.pow(tf.maximum(img, 0.0001), param_1[:, None, None, :])
        # return img
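The 5×5 kernel hard-coded in UsmFilter is a truncated Gaussian. The sketch below is my own check, not IA-SEG code. It confirms three things: the kernel is separable (an outer product of a single 1-D profile), it sums to about 0.997 rather than exactly 1, and the unsharp-mask step `(img - blur) * param + img` overshoots on both sides of a step edge, which is what makes edges look sharper.

```python
import torch
import torch.nn.functional as F

k = torch.tensor([[0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633],
                  [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
                  [0.01330373, 0.11098164, 0.22508352, 0.11098164, 0.01330373],
                  [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
                  [0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633]], dtype=torch.float64)

# separability: k == outer(v, v) with v = centre row / sqrt(centre value)
v = k[2] / k[2, 2].sqrt()
print(torch.allclose(k, torch.outer(v, v), atol=1e-6))   # True
print(round(k.sum().item(), 4))                          # 0.9967

# unsharp mask on a vertical step edge (dark left half, bright right half), 3 channels
img = torch.zeros(1, 3, 9, 9, dtype=torch.float64)
img[..., 5:] = 0.8
weight = k[None, None].repeat(3, 1, 1, 1)                # depthwise kernel, shape (3, 1, 5, 5)
blur = F.conv2d(img, weight, padding=2, groups=3)
sharp = (img - blur) * 1.0 + img                         # param = 1.0
row = sharp[0, 0, 4]
print(row[4].item() < 0, row[5].item() > 0.8)            # True True: undershoot / overshoot at the edge
```

Because the kernel sum is slightly below 1, a flat region is not left exactly unchanged: it is brightened by roughly 0.3% × param.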
2.3 Related code
util_filters provides basic helper functions for the filters, such as rgb2lum, tanh_range and lerp.
Full code: https://github.com/wenyyu/IA-Seg/blob/main/network/util_filters.py
The main code is:
import math
import cv2
import torch
import torch.nn as nn


def rgb2lum(image):
    # note: indexes the LAST axis as the colour channel (NHWC, inherited from the
    # TensorFlow version); PyTorch image tensors are normally NCHW
    image = 0.27 * image[:, :, :, 0] + 0.67 * image[:, :, :, 1] + 0.06 * image[:, :, :, 2]
    return image[:, :, :, None]


def tanh01(x):
    # return tf.tanh(x) * 0.5 + 0.5
    return torch.tanh(x) * 0.5 + 0.5


def tanh_range(l, r, initial=None):
    # maps x into (l, r); `initial` shifts the curve so that an input of 0 maps to `initial`
    def get_activation(left, right, initial):
        def activation(x):
            if initial is not None:
                bias = math.atanh(2 * (initial - left) / (right - left) - 1)
            else:
                bias = 0
            return tanh01(x + bias) * (right - left) + left
        return activation
    return get_activation(l, r, initial)


def lerp(a, b, l):
    # linear interpolation: l = 0 gives a, l = 1 gives b
    return (1 - l) * a + l * b
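Two properties of these helpers are easy to miss. First, the `initial` argument of tanh_range adds a bias so that a raw input of 0 maps exactly to `initial`. Second, rgb2lum as written reads the colour channel from the last axis, whereas PyTorch images are usually NCHW. The sketch below checks the first property and compares the original rgb2lum with a channel-first variant, `rgb2lum_nchw`, which is a hypothetical helper of mine.

```python
import math
import torch


def tanh01(x):
    return torch.tanh(x) * 0.5 + 0.5


def tanh_range(l, r, initial=None):
    # same logic as util_filters.tanh_range, flattened into one closure
    def activation(x):
        # bias chosen so that activation(0) == initial
        bias = math.atanh(2 * (initial - l) / (r - l) - 1) if initial is not None else 0
        return tanh01(x + bias) * (r - l) + l
    return activation


def rgb2lum(image):
    # as written in util_filters.py: channel taken from the last axis (NHWC)
    image = 0.27 * image[:, :, :, 0] + 0.67 * image[:, :, :, 1] + 0.06 * image[:, :, :, 2]
    return image[:, :, :, None]


def rgb2lum_nchw(image):
    # hypothetical channel-first variant: same weights, channel axis = dim 1
    lum = 0.27 * image[:, 0] + 0.67 * image[:, 1] + 0.06 * image[:, 2]
    return lum[:, None]


zero = torch.zeros(1)
print(round(tanh_range(0.0, 5.0, initial=1.0)(zero).item(), 4))   # 1.0
print(round(tanh_range(0.0, 5.0)(zero).item(), 4))                # 2.5, the midpoint

img = torch.rand(2, 3, 4, 6)                                      # NCHW
print(tuple(rgb2lum(img).shape))                                  # (2, 3, 4, 1): wrong axis for NCHW
print(tuple(rgb2lum_nchw(img).shape))                             # (2, 1, 4, 6)
```

On NCHW input, the original version still runs without error: its (N, 3, H, 1) result broadcasts against the image in ContrastFilter. That is why this kind of axis mix-up is easy to overlook.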
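3. Using the IPAM module
IPAM (the paper's IAPM module) is simply the combination of the CNN-PP and DIF modules described above. I single it out here to extract the code from the IA-SEG project so that it can be used on its own.
In IA-SEG, the DIF module is effectively already embedded in the CNN-PP model, and together they form this module. However, the related functions are spread across several .py files, which is inconvenient, so I merged them.
3.1 Merged code
Install the dependency: pip install easydict
The merged code is below; only the cfg near the bottom needs to be edited. It defines an IPAM class that can be used directly for image-domain adaptation.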
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2  # used only by the visualization helpers in Filter
import time
import math


# ----- basic helper functions for the filters -----
def rgb2lum(image):
    # luminance of an NCHW tensor -> (N, 1, H, W); the IA-SEG version indexes the last
    # axis (NHWC, inherited from TensorFlow), which is the width axis for NCHW input
    image = 0.27 * image[:, 0, :, :] + 0.67 * image[:, 1, :, :] + 0.06 * image[:, 2, :, :]
    return image[:, None, :, :]


def tanh01(x):
    # return tf.tanh(x) * 0.5 + 0.5
    return torch.tanh(x) * 0.5 + 0.5


def tanh_range(l, r, initial=None):
    def get_activation(left, right, initial):
        def activation(x):
            if initial is not None:
                bias = math.atanh(2 * (initial - left) / (right - left) - 1)
            else:
                bias = 0
            return tanh01(x + bias) * (right - left) + left
        return activation
    return get_activation(l, r, initial)


def lerp(a, b, l):
    return (1 - l) * a + l * b


# ----- filter implementations -----
class Filter(nn.Module):
    def __init__(self, net, cfg):
        super(Filter, self).__init__()
        self.cfg = cfg
        self.num_filter_parameters = None
        self.short_name = None
        self.filter_parameters = None

    def get_short_name(self):
        assert self.short_name
        return self.short_name

    def get_num_filter_parameters(self):
        assert self.num_filter_parameters
        return self.num_filter_parameters

    def get_begin_filter_parameter(self):
        return self.begin_filter_parameter

    def extract_parameters(self, features):
        # slice this filter's channels out of the CNN-PP output Pr
        return features[:, self.get_begin_filter_parameter():(self.get_begin_filter_parameter() + self.get_num_filter_parameters())], \
               features[:, self.get_begin_filter_parameter():(self.get_begin_filter_parameter() + self.get_num_filter_parameters())]

    # Should be implemented in child classes
    def filter_param_regressor(self, features):
        assert False

    # Process the whole image, without masking
    # Should be implemented in child classes
    def process(self, img, param, defog, IcA):
        assert False

    def debug_info_batched(self):
        return False

    def no_high_res(self):
        return False

    # Apply the whole filter with masking
    def apply(self,
              img,
              img_features=None,
              defog_A=None,
              IcA=None,
              specified_parameter=None,
              high_res=None):
        assert (img_features is None) ^ (specified_parameter is None)
        if img_features is not None:
            filter_features, mask_parameters = self.extract_parameters(img_features)
            filter_parameters = self.filter_param_regressor(filter_features)
        else:
            assert not self.use_masking()
            filter_parameters = specified_parameter
        if high_res is not None:
            # working on high res...
            pass
        debug_info = {}
        # We only debug the first image of this batch
        if self.debug_info_batched():
            debug_info['filter_parameters'] = filter_parameters
        else:
            debug_info['filter_parameters'] = filter_parameters[0]
        # self.mask_parameters = mask_parameters
        # self.mask = self.get_mask(img, mask_parameters)
        # debug_info['mask'] = self.mask[0]
        #low_res_output = lerp(img, self.process(img, filter_parameters), self.mask)
        low_res_output = self.process(img, filter_parameters, defog_A, IcA)
        if high_res is not None:
            if self.no_high_res():
                high_res_output = high_res
            else:
                self.high_res_mask = self.get_mask(high_res, mask_parameters)
                # high_res_output = lerp(high_res,
                #                        self.process(high_res, filter_parameters, defog, IcA),
                #                        self.high_res_mask)
        else:
            high_res_output = None
        #return low_res_output, high_res_output, debug_info
        return low_res_output, filter_parameters

    def use_masking(self):
        return self.cfg.masking

    def get_num_mask_parameters(self):
        return 6

    # Input: no need for tanh or sigmoid
    # Closer to 1 values are applied by filter more strongly
    # no additional TF variables inside
    # (legacy TensorFlow code, tf is not imported: only reached when apply() receives high_res, which DIF never passes)
    def get_mask(self, img, mask_parameters):
        if not self.use_masking():
            print('* Masking Disabled')
            return tf.ones(shape=(1, 1, 1, 1), dtype=tf.float32)
        else:
            print('* Masking Enabled')
        with tf.name_scope(name='mask'):
            # Six parameters for one filter
            filter_input_range = 5
            assert mask_parameters.shape[1] == self.get_num_mask_parameters()
            mask_parameters = tanh_range(
                l=-filter_input_range, r=filter_input_range,
                initial=0)(mask_parameters)
            size = list(map(int, img.shape[1:3]))
            grid = np.zeros(shape=[1] + size + [2], dtype=np.float32)
            shorter_edge = min(size[0], size[1])
            for i in range(size[0]):
                for j in range(size[1]):
                    grid[0, i, j,
                         0] = (i + (shorter_edge - size[0]) / 2.0) / shorter_edge - 0.5
                    grid[0, i, j,
                         1] = (j + (shorter_edge - size[1]) / 2.0) / shorter_edge - 0.5
            grid = tf.constant(grid)
            # Ax + By + C * L + D
            inp = grid[:, :, :, 0, None] * mask_parameters[:, None, None, 0, None] + \
                  grid[:, :, :, 1, None] * mask_parameters[:, None, None, 1, None] + \
                  mask_parameters[:, None, None, 2, None] * (rgb2lum(img) - 0.5) + \
                  mask_parameters[:, None, None, 3, None] * 2
            # Sharpness and inversion
            inp *= self.cfg.maximum_sharpness * mask_parameters[:, None, None, 4,
                                                                None] / filter_input_range
            mask = tf.sigmoid(inp)
            # Strength
            mask = mask * (
                mask_parameters[:, None, None, 5, None] / filter_input_range * 0.5 +
                0.5) * (1 - self.cfg.minimum_strength) + self.cfg.minimum_strength
        print('mask', mask.shape)
        return mask

    # def visualize_filter(self, debug_info, canvas):
    #     # Visualize only the filter information
    #     assert False

    def visualize_mask(self, debug_info, res):
        return cv2.resize(
            debug_info['mask'] * np.ones((1, 1, 3), dtype=np.float32),
            dsize=res,
            interpolation=cv2.INTER_NEAREST)

    def draw_high_res_text(self, text, canvas):
        cv2.putText(
            canvas,
            text, (30, 128),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8, (0, 0, 0),
            thickness=5)
        return canvas


class ExposureFilter(Filter):
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.short_name = 'E'
        self.begin_filter_parameter = cfg.exposure_begin_param
        self.num_filter_parameters = 1

    def filter_param_regressor(self, features):  # param is in (-self.cfg.exposure_range, self.cfg.exposure_range)
        return tanh_range(
            -self.cfg.exposure_range, self.cfg.exposure_range, initial=0)(features)

    def process(self, img, param, defog, IcA):
        return img * torch.exp(param * np.log(2))


class UsmFilter(Filter):  # usm_param is in cfg.usm_range
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.short_name = 'UF'
        self.begin_filter_parameter = cfg.usm_begin_param
        self.num_filter_parameters = 1

    def filter_param_regressor(self, features):
        return tanh_range(*self.cfg.usm_range)(features)

    def process(self, img, param, defog_A, IcA):
        self.channels = 3
        kernel = [[0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633],
                  [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
                  [0.01330373, 0.11098164, 0.22508352, 0.11098164, 0.01330373],
                  [0.00655965, 0.05472157, 0.11098164, 0.05472157, 0.00655965],
                  [0.00078633, 0.00655965, 0.01330373, 0.00655965, 0.00078633]]
        kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
        # (3, 1, 5, 5) depthwise kernel; torch's repeat instead of np.repeat on a tensor
        kernel = kernel.repeat(self.channels, 1, 1, 1)
        # print('      param:', param)
        kernel = kernel.to(img.device)
        # self.weight = nn.Parameter(data=kernel, requires_grad=False)
        # self.weight.to(device)
        output = F.conv2d(img, kernel, padding=2, groups=self.channels)  # Gaussian blur
        img_out = (img - output) * param + img  # unsharp masking
        # img_out = (img - output) * torch.tensor(0.043).cuda() + img
        return img_out


class ContrastFilter(Filter):
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.short_name = 'Ct'
        self.begin_filter_parameter = cfg.contrast_begin_param
        self.num_filter_parameters = 1

    def filter_param_regressor(self, features):
        return tanh_range(*self.cfg.cont_range)(features)

    def process(self, img, param, defog, IcA):
        # print('      param.shape:', param.shape)
        # luminance = torch.minimum(torch.maximum(rgb2lum(img), 0.0), 1.0)
        luminance = rgb2lum(img)
        zero = torch.zeros_like(luminance)
        one = torch.ones_like(luminance)
        luminance = torch.where(luminance < 0, zero, luminance)
        luminance = torch.where(luminance > 1, one, luminance)
        contrast_lum = -torch.cos(math.pi * luminance) * 0.5 + 0.5
        contrast_image = img / (luminance + 1e-6) * contrast_lum
        return lerp(img, contrast_image, param)
        # return lerp(img, contrast_image, torch.tensor(0.015).cuda())


class ToneFilter(Filter):
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.curve_steps = cfg.curve_steps
        self.short_name = 'T'
        self.begin_filter_parameter = cfg.tone_begin_param
        self.num_filter_parameters = cfg.curve_steps

    def filter_param_regressor(self, features):
        tone_curve = tanh_range(*self.cfg.tone_curve_range)(features)
        return tone_curve

    def process(self, img, param, defog, IcA):
        param = torch.unsqueeze(param, 3)
        # print('      param.shape:', param.shape)
        tone_curve = param
        tone_curve_sum = torch.sum(tone_curve, axis=1) + 1e-30
        # print('      tone_curve_sum.shape:', tone_curve_sum.shape)
        total_image = img * 0
        for i in range(self.cfg.curve_steps):
            total_image += torch.clamp(img - 1.0 * i / self.cfg.curve_steps, 0, 1.0 / self.cfg.curve_steps) \
                           * param[:, i, :, :]
        total_image *= self.cfg.curve_steps / tone_curve_sum
        img = total_image
        return img


class GammaFilter(Filter):  # gamma_param is in [1/gamma_range, gamma_range]
    def __init__(self, net, cfg):
        Filter.__init__(self, net, cfg)
        self.short_name = 'G'
        self.begin_filter_parameter = cfg.gamma_begin_param
        self.num_filter_parameters = 1

    def filter_param_regressor(self, features):
        log_gamma_range = np.log(self.cfg.gamma_range)
        # return tf.exp(tanh_range(-log_gamma_range, log_gamma_range)(features))
        return torch.exp(tanh_range(-log_gamma_range, log_gamma_range)(features))

    def process(self, img, param, defog_A, IcA):
        # print('      param:', param)
        # param_1 = param.repeat(1, 3)
        zero = torch.zeros_like(img) + 0.00001
        img = torch.where(img <= 0, zero, img)
        # print("GAMMMA", param)
        return torch.pow(img, param)


# ---------- filter parameters (cfg) ----------
from easydict import EasyDict as edict
cfg = edict()
cfg.num_filter_parameters = 4
# the settings below are all used by the DIF filters
cfg.exposure_begin_param = 0
cfg.gamma_begin_param = 1
cfg.contrast_begin_param = 2
cfg.usm_begin_param = 3
# Gamma = 1/x ~ x
cfg.curve_steps = 8
cfg.gamma_range = 3
cfg.exposure_range = 3.5
cfg.wb_range = 1.1
cfg.color_curve_range = (0.90, 1.10)
cfg.lab_curve_range = (0.90, 1.10)
cfg.tone_curve_range = (0.5, 2)
cfg.defog_range = (0.1, 1.0)
cfg.usm_range = (0.0, 5)
cfg.cont_range = (0.0, 1.0)


# ---------- DIF module ----------
class DIF(nn.Module):
    def __init__(self, Filters):
        super(DIF, self).__init__()
        self.Filters = Filters

    def forward(self, img_input, Pr):
        self.filtered_image_batch = img_input
        filters = [x(img_input, cfg) for x in self.Filters]
        self.filter_parameters = []
        self.filtered_images = []
        for j, filter in enumerate(filters):
            self.filtered_image_batch, filter_parameter = filter.apply(
                self.filtered_image_batch, Pr)
            self.filter_parameters.append(filter_parameter)
            self.filtered_images.append(self.filtered_image_batch)
        return self.filtered_image_batch, self.filtered_images, Pr, self.filter_parameters


# ---------- IPAM module ----------
def conv_downsample(in_filters, out_filters, normalization=False):
    layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1)]
    layers.append(nn.LeakyReLU(0.2))
    if normalization:
        layers.append(nn.InstanceNorm2d(out_filters, affine=True))
    return layers


class IPAM(nn.Module):
    def __init__(self):
        super(IPAM, self).__init__()
        self.CNN_PP = nn.Sequential(
            nn.Upsample(size=(256, 256), mode='bilinear'),
            nn.Conv2d(3, 16, 3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.InstanceNorm2d(16, affine=True),
            *conv_downsample(16, 32, normalization=True),
            *conv_downsample(32, 64, normalization=True),
            *conv_downsample(64, 128, normalization=True),
            *conv_downsample(128, 128),
            #*discriminator_block(128, 128, normalization=True),
            nn.Dropout(p=0.5),
            nn.Conv2d(128, cfg.num_filter_parameters, 8, padding=0),
        )
        Filters = [ExposureFilter, GammaFilter, ContrastFilter, UsmFilter]
        self.dif = DIF(Filters)

    def forward(self, img_input):
        self.Pr = self.CNN_PP(img_input)  # (N, 4, 1, 1) filter parameters
        out = self.dif(img_input, self.Pr)
        return out
3.2 Usage code
Usage:
model = IPAM()
print(model)
x=torch.rand((1,3,256,256))
filtered_image_batch,filtered_images,Pr,filter_parameters=model(x)
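The output is below. filtered_image_batch is the enhanced image. filtered_images is a list of length 4 holding the image after each of the 4 enhancement steps. Pr is the output of CNN-PP, and filter_parameters holds the actual DIF parameters.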
IPAM(
(CNN_PP): Sequential(
(0): Upsample(size=(256, 256), mode='bilinear')
(1): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(2): LeakyReLU(negative_slope=0.2)
(3): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(4): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(5): LeakyReLU(negative_slope=0.2)
(6): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(8): LeakyReLU(negative_slope=0.2)
(9): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(10): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(11): LeakyReLU(negative_slope=0.2)
(12): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(13): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(14): LeakyReLU(negative_slope=0.2)
(15): Dropout(p=0.5, inplace=False)
(16): Conv2d(128, 4, kernel_size=(8, 8), stride=(1, 1))
)
(dif): DIF()
)
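3.3 Usage notes
Since IPAM's parameters are untrained, the generated image is fairly random.
The ImgUilt code is in my post Python Utility Methods 28. Note that the image fed into the IPAM module must be normalized to [0, 1]; this can be confirmed from the IA-SEG author's dataset code.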
import cv2,torch
from ImgUilt import *
import numpy as np
p=r'D:\YOLO_seq\helmet_yolo\images\train\000092.jpg'
im_tensor,img=read_img_as_tensor(p)
model = IPAM().cuda()  # IPAM class from Section 3.1
im_tensor = im_tensor.float().cuda() / 255  # IPAM expects input scaled to [0, 1]
filtered_image_batch,filtered_images,Pr,filter_parameters=model(im_tensor)
new_img=tensor2img(filtered_image_batch.detach()*255)
myimshows([img,new_img])
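ImgUilt comes from another post and is not reproduced here. If you don't have it, the loading step can be replaced with something like the sketch below. `to_ipam_input` and `to_uint8_image` are hypothetical helpers, not part of ImgUilt or IA-SEG. They assume the image was read with cv2.imread (HWC, BGR, uint8) and convert it to and from the NCHW, RGB, [0, 1] float tensor that IPAM expects.

```python
import numpy as np
import torch


def to_ipam_input(bgr_uint8):
    # HWC BGR uint8 (cv2.imread layout) -> 1xCxHxW RGB float in [0, 1]
    rgb = bgr_uint8[:, :, ::-1].copy()                  # BGR -> RGB; copy makes strides positive
    tensor = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
    return tensor.unsqueeze(0)


def to_uint8_image(tensor):
    # inverse for display: 1xCxHxW RGB (clipped to [0, 1]) -> HWC BGR uint8
    rgb = tensor[0].detach().clamp(0, 1).permute(1, 2, 0).cpu().numpy()
    return (rgb * 255.0 + 0.5).astype(np.uint8)[:, :, ::-1]


fake = np.zeros((4, 5, 3), dtype=np.uint8)
fake[..., 2] = 255                                      # pure red in BGR order
x = to_ipam_input(fake)
print(tuple(x.shape), x[0, :, 0, 0].tolist())           # (1, 3, 4, 5) [1.0, 0.0, 0.0]
print(np.array_equal(to_uint8_image(x), fake))          # True
```

With a real file this would be used as `to_ipam_input(cv2.imread(p)).cuda()`. The clamp in the inverse matters because the exposure and USM filters can push values outside [0, 1].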
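After IPAM processing, the image is randomly enhanced; the most visible effect is local edge enhancement.
Following the IA-SEG author's usage, optimizing IPAM's parameters requires no extra loss: the module only needs to be connected into the normal model's forward pass. The training code is below. CNNPP's output only feeds the model's input and has no direct link to any loss.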
enhanced_images_pre, ci_map, Pr, filter_parameters = CNNPP(images)
enhanced_images = enhanced_images_pre
for i_pre in range(enhanced_images_pre.shape[0]):
    enhanced_images[i_pre, ...] = standard_transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(
        enhanced_images_pre[i_pre, ...])
if args.model == 'RefineNet' or args.model.startswith('deeplabv3'):
    pred_c = model(enhanced_images)
else:
    _, pred_c = model(enhanced_images)
pred_c = interp(pred_c)
loss_seg = seg_loss(pred_c, labels)
loss = loss_seg  # + loss_seg_dark_dynamic + loss_seg_mix  # + loss_seg_dark_dynamic  # + loss_enhance
loss_s = loss / args.iter_size
loss_s.backward(retain_graph=True)
loss_seg_value += loss_seg.item() / args.iter_size
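Alternatively, following IA-YOLO, you can feed a degraded (e.g. noise-added) image to IPAM and compute a loss between the original clean image and the IPAM-enhanced image.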
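A minimal sketch of that reconstruction-style training follows, under stated assumptions. `ExposureOnlyPP` is a hypothetical one-filter stand-in for IPAM. The "degradation" is a synthetic 2-stop darkening rather than added noise, because an exposure-only model can invert it exactly. The loss is MSE against the clean image, and the model should learn an exposure close to +2 stops.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class ExposureOnlyPP(nn.Module):
    # hypothetical stand-in for IPAM: predicts one exposure value (in stops) per image
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)
        nn.init.zeros_(self.fc.weight)                # start from "no change" (exposure 0)
        nn.init.zeros_(self.fc.bias)

    def exposure(self, img):
        return torch.tanh(self.fc(img.mean(dim=(2, 3)))) * 3.5   # range of cfg.exposure_range

    def forward(self, img):
        return img * torch.exp(self.exposure(img) * math.log(2))[:, :, None, None]


clean = torch.rand(4, 3, 32, 32)
degraded = clean * 0.25                               # synthetic "night" image: 2 stops darker

pp = ExposureOnlyPP()
optimizer = torch.optim.Adam(pp.parameters(), lr=0.05)
losses = []
for step in range(300):
    optimizer.zero_grad()
    loss = F.mse_loss(pp(degraded), clean)            # compare the enhanced output with the clean image
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"loss {losses[0]:.4f} -> {losses[-1]:.6f}")
print(pp.exposure(degraded).mean().item())            # learned exposure in stops
```

In the real setup the full IPAM replaces the stand-in, and this reconstruction loss can be added to (or replace) the task loss.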
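4. LGF module
4.1 Overview
The guided filter is an edge- and gradient-preserving image operation. It uses object boundaries in the guidance image to detect object saliency, suppresses saliency outside objects, and thereby improves downstream detection or segmentation performance. In effect, it fine-tunes the output feature map.
The LGF module follows the guided-filter algorithm. In it, f_mean denotes a mean filter with window radius r, and the abbreviations corr, var and cov stand for correlation, variance and covariance. See the paper for more details.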
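Written out, the computation that GuidedFilter.forward performs in Section 4.2 is the standard guided-filter recurrence below. I is the guidance map (x in the code), p is the map to be filtered (y), ∘ and division are element-wise, ε is eps, and f_mean is the box mean of radius r.

```latex
\begin{aligned}
\mu_I &= f_{\mathrm{mean}}(I), & \mu_p &= f_{\mathrm{mean}}(p),\\
\mathrm{corr}_I &= f_{\mathrm{mean}}(I \circ I), & \mathrm{corr}_{Ip} &= f_{\mathrm{mean}}(I \circ p),\\
\mathrm{var}_I &= \mathrm{corr}_I - \mu_I \circ \mu_I, & \mathrm{cov}_{Ip} &= \mathrm{corr}_{Ip} - \mu_I \circ \mu_p,\\
A &= \mathrm{cov}_{Ip} / (\mathrm{var}_I + \epsilon), & b &= \mu_p - A \circ \mu_I,\\
\bar{A} &= f_{\mathrm{mean}}(A), & \bar{b} &= f_{\mathrm{mean}}(b),\\
q &= \bar{A} \circ I + \bar{b}.
\end{aligned}
```

A larger ε pushes A toward 0, so the output approaches a plain box blur of p. A smaller ε lets the output follow the edges of the guidance map more closely.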
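4.2 Implementation
Code: https://github.com/wenyyu/IA-Seg/blob/d6393cc87e5ca95ab3b27dee4ec31293256ab9a4/network/guided_filter.py
The original code is below; guided_filter depends on no external functions. It contains two classes, GuidedFilter and FastGuidedFilter, and IA-SEG does not use FastGuidedFilter. Even when its inputs lr_x and hr_x are identical, FastGuidedFilter is not exactly equivalent to GuidedFilter. It only resizes the coefficients A and b to the output size with bilinear interpolation, whereas GuidedFilter box-averages them before applying them.
The highlight of the code is its differentiable box filter, computed with cumulative sums.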
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable


def diff_x(input, r):
    assert input.dim() == 4
    left = input[:, :, r:2 * r + 1]
    middle = input[:, :, 2 * r + 1:] - input[:, :, :-2 * r - 1]
    right = input[:, :, -1:] - input[:, :, -2 * r - 1: -r - 1]
    output = torch.cat([left, middle, right], dim=2)
    return output


def diff_y(input, r):
    assert input.dim() == 4
    left = input[:, :, :, r:2 * r + 1]
    middle = input[:, :, :, 2 * r + 1:] - input[:, :, :, :-2 * r - 1]
    right = input[:, :, :, -1:] - input[:, :, :, -2 * r - 1: -r - 1]
    output = torch.cat([left, middle, right], dim=3)
    return output


class BoxFilter(nn.Module):
    def __init__(self, r):
        super(BoxFilter, self).__init__()
        self.r = r

    def forward(self, x):
        assert x.dim() == 4
        return diff_y(diff_x(x.cumsum(dim=2), self.r).cumsum(dim=3), self.r)


class FastGuidedFilter(nn.Module):
    def __init__(self, r, eps=1e-8):
        super(FastGuidedFilter, self).__init__()
        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)

    def forward(self, lr_x, lr_y, hr_x):
        n_lrx, c_lrx, h_lrx, w_lrx = lr_x.size()
        n_lry, c_lry, h_lry, w_lry = lr_y.size()
        n_hrx, c_hrx, h_hrx, w_hrx = hr_x.size()
        assert n_lrx == n_lry and n_lry == n_hrx
        assert c_lrx == c_hrx and (c_lrx == 1 or c_lrx == c_lry)
        assert h_lrx == h_lry and w_lrx == w_lry
        assert h_lrx > 2 * self.r + 1 and w_lrx > 2 * self.r + 1
        ## N
        N = self.boxfilter(Variable(lr_x.data.new().resize_((1, 1, h_lrx, w_lrx)).fill_(1.0)))
        ## mean_x
        mean_x = self.boxfilter(lr_x) / N
        ## mean_y
        mean_y = self.boxfilter(lr_y) / N
        ## cov_xy
        cov_xy = self.boxfilter(lr_x * lr_y) / N - mean_x * mean_y
        ## var_x
        var_x = self.boxfilter(lr_x * lr_x) / N - mean_x * mean_x
        ## A
        A = cov_xy / (var_x + self.eps)
        ## b
        b = mean_y - A * mean_x
        ## mean_A; mean_b
        mean_A = F.interpolate(A, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
        mean_b = F.interpolate(b, (h_hrx, w_hrx), mode='bilinear', align_corners=True)
        return mean_A * hr_x + mean_b


class GuidedFilter(nn.Module):
    def __init__(self, r, eps=1e-8):
        super(GuidedFilter, self).__init__()
        self.r = r
        self.eps = eps
        self.boxfilter = BoxFilter(r)

    def forward(self, x, y):
        n_x, c_x, h_x, w_x = x.size()
        n_y, c_y, h_y, w_y = y.size()
        assert n_x == n_y
        assert c_x == 1 or c_x == c_y
        assert h_x == h_y and w_x == w_y
        assert h_x > 2 * self.r + 1 and w_x > 2 * self.r + 1
        # N
        N = self.boxfilter(Variable(x.data.new().resize_((1, 1, h_x, w_x)).fill_(1.0)))
        # mean_x
        mean_x = self.boxfilter(x) / N
        # mean_y
        mean_y = self.boxfilter(y) / N
        # cov_xy
        cov_xy = self.boxfilter(x * y) / N - mean_x * mean_y
        # var_x
        var_x = self.boxfilter(x * x) / N - mean_x * mean_x
        # A
        A = cov_xy / (var_x + self.eps)
        # b
        b = mean_y - A * mean_x
        # mean_A; mean_b
        mean_A = self.boxfilter(A) / N
        mean_b = self.boxfilter(b) / N
        return mean_A * x + mean_b
The code above can be saved as guided_filter.py.
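Two quick checks on guided_filter.py are shown below as a self-contained sketch. The cumulative-sum box filter is copied from the code above. `guided` and `fast_guided` are my condensed functional versions of the two classes, without the shape asserts. The first check confirms that the cumsum-based box filter equals a plain windowed sum with zero padding. The second shows that, with lr_x equal to hr_x, the fast variant still differs from GuidedFilter, because it resizes A and b instead of box-averaging them.

```python
import torch
import torch.nn.functional as F


def diff_x(x, r):
    left = x[:, :, r:2 * r + 1]
    middle = x[:, :, 2 * r + 1:] - x[:, :, :-2 * r - 1]
    right = x[:, :, -1:] - x[:, :, -2 * r - 1:-r - 1]
    return torch.cat([left, middle, right], dim=2)


def diff_y(x, r):
    left = x[:, :, :, r:2 * r + 1]
    middle = x[:, :, :, 2 * r + 1:] - x[:, :, :, :-2 * r - 1]
    right = x[:, :, :, -1:] - x[:, :, :, -2 * r - 1:-r - 1]
    return torch.cat([left, middle, right], dim=3)


def box(x, r):
    # (2r+1)x(2r+1) window sum, truncated at the borders, via cumulative sums
    return diff_y(diff_x(x.cumsum(dim=2), r).cumsum(dim=3), r)


def coeffs(x, y, r, eps):
    n = box(torch.ones_like(x[:1, :1]), r)           # number of pixels in each window
    mean_x, mean_y = box(x, r) / n, box(y, r) / n
    cov_xy = box(x * y, r) / n - mean_x * mean_y
    var_x = box(x * x, r) / n - mean_x * mean_x
    a = cov_xy / (var_x + eps)
    return a, mean_y - a * mean_x, n


def guided(x, y, r, eps):                            # GuidedFilter.forward
    a, b, n = coeffs(x, y, r, eps)
    return box(a, r) / n * x + box(b, r) / n


def fast_guided(lr_x, lr_y, hr_x, r, eps):           # FastGuidedFilter.forward
    a, b, _ = coeffs(lr_x, lr_y, r, eps)
    size = hr_x.shape[-2:]
    a = F.interpolate(a, size, mode='bilinear', align_corners=True)
    b = F.interpolate(b, size, mode='bilinear', align_corners=True)
    return a * hr_x + b


torch.manual_seed(0)
r, eps = 2, 1e-2
x = torch.rand(1, 3, 12, 12, dtype=torch.float64)
y = torch.rand(1, 3, 12, 12, dtype=torch.float64)

ones = torch.ones(3, 1, 2 * r + 1, 2 * r + 1, dtype=torch.float64)
print(torch.allclose(box(x, r), F.conv2d(x, ones, padding=r, groups=3)))   # True
print(torch.allclose(guided(x, y, r, eps), fast_guided(x, y, x, r, eps)))  # False
```

The cumsum trick costs O(1) per pixel regardless of r, and it is built only from slicing, subtraction and cumsum, so autograd differentiates through it without any custom backward.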
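4.3 Usage code
I currently have no semantic segmentation project, so I only analyze how IA-SEG uses the module.
GuidedFilter takes two inputs: a guidance map and the map to be filtered. In IA-SEG, the guidance map is derived from the input image by an extra small branch (two 1×1 convolutions), and the map to be filtered is the upsampled segmentation output.
The code below wraps an ordinary semantic segmentation model into a model with LGF. The model returns x1 and x2: x1 is the plain segmentation prediction at the network's output resolution, and x2 is the LGF-refined result at input resolution.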
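import torch.nn as nn
import torch.nn.functional as F
from guided_filter import GuidedFilter


class LGFModel(nn.Module):
    def __init__(self, num_classes, dgf, dgf_r, dgf_eps):
        super(LGFModel, self).__init__()
        self.inplanes = 64
        self.dgf = dgf
        self.model = SegModel()  # placeholder: any segmentation network returning (N, num_classes, h, w) logits
        if self.dgf:
            # guidance branch: two 1x1 convs turn the RGB image into a num_classes-channel guidance map
            self.guided_map_conv1 = nn.Conv2d(3, 64, 1)
            self.guided_map_relu1 = nn.ReLU(inplace=True)
            self.guided_map_conv2 = nn.Conv2d(64, num_classes, 1)
            self.guided_filter = GuidedFilter(dgf_r, dgf_eps)

    def forward(self, x1):
        im = x1
        x1 = self.model(x1)
        x2 = None
        if self.dgf:
            g = self.guided_map_relu1(self.guided_map_conv1(im))
            g = self.guided_map_conv2(g)
            # upsample the logits to input resolution, then refine them with the guided filter
            x2 = F.interpolate(x1, im.size()[2:], mode='bilinear', align_corners=True)
            x2 = self.guided_filter(g, x2)
        return x1, x2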
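When using LGFModel, you usually only need to compute the loss on x2; the loss on x1 does not have to be backpropagated. If the model converges slowly, you can also backpropagate a loss on x1.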
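A runnable end-to-end sketch of this setup follows, under explicit assumptions. `ToySeg` is a hypothetical stand-in for PSPNet/DeepLab that outputs logits at 1/4 resolution. The box mean is computed with `F.avg_pool2d(..., count_include_pad=False)` instead of the cumsum BoxFilter; both give the border-truncated window mean. The weight 0.4 on the optional x1 loss is an arbitrary illustrative choice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def box_mean(x, r):
    # mean over a (2r+1)x(2r+1) window, ignoring out-of-image pixels (same as box(x) / N)
    return F.avg_pool2d(x, 2 * r + 1, stride=1, padding=r, count_include_pad=False)


class GuidedFilter(nn.Module):
    def __init__(self, r, eps=1e-8):
        super().__init__()
        self.r, self.eps = r, eps

    def forward(self, x, y):
        mean_x, mean_y = box_mean(x, self.r), box_mean(y, self.r)
        cov_xy = box_mean(x * y, self.r) - mean_x * mean_y
        var_x = box_mean(x * x, self.r) - mean_x * mean_x
        a = cov_xy / (var_x + self.eps)
        b = mean_y - a * mean_x
        return box_mean(a, self.r) * x + box_mean(b, self.r)


class ToySeg(nn.Module):
    # hypothetical stand-in for PSPNet/DeepLab: logits at 1/4 input resolution
    def __init__(self, num_classes):
        super().__init__()
        self.net = nn.Conv2d(3, num_classes, kernel_size=4, stride=4)

    def forward(self, x):
        return self.net(x)


class LGFModel(nn.Module):
    def __init__(self, num_classes, dgf_r=2, dgf_eps=1e-2):
        super().__init__()
        self.model = ToySeg(num_classes)
        self.guided_map = nn.Sequential(nn.Conv2d(3, 64, 1), nn.ReLU(inplace=True),
                                        nn.Conv2d(64, num_classes, 1))
        self.guided_filter = GuidedFilter(dgf_r, dgf_eps)

    def forward(self, im):
        x1 = self.model(im)                                  # low-resolution logits
        g = self.guided_map(im)                              # guidance map from the image
        x2 = F.interpolate(x1, im.shape[2:], mode='bilinear', align_corners=True)
        return x1, self.guided_filter(g, x2)                 # x2: refined, full resolution


model = LGFModel(num_classes=5)
im = torch.rand(2, 3, 32, 32)
labels = torch.randint(0, 5, (2, 32, 32))
x1, x2 = model(im)
lambda_x1 = 0.4                                              # assumed weight of the optional x1 loss
loss = F.cross_entropy(x2, labels) + lambda_x1 * F.cross_entropy(
    F.interpolate(x1, im.shape[2:], mode='bilinear', align_corners=True), labels)
loss.backward()
print(tuple(x1.shape), tuple(x2.shape))                      # (2, 5, 8, 8) (2, 5, 32, 32)
```

Setting lambda_x1 to 0 gives the "loss on x2 only" setting. A positive weight adds direct supervision to the backbone, which may help when convergence is slow.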