当前位置:
首页 > Python基础教程 >
-
30秒轻松实现TensorFlow物体检测(2)
下载模型
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
opener = urllib.request.URLopener() opener.retrieve(DOWNLOAD_BASE + MODEL_FILE, MODEL_FILE) tar_file = tarfile. open (MODEL_FILE) for file in tar_file.getmembers(): file_name = os.path.basename( file .name) if 'frozen_inference_graph.pb' in file_name: tar_file.extract( file , os.getcwd()) 将(frozen)TensorFlow模型载入内存 detection_graph = tf.Graph() with detection_graph.as_default(): od_graph_def = tf.GraphDef() with tf.gfile.GFile(PATH_TO_CKPT, 'rb' ) as fid: serialized_graph = fid.read() od_graph_def.ParseFromString(serialized_graph) tf.import_graph_def(od_graph_def, name = '') |
载入标签图
标签图将索引映射到类名称,当我们的卷积预测5时,我们知道它对应飞机。这里我们使用内置函数,但是任何返回将整数映射到恰当字符标签的字典都适用。
1
2
3
|
label_map = label_map_util.load_labelmap(PATH_TO_LABELS) categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes = NUM_CLASSES, use_display_name = True ) category_index = label_map_util.create_category_index(categories) |
辅助代码
1
2
3
|
def load_image_into_numpy_array(image): (im_width, im_height) = image.size return np.array(image.getdata()).reshape( (im_height, im_width, 3 )).astype(np.uint8) |
检测
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
|
PATH_TO_TEST_IMAGES_DIR = 'test_images' TEST_IMAGE_PATHS = [ os.path.join(PATH_TO_TEST_IMAGES_DIR, 'image{}.jpg' . format (i)) for i in range ( 1 , 3 ) ] IMAGE_SIZE = ( 12 , 8 ) with detection_graph.as_default(): with tf.Session(graph = detection_graph) as sess: for image_path in TEST_IMAGE_PATHS: image = Image. open (image_path) # 这个array在之后会被用来准备为图片加上框和标签 image_np = load_image_into_numpy_array(image) # 扩展维度,应为模型期待: [1, None, None, 3] image_np_expanded = np.expand_dims(image_np, axis = 0 ) image_tensor = detection_graph.get_tensor_by_name( 'image_tensor:0' ) # 每个框代表一个物体被侦测到. boxes = detection_graph.get_tensor_by_name( 'detection_boxes:0' ) # 每个分值代表侦测到物体的可信度. scores = detection_graph.get_tensor_by_name( 'detection_scores:0' ) classes = detection_graph.get_tensor_by_name( 'detection_classes:0' ) num_detections = detection_graph.get_tensor_by_name( 'num_detections:0' ) # 执行侦测任务. (boxes, scores, classes, num_detections) = sess.run( [boxes, scores, classes, num_detections], feed_dict = {image_tensor: image_np_expanded}) # 图形化. vis_util.visualize_boxes_and_labels_on_image_array( image_np, np.squeeze(boxes), np.squeeze(classes).astype(np.int32), np.squeeze(scores), category_index, use_normalized_coordinates = True , line_thickness = 8 ) plt.figure(figsize = IMAGE_SIZE) plt.imshow(image_np) |
栏目列表
最新更新
nodejs爬虫
Python正则表达式完全指南
爬取豆瓣Top250图书数据
shp 地图文件批量添加字段
爬虫小试牛刀(爬取学校通知公告)
【python基础】函数-初识函数
【python基础】函数-返回值
HTTP请求:requests模块基础使用必知必会
Python初学者友好丨详解参数传递类型
如何有效管理爬虫流量?
SQL SERVER中递归
2个场景实例讲解GaussDB(DWS)基表统计信息估
常用的 SQL Server 关键字及其含义
动手分析SQL Server中的事务中使用的锁
openGauss内核分析:SQL by pass & 经典执行
一招教你如何高效批量导入与更新数据
天天写SQL,这些神奇的特性你知道吗?
openGauss内核分析:执行计划生成
[IM002]Navicat ODBC驱动器管理器 未发现数据
初入Sql Server 之 存储过程的简单使用
这是目前我见过最好的跨域解决方案!
减少回流与重绘
减少回流与重绘
如何使用KrpanoToolJS在浏览器切图
performance.now() 与 Date.now() 对比
一款纯 JS 实现的轻量化图片编辑器
关于开发 VS Code 插件遇到的 workbench.scm.
前端设计模式——观察者模式
前端设计模式——中介者模式
创建型-原型模式