博客
关于我
用tensorflow object detection api做手势识别
阅读量:742 次
发布时间:2019-03-22

本文共 5587 字,大约阅读时间需要 18 分钟。

使用TensorFlow进行目标检测:完整步骤指南

安装TensorFlow Object Detection API

TensorFlow Object Detection API 是一个强大的工具,用于实现目标检测任务。首先需要从GitHub克隆TensorFlow的模型仓库:

git clone https://github.com/tensorflow/models.git

安装完成后,按照以下步骤进行数据准备和模型训练。

数据准备

为了训练目标检测模型,我们需要摄像头捕捉图像并存储到特定目录。以下是一个简单的Python脚本示例:

import cv2
cap = cv2.VideoCapture(0)
idx = 0
while True:
ret, frame = cap.read()
if ret is True:
cv2.imshow('frame', frame)
if idx % 5 == 0:
cv2.imwrite('gesture_data/VOC2012/JPEGImages/{}.jpg'.format(idx), frame)
cv2.waitKey(50)
idx += 1
else:
break
cv2.destroyAllWindows()

模型准备

接下来,我们需要标注图片以训练模型。使用 labelImg.exe 工具进行数据标注。由于标注数据量较大,这一步可能会比较耗时。

将数据转换为TFRecord格式

TFRecord是TensorFlow中常用的数据格式,用于高效存储和加载数据集。以下是将PASCAL VOC数据转换为TFRecord的命令:

cd /path/to/tensorflow/models/research
python object_detection/dataset_tools/create_pascal_tf_record.py \
--label_map_path=/path/to/label_map.pbtxt \
--data_dir=/path/to/data \
--year=VOC2012 \
--set=train \
--output_path=/path/to/output.record \
--category=hand

注意:确保修改 create_pascal_tf_record.py 文件,使其支持自定义分类。

迁移训练

使用预训练模型进行迁移训练可以加速模型收敛。以下是迁移训练的命令示例:

cd /path/to/tensorflow/models/research
python object_detection/model_main.py \
--pipeline_config_path=/path/to/config-file \
--model_dir=/path/to/training-data \
--num_train_steps=1000 \
--num_eval_steps=15

导出模型

训练完成后,需要将模型导出为可推理的格式:

cd /path/to/tensorflow/models/research
python object_detection/export_inference_graph.py \
--input_type=image_tensor \
--pipeline_config_path=/path/to/config-file \
--trained_checkpoint-prefix=/path/to/training-model.ckpt \
--output-directory=/path/to/exported-model

使用模型

导出后的模型文件可以用于推理,以下是一个使用模型的示例代码:

import pathlib
import cv2 as cv
import numpy as np
import os
import tarfile
import tensorflow as tf
import zipfile
from collections import defaultdict
from io import StringIO
from matplotlib import pyplot as plt
from PIL import Image
from IPython.display import display
from object_detection.utils import ops as utils_ops
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
# TensorFlow 1.x兼容性修复
utils_ops.tf = tf.compat.v1
tf.gfile = tf.io.gfile
# 加载标签映射文件
PATH_TO_LABELS = '/path/to/label_map.pbtxt'
category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)
# 导入已训练模型
PATH_TO_FROZEN_GRAPH = '/path/to/exported-model/frozen_inference_graph.pb'
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
def run_inference_for_single_image(image, graph):
with graph.as_default():
with tf.Session() as sess:
ops = tf.get_default_graph().get_operations()
all_tensor_names = {output.name for op in ops for output in op.outputs}
tensor_dict = {}
for key in ['num_detections', 'detection_boxes', 'detection_scores', 'detection_classes', 'detection_masks']:
tensor_name = key + ':0'
if tensor_name in all_tensor_names:
tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(tensor_name)
if 'detection_masks' in tensor_dict:
# 简化处理,适用于单张图片
detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])
real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)
detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])
detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])
detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
detection_masks, detection_boxes, image.shape[0], image.shape[1])
detection_masks_reframed = tf.cast(tf.greater(detection_masks_reframed, 0.5), tf.uint8)
tensor_dict['detection_masks'] = tf.expand_dims(detection_masks_reframed, 0)
image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')
output_dict = sess.run(tensor_dict,
feed_dict={image_tensor: np.expand_dims(image, 0)})
output_dict['num_detections'] = int(output_dict['num_detections'][0])
output_dict['detection_classes'] = output_dict['detection_classes'][0].astype(np.uint8)
output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
output_dict['detection_scores'] = output_dict['detection_scores'][0]
if 'detection_masks' in output_dict:
output_dict['detection_masks'] = output_dict['detection_masks'][0]
return output_dict
# 加载示例图片
image_path = 'data/test_images/hand/two.jpg'
image = cv.imread(image_path)
# 进行推理
output_dict = run_inference_for_single_image(image, detection_graph)
# 可视化结果
vis_util.visualize_boxes_and_labels_on_image_array(
image,
output_dict['detection_boxes'],
output_dict['detection_classes'],
output_dict['detection_scores'],
category_index,
instance_masks=output_dict.get('detection_masks'),
min_score_thresh=0.5,
use_normalized_coordinates=True,
line_thickness=4)
# 保存结果
cv.imwrite('data/test_images/hand/two-result.jpg', image)
cv.destroyAllWindows()

转载地址:http://mjkwk.baihongyu.com/

你可能感兴趣的文章
MySQL一个表A中多个字段关联了表B的ID,如何关联查询?
查看>>
MYSQL一直显示正在启动
查看>>
MySQL一站到底!华为首发MySQL进阶宝典,基础+优化+源码+架构+实战五飞
查看>>
MySQL万字总结!超详细!
查看>>
Mysql下载以及安装(新手入门,超详细)
查看>>
MySQL不会性能调优?看看这份清华架构师编写的MySQL性能优化手册吧
查看>>
MySQL不同字符集及排序规则详解:业务场景下的最佳选
查看>>
Mysql不同官方版本对比
查看>>
MySQL与Informix数据库中的同义表创建:深入解析与比较
查看>>
mysql与mem_细说 MySQL 之 MEM_ROOT
查看>>
MySQL与Oracle的数据迁移注意事项,另附转换工具链接
查看>>
mysql丢失更新问题
查看>>
MySQL两千万数据优化&迁移
查看>>
MySql中 delimiter 详解
查看>>
MYSQL中 find_in_set() 函数用法详解
查看>>
MySQL中auto_increment有什么作用?(IT枫斗者)
查看>>
MySQL中B+Tree索引原理
查看>>
mysql中cast() 和convert()的用法讲解
查看>>
mysql中datetime与timestamp类型有什么区别
查看>>
MySQL中DQL语言的执行顺序
查看>>