ICode9

精准搜索请尝试: 精确搜索
首页 > 编程语言> 文章详细

pytorch源码解析系列-yolov4最核心技巧代码详解(3)- 数据处理以及图像增强

2021-09-26 13:06:31  阅读:359  来源: 互联网

标签:yolov4 img cfg self cut 源码 bboxes shift 图像增强


输入是啥?

现在 我们知道了yolo的模型,知道了模型获取的推理输出,在进入训练代码之前,我们需要了解到yolo数据的处理方式,以及用了何种图像增强方法。(不然连输入是啥都不知道,你怎么看的懂训练过程呢)

数据载入

其实数据载入这块没有必要单独拿出来说的,但是学会如何写collate function是很重要的
我们先看下数据源

## 如果你多进程报错 num_workers要改成1 大部分CPU训练机器都会出这个问题
## batchsize = batch // subdivisions 所以你就知道超参数这块实际上最终的batchsize是由这两个数据决定的
train_loader = DataLoader(train_dataset, batch_size=config.batch // config.subdivisions, shuffle=True,
                              num_workers=8, pin_memory=True, drop_last=True, collate_fn=collate)

只需要在collate_fn指定自己写的方法就行了,自己写可以把很多数据预处理都搞定在这个方法里

def collate(batch):
    images = []
    bboxes = []
    for img, box in batch:
        images.append([img])
        bboxes.append([box])
    #拿torch和numpy写都行,反正都要最后转成torch的
    #这边就是数据的标准预处理了,把格式处理成[B,C,H,W] 再除以255
    images = np.concatenate(images, axis=0)
    images = images.transpose(0, 3, 1, 2)
    images = torch.from_numpy(images).div(255.0)
    bboxes = np.concatenate(bboxes, axis=0)
    bboxes = torch.from_numpy(bboxes)
    return images, bboxes

Mosaic图像增强

小目标的AP一般比中目标和大目标低很多,而且小目标的分布并不均匀,所以作者用Mosaic处理这种问题
CutMix数据增强:两张图片进行拼接
然后给他变形一下↓
Mosaic数据增强:4张图片,随机缩放、随机裁剪、随机排布的方式进行拼接。
好处就是,一次把4个图片拧一块,batch直接4倍了,岂不美哉?
看看代码如何实现的:(注:代码实现较老,没有用torchvision)

#yolo的Dataset写法不会的建议先学下基础 这里的代码实际上没必要,
#dataset的label文件是指定一个txt的 你如果跑过yolo4的训练就知道要自己制作train.txt
class Yolo_dataset(Dataset):
    def __init__(self, lable_path, cfg, train=True):
        super(Yolo_dataset, self).__init__()
    #这边参数的意思是 mixup=4是用 mosaic+cutmix, 2是只用mixup,3是mosaic 。
        if cfg.mixup == 2:
            print("cutmix=1 - isn't supported for Detector")
            raise
        elif cfg.mixup == 2 and cfg.letter_box:
            print("Combination: letter_box=1 & mosaic=1 - isn't supported, use only 1 of these parameters")
            raise

        self.cfg = cfg
        self.train = train
		# label载入 每张图片的里面分类都载进字典
        truth = {}
        f = open(lable_path, 'r', encoding='utf-8')
        for line in f.readlines():
            data = line.split(" ")
            truth[data[0]] = []
            for i in data[1:]:
                truth[data[0]].append([int(float(j)) for j in i.split(',')])

        self.truth = truth
        self.imgs = list(self.truth.keys())

