当前位置:
首页 > temp > 简明python教程 >
-
文字检测模型EAST应用详解 ckpt pb的tf加载,opencv加载
参考链接:https://github.com/argman/EAST (项目来源)
https://github.com/opencv/opencv/issues/12491 (遇到的问题)
https://www.pyimagesearch.com/2018/08/20/opencv-text-detection-east-text-detector/ (opencv加载)
文字检测有很多比较好的现成的模型比如yolov3,pesnet,pennet,east。不一一赘述,讲一下自己跑通east的过程。
在https://github.com/argman/EAST链接中下载项目,windows下,各种包的版本要正确否则会出一些乱七八糟的错误。
运行EAST/eval.py。没有什么特别的问题要说,我在cpu下单张640*480的图能够达到每张0.4秒左右,还是非常优秀的。中英文数字都可。
但是源代码是ckpt,非常大,转成pb会稍微小点。添加:
##生成pb模型,但需要修改model.py output_graph_def = tf.graph_util.convert_variables_to_constants(self.sess, # The session is used to retrieve the weights tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes ["feature_fusion/Conv_7/Sigmoid", "feature_fusion/concat_3"] ) output_graph='D:\\work\\video\\hand_tracking_no_op\\hand_tracking\\EAST\\east_icdar2015_resnet_v1_50_rbox\\out.pb' with tf.gfile.GFile(output_graph, "wb") as f: f.write(output_graph_def.SerializeToString()) print("%d ops in the final graph." % len(output_graph_def.node)) 位置在eval.py中的 saver.restore(self.sess, model_path)后面。注意如果你想要opencv加载pb还要修改model.py中的内容,这个在后面一篇文章中会讲到。 生成后用tf加载,方法跟加载ckpt相似:
import os os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_list try: os.makedirs(FLAGS.output_dir) except OSError as e: if e.errno != 17: raise print("load_graph") graph = load_graph(FLAGS.checkpoint_path) input_images = graph.get_tensor_by_name( 'import/input_images:0') f_score = graph.get_tensor_by_name('import/feature_fusion/Conv_7/Sigmoid:0') f_geometry = graph.get_tensor_by_name( 'import/feature_fusion/concat_3:0') with tf.Session(graph=graph) as sess: im_fn_list = get_images() for im_fn in im_fn_list: im = cv2.imread(im_fn)[:, :, ::-1] start_time = time.time() im_resized, (ratio_h, ratio_w) = resize_image(im) timer = {'net': 0, 'restore': 0, 'nms': 0} start = time.time() #file_writer = tf.summary.FileWriter('tmp/log', sess.graph) score, geometry = sess.run([f_score, f_geometry], feed_dict={ input_images: [im_resized]}) timer['net'] = time.time() - start boxes, timer = detect(score_map=score, geo_map=geometry, timer=timer) print('{} : net {:.0f}ms, restore {:.0f}ms, nms {:.0f}ms'.format( im_fn, timer['net']*1000, timer['restore']*1000, timer['nms']*1000)) if boxes is not None: boxes = boxes[:, :8].reshape((-1, 4, 2)) boxes[:, :, 0] /= ratio_w boxes[:, :, 1] /= ratio_h duration = time.time() - start_time print('[timing] {}'.format(duration)) # save to file if boxes is not None: res_file = os.path.join( FLAGS.output_dir, '{}.txt'.format( os.path.basename(im_fn).split('.')[0])) with open(res_file, 'w') as f: for box in boxes: # to avoid submitting errors box = sort_poly(box.astype(np.int32)) if np.linalg.norm(box[0] - box[1]) < 5 or np.linalg.norm(box[3]-box[0]) < 5: continue f.write('{},{},{},{},{},{},{},{}\r\n'.format( box[0, 0], box[0, 1], box[1, 0], box[1, 1], box[2, 0], box[2, 1], box[3, 0], box[3, 1], )) cv2.polylines(im[:, :, ::-1], [box.astype(np.int32).reshape((-1, 1, 2))], True, color=(255, 255, 0), thickness=1) if not FLAGS.no_write_images: img_path = os.path.join(FLAGS.output_dir, os.path.basename(im_fn)) cv2.imwrite(img_path, im[:, :, ::-1]) 以上就是EAST的ckpt转pb用tf加载啦。 下一篇讲opencv加载east的pb。
栏目列表
最新更新
nodejs爬虫
Python正则表达式完全指南
爬取豆瓣Top250图书数据
shp 地图文件批量添加字段
爬虫小试牛刀(爬取学校通知公告)
【python基础】函数-初识函数
【python基础】函数-返回值
HTTP请求:requests模块基础使用必知必会
Python初学者友好丨详解参数传递类型
如何有效管理爬虫流量?
2个场景实例讲解GaussDB(DWS)基表统计信息估
常用的 SQL Server 关键字及其含义
动手分析SQL Server中的事务中使用的锁
openGauss内核分析:SQL by pass & 经典执行
一招教你如何高效批量导入与更新数据
天天写SQL,这些神奇的特性你知道吗?
openGauss内核分析:执行计划生成
[IM002]Navicat ODBC驱动器管理器 未发现数据
初入Sql Server 之 存储过程的简单使用
SQL Server -- 解决存储过程传入参数作为s
关于JS定时器的整理
JS中使用Promise.all控制所有的异步请求都完
js中字符串的方法
import-local执行流程与node模块路径解析流程
检测数据类型的四种方法
js中数组的方法,32种方法
前端操作方法
数据类型
window.localStorage.setItem 和 localStorage.setIte
如何完美解决前端数字计算精度丢失与数