博客
关于我
用tensorflow object detection api做手势识别
阅读量:742 次
发布时间:2019-03-22

本文共 5478 字,大约阅读时间需要 18 分钟。

使用TensorFlow进行目标检测:完整步骤指南

安装TensorFlow Object Detection API

TensorFlow Object Detection API 是一个强大的工具,用于实现目标检测任务。首先需要从GitHub克隆TensorFlow的模型仓库:

git clone https://github.com/tensorflow/models.git

安装完成后,按照以下步骤进行数据准备和模型训练。

数据准备

为了训练目标检测模型,我们需要摄像头捕捉图像并存储到特定目录。以下是一个简单的Python脚本示例:

import cv2cap = cv2.VideoCapture(0)idx = 0while True:    ret, frame = cap.read()    if ret is True:        cv2.imshow('frame', frame)        if idx % 5 == 0:            cv2.imwrite('gesture_data/VOC2012/JPEGImages/{}.jpg'.format(idx), frame)        cv2.waitKey(50)        idx += 1    else:        breakcv2.destroyAllWindows()

模型准备

接下来,我们需要标注图片以训练模型。使用 labelImg.exe 工具进行数据标注。由于标注数据量较大,这一步可能会比较耗时。

将数据转换为TFRecord格式

TFRecord是TensorFlow中常用的数据格式,用于高效存储和加载数据集。以下是将PASCAL VOC数据转换为TFRecord的命令:

cd /path/to/tensorflow/models/researchpython object_detection/dataset_tools/create_pascal_tf_record.py \    --label_map_path=/path/to/label_map.pbtxt \    --data_dir=/path/to/data \    --year=VOC2012 \    --set=train \    --output_path=/path/to/output.record \    --category=hand

注意:确保修改 create_pascal_tf_record.py 文件,使其支持自定义分类。

迁移训练

使用预训练模型进行迁移训练可以加速模型收敛。以下是迁移训练的命令示例:

cd /path/to/tensorflow/models/researchpython object_detection/model_main.py \    --pipeline_config_path=/path/to/config-file \    --model_dir=/path/to/training-data \    --num_train_steps=1000 \    --num_eval_steps=15

导出模型

训练完成后,需要将模型导出为可推理的格式:

cd /path/to/tensorflow/models/researchpython object_detection/export_inference_graph.py \    --input_type=image_tensor \    --pipeline_config_path=/path/to/config-file \    --trained_checkpoint-prefix=/path/to/training-model.ckpt \    --output-directory=/path/to/exported-model

使用模型

导出后的模型文件可以用于推理,以下是一个使用模型的示例代码:

import pathlibimport cv2 as cvimport numpy as npimport osimport tarfileimport tensorflow as tfimport zipfilefrom collections import defaultdictfrom io import StringIOfrom matplotlib import pyplot as pltfrom PIL import Imagefrom IPython.display import displayfrom object_detection.utils import ops as utils_opsfrom object_detection.utils import label_map_utilfrom object_detection.utils import visualization_utils as vis_util# TensorFlow 1.x兼容性修复utils_ops.tf = tf.compat.v1tf.gfile = tf.io.gfile# 加载标签映射文件PATH_TO_LABELS = '/path/to/label_map.pbtxt'category_index = label_map_util.create_category_index_from_labelmap(PATH_TO_LABELS, use_display_name=True)# 导入已训练模型PATH_TO_FROZEN_GRAPH = '/path/to/exported-model/frozen_inference_graph.pb'detection_graph = tf.Graph()with detection_graph.as_default():    od_graph_def = tf.GraphDef()    with tf.gfile.GFile(PATH_TO_FROZEN_GRAPH, 'rb') as fid:        serialized_graph = fid.read()        od_graph_def.ParseFromString(serialized_graph)        tf.import_graph_def(od_graph_def, name='')def run_inference_for_single_image(image, graph):    with graph.as_default():        with tf.Session() as sess:            ops = tf.get_default_graph().get_operations()            all_tensor_names = {output.name for op in ops for output in op.outputs}            tensor_dict = {}            for key in ['num_detections', 'detection_boxes', 'detection_scores', 'detection_classes', 'detection_masks']:                tensor_name = key + ':0'                if tensor_name in all_tensor_names:                    tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(tensor_name)            if 'detection_masks' in tensor_dict:                # 简化处理,适用于单张图片                detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])                detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])                real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)                detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])                detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])                detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(                    detection_masks, detection_boxes, image.shape[0], image.shape[1])                detection_masks_reframed = tf.cast(tf.greater(detection_masks_reframed, 0.5), tf.uint8)                tensor_dict['detection_masks'] = tf.expand_dims(detection_masks_reframed, 0)            image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')            output_dict = sess.run(tensor_dict,                                   feed_dict={image_tensor: np.expand_dims(image, 0)})            output_dict['num_detections'] = int(output_dict['num_detections'][0])            output_dict['detection_classes'] = output_dict['detection_classes'][0].astype(np.uint8)            output_dict['detection_boxes'] = output_dict['detection_boxes'][0]            output_dict['detection_scores'] = output_dict['detection_scores'][0]            if 'detection_masks' in output_dict:                output_dict['detection_masks'] = output_dict['detection_masks'][0]            return output_dict# 加载示例图片image_path = 'data/test_images/hand/two.jpg'image = cv.imread(image_path)# 进行推理output_dict = run_inference_for_single_image(image, detection_graph)# 可视化结果vis_util.visualize_boxes_and_labels_on_image_array(    image,    output_dict['detection_boxes'],    output_dict['detection_classes'],    output_dict['detection_scores'],    category_index,    instance_masks=output_dict.get('detection_masks'),    min_score_thresh=0.5,    use_normalized_coordinates=True,    line_thickness=4)# 保存结果cv.imwrite('data/test_images/hand/two-result.jpg', image)cv.destroyAllWindows()

转载地址:http://mjkwk.baihongyu.com/

你可能感兴趣的文章
NFS的常用挂载参数
查看>>
NFS网络文件系统
查看>>
NFS远程目录挂载
查看>>
nft文件传输_利用remoting实现文件传输-.NET教程,远程及网络应用
查看>>
NFV商用可行新华三vBRAS方案实践验证
查看>>
ng build --aot --prod生成文件报错
查看>>
ng 指令的自定义、使用
查看>>
ng6.1 新特性:滚回到之前的位置
查看>>
nghttp3使用指南
查看>>
Nginx
查看>>
nginx + etcd 动态负载均衡实践(一)—— 组件介绍
查看>>
nginx + etcd 动态负载均衡实践(三)—— 基于nginx-upsync-module实现
查看>>
nginx + etcd 动态负载均衡实践(二)—— 组件安装
查看>>
nginx + etcd 动态负载均衡实践(四)—— 基于confd实现
查看>>
Nginx + Spring Boot 实现负载均衡
查看>>
Nginx + Tomcat + SpringBoot 部署项目
查看>>
Nginx + uWSGI + Flask + Vhost
查看>>
Nginx - Header详解
查看>>
nginx - thinkphp 如何实现url的rewrite
查看>>
Nginx - 反向代理、负载均衡、动静分离、底层原理(案例实战分析)
查看>>