分享 TensorFlow Lite模型转换为C语言代码 lihuibear 2025-04-27 2025-04-27 TensorFlow Lite模型转换为C语言代码 tflite2c.py脚本概述 tools/tflite2c.py
脚本旨在将TensorFlow Lite模型文件(.tflite
)转换为C或C++代码文件,方便在C/C++项目中使用。该脚本具备以下特点:
输入输出管理 :通过命令行参数接收输入的TensorFlow Lite模型文件路径、输出目录、模型名称等信息,并根据这些参数生成对应的C或C++文件。
模型解析与转换 :能够读取TensorFlow Lite模型文件内容,对文件格式进行验证,并将模型数据转换为C语言可识别的格式。
类别信息处理 :支持将模型的类别名称作为参数输入,并在生成的代码中进行定义,便于后续的模型推理与结果处理。
技术原理与实现过程 在实现过程中,首先利用argparse
模块解析命令行参数,获取输入模型文件路径、输出目录、模型名称、是否生成 C++ 文件及类别名称等信息;接着检查输入文件是否存在,若未指定模型名称则从输入文件路径提取;然后确定输出文件路径,以二进制模式读取输入文件并验证其是否为有效的 TensorFlow Lite 文件,通过binascii.hexlify
将二进制数据转为十六进制字符串;最后分别生成头文件和源文件。
转换过程 将 TensorFlow Lite 模型的二进制数据转换为适合在 C 语言中表示的十六进制数组形式。首先,脚本读取 TFLite 模型的二进制数据,接着使用binascii.hexlify
函数将二进制数据转换为十六进制字节串,再通过decode('utf-8')
方法将字节串转换为十六进制字符串。之后,在生成 C 语言源文件时,通过循环按步长 2 遍历该十六进制字符串,每次取两个字符,在其前面添加0x
,组成 C 语言数组的一个元素值。
命令行参数解析 脚本使用argparse
模块解析命令行参数,获取输入文件路径--input
、输出目录--output_dir
、模型名称--name
、是否输出C++文件--cpp
以及类别名称--classes
。
def parse_args (): parser = argparse.ArgumentParser( description='Convert tflite to c or cpp file' ) parser.add_argument('--input' , help ='input tflite file' ) parser.add_argument('--output_dir' , help ='output directory' ) parser.add_argument('--name' , help ='model name' ) parser.add_argument('--cpp' , action='store_true' , default=True , help ='output cpp file' ) parser.add_argument('--classes' , type =str , help ='classes name' ) args = parser.parse_args() return args
输入文件检查与处理
文件存在性检查 :验证输入的TensorFlow Lite模型文件是否存在,若不存在则输出错误信息并退出程序。
if not os.path.exists(input ): print ('input file not exist' ) sys.exit(1 )
模型名称处理 :若未指定模型名称,则从输入文件路径中提取文件名作为模型名称。
if name == None : name = input .split('/' )[-1 ].split('.' )[0 ]
类别名称处理 :若指定了类别名称,则将其按逗号分隔为列表。
if classes != None : classes = list (classes.split(',' ))
输出文件路径确定 根据输出目录、模型名称和是否输出C++文件确定输出的头文件和源文件路径。
output_h = os.path.join(output_dir, name + '_model_data.h' ) if args.cpp: output_c = os.path.join(output_dir, name + '_model_data.cpp' ) else : output_c = os.path.join(output_dir, name + '_model_data.c' )
模型文件读取与验证
文件读取 :以二进制模式读取输入的TensorFlow Lite文件。
with open (input , 'rb' ) as f_input: data = f_input.read()
格式验证 :检查文件的魔数是否为TFL3
,若不是则输出错误信息并退出程序,确保输入文件是有效的TensorFlow Lite模型文件。
if data[4 :8 ] != b'TFL3' : print ('input file is not tflite' ) sys.exit(1 )
数据转换与文件生成
数据转换 :将读取的二进制数据转换为十六进制字符串,以便在C语言中表示。
data = binascii.hexlify(data) data = data.decode('utf-8' )
头文件生成 :生成的头文件包含头文件保护宏、引入<stdint.h>
头文件、声明模型数据数组和其长度,若指定了类别名称,还会声明类别名称数组和其数量。
with open (output_h, 'w' ) as f_output_h: f_output_h.write('#ifndef __%s_MODEL_DATA_H__\r\n' % name.upper()) f_output_h.write('#define __%s_MODEL_DATA_H__\r\n' % name.upper()) f_output_h.write('\r\n//this file is generated by tflite2c.py\r\n' ) f_output_h.write('\r\n#include <stdint.h>\r\n' ) f_output_h.write('extern const unsigned char g_%s_model_data[];\r\n' % name) f_output_h.write('extern const unsigned int g_%s_model_data_len;\r\n' % name) if classes != None : f_output_h.write('extern const char* g_%s_model_classes[];\r\n' % name) f_output_h.write('extern const unsigned int g_%s_model_classes_num;\r\n' % name) f_output_h.write('\r\n#endif\r\n' )
源文件生成 :源文件引入<stdint.h>
头文件和对应的头文件,定义模型数据数组,将十六进制字符串按每两个字符一组转换为十六进制数值写入数组,并定义模型数据数组的长度和类别名称数组(若有)。
with open (output_c, 'w' ) as f_output_c: f_output_c.write('#include <stdint.h>\r\n' ) f_output_c.write('\r\n#include "%s_model_data.h"\r\n\r\n' % name) f_output_c.write('const unsigned char g_%s_model_data[] = {\r\n' % name) for i in range (0 , len (data), 2 ): f_output_c.write('0x' ) f_output_c.write(data[i]) f_output_c.write(data[i+1 ]) f_output_c.write(', ' ) if i % 36 == 34 : f_output_c.write('\r\n' ) f_output_c.write('};\r\n\r\n' ) f_output_c.write('const unsigned int g_%s_model_data_len = %d;\r\n' % (name, len (data) // 2 )) if classes != None : f_output_c.write('const char* g_%s_model_classes[] = {' % name) for i in range (len (classes)): f_output_c.write('"%s", ' % classes[i]) f_output_c.write('};\r\n\r\n' ) f_output_c.write('const unsigned int g_%s_model_classes_num = %d;\r\n' % (name, len (classes))) else : f_output_c.write('const char* g_%s_model_classes[] = {};\r\n' % name) f_output_c.write('const unsigned int g_%s_model_classes_num = 0;\r\n' % name)