已解决
将PaddleOCR中的字符识别模型转换为ONNX
来自网友在路上 163863提问 提问时间:2023-10-15 08:53:57阅读次数: 63
最佳答案 问答题库638位专家为你答疑解惑
将PaddleOCR中的（字符识别）模型转换成ONNX格式。
转换代码:
import argparse
import os
import sys
from argparse import ArgumentParser, RawDescriptionHelpFormatter

import cv2
import numpy as np
import paddle
import paddle.distributed as dist
import yaml
from paddle import nn

from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import init_model
class AttrDict(dict):
    """Single-level attribute dict (NOT recursive).

    Keys are readable as attributes (``d.key``); nested dicts stay plain dicts.
    """

    def __init__(self, **kwargs):
        super(AttrDict, self).__init__(**kwargs)

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails; fall back to the keys.
        try:
            return self[key]
        except KeyError:
            raise AttributeError("object has no attribute '{}'".format(key)) from None


# Module-level config that merge_config()/load_config() populate in place.
global_config = AttrDict()
default_config = {'Global': {'debug': False, }}


class ArgsParser(ArgumentParser):
    """Command-line parser: ``-c/--config`` YAML path and ``-o/--opt key=value`` overrides."""

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        # Alternative config used previously: './configs/ch_PP-OCRv2_rec_idcard.yml'
        self.add_argument("-c", "--config", default='./configs/ch_PP-OCRv2_rec.yml',
                          help="configuration file to use")
        self.add_argument("-o", "--opt", nargs='+', help="set configuration options")

    def parse_args(self, argv=None):
        """Parse argv and convert the ``--opt`` list into a dict of overrides."""
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Turn ``["a.b=1", ...]`` into ``{"a.b": 1, ...}`` (values parsed as YAML)."""
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            # Split on the first '=' only so that values may themselves contain '='.
            k, v = s.split('=', 1)
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config


def merge_config(config):
    """Merge ``config`` into the module-level ``global_config`` in place.

    Top-level dict values are shallow-merged into existing sections; dotted keys
    such as ``"Global.use_gpu"`` address nested entries.

    Args:
        config (dict): Config to be merged.

    Returns:
        None. ``global_config`` is mutated.
    """
    for key, value in config.items():
        if "." not in key:
            if isinstance(value, dict) and key in global_config:
                global_config[key].update(value)
            else:
                global_config[key] = value
        else:
            sub_keys = key.split('.')
            assert (sub_keys[0] in global_config), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(global_config.keys(), sub_keys[0])
            cur = global_config[sub_keys[0]]
            for idx, sub_key in enumerate(sub_keys[1:]):
                # The last sub-key receives the value; earlier ones descend.
                if idx == len(sub_keys) - 2:
                    cur[sub_key] = value
                else:
                    cur = cur[sub_key]


def load_config(file_path):
    """Load a yml/yaml config file into ``global_config``.

    Args:
        file_path (str): Path of the config file to be loaded.

    Returns:
        The (module-level) global config.
    """
    merge_config(default_config)
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # NOTE(review): yaml.Loader can construct arbitrary Python objects;
    # only load trusted config files with it.
    with open(file_path, 'rb') as f:
        merge_config(yaml.load(f, Loader=yaml.Loader))
    return global_config


def check_device(use_gpu, use_xpu=False):
    """Print an error and exit when GPU/XPU is requested but Paddle lacks support."""
    err = "Config {} cannot be set as true while your paddle " \
          "is not compiled with {} ! \nPlease try: \n" \
          "\t1. Install paddlepaddle to run model on {} \n" \
          "\t2. Set {} as false in config file to run " \
          "model on CPU"
    try:
        if use_gpu and use_xpu:
            print("use_xpu and use_gpu can not both be ture.")
        if use_gpu and not paddle.is_compiled_with_cuda():
            print(err.format("use_gpu", "cuda", "gpu", "use_gpu"))
            sys.exit(1)
        if use_xpu and not paddle.device.is_compiled_with_xpu():
            print(err.format("use_xpu", "xpu", "xpu", "use_xpu"))
            sys.exit(1)
    except Exception:
        # Deliberately best-effort: presumably tolerates Paddle builds missing one
        # of these probes. sys.exit raises SystemExit (not an Exception subclass),
        # so the exits above still propagate.
        pass


def getArgs(is_train=False):
    """Parse CLI args, load and merge the config, and select the Paddle device.

    Args:
        is_train: unused; kept for call compatibility.

    Returns:
        ``(config, device)``, where ``device`` is what ``paddle.set_device`` returns.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    merge_config(FLAGS.opt)
    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global']['use_gpu']
    use_xpu = False
    alg = config['Architecture']['algorithm']
    assert alg in ['EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
                   'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
                   'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
                   'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
                   'Gestalt', 'SLANet', 'RobustScanner']
    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    check_device(use_gpu, use_xpu)
    device = paddle.set_device(device)
    config['Global']['distributed'] = dist.get_world_size() != 1
    return config, device


