TensorFlow Lite模型转换为C语言代码

TensorFlow Lite模型转换为C语言代码

tflite2c.py脚本概述

tools/tflite2c.py脚本旨在将TensorFlow Lite模型文件(.tflite)转换为C或C++代码文件,方便在C/C++项目中使用。该脚本具备以下特点:

  1. 输入输出管理:通过命令行参数接收输入的TensorFlow Lite模型文件路径、输出目录、模型名称等信息,并根据这些参数生成对应的C或C++文件。
  2. 模型解析与转换:能够读取TensorFlow Lite模型文件内容,对文件格式进行验证,并将模型数据转换为C语言可识别的格式。
  3. 类别信息处理:支持将模型的类别名称作为参数输入,并在生成的代码中进行定义,便于后续的模型推理与结果处理。

技术原理与实现过程

在实现过程中,首先利用argparse模块解析命令行参数,获取输入模型文件路径、输出目录、模型名称、是否生成 C++ 文件及类别名称等信息;接着检查输入文件是否存在,若未指定模型名称则从输入文件路径提取;然后确定输出文件路径,以二进制模式读取输入文件并验证其是否为有效的 TensorFlow Lite 文件,通过binascii.hexlify将二进制数据转为十六进制字符串;最后分别生成头文件和源文件。

gwgf9j0q_

转换过程

将 TensorFlow Lite 模型的二进制数据转换为适合在 C 语言中表示的十六进制数组形式。首先,脚本读取 TFLite 模型的二进制数据,接着使用binascii.hexlify函数将二进制数据转换为十六进制字节串,再通过decode('utf-8')方法将字节串转换为十六进制字符串。之后,在生成 C 语言源文件时,通过循环按步长 2 遍历该十六进制字符串,每次取两个字符,在其前面添加0x,组成 C 语言数组的一个元素值。

iftdte2z_

命令行参数解析

脚本使用argparse模块解析命令行参数,获取输入文件路径--input、输出目录--output_dir、模型名称--name、是否输出C++文件--cpp以及类别名称--classes

def parse_args():
parser = argparse.ArgumentParser(
description='Convert tflite to c or cpp file')

parser.add_argument('--input', help='input tflite file')
parser.add_argument('--output_dir', help='output directory')
parser.add_argument('--name', help='model name')
parser.add_argument('--cpp', action='store_true',
default=True, help='output cpp file')
parser.add_argument('--classes', type=str, help='classes name')

args = parser.parse_args()

return args

输入文件检查与处理

  1. 文件存在性检查:验证输入的TensorFlow Lite模型文件是否存在,若不存在则输出错误信息并退出程序。
if not os.path.exists(input):
print('input file not exist')
sys.exit(1)
  1. 模型名称处理:若未指定模型名称,则从输入文件路径中提取文件名作为模型名称。
if name == None:
name = input.split('/')[-1].split('.')[0]
  1. 类别名称处理:若指定了类别名称,则将其按逗号分隔为列表。
if classes != None:
classes = list(classes.split(','))

输出文件路径确定

根据输出目录、模型名称和是否输出C++文件确定输出的头文件和源文件路径。

output_h = os.path.join(output_dir, name + '_model_data.h')
if args.cpp:
output_c = os.path.join(output_dir, name + '_model_data.cpp')
else:
output_c = os.path.join(output_dir, name + '_model_data.c')

模型文件读取与验证

  1. 文件读取:以二进制模式读取输入的TensorFlow Lite文件。
with open(input, 'rb') as f_input:
data = f_input.read()
  1. 格式验证:检查文件的魔数是否为TFL3,若不是则输出错误信息并退出程序,确保输入文件是有效的TensorFlow Lite模型文件。
if data[4:8] != b'TFL3':
print('input file is not tflite')
sys.exit(1)

数据转换与文件生成

  1. 数据转换:将读取的二进制数据转换为十六进制字符串,以便在C语言中表示。
data = binascii.hexlify(data)
data = data.decode('utf-8')
  1. 头文件生成:生成的头文件包含头文件保护宏、引入<stdint.h>头文件、声明模型数据数组和其长度,若指定了类别名称,还会声明类别名称数组和其数量。
with open(output_h, 'w') as f_output_h:
f_output_h.write('#ifndef __%s_MODEL_DATA_H__\r\n' % name.upper())
f_output_h.write('#define __%s_MODEL_DATA_H__\r\n' % name.upper())
f_output_h.write('\r\n//this file is generated by tflite2c.py\r\n')
f_output_h.write('\r\n#include <stdint.h>\r\n')
f_output_h.write('extern const unsigned char g_%s_model_data[];\r\n' % name)
f_output_h.write('extern const unsigned int g_%s_model_data_len;\r\n' % name)
if classes != None:
f_output_h.write('extern const char* g_%s_model_classes[];\r\n' % name)
f_output_h.write('extern const unsigned int g_%s_model_classes_num;\r\n' % name)
f_output_h.write('\r\n#endif\r\n')
  1. 源文件生成:源文件引入<stdint.h>头文件和对应的头文件,定义模型数据数组,将十六进制字符串按每两个字符一组转换为十六进制数值写入数组,并定义模型数据数组的长度和类别名称数组(若有)。
with open(output_c, 'w') as f_output_c:
f_output_c.write('#include <stdint.h>\r\n')
f_output_c.write('\r\n#include "%s_model_data.h"\r\n\r\n' % name)
f_output_c.write('const unsigned char g_%s_model_data[] = {\r\n' % name)
for i in range(0, len(data), 2):
f_output_c.write('0x')
f_output_c.write(data[i])
f_output_c.write(data[i+1])
f_output_c.write(', ')
if i % 36 == 34:
f_output_c.write('\r\n')
f_output_c.write('};\r\n\r\n')
f_output_c.write('const unsigned int g_%s_model_data_len = %d;\r\n' % (name, len(data) // 2))
if classes != None:
f_output_c.write('const char* g_%s_model_classes[] = {' % name)
for i in range(len(classes)):
f_output_c.write('"%s", ' % classes[i])
f_output_c.write('};\r\n\r\n')
f_output_c.write('const unsigned int g_%s_model_classes_num = %d;\r\n' % (name, len(classes)))
else:
f_output_c.write('const char* g_%s_model_classes[] = {};\r\n' % name)
f_output_c.write('const unsigned int g_%s_model_classes_num = 0;\r\n' % name)