重点在下面mosiac的实现:

	# 每次从dataset拿数据的时候,都会随机抽4张图片做mosiac
    def __getitem__(self, index):
        if not self.train:
            return self._get_val_item(index)
        img_path = self.imgs[index]
        #从label里拿到类别bboxes
        bboxes = np.array(self.truth.get(img_path), dtype=np.float)
        img_path = os.path.join(self.cfg.dataset_dir, img_path)
        use_mixup = self.cfg.mixup
        # 一半概率不用增强
        if random.randint(0, 1):
            use_mixup = 0
		# mixup==3是使用mosaic 
        if use_mixup == 3:
            min_offset = 0.2 #指定剪切率之后 随机从宽高剪切一个区域
            cut_x = random.randint(int(self.cfg.w * min_offset), int(self.cfg.w * (1 - min_offset)))
            cut_y = random.randint(int(self.cfg.h * min_offset), int(self.cfg.h * (1 - min_offset)))

        r1, r2, r3, r4, r_scale = 0, 0, 0, 0, 0
        dhue, dsat, dexp, flip, blur = 0, 0, 0, 0, 0
        gaussian_noise = 0

        out_img = np.zeros([self.cfg.h, self.cfg.w, 3])
        out_bboxes = []
		#这边就知道 use_mixup里数字的意思了,0就是不用 只循环1次,3是mosiac,循环4次
        for i in range(use_mixup + 1):
            if i != 0:
            	#不是第一张图片 就从图片库里随便抽一张做拼接
                img_path = random.choice(list(self.truth.keys()))
                bboxes = np.array(self.truth.get(img_path), dtype=np.float)
                img_path = os.path.join(self.cfg.dataset_dir, img_path)
            img = cv2.imread(img_path) 
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #bgr2rgb处理 你不会这个都不知道吧?
            if img is None:
                continue
            oh, ow, oc = img.shape
            # 默认jitter=0.2 hue = 0.1 sat = 1.5 exp = 1.5
            # 这边三个增强方法写在最后,有兴趣看看,主要是随机获取HSV中三个随机增强参数
            dh, dw, dc = np.array(np.array([oh, ow, oc]) * self.cfg.jitter, dtype=np.int)
            dhue = rand_uniform_strong(-self.cfg.hue, self.cfg.hue)
            dsat = rand_scale(self.cfg.saturation)
            dexp = rand_scale(self.cfg.exposure)
			# 下面很长 有兴趣的可以看看 主要是图像增强的其他一系列旋转、滤波、拼接操作
            pleft = random.randint(-dw, dw)
            pright = random.randint(-dw, dw)
            ptop = random.randint(-dh, dh)
            pbot = random.randint(-dh, dh)
			#旋转
            flip = random.randint(0, 1) if self.cfg.flip else 0
			#高斯滤波
            if (self.cfg.blur):
                tmp_blur = random.randint(0, 2)  # 0 - disable, 1 - blur background, 2 - blur the whole image
                if tmp_blur == 0:
                    blur = 0
                elif tmp_blur == 1:
                    blur = 1
                else:
                    blur = self.cfg.blur
			#高斯
            if self.cfg.gaussian and random.randint(0, 1):
                gaussian_noise = self.cfg.gaussian
            else:
                gaussian_noise = 0
			#补灰(类似pad)
            if self.cfg.letter_box:
                img_ar = ow / oh
                net_ar = self.cfg.w / self.cfg.h
                result_ar = img_ar / net_ar
                if result_ar > 1:  # 补sheight
                    oh_tmp = ow / net_ar
                    delta_h = (oh_tmp - oh) / 2
                    ptop = ptop - delta_h
                    pbot = pbot - delta_h
                else:  # 补swidth
                    ow_tmp = oh * net_ar
                    delta_w = (ow_tmp - ow) / 2
                    pleft = pleft - delta_w
                    pright = pright - delta_w

            swidth = ow - pleft - pright
            sheight = oh - ptop - pbot
			#这个代码贴在最后 有兴趣看一下 主要是拿剪切后的label
            truth, min_w_h = fill_truth_detection(bboxes, self.cfg.boxes, self.cfg.classes, flip, pleft, ptop, swidth,
                                                  sheight, self.cfg.w, self.cfg.h)
            if (min_w_h / 8) < blur and blur > 1:  # 如果太小就不用滤波了
                blur = min_w_h / 8
			# 主要图片增强处理 核心内容是图片转成HSV格式,用参数对其增强,有兴趣看最后的代码
            ai = image_data_augmentation(img, self.cfg.w, self.cfg.h, pleft, ptop, swidth, sheight, flip,
                                         dhue, dsat, dexp, gaussian_noise, blur, truth)

            if use_mixup == 0:
                out_img = ai
                out_bboxes = truth
            ## 1是cutmix 拼接两个图像
            if use_mixup == 1:
                if i == 0:
                    old_img = ai.copy()
                    old_truth = truth.copy()
                elif i == 1:
                    out_img = cv2.addWeighted(ai, 0.5, old_img, 0.5)
                    out_bboxes = np.concatenate([old_truth, truth], axis=0)
             ## 3 是mosaic:旋转、滤波
            elif use_mixup == 3:
                if flip:
                    tmp = pleft
                    pleft = pright
                    pright = tmp

                left_shift = int(min(cut_x, max(0, (-int(pleft) * self.cfg.w / swidth))))
                top_shift = int(min(cut_y, max(0, (-int(ptop) * self.cfg.h / sheight))))

                right_shift = int(min((self.cfg.w - cut_x), max(0, (-int(pright) * self.cfg.w / swidth))))
                bot_shift = int(min(self.cfg.h - cut_y, max(0, (-int(pbot) * self.cfg.h / sheight))))
				#blend_truth_mosaic这个代码也贴在最后 有兴趣看看 主要是拿剪切后的图片和label
                out_img, out_bbox = blend_truth_mosaic(out_img, ai, truth.copy(), self.cfg.w, self.cfg.h, cut_x,
                                                       cut_y, i, left_shift, right_shift, top_shift, bot_shift)
                out_bboxes.append(out_bbox)
        #后面就是拼接图片和label了
        if use_mixup == 3:
            out_bboxes = np.concatenate(out_bboxes, axis=0)
        out_bboxes1 = np.zeros([self.cfg.boxes, 5])
        out_bboxes1[:min(out_bboxes.shape[0], self.cfg.boxes)] = out_bboxes[:min(out_bboxes.shape[0], self.cfg.boxes)]
        return out_img, out_bboxes1

