ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

使用camera在tensorflow/slim下调用pb文件进行图像识别的预测

2021-03-25 12:33:31  阅读:200  来源: 互联网

标签:图像识别 graph image slim pb depth np import self


建立demo_cam.py文件,python代码如下:
代码中的camera使用的是realsenseD435i

import tensorflow as tf
import numpy as np
import cv2
from datasets import dataset_utils
#from IPython import display
#import pylab
#import PIL
import time
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFont
import matplotlib.font_manager as fm
import pyrealsense2 as rs

pipeline = rs.pipeline()
config = rs.config()
config.enable_stream(rs.stream.depth, 640, 480, rs.format.z16, 30)
config.enable_stream(rs.stream.color, 640, 480, rs.format.bgr8, 30)
profile = pipeline.start(config)
align_to = rs.stream.color
align = rs.align(align_to)


#image_dir='./data/flower_photos/daisy/5547758_eea9edfd54_n.jpg'
dataset_dir='./data/flower_photos'
model_dir ='./output_model_pb/frozen_graph.pb'


def get_aligned_images():
    frames = pipeline.wait_for_frames()
    aligned_frames = align.process(frames)
    aligned_depth_frame = aligned_frames.get_depth_frame()
    color_frame = aligned_frames.get_color_frame()
    depth_image = np.asanyarray(aligned_depth_frame.get_data())
    depth_image_8bit = cv2.convertScaleAbs(depth_image, alpha=0.03)
    depth_org = depth_image_8bit
    depth_image_8bit = 255 - depth_image_8bit
    pos=np.where(depth_image_8bit==255)
    depth_image_8bit[pos]=0
    depth_medianBlur = cv2.medianBlur(depth_image_8bit, 5)  # 中值滤波
    depth_max = np.max(depth_medianBlur)
    #print(depth_max)
    color_image = np.asanyarray(color_frame.get_data())
    depth_image_3d = np.dstack((depth_image_8bit,depth_image_8bit,depth_image_8bit)) #depth image is 1 channel, color is 3 channels
    depth_image_3d_org = np.dstack((depth_org, depth_org, depth_org))
    #视差图
    depth_colormap = cv2.applyColorMap(cv2.convertScaleAbs(depth_image, alpha=0.03), cv2.COLORMAP_JET)
    return color_image, depth_medianBlur, depth_image_3d_org


#opencv
class TOD(object):
  def __init__(self):
    self.PATH_TO_CKPT = './output_model_pb/frozen_graph.pb'
    self.NUM_CLASSES = 5
    self.detection_graph = tf.Graph()
    self.label_map = dataset_utils.read_label_file(dataset_dir)
    with self.detection_graph.as_default():
      od_graph_def = tf.GraphDef()
      with tf.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
    #return detection_graph
    with self.detection_graph.as_default():
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(graph=self.detection_graph, config=config)
        self.windowNotSet = True

  def visualization(self, image, str):
      image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
      draw = ImageDraw.Draw(image_pil)
      font = ImageFont.truetype(fm.findfont(fm.FontProperties(family='DejaVu Sans')), 15)  # 设置字体DejaVu Sans
      draw.text((10, 10), str, 'red', font)  # 'fuchsia'
      np.copyto(image, np.array(image_pil))
      return image

  def classify(self,image,resized):
    # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
    image_np_expanded = np.expand_dims(resized, axis=0)
    inp = self.detection_graph.get_tensor_by_name('input:0')
    #predictions = self.detection_graph.get_tensor_by_name('InceptionResnetV2/Predictions/Reshape_1:0')
    predictions = self.detection_graph.get_tensor_by_name('InceptionResnetV2/Logits/Predictions:0')
    start_time = time.time()
    pred = self.sess.run(
        predictions,
        feed_dict={inp: image_np_expanded})
    elapsed_time = time.time() - start_time
    #print(pred)
    print('inference time cost: {}'.format(elapsed_time))
    font1 = str(self.label_map[pred.argmax()])
    font2 = str(pred.max())
    font3 = font1 + ":" + font2
    img = self.visualization(image,font3)
    #return pred
    #print("Top 1 Prediction: ", x.argmax(), self.label_map[x.argmax()], x.max())
    cv2.namedWindow("classification", cv2.WINDOW_NORMAL)
    cv2.imshow("classification", img)



if __name__ == '__main__':
  width = 299
  height = 299
  dim = (width, height)
  # resize image to [-1,1] Maps pixel values to the range [-1, 1]
  classifier = TOD()
  while 1:
      rgb, depth, depcol = get_aligned_images()
      #image = cv2.imread(image_dir)
      image = rgb
      resized = (cv2.resize(image, dim)).astype(np.float) / 128 - 1
      classifier.classify(image,resized)
      k = cv2.waitKey(1) & 0xff
      if k == ord('q') or k == 27:
          pipeline.stop()
          break
  cv2.destroyAllWindows()

其中用到的labels.txt文件的格式为:

0:daisy
1:dandelion
2:roses
3:sunflowers
4:tulips

运行

python demo_cam.py

标签:图像识别,graph,image,slim,pb,depth,np,import,self
来源: https://blog.csdn.net/gaoqing_dream163/article/details/115205714

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有