最近在接手一个遗留数据库项目时,我遇到了一个棘手问题:代码库中散落着数百个SQL脚本文件,其中部分文件包含DROP TABLE语句。这些语句如果被误执行,将导致生产环境数据表被意外删除。为了快速定位风险点,我需要开发一个Python工具来自动扫描目录下所有SQL文件,提取其中的DROP TABLE语句及其相关信息。
这个需求在数据库迁移、版本控制和代码审计等场景中非常常见。比如:
解决方案的核心流程可以分为三个步骤:
选择Python实现主要基于以下考虑:
python复制import os
def find_sql_files(directory):
sql_files = []
for root, _, files in os.walk(directory):
for file in files:
if file.lower().endswith('.sql'):
sql_files.append(os.path.join(root, file))
return sql_files
这个函数使用os.walk递归遍历目录,收集所有.sql后缀的文件路径。注意:
python复制import re
def parse_sql_file(file_path):
drop_table_pattern = re.compile(
r'DROP\s+TABLE\s+(IF\s+EXISTS\s+)?(`?.+?`?|\".+?\"|\[.+?\])',
re.IGNORECASE
)
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
matches = drop_table_pattern.finditer(content)
results = []
for match in matches:
full_statement = match.group(0)
table_name = match.group(2).strip('`"[]')
if_exists = bool(match.group(1))
results.append({
'file': file_path,
'line': content[:match.start()].count('\n') + 1,
'statement': full_statement,
'table': table_name,
'if_exists': if_exists
})
return results
关键点解析:
正则表达式设计:
行号计算技巧:
结果结构化:
python复制def generate_report(results, output_file=None):
if not results:
print("未发现DROP TABLE语句")
return
report = ["发现DROP TABLE语句汇总:"]
for item in results:
report.append(
f"文件: {item['file']}\n"
f"行号: {item['line']}\n"
f"表名: {item['table']}\n"
f"条件执行: {'是' if item['if_exists'] else '否'}\n"
f"完整语句: {item['statement']}\n"
f"{'-'*50}"
)
if output_file:
with open(output_file, 'w', encoding='utf-8') as f:
f.write('\n'.join(report))
else:
print('\n'.join(report))
输出设计考虑:
python复制import os
import re
import argparse
def main():
parser = argparse.ArgumentParser(
description='扫描SQL文件中的DROP TABLE语句'
)
parser.add_argument('directory', help='要扫描的目录路径')
parser.add_argument('-o', '--output', help='输出报告文件路径')
args = parser.parse_args()
sql_files = find_sql_files(args.directory)
all_results = []
for sql_file in sql_files:
all_results.extend(parse_sql_file(sql_file))
generate_report(all_results, args.output)
if __name__ == '__main__':
main()
使用说明:
实际项目中可能需要检测其他危险操作:
python复制dangerous_patterns = {
'DROP_TABLE': r'DROP\s+TABLE\s+(IF\s+EXISTS\s+)?(.+)',
'TRUNCATE': r'TRUNCATE\s+(TABLE\s+)?(.+)',
'DROP_DATABASE': r'DROP\s+DATABASE\s+(IF\s+EXISTS\s+)?(.+)'
}
处理大型SQL文件时:
python复制def safe_read_file(file_path):
size = os.path.getsize(file_path)
if size > 10 * 1024 * 1024: # 10MB
raise ValueError(f"文件过大: {file_path} ({size/1024/1024:.2f}MB)")
with open(file_path, 'r', encoding='utf-8') as f:
return f.read()
示例GitHub Actions配置:
yaml复制name: SQL安全检查
on: [push, pull_request]
jobs:
scan-sql:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: 设置Python
uses: actions/setup-python@v2
with:
python-version: '3.9'
- name: 执行扫描
run: |
python scan_drop_tables.py ./sql --output sql_scan_report.txt
if [ -s sql_scan_report.txt ]; then
echo "发现危险SQL语句!"
cat sql_scan_report.txt
exit 1
fi
不同SQL文件可能使用不同编码:
python复制def detect_encoding(file_path):
import chardet
with open(file_path, 'rb') as f:
result = chardet.detect(f.read(1024))
return result['encoding']
对于跨行语句,需要特殊处理:
python复制def preprocess_sql(content):
# 移除单行注释
content = re.sub(r'--.*$', '', content, flags=re.MULTILINE)
# 移除多行注释
content = re.sub(r'/\*.*?\*/', '', content, flags=re.DOTALL)
# 合并跨行语句
content = ' '.join(line.strip() for line in content.splitlines())
return content
可能误判的情况:
解决方案:
在某次数据库迁移项目中,使用此脚本发现了以下问题:
修复措施:
关键经验:自动化检查不能完全替代人工审核,但可以显著提高效率和发现率。建议将此类检查工具集成到开发流程的以下环节:
- 代码提交前钩子
- CI流水线
- 部署前检查清单