上面的内容 用一句话概括:
50%的概率使用图像增强,增强方式是对4张图片采用随机缩放,裁剪,旋转,滤波,最后拼在一起的方式搞成一张图片
在这里插入图片描述

选看

下面的代码可以看,可以不看,主要是之前代码里面图像增强的具体实现方式,如果你真的有手撕代码的需要,可以用来参考

从简单的来看,HSE三个随机参数:
没啥东西 就是随机增强

def rand_uniform_strong(min, max):
    if min > max:
        swap = min
        min = max
        max = swap
    return random.random() * (max - min) + min
def rand_scale(s):
    scale = rand_uniform_strong(1, s)
    if random.randint(0, 1) % 2:
        return scale
    return 1. / scale
def rand_precalc_random(min, max, random_part):
    if max < min:
        swap = min
        min = max
        max = swap
    return (random_part * (max - min)) + min

然后是处理图片裁剪和缩放后的label:

#truth, min_w_h = fill_truth_detection(bboxes, self.cfg.boxes, self.cfg.classes, flip, pleft, ptop, swidth,sheight, self.cfg.w, self.cfg.h)
#这边的dx,dy,sx,sy对应之前的pleft,ptop,swidth,sheight 了解了这个再往后看
def fill_truth_detection(bboxes, num_boxes, classes, flip, dx, dy, sx, sy, net_w, net_h):
    if bboxes.shape[0] == 0:
        return bboxes, 10000
    np.random.shuffle(bboxes)
    bboxes[:, 0] -= dx
    bboxes[:, 2] -= dx
    bboxes[:, 1] -= dy
    bboxes[:, 3] -= dy
	#上面减去随机位置后,用下面的参数保证长宽落在[0,sx]中间
    bboxes[:, 0] = np.clip(bboxes[:, 0], 0, sx)
    bboxes[:, 2] = np.clip(bboxes[:, 2], 0, sx)
	#同理
    bboxes[:, 1] = np.clip(bboxes[:, 1], 0, sy)
    bboxes[:, 3] = np.clip(bboxes[:, 3], 0, sy)
	#过滤掉out_box,也就是说过滤掉长宽都同时落在边缘(如果图片裁剪区域里没有目标)
    out_box = list(np.where(((bboxes[:, 1] == sy) & (bboxes[:, 3] == sy)) |
                            ((bboxes[:, 0] == sx) & (bboxes[:, 2] == sx)) |
                            ((bboxes[:, 1] == 0) & (bboxes[:, 3] == 0)) |
                            ((bboxes[:, 0] == 0) & (bboxes[:, 2] == 0)))[0])
    list_box = list(range(bboxes.shape[0]))
    for i in out_box:
        list_box.remove(i)
    bboxes = bboxes[list_box]

    if bboxes.shape[0] == 0:
        return bboxes, 10000
	#我感觉这一步多余的 类别判断为啥在这里
    bboxes = bboxes[np.where((bboxes[:, 4] < classes) & (bboxes[:, 4] >= 0))[0]]

    if bboxes.shape[0] > num_boxes:
        bboxes = bboxes[:num_boxes]
	#最短宽高
    min_w_h = np.array([bboxes[:, 2] - bboxes[:, 0], bboxes[:, 3] - bboxes[:, 1]]).min()
	#这个net参数是自己cfg设的,控制缩放比率
    bboxes[:, 0] *= (net_w / sx)
    bboxes[:, 2] *= (net_w / sx)
    bboxes[:, 1] *= (net_h / sy)
    bboxes[:, 3] *= (net_h / sy)
	#是否翻转
    if flip:
        temp = net_w - bboxes[:, 0]
        bboxes[:, 0] = net_w - bboxes[:, 2]
        bboxes[:, 2] = temp
	#最后是返回处理后的图片label和最短宽高
    return bboxes, min_w_h

