
yolov4-deepsort: object tracking, ROI counting, and trajectory drawing

2021-04-15



For environment setup on Windows, see my other post, "Configuring the yolov4-deepsort project with TensorFlow GPU on Windows + hands-on project".
For Linux deployment, follow the source repo on GitHub directly.
Demo:
[demo image]


1 Introduction

Here I've tidied up some earlier work and added a few small features on top of the original yolov4-deepsort functionality:

  1. Draw an ROI region and count the targets inside it
  2. Display 'enter' when a target is inside the ROI
  3. Print the number of tracked targets of each class
  4. Draw each target's motion trajectory (see the sketch after this list)

Similar tasks can be handled with the same approach.
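For feature 4, the mechanism is just bounded per-track history: keep each track's recent box centers and connect consecutive points with cv2.line. A minimal sketch of that bookkeeping (traces and update_and_draw_trace are illustrative names, not part of the project; the full script below does the same job by hand with plain lists and a frame-credit counter):

from collections import deque
import cv2

traces = {}  # track_id -> deque of recent (x, y) box centers

def update_and_draw_trace(frame, track_id, center, max_len=20, color=(0, 255, 0)):
    # deque(maxlen=...) drops the oldest point automatically once the cap is hit
    trace = traces.setdefault(track_id, deque(maxlen=max_len))
    trace.append(center)
    # connect consecutive centers to draw the trajectory
    for p1, p2 in zip(trace, list(trace)[1:]):
        cv2.line(frame, p1, p2, color, 2)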

2 Reading and rewriting object_tracker.py

The object_tracker.py that produces the effect above is listed below (the remaining files are the author's original source). What you need to do here is supply your own ROI corner coordinates; after picking the points, fit the ROI boundary lines. For that, see my other post, "ROI region extraction (pick coordinate points directly on the image with mouse events, with visualization)", then swap the corresponding membership conditions for your own data.
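To make those membership conditions concrete: each ROI edge is fit as a line y = kx + b through two adjacent corners, and a point lies inside the quadrilateral when it falls on the correct side of all four lines. A minimal sketch under that assumption (edge_line is a hypothetical helper, not part of the project; the corner coordinates are the ones hard-coded later in this script), with cv2.pointPolygonTest shown as an equivalent test that needs no hand-fitting:

import numpy as np
import cv2

# ROI corners in image coordinates (the same points the script hard-codes):
# a = bottom-left, b = top-left, c = top-right, d = bottom-right
a, b, c, d = (173, 456), (966, 91), (1240, 122), (574, 515)

def edge_line(p, q):
    # hypothetical helper: slope and intercept of the line through p and q
    k = (q[1] - p[1]) / (q[0] - p[0])
    return k, p[1] - k * p[0]

for name, (p, q) in {"ab": (a, b), "bc": (b, c), "cd": (c, d), "da": (d, a)}.items():
    k, bias = edge_line(p, q)
    print("y%s = %.2f*x + %.2f" % (name, k, bias))  # matches the yab/ybc/ycd/yda constants below

# equivalent inside test without fitting lines by hand:
roi = np.array([a, b, c, d], np.int32)
print(cv2.pointPolygonTest(roi, (700, 300), False) >= 0)  # True: inside or on the edge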

import os
# comment out below line to enable tensorflow logging outputs
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import time
import tensorflow as tf
physical_devices = tf.config.experimental.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)
from absl import app, flags, logging
from absl.flags import FLAGS
import core.utils as utils
from core.yolov4 import filter_boxes
from tensorflow.python.saved_model import tag_constants
from core.config import cfg
from PIL import Image
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
# deep sort imports
from deep_sort import preprocessing, nn_matching
from deep_sort.detection import Detection
from deep_sort.tracker import Tracker
from tools import generate_detections as gdet
flags.DEFINE_string('framework', 'tf', '(tf, tflite, trt)')
flags.DEFINE_string('weights', './checkpoints/yolov4-416',
                    'path to weights file')
flags.DEFINE_integer('size', 416, 'resize images to')
flags.DEFINE_boolean('tiny', False, 'yolo or yolo-tiny')
flags.DEFINE_string('model', 'yolov4', 'yolov3 or yolov4')
flags.DEFINE_string('video', './data/video/test.mp4', 'path to input video or set to 0 for webcam')
flags.DEFINE_string('output','./outputs/test2.avi', 'path to output video')
flags.DEFINE_string('output_format', 'XVID', 'codec used in VideoWriter when saving video to file')
flags.DEFINE_float('iou', 0.45, 'iou threshold')
flags.DEFINE_float('score', 0.50, 'score threshold')
flags.DEFINE_boolean('dont_show', False, 'dont show video output')
flags.DEFINE_boolean('info', False, 'show detailed info of tracked objects')
flags.DEFINE_boolean('count', True, 'count objects being tracked on screen')
flags.DEFINE_boolean('roi', True, 'draw ROI and count objects inside it')


def main(_argv):
    # Definition of the parameters
    max_cosine_distance = 0.4
    nn_budget = None
    nms_max_overlap = 1.0
    
    # initialize deep sort
    model_filename = 'model_data/mars-small128.pb'
    encoder = gdet.create_box_encoder(model_filename, batch_size=1)
    # calculate cosine distance metric
    metric = nn_matching.NearestNeighborDistanceMetric("cosine", max_cosine_distance, nn_budget)
    # initialize tracker
    tracker = Tracker(metric)

    # load configuration for object detector
    config = ConfigProto()
    config.gpu_options.allow_growth = True
    session = InteractiveSession(config=config)
    STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)
    input_size = FLAGS.size
    video_path = FLAGS.video

    # load tflite model if flag is set
    if FLAGS.framework == 'tflite':
        interpreter = tf.lite.Interpreter(model_path=FLAGS.weights)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        print(input_details)
        print(output_details)
    # otherwise load standard tensorflow saved model
    else:
        saved_model_loaded = tf.saved_model.load(FLAGS.weights, tags=[tag_constants.SERVING])
        infer = saved_model_loaded.signatures['serving_default']

    # begin video capture
    try:
        vid = cv2.VideoCapture(int(video_path))
    except:
        vid = cv2.VideoCapture(video_path)

    out = None

    # get video ready to save locally if flag is set
    if FLAGS.output:
        # by default VideoCapture returns float instead of int
        width = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(vid.get(cv2.CAP_PROP_FPS))
        codec = cv2.VideoWriter_fourcc(*FLAGS.output_format)
        out = cv2.VideoWriter(FLAGS.output, codec, fps, (width, height))

    frame_num = 0

    # dictionary that stores per-track state (trajectory points and a missing-frame counter)
    object_dic = {}

    # while video is running
    while True:
        return_value, frame = vid.read()
        if return_value:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            image = Image.fromarray(frame)
        else:
            print('Video has ended or failed, try a different video format!')
            break
        frame_num += 1
        print('Frame #: ', frame_num)
        frame_size = frame.shape[:2]
        image_data = cv2.resize(frame, (input_size, input_size))
        image_data = image_data / 255.
        image_data = image_data[np.newaxis, ...].astype(np.float32)
        start_time = time.time()

        # run detections on tflite if flag is set
        if FLAGS.framework == 'tflite':
            interpreter.set_tensor(input_details[0]['index'], image_data)
            interpreter.invoke()
            pred = [interpreter.get_tensor(output_details[i]['index']) for i in range(len(output_details))]
            # run detections using yolov3 if flag is set
            if FLAGS.model == 'yolov3' and FLAGS.tiny:
                boxes, pred_conf = filter_boxes(pred[1], pred[0], score_threshold=0.25,
                                                input_shape=tf.constant([input_size, input_size]))
            else:
                boxes, pred_conf = filter_boxes(pred[0], pred[1], score_threshold=0.25,
                                                input_shape=tf.constant([input_size, input_size]))
        else:
            batch_data = tf.constant(image_data)
            pred_bbox = infer(batch_data)
            for key, value in pred_bbox.items():
                boxes = value[:, :, 0:4]
                pred_conf = value[:, :, 4:]

        boxes, scores, classes, valid_detections = tf.image.combined_non_max_suppression(
            boxes=tf.reshape(boxes, (tf.shape(boxes)[0], -1, 1, 4)),
            scores=tf.reshape(
                pred_conf, (tf.shape(pred_conf)[0], -1, tf.shape(pred_conf)[-1])),
            max_output_size_per_class=50,
            max_total_size=50,
            iou_threshold=FLAGS.iou,
            score_threshold=FLAGS.score
        )

        # convert data to numpy arrays and slice out unused elements
        num_objects = valid_detections.numpy()[0]  # total number of detected objects (int)
        bboxes = boxes.numpy()[0]
        bboxes = bboxes[0:int(num_objects)]
        scores = scores.numpy()[0]
        scores = scores[0:int(num_objects)]
        classes = classes.numpy()[0]
        classes = classes[0:int(num_objects)]     # detected class indices (np.ndarray)

        # format bounding boxes from normalized ymin, xmin, ymax, xmax ---> xmin, ymin, width, height
        original_h, original_w, _ = frame.shape
        bboxes = utils.format_boxes(bboxes, original_h, original_w)

        # store all predictions in one parameter for simplicity when calling functions
        pred_bbox = [bboxes, scores, classes, num_objects]

        # read in all class names from config
        class_names = utils.read_class_names(cfg.YOLO.CLASSES)

        # by default allow all classes in .names file
        allowed_classes = list(class_names.values())    # default set of classes to track
        
        # custom allowed classes (uncomment the line below to restrict the tracker to specific classes)
        # allowed_classes = ['person', 'car']

        # loop through objects and use class index to get class name, allow only classes in allowed_classes list

        # count the tracked objects of each class in the current frame, as {'name': num};
        # the counts are printed on screen further down
        class_counts = {}
        for i in allowed_classes:
            class_indx = (list(class_names.keys()))[list(class_names.values()).index(i)]
            class_num = np.count_nonzero(classes == class_indx)
            class_counts[i] = class_num

        # same idea as above: filter the detections down to the allowed classes
        names = []
        deleted_indx = []
        for i in range(num_objects):
            # class index
            class_indx = int(classes[i])
            # look up the class name
            class_name = class_names[class_indx]
            if class_name not in allowed_classes:
                deleted_indx.append(i)
            else:
                names.append(class_name)
        names = np.array(names)
        # number of targets currently being tracked
        count = len(names)
        # if counting is enabled
        if FLAGS.count:
            # print the total count
            cv2.putText(frame, "Objects being tracked: {}".format(count), (5, 35), cv2.FONT_HERSHEY_COMPLEX_SMALL, 2, (0, 255, 0), 2)
            y = 70
            # print the per-class counts
            for key, value in class_counts.items():
                if value != 0:
                    cv2.putText(frame, "{} being tracked: {}".format(key, value), (5, y), cv2.FONT_HERSHEY_COMPLEX_SMALL, 2, (0, 255, 0), 2)
                    y += 35

            print("Objects being tracked: {}".format(count))
        # delete detections that are not in allowed_classes
        bboxes = np.delete(bboxes, deleted_indx, axis=0)
        scores = np.delete(scores, deleted_indx, axis=0)

        # encode yolo detections and feed to tracker
        features = encoder(frame, bboxes)
        detections = [Detection(bbox, score, class_name, feature) for bbox, score, class_name, feature in zip(bboxes, scores, names, features)]

        # initialize color map
        cmap = plt.get_cmap('tab20b')
        colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]

        # run non-maxima suppression
        boxs = np.array([d.tlwh for d in detections])
        scores = np.array([d.confidence for d in detections])
        classes = np.array([d.class_name for d in detections])
        indices = preprocessing.non_max_suppression(boxs, classes, nms_max_overlap, scores)
        detections = [detections[i] for i in indices]

        # Call the tracker
        tracker.predict()
        tracker.update(detections)

        # list that collects the targets currently inside the ROI
        target = []

        track_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0),
                        (0, 255, 255), (255, 0, 255), (255, 127, 255),
                        (127, 0, 255), (127, 0, 127)]

        # update tracks
        for track in tracker.tracks:
            if not track.is_confirmed() or track.time_since_update > 1:
                continue
            bbox = track.to_tlbr()
            class_name = track.get_class()

            # draw bbox on screen
            # box indices: 0 = xmin, 1 = ymin, 2 = xmax, 3 = ymax
            color = colors[int(track.track_id) % len(colors)]
            color = [i * 255 for i in color]
            cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), color, 2)
            cv2.rectangle(frame, (int(bbox[0]), int(bbox[1])), (int(bbox[0])+(len(class_name)+len(str(track.track_id)))*17, int(bbox[1]+20)), color, -1)
            cv2.putText(frame, class_name + "-" + str(track.track_id),(int(bbox[0]), int(bbox[1]+10)),0, 0.75, (255,255,255),2)

            # if enable info flag then print details about each track
            if FLAGS.info:
                print("Tracker ID: {}, Class: {},  BBox Coords (xmin, ymin, xmax, ymax): {}".format(str(track.track_id), class_name, (int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]))))


            # center of the detection box: [cx, cy, w, h]
            center = [int((bbox[0] + bbox[2]) / 2), int((bbox[1] + bbox[3]) / 2), int(bbox[2] - bbox[0]),
                      int(bbox[3] - bbox[1])]

            # add the target to the trajectory dictionary:
            # {'id': {'trace': [[cx, cy, w, h], ...], 'traced_frames': num}, ...}
            # a track is dropped once it has been missing for 10 frames
            if "%d" % track.track_id not in object_dic:
                object_dic["%d" % track.track_id] = {"trace": [], "traced_frames": 10}
            object_dic["%d" % track.track_id]["trace"].append(center)
            object_dic["%d" % track.track_id]["traced_frames"] += 1

            # ROI drawing, coordinate test, and trajectory bookkeeping
            if FLAGS.roi:
                # ROI corner coordinates: [173, 456], [966, 91], [1240, 122], [574, 515]
                pts1 = np.array([[173, 456], [966, 91], [1240, 122], [574, 515]], np.int32)
                pts1 = pts1.reshape((-1, 1, 2))
                cv2.polylines(frame, [pts1], True, (0, 255, 255), thickness=2)

                # test whether the target is inside the ROI,
                # using the midpoint of the box's bottom edge
                x = int((bbox[0] + bbox[2]) / 2)
                y = int(bbox[3]) - 10  # small offset
                # the four boundary lines Lab, Lbc, Lcd, Lda,
                # with a = bottom-left, b = top-left, c = top-right, d = bottom-right
                yab = round(-0.46 * x + 535.63, 2)
                ybc = round(0.11 * x - 18.29, 2)
                ycd = round(-0.59 * x + 853.71, 2)
                yda = round(0.15 * x + 430.55, 2)
                # the point is inside the ROI when it is on the inner side of all four lines
                if y > yab and y > ybc and y < ycd and y < yda:
                    target.append(x)
                    cv2.putText(frame, str('enter'), (int(bbox[2] - 65), int(bbox[3] - 5)),
                                cv2.FONT_HERSHEY_COMPLEX_SMALL, 1, (255, 255, 0), 1)
        cv2.putText(frame, "ROI count: {}".format(str(len(target))), (1500, 35), cv2.FONT_HERSHEY_COMPLEX_SMALL,2, (255, 0, 0), 2)

        # draw trajectories
        for s in object_dic:
            i = int(s)

            # the stored coordinates could also be used later to estimate each
            # target's speed and direction of travel
            # xlist, ylist, wlist, hlist = [], [], [], []

            # cap each trace at its 20 most recent points
            if len(object_dic["%d" % i]["trace"]) > 20:
                object_dic["%d" % i]["trace"] = object_dic["%d" % i]["trace"][-20:]

            # draw the trajectory
            if len(object_dic["%d" % i]["trace"]) > 2:
                for j in range(1, len(object_dic["%d" % i]["trace"]) - 1):
                    pot1_x = object_dic["%d" % i]["trace"][j][0]
                    pot1_y = object_dic["%d" % i]["trace"][j][1]
                    pot2_x = object_dic["%d" % i]["trace"][j + 1][0]
                    pot2_y = object_dic["%d" % i]["trace"][j + 1][1]
                    clr = i % 9  # trajectory color chosen by track id
                    cv2.line(frame, (pot1_x, pot1_y), (pot2_x, pot2_y), track_colors[clr], 2)

        # age out targets that have disappeared: tracks updated this frame regained
        # the frame of credit they lose here, so only missing tracks drain to zero
        for s in object_dic:
            if object_dic["%d" % int(s)]["traced_frames"] > 0:
                object_dic["%d" % int(s)]["traced_frames"] -= 1
        for n in list(object_dic):
            if object_dic["%d" % int(n)]["traced_frames"] == 0:
                del object_dic["%d" % int(n)]

        fps = 1.0 / (time.time() - start_time)
        print("FPS: %.2f" % fps)
        result = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

        if not FLAGS.dont_show:
            cv2.imshow("Output Video", result)

        # if output flag is set, save video file
        if FLAGS.output:
            out.write(result)
        if cv2.waitKey(1) & 0xFF == ord('q'): break
    cv2.destroyAllWindows()

if __name__ == '__main__':
    try:
        app.run(main)
    except SystemExit:
        pass
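With the checkpoints in place (the yolov4-416 SavedModel converted as in the upstream repo, plus model_data/mars-small128.pb), the script is driven by the flags defined at the top; for example, using the default paths from the flag definitions:

python object_tracker.py --video ./data/video/test.mp4 --output ./outputs/test2.avi --model yolov4 --info

--count and --roi default to True here, so on-screen counting and the ROI overlay are already enabled; pass --video 0 to read from a webcam instead.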

Source: https://blog.csdn.net/weixin_48994268/article/details/115724165
