利用tree-sitter提取代码文件中的函数和注释
- 1. 需求
- 2. 工具
- 3. 实现
1. 需求
提取 .c 或 .cpp 文件中带有注释的函数，作为训练数据提供给大语言模型。具体要求：支持批量处理；提取前面带有注释的函数及该注释，并将函数体内的注释一并提取出来作为辅助训练数据；结果保存为 JSON 文件。
2. 工具
tree-sitter
tree-sitter 使用环境的配置方法见：https://blog.csdn.net/sluck_0430/article/details/134194493
PyCharm
将 conda 虚拟环境中的 Python 解释器添加到 PyCharm 的方法见：https://blog.csdn.net/weixin_62783109/article/details/129962054
3. 实现
from tree_sitter import Language, Parser
import json
import os
import re
# Build the tree-sitter C grammar into a shared library and create a parser for it.
# NOTE(review): Language.build_library / Parser.set_language belong to the older
# py-tree-sitter API and are gone in newer releases - confirm the installed version.
_LIBRARY_PATH = 'build/my-languages.so'
Language.build_library(_LIBRARY_PATH, ['vendor/tree-sitter-c'])

C_LANGUAGE = Language(_LIBRARY_PATH, 'c')

parser = Parser()
parser.set_language(C_LANGUAGE)
# Extract commented functions from a parsed syntax tree.
def extract_code_information(node, code):
    """Collect every direct child function of ``node`` preceded by a comment.

    Only the direct children of ``node`` are inspected, so functions nested
    inside other constructs are not visited.

    Args:
        node: tree-sitter node whose children are scanned (normally the root
            node of a parsed file).
        code: decoded source text the tree was parsed from.

    Returns:
        A list of dicts, one per matching function, with keys:
        'comment_before_function' - the run of consecutive comment siblings
            directly before the function, in source order, joined with '\n';
        'comment_in_function' - comments inside the function, as returned
            by ``traverse_children``;
        'function' - the function's source text.
    """
    functions = []
    for child in node.children:
        previous = child.prev_sibling
        # Keep only function definitions whose immediate previous sibling is a comment.
        if child.type != 'function_definition' or previous is None or previous.type != 'comment':
            continue

        # Walk backwards over the run of comment siblings. They are found
        # nearest-first, so reverse them afterwards to restore source order.
        # NOTE: row positions are not checked, so a comment separated from the
        # function only by blank lines (e.g. a file header) is still collected.
        leading_comments = []
        while previous is not None and previous.type == 'comment':
            leading_comments.append(extract_node_information(previous, code))
            previous = previous.prev_sibling
        leading_comments.reverse()

        functions.append({
            'comment_before_function': '\n'.join(leading_comments),
            'comment_in_function': traverse_children(child, code),
            'function': extract_node_information(child, code),
        })
    return functions
# Collect every comment in a subtree (depth-first, pre-order).
def traverse_children(node, code):
    """Return the text of all comment nodes in the subtree rooted at ``node``.

    Nodes are visited depth-first in pre-order (parents before children,
    left to right), and ``node`` itself is included if it is a comment.
    The comment texts are joined with '\n' so separate comments stay
    distinguishable. An explicit stack replaces recursion so that deeply
    nested syntax trees cannot exceed Python's recursion limit.

    Args:
        node: tree-sitter node (anything with ``type`` and ``children``), or None.
        code: decoded source text the tree was parsed from.

    Returns:
        The newline-joined comment texts, or '' when ``node`` is None or
        the subtree has no comments.
    """
    if node is None:
        return ''
    comments = []
    stack = [node]
    while stack:
        current = stack.pop()
        if current.type == 'comment':
            comments.append(extract_node_information(current, code))
        # Push children reversed so the leftmost child is popped (visited) first.
        stack.extend(reversed(current.children))
    return '\n'.join(comments)
# Extract the source text covered by a node.
def extract_node_information(node, code):
    """Return the slice of ``code`` spanned by ``node``.

    Args:
        node: object with ``start_point`` / ``end_point`` attributes holding
            ``(row, column)`` pairs, such as a tree-sitter ``Node``.
        code: source text the tree was parsed from.

    Returns:
        The covered text with its original line breaks preserved, or '' if
        ``node`` (or ``code``) lacks the expected attributes, e.g. is None.

    NOTE(review): tree-sitter reports columns as byte offsets into the parsed
    bytes, while these slices index characters. They agree only when the text
    before the column on that line is single-byte (ASCII).
    """
    try:
        start_row, start_col = node.start_point
        end_row, end_col = node.end_point
        lines = code.split('\n')
    except AttributeError:
        return ''

    if start_row == end_row:
        return lines[start_row][start_col:end_col]

    # Rejoin with '\n' so every line break in the span is kept, including
    # the one after the (partial) first line.
    pieces = [lines[start_row][start_col:]]
    pieces.extend(lines[start_row + 1:end_row])
    pieces.append(lines[end_row][:end_col])
    return '\n'.join(pieces)
# Locate .c and .cpp source files under a directory tree.
def get_c_files(folder):
    """Return the paths of all files below ``folder`` named ``*.c`` or ``*.cpp``.

    The tree is walked with ``os.walk``; each match is joined onto the
    directory it was found in, and results are returned in walk order.
    """
    source_pattern = re.compile(r'\.c$|\.cpp$')
    return [
        os.path.join(directory, name)
        for directory, _subdirs, names in os.walk(folder)
        for name in names
        if source_pattern.search(name)
    ]
# Process every .c/.cpp file under a folder and save the results as JSON.
def pipeline(folder_path, output_path='functions.json'):
    """Extract commented functions from all .c/.cpp files under ``folder_path``.

    Each file is decoded as UTF-8 first and GBK second. Strict UTF-8 rarely
    accepts GBK bytes, whereas many UTF-8 byte sequences are also valid GBK,
    so trying GBK first can silently turn UTF-8 Chinese text into mojibake.
    The text is re-encoded with the codec that decoded it before parsing.

    Args:
        folder_path: root directory searched recursively for source files.
        output_path: destination JSON file (default 'functions.json').

    Side effects:
        Prints each file path as it is processed and writes ``output_path``
        (UTF-8). The JSON holds one list per successfully decoded file
        (possibly empty), as produced by ``extract_code_information``.
        Files that cannot be read or decoded are reported and skipped.
    """
    functions = []
    for c_file in get_c_files(folder_path):
        print(c_file)
        code, encoding = _read_source(c_file)
        if code is None:
            continue
        tree = parser.parse(code.encode(encoding))
        functions.append(extract_code_information(tree.root_node, code))

    with open(output_path, 'w', encoding='utf8') as json_file:
        json.dump(functions, json_file, indent=4, ensure_ascii=False)


def _read_source(path, encodings=('utf8', 'gbk')):
    """Decode ``path`` with the first codec in ``encodings`` that succeeds.

    Returns ``(text, encoding)``, or ``(None, None)`` after printing a
    message when the file cannot be opened or no codec decodes it.
    """
    for encoding in encodings:
        try:
            with open(path, 'r', encoding=encoding) as file:
                return file.read(), encoding
        except UnicodeDecodeError:
            continue  # try the next codec
        except OSError as err:
            # Report and skip, instead of aborting the whole batch unwritten.
            print(f"Cannot read {path}: {err}")
            return None, None
    print("UnicodeDecodeError!")
    return None, None
if __name__ == '__main__':
    import sys

    # Take the target folder from the command line if given; otherwise fall
    # back to the placeholder, which must be replaced with a real absolute path.
    folder_path = sys.argv[1] if len(sys.argv) > 1 else '文件夹的绝对路径'
    pipeline(folder_path)