这类代码写起来比看要容易,因为当你有明确的目标后,你再去写就有思路,光看是很难看出思路的

上面获取偏转缩放后的label,经过mosaic还要处理一次,核心思想就是4张图片分别用何种处理获得mosiac后的label区间

def blend_truth_mosaic(out_img, img, bboxes, w, h, cut_x, cut_y, i_mixup,
                       left_shift, right_shift, top_shift, bot_shift):
    left_shift = min(left_shift, w - cut_x)
    top_shift = min(top_shift, h - cut_y)
    right_shift = min(right_shift, cut_x)
    bot_shift = min(bot_shift, cut_y)
	#i_mixup 对应外面4张图片的4个循环
    if i_mixup == 0:
    #filter_truth这个代码不写了,主要就是fill_truth_detection里OUTBOX做的事情,过滤掉处理后没有目标的图片
        bboxes = filter_truth(bboxes, left_shift, top_shift, cut_x, cut_y, 0, 0)
        out_img[:cut_y, :cut_x] = img[top_shift:top_shift + cut_y, left_shift:left_shift + cut_x]
    if i_mixup == 1:
        bboxes = filter_truth(bboxes, cut_x - right_shift, top_shift, w - cut_x, cut_y, cut_x, 0)
        out_img[:cut_y, cut_x:] = img[top_shift:top_shift + cut_y, cut_x - right_shift:w - right_shift]
    if i_mixup == 2:
        bboxes = filter_truth(bboxes, left_shift, cut_y - bot_shift, cut_x, h - cut_y, 0, cut_y)
        out_img[cut_y:, :cut_x] = img[cut_y - bot_shift:h - bot_shift, left_shift:left_shift + cut_x]
    if i_mixup == 3:
        bboxes = filter_truth(bboxes, cut_x - right_shift, cut_y - bot_shift, w - cut_x, h - cut_y, cut_x, cut_y)
        out_img[cut_y:, cut_x:] = img[cut_y - bot_shift:h - bot_shift, cut_x - right_shift:w - right_shift]

    return out_img, bboxes

最后的代码是图片的处理,主要是转成HSV处理增强 再转回来


def image_data_augmentation(mat, w, h, pleft, ptop, swidth, sheight, flip, dhue, dsat, dexp, gaussian_noise, blur,
                            truth):
    try:
        img = mat
        oh, ow, _ = img.shape
        #剪切缩放
        pleft, ptop, swidth, sheight = int(pleft), int(ptop), int(swidth), int(sheight)
        # 剪切目标位置
        src_rect = [pleft, ptop, swidth + pleft, sheight + ptop]  # x1,y1,x2,y2
        img_rect = [0, 0, ow, oh]
        new_src_rect = rect_intersection(src_rect, img_rect)  # 交集 就是获取[minx, miny, maxx, maxy]

        dst_rect = [max(0, -pleft), max(0, -ptop), max(0, -pleft) + new_src_rect[2] - new_src_rect[0],
                    max(0, -ptop) + new_src_rect[3] - new_src_rect[1]]
        # cv2.Mat sized

        if (src_rect[0] == 0 and src_rect[1] == 0 and src_rect[2] == img.shape[0] and src_rect[3] == img.shape[1]):
            sized = cv2.resize(img, (w, h), cv2.INTER_LINEAR)
        else:
            cropped = np.zeros([sheight, swidth, 3])
            cropped[:, :, ] = np.mean(img, axis=(0, 1))

            cropped[dst_rect[1]:dst_rect[3], dst_rect[0]:dst_rect[2]] = \
                img[new_src_rect[1]:new_src_rect[3], new_src_rect[0]:new_src_rect[2]]

            # 线性差值resize
            sized = cv2.resize(cropped, (w, h), cv2.INTER_LINEAR)

        # 翻转
        if flip:
            sized = cv2.flip(sized, 1)  # 0 - x-axis, 1 - y-axis, -1 - both axes (x & y)

        # HSV增强,主要用下面机种方式把图转成HSV
        # cv2.COLOR_BGR2HSV, cv2.COLOR_RGB2HSV, cv2.COLOR_HSV2BGR, cv2.COLOR_HSV2RGB
        if dsat != 1 or dexp != 1 or dhue != 0:
            if img.shape[2] >= 3:
                hsv_src = cv2.cvtColor(sized.astype(np.float32), cv2.COLOR_RGB2HSV) 
                hsv = cv2.split(hsv_src)
                #之前的三个参数增强对hsv增强
                hsv[1] *= dsat
                hsv[2] *= dexp
                hsv[0] += 179 * dhue
                hsv_src = cv2.merge(hsv)
                sized = np.clip(cv2.cvtColor(hsv_src, cv2.COLOR_HSV2RGB), 0, 255)  # 转回RGB,记得合法范围[0,255]
            else:
                sized *= dexp
		#高斯滤波 平滑
        if blur:
            if blur == 1:
                dst = cv2.GaussianBlur(sized, (17, 17), 0)
                # cv2.bilateralFilter(sized, dst, 17, 75, 75)
            else:
                ksize = (blur / 2) * 2 + 1
                dst = cv2.GaussianBlur(sized, (ksize, ksize), 0)

            if blur == 1:
                img_rect = [0, 0, sized.cols, sized.rows]
                for b in truth:
                    left = (b.x - b.w / 2.) * sized.shape[1]
                    width = b.w * sized.shape[1]
                    top = (b.y - b.h / 2.) * sized.shape[0]
                    height = b.h * sized.shape[0]
                    roi(left, top, width, height)
                    roi = roi & img_rect
                    dst[roi[0]:roi[0] + roi[2], roi[1]:roi[1] + roi[3]] = sized[roi[0]:roi[0] + roi[2],
                                                                          roi[1]:roi[1] + roi[3]]

            sized = dst

        if gaussian_noise:
            noise = np.array(sized.shape)
            gaussian_noise = min(gaussian_noise, 127)
            gaussian_noise = max(gaussian_noise, 0)
            cv2.randn(noise, 0, gaussian_noise)  # mean and variance
            sized = sized + noise
    except:
        print("OpenCV can't augment image: " + str(w) + " x " + str(h))
        sized = mat

    return sized

标签:yolov4,img,cfg,self,cut,源码,bboxes,shift,图像增强
来源: https://blog.csdn.net/weixin_48174100/article/details/120482685

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有