class CRNN(nn.Layer):
    """Export wrapper around a PaddleOCR recognition model.

    ``forward`` returns the per-timestep argmax class index and max score, both
    cast to float32, so that CTC decoding can be done outside the graph.
    """

    def __init__(self, config, device):
        super(CRNN, self).__init__()
        # Normalization parameters. They are not applied in forward(), so the input
        # must already be normalized as (x / 255 - mean) / std.
        mean = (0.5, 0.5, 0.5)
        std = (0.5, 0.5, 0.5)
        self.mean = paddle.to_tensor(mean).reshape([1, 3, 1, 1])
        self.std = paddle.to_tensor(std).reshape([1, 3, 1, 1])
        self.config = config

        # Build the post-process; its character list determines the head sizes.
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     config['Global'])

        character = getattr(self.post_process_class, 'character', None)
        if character is not None:
            char_num = len(character)
            arch = self.config['Architecture']
            post_name = self.config['PostProcess']['name']
            if arch["algorithm"] in ["Distillation", ]:  # distillation model
                for key in arch["Models"]:
                    head = arch['Models'][key]['Head']
                    if head['name'] == 'MultiHead':  # for multi head
                        # Computed per sub-model: the original decremented char_num
                        # itself, once for every MultiHead model in the loop.
                        ctc_num = char_num - 2 if post_name == 'DistillationSARLabelDecode' else char_num
                        head['out_channels_list'] = {
                            'CTCLabelDecode': ctc_num,
                            'SARLabelDecode': ctc_num + 2,
                        }
                    else:
                        head['out_channels'] = char_num
            elif arch['Head']['name'] == 'MultiHead':  # for multi head
                ctc_num = char_num - 2 if post_name == 'SARLabelDecode' else char_num
                arch['Head']['out_channels_list'] = {
                    'CTCLabelDecode': ctc_num,
                    'SARLabelDecode': ctc_num + 2,
                }
            else:  # base rec model
                arch["Head"]['out_channels'] = char_num

        # Build the network, load the weights configured in config, switch to eval mode.
        self.model = build_model(config['Architecture'])
        init_model(self.config, self.model)
        self.model.eval()

    def forward(self, x):
        """Run the model and reduce over the class axis.

        Assumes the model returns a single tensor of shape (batch, time, classes);
        the axis=2 reductions require rank >= 3. TODO confirm for Distillation
        configs, which may return a dict instead.
        """
        # Preprocessing (NHWC->NCHW, /255, (x - mean) / std) is intentionally NOT
        # done in-graph, so the exported model takes normalized NCHW input.
        model_out = self.model(x)
        preds_idx = model_out.argmax(axis=2, name='class').astype('float32')
        preds_prob = model_out.max(axis=2, name='score').astype('float32')
        return preds_idx, preds_prob


EXPORT_ONNX = True
DYNAMIC = False  # True: export with a dynamic batch dimension

if __name__ == '__main__':
    config, device = getArgs()
    model_crnn = CRNN(config, device=device)

    # Build a sample input from an image on disk (sanity run + export tracing).
    image_path = "1.jpg"
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread returns None instead of raising for missing/unreadable files.
        raise FileNotFoundError("cannot read image: {}".format(image_path))
    img = cv2.resize(img, (320, 32))
    print('input data:', img.shape)
    img = img.astype(np.float32)
    # HWC -> CHW, scale to [0, 1], then normalize with mean = std = 0.5 to [-1, 1]:
    # CRNN.forward does not normalize, so the input must arrive normalized.
    img = img.transpose((2, 0, 1)) / 255
    img = (img - 0.5) / 0.5
    input_data = img[np.newaxis, :]
    print('input data:', input_data.shape)
    x = paddle.to_tensor(input_data)
    print('input data:', x.shape)

    output_idx, output_prob = model_crnn(x)
    print('output_idx: ', output_idx)
    print('output_prob: ', output_prob)

    input_spec = paddle.static.InputSpec.from_tensor(x, name='input')
    onnx_save_path = "./export_onnx"
    if EXPORT_ONNX:
        os.makedirs(onnx_save_path, exist_ok=True)
        onnx_model_name = os.path.join(onnx_save_path, "char_recognize_20230526_v1")
        if DYNAMIC:
            # The model consumes NCHW (the in-graph transpose is disabled), so only
            # the batch dimension is made dynamic; the original spec was NHWC.
            input_spec = paddle.static.InputSpec(shape=[None, 3, 32, 320], dtype='float32', name='input')
        # Export the ONNX model.
        paddle.onnx.export(model_crnn, onnx_model_name, input_spec=[input_spec], opset_version=11,
                           enable_onnx_checker=True, output_spec=[output_idx, output_prob])
使用Netron工具将转换后的网络结构可视化。
可视化得到的网络起始部分和末尾部分的结构如下（原文为截图，此处从略）：
测试ONNX模型的代码：
'''
Test the exported ONNX model.
'''
import math

import cv2
import numpy as np
import onnxruntime as rt
import torch
import os


class TestOnnx:
    """Run an exported text-recognition ONNX model and CTC-decode its outputs.

    Expects the model's first two outputs to be per-timestep class indices and
    per-timestep scores (presumably the export wrapper's ``(preds_idx, preds_prob)``).
    """

    def __init__(self, onnx_file, character_dict_path, use_space_char=True, providers=None):
        """
        Args:
            onnx_file: path of the .onnx model.
            character_dict_path: UTF-8 text file with one character per line.
            use_space_char: append " " as the last class.
            providers: optional onnxruntime execution-provider list. None keeps
                onnxruntime's default; GPU builds of onnxruntime >= 1.9 require
                an explicit list.
        """
        self.sess = rt.InferenceSession(onnx_file, providers=providers)
        self.input_names = [node.name for node in self.sess.get_inputs()]
        self.output_names = [node.name for node in self.sess.get_outputs()]

        # Index 0 is the CTC blank; dictionary characters start at index 1.
        self.character = ["blank"]
        with open(character_dict_path, "rb") as fin:
            for line in fin:
                self.character.append(line.decode('utf-8').strip("\r\n"))
        if use_space_char:
            self.character.append(" ")

    def resize_norm_img(self, img, image_shape=(3, 32, 320)):
        """Resize keeping aspect ratio, normalize to [-1, 1], right-pad to imgW.

        Returns a float32 array of shape ``image_shape`` (C, H, W).
        """
        imgC, imgH, imgW = image_shape
        h, w = img.shape[:2]
        ratio = w / float(h)
        # Clamp the resized width to the target width.
        resized_w = min(imgW, int(math.ceil(imgH * ratio)))
        resized_image = cv2.resize(img, (resized_w, imgH)).astype('float32')
        if image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def process(self, input_names, image):
        """Build the onnxruntime feed dict, feeding ``image`` to every input."""
        return {input_name: image for input_name in input_names}

    def get_ignored_tokens(self):
        """Class indices dropped during decoding (the CTC blank)."""
        return [0]

    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
        """Convert index sequences into ``[(text, mean_confidence), ...]``, one per batch row."""
        result_list = []
        ignored_tokens = self.get_ignored_tokens()
        for batch_idx, indices in enumerate(text_index):
            indices = np.asarray(indices)
            selection = np.ones(len(indices), dtype=bool)
            if is_remove_duplicate:
                # CTC: collapse consecutive repeats.
                selection[1:] = indices[1:] != indices[:-1]
            for ignored_token in ignored_tokens:
                selection &= indices != ignored_token
            # Indices may arrive as float32 (the export casts them), hence int().
            char_list = [self.character[int(text_id)].replace('\n', '')
                         for text_id in indices[selection]]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                # One unit confidence per *kept* character (was per timestep).
                conf_list = [1] * int(selection.sum())
            if len(conf_list) == 0:
                conf_list = [0]
            result_list.append((''.join(char_list), np.mean(conf_list).tolist()))
        return result_list

    def test(self, image_path):
        """Recognize the text in one image, print it, and return ``(text, score)``.

        Returns None (after printing a message) when the image cannot be read.
        """
        img = cv2.imread(image_path)
        if img is None:
            print(image_path, 'cannot be read as an image, skipped')
            return None
        norm_img = self.resize_norm_img(img)
        onnx_indata = np.asarray(norm_img[np.newaxis, :, :, :], dtype=np.float32)
        print('image shape: ', onnx_indata.shape)
        feed_dict = self.process(self.input_names, onnx_indata)
        output_onnx = self.sess.run(self.output_names, feed_dict)
        preds_idx = output_onnx[0]
        preds_prob = output_onnx[1]
        text, score = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)[0]
        print(image_path, text)
        return text, score


if __name__ == '__main__':
    image_dir = "./sample/img"
    onnx_file = './export_onnx/char_recognize_20230526_v1.onnx'
    character_dict_path = './all_label_num_20230517.txt'
    testobj = TestOnnx(onnx_file, character_dict_path)
    for file_name in sorted(os.listdir(image_dir)):
        image_path = os.path.join(image_dir, file_name)
        if os.path.isfile(image_path):
            testobj.test(image_path)
模型转换结束。
查看全文
99%的人还看了
相似问题
- 最新AIGC创作系统ChatGPT系统源码,支持最新GPT-4-Turbo模型,支持DALL-E3文生图,图片对话理解功能
- 思维模型 等待效应
- FinGPT:金融垂类大模型架构
- 人工智能基础_机器学习044_使用逻辑回归模型计算逻辑回归概率_以及_逻辑回归代码实现与手动计算概率对比---人工智能工作笔记0084
- Pytorch完整的模型训练套路
- Doris数据模型的选择建议(十三)
- python自动化标注工具+自定义目标P图替换+深度学习大模型(代码+教程+告别手动标注)
- ChatGLM2 大模型微调过程中遇到的一些坑及解决方法(更新中)
- Python实现WOA智能鲸鱼优化算法优化随机森林分类模型(RandomForestClassifier算法)项目实战
- 扩散模型实战(十一):剖析Stable Diffusion Pipeline各个组件
猜你感兴趣
版权申明
本文"对paddleOCR中的字符识别模型转ONNX":http://eshow365.cn/6-19596-0.html 内容来自互联网,请自行判断内容的正确性。如有侵权请联系我们,立即删除!