426 lines
16 KiB
Python
426 lines
16 KiB
Python
import cv2
|
|
import numpy as np
|
|
import os
|
|
import pickle
|
|
import glob
|
|
from datetime import datetime
|
|
import logging
|
|
from collections import defaultdict
|
|
|
|
# Configure root logging once at import time (INFO level, timestamped lines)
# and create the logger for this module.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(name=__name__)
|
|
|
|
class SimpleFaceComparator:
    """Compare query face images against a database of reference faces.

    Embeddings come from an OpenFace Torch model loaded through ``cv2.dnn``.
    Reference embeddings are built from the images in ``database_folder``
    (one database entry per file stem) and cached with pickle in
    ``embeddings_file``. Every image in ``query_folder`` is matched against
    the database by cosine distance. Annotated result images and a text
    report are written to ``output_folder``.
    """

    # Name returned for "no match". It also appears in result file names and
    # in the report, so it must stay byte-identical.
    UNKNOWN_NAME = "未知"

    # File extensions (lower-case) treated as images when scanning folders.
    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self):
        # Model path
        self.face_recognizer_model = './models/openface_nn4.small2.v1.t7'

        # Input folders
        self.database_folder = './data/data_faces/person_mabinhao/'  # reference (database) faces
        self.query_folder = './faces_output/all_faces/'  # faces to identify

        # Output paths
        self.output_folder = './comparison_results/'
        self.embeddings_file = './models/face_embeddings.pkl'

        # Maximum cosine distance (1 - dot product) accepted as a match
        self.recognition_threshold = 0.4

        # Load the embedding network
        self.load_models()

        # Load the cached database or build it from database_folder
        self.face_database = self.load_or_create_database()

    @staticmethod
    def _cuda_available():
        """Return True if OpenCV reports at least one CUDA device.

        If the probe itself fails (e.g. no ``cv2.cuda`` module in this OpenCV
        build), the result is False. That way the model still loads on the CPU
        instead of the whole load being aborted.
        """
        try:
            return cv2.cuda.getCudaEnabledDeviceCount() > 0
        except (AttributeError, cv2.error):
            return False

    def load_models(self):
        """Load the face-embedding network and choose CUDA or CPU execution.

        Raises:
            Exception: whatever ``cv2.dnn.readNetFromTorch`` raises when the
                model file cannot be loaded (re-raised after printing).
        """
        try:
            self.face_recognizer = cv2.dnn.readNetFromTorch(self.face_recognizer_model)
        except Exception as e:
            print(f"❌ 模型加载失败: {e}")
            raise

        # Prefer GPU acceleration when available
        if self._cuda_available():
            self.face_recognizer.setPreferableBackend(cv2.dnn.DNN_BACKEND_CUDA)
            self.face_recognizer.setPreferableTarget(cv2.dnn.DNN_TARGET_CUDA)
            print("✅ 使用CUDA加速")
        else:
            print("✅ 使用CPU")

        print("✅ 模型加载成功")

    @staticmethod
    def _normalize(vector):
        """Return *vector* scaled to unit L2 norm, or None if that is impossible.

        A zero or non-finite norm would otherwise yield NaNs that make every
        later distance comparison meaningless.
        """
        norm = np.linalg.norm(vector)
        if not np.isfinite(norm) or norm == 0:
            return None
        return vector / norm

    @classmethod
    def _list_images(cls, folder):
        """Return the sorted paths of image files directly inside *folder*.

        Extensions are matched case-insensitively, so ``.JPG`` is not skipped
        on case-sensitive file systems. Hidden files (leading dot) are skipped,
        as the previous ``glob`` patterns did, and so are directories. Returns
        an empty list if *folder* cannot be listed.
        """
        try:
            entries = os.listdir(folder)
        except OSError:
            return []
        paths = [
            os.path.join(folder, entry)
            for entry in entries
            if not entry.startswith('.')
            and os.path.splitext(entry)[1].lower() in cls.IMAGE_EXTENSIONS
        ]
        return sorted(path for path in paths if os.path.isfile(path))

    def extract_face_embedding(self, image):
        """Compute a unit-length face embedding for *image*.

        Args:
            image: image array as returned by ``cv2.imread`` (BGR channel
                order; ``swapRB=True`` converts it to RGB for the network).

        Returns:
            The flattened, L2-normalized network output, or None if *image*
            is missing/empty or the embedding has zero norm.
        """
        if image is None or image.size == 0:
            return None

        # Resize to 96x96, scale pixels to [0, 1], swap BGR -> RGB
        face_blob = cv2.dnn.blobFromImage(
            image,
            1.0 / 255,
            (96, 96),
            (0, 0, 0),
            swapRB=True,
            crop=False,
        )

        self.face_recognizer.setInput(face_blob)
        vec = self.face_recognizer.forward()

        return self._normalize(vec.flatten())

    def load_or_create_database(self):
        """Load the pickled embedding database, or build it if unavailable.

        The cache is not invalidated when database_folder changes; use
        ``build_database_from_folder`` to refresh it. An unreadable or
        corrupt cache file is rebuilt instead of crashing construction.
        """
        if os.path.exists(self.embeddings_file):
            print("加载已有的人脸数据库...")
            try:
                # SECURITY NOTE: pickle.load can execute arbitrary code; only
                # load cache files written by this program.
                with open(self.embeddings_file, 'rb') as f:
                    database = pickle.load(f)
            except (OSError, EOFError, ValueError, pickle.UnpicklingError) as e:
                print(f"❌ 数据库文件读取失败,重新构建: {e}")
                return self.build_database_from_folder()
            print(f"✅ 数据库加载成功: {len(database)} 个人物")
            return database

        print("创建新的人脸数据库")
        return self.build_database_from_folder()

    def build_database_from_folder(self):
        """Build (and save, if non-empty) the database from database_folder.

        Each image file's stem is used as the person name, so ``a.jpg`` and
        ``a.png`` are grouped into one entry. Each entry stores the
        re-normalized mean of its images' embeddings plus bookkeeping fields.
        If the folder does not exist, it is created and an empty dict is
        returned.
        """
        print(f"从文件夹构建数据库: {self.database_folder}")

        if not os.path.exists(self.database_folder):
            os.makedirs(self.database_folder, exist_ok=True)
            print(f"✅ 创建数据库文件夹: {self.database_folder}")
            return {}

        database = {}

        image_files = self._list_images(self.database_folder)
        if not image_files:
            print("❌ 数据库文件夹中没有图片")
            return database

        # Group by person name (file name without extension)
        person_images = defaultdict(list)
        for image_file in image_files:
            person_name = os.path.splitext(os.path.basename(image_file))[0]
            person_images[person_name].append(image_file)

        # Extract features for each person
        for person_name, image_paths in person_images.items():
            print(f"处理数据库人物: {person_name} ({len(image_paths)} 张图片)")

            embeddings = []
            for image_path in image_paths:
                image = cv2.imread(image_path)
                if image is None:
                    print(f" ❌ 无法读取图片: {image_path}")
                    continue

                embedding = self.extract_face_embedding(image)
                if embedding is not None:
                    embeddings.append(embedding)
                    print(f" ✅ 成功提取特征: {os.path.basename(image_path)}")
                else:
                    print(f" ❌ 特征提取失败: {os.path.basename(image_path)}")

            # The mean of unit vectors can be (near) zero, so re-normalize safely
            avg_embedding = self._normalize(np.mean(embeddings, axis=0)) if embeddings else None
            if avg_embedding is None:
                print(f"❌ {person_name} 注册失败,无有效特征")
                continue

            database[person_name] = {
                'embedding': avg_embedding,
                'samples': len(embeddings),
                'image_count': len(image_paths),
                'image_paths': image_paths,
                'registered_time': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            }
            print(f"✅ {person_name} 注册成功! 使用 {len(embeddings)}/{len(image_paths)} 张有效样本")

        if database:
            self.save_database(database)

        return database

    def save_database(self, database=None):
        """Pickle *database* (default: the current face_database) to embeddings_file.

        The parent directory of embeddings_file is created if missing.
        """
        if database is None:
            database = self.face_database

        directory = os.path.dirname(self.embeddings_file)
        if directory:
            os.makedirs(directory, exist_ok=True)

        with open(self.embeddings_file, 'wb') as f:
            pickle.dump(database, f)
        print(f"✅ 数据库已保存: {len(database)} 个人物")

    def recognize_face(self, embedding):
        """Find the closest database entry to *embedding*.

        Distance is ``1 - dot(embedding, db_embedding)``, i.e. cosine distance
        for unit vectors.

        Returns:
            ``(name, distance, person_data)`` for the best match. If the best
            distance exceeds recognition_threshold, returns
            ``(UNKNOWN_NAME, distance, None)``. With an empty database, returns
            ``(UNKNOWN_NAME, 1.0, None)``.
        """
        if not self.face_database:
            return self.UNKNOWN_NAME, 1.0, None

        best_name = self.UNKNOWN_NAME
        best_distance = float('inf')
        best_person_data = None

        for name, data in self.face_database.items():
            distance = 1 - np.dot(embedding, data['embedding'])
            if distance < best_distance:
                best_distance = distance
                best_name = name
                best_person_data = data

        if best_distance > self.recognition_threshold:
            return self.UNKNOWN_NAME, best_distance, None
        return best_name, best_distance, best_person_data

    def compare_faces(self):
        """Identify every image in query_folder against the database.

        Writes one annotated image per query and a summary report into
        output_folder.

        Returns:
            A list of per-image result dicts, or None if the query folder is
            missing, the database is empty, or no query images exist.
        """
        print(f"\n开始人脸对比检测...")
        print(f"数据库文件夹: {self.database_folder}")
        print(f"查询文件夹: {self.query_folder}")

        if not os.path.exists(self.query_folder):
            print(f"❌ 查询文件夹不存在: {self.query_folder}")
            return None

        if not self.face_database:
            print("❌ 数据库为空,无法进行对比")
            return None

        query_files = self._list_images(self.query_folder)
        if not query_files:
            print(f"❌ 查询文件夹中没有图片: {self.query_folder}")
            return None

        print(f"找到 {len(query_files)} 张待识别人脸图片")

        os.makedirs(self.output_folder, exist_ok=True)

        all_results = []

        for query_file in query_files:
            print(f"\n处理查询图片: {os.path.basename(query_file)}")

            query_image = cv2.imread(query_file)
            if query_image is None:
                print(f" ❌ 无法读取图片: {query_file}")
                continue

            query_embedding = self.extract_face_embedding(query_image)
            if query_embedding is None:
                print(f" ❌ 无法提取人脸特征: {os.path.basename(query_file)}")
                continue

            name, distance, person_data = self.recognize_face(query_embedding)
            recognition_confidence = max(0, 1 - distance)

            result = {
                'query_file': os.path.basename(query_file),
                'query_path': query_file,
                'recognized_name': name,
                'recognition_confidence': recognition_confidence,
                'distance': distance,
                'person_data': person_data,
                'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            }
            all_results.append(result)

            status = "✅" if name != self.UNKNOWN_NAME else "❌"
            print(f" {status} 识别结果: {name} (置信度: {recognition_confidence:.3f})")

            self.visualize_comparison(result, query_image)

        if all_results:
            self.generate_comparison_report(all_results)

        return all_results

    @staticmethod
    def _write_image(path, image):
        """Encode *image* according to *path*'s extension and write it.

        ``cv2.imwrite`` may fail on non-ASCII paths on Windows, and result
        names can contain "未知". It also only signals failure through its
        return value. Encoding in memory and writing with numpy avoids the
        path issue.

        Returns:
            True on success, False if encoding or writing failed.
        """
        extension = os.path.splitext(path)[1] or '.jpg'
        try:
            ok, buffer = cv2.imencode(extension, image)
        except cv2.error:
            return False
        if not ok:
            return False
        try:
            buffer.tofile(path)
        except OSError:
            return False
        return True

    def visualize_comparison(self, result, query_image):
        """Save a copy of *query_image* annotated with the recognition result.

        Draws a green (match) or red (unknown) frame, the name and confidence,
        the source file name and, for matches, the database sample count.
        """
        output_image = query_image.copy()

        height, width = output_image.shape[:2]
        query_filename = result['query_file']
        recognized_name = result['recognized_name']
        confidence = result['recognition_confidence']
        is_known = recognized_name != self.UNKNOWN_NAME

        color = (0, 255, 0) if is_known else (0, 0, 255)

        cv2.rectangle(output_image, (0, 0), (width, height), color, 8)

        # Hershey fonts only cover ASCII; the Chinese sentinel would be drawn
        # as "??", so label unknown faces in English on the image.
        label = recognized_name if is_known else "Unknown"
        result_text = f"{label} ({confidence:.3f})"
        cv2.putText(output_image, result_text, (20, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.2, color, 3)

        cv2.putText(output_image, f"File: {query_filename}", (20, height - 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        if result['person_data']:
            sample_text = f"Samples: {result['person_data']['samples']}"
            cv2.putText(output_image, sample_text, (20, 80),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

        output_filename = f"result_{recognized_name}_{confidence:.3f}_{query_filename}"
        output_path = os.path.join(self.output_folder, output_filename)

        if self._write_image(output_path, output_image):
            print(f" 💾 结果图片已保存: {output_filename}")
        else:
            print(f" ❌ 结果图片保存失败: {output_filename}")

    def generate_comparison_report(self, results):
        """Write a UTF-8 text report for *results* and print a console summary.

        The report goes to ``face_comparison_report.txt`` in output_folder and
        contains overall counts, per-person statistics (sorted by number of
        matches), per-file records and the database inventory.
        """
        report_path = os.path.join(self.output_folder, "face_comparison_report.txt")

        total_queries = len(results)
        recognized_queries = [r for r in results if r['recognized_name'] != self.UNKNOWN_NAME]
        unrecognized_queries = [r for r in results if r['recognized_name'] == self.UNKNOWN_NAME]
        success_rate = len(recognized_queries) / max(total_queries, 1) * 100

        # Per-person statistics
        person_stats = defaultdict(lambda: {'count': 0, 'total_confidence': 0, 'files': []})
        for result in recognized_queries:
            stats = person_stats[result['recognized_name']]
            stats['count'] += 1
            stats['total_confidence'] += result['recognition_confidence']
            stats['files'].append(result['query_file'])

        for stats in person_stats.values():
            stats['avg_confidence'] = stats['total_confidence'] / stats['count']

        ranked = sorted(person_stats.items(), key=lambda x: x[1]['count'], reverse=True)

        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("人脸对比检测报告\n")
            f.write("=" * 60 + "\n\n")

            f.write(f"对比时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"数据库文件夹: {self.database_folder}\n")
            f.write(f"查询文件夹: {self.query_folder}\n\n")

            f.write("统计摘要:\n")
            f.write("-" * 40 + "\n")
            f.write(f"总查询图片数: {total_queries}\n")
            f.write(f"成功识别: {len(recognized_queries)}\n")
            f.write(f"识别失败: {len(unrecognized_queries)}\n")
            f.write(f"识别成功率: {success_rate:.1f}%\n\n")

            f.write("识别出的人物统计:\n")
            f.write("-" * 40 + "\n")
            for name, stats in ranked:
                f.write(f"{name}:\n")
                f.write(f" 出现次数: {stats['count']}\n")
                f.write(f" 平均置信度: {stats['avg_confidence']:.3f}\n")
                f.write(f" 出现的文件: {', '.join(stats['files'][:5])}")
                if len(stats['files']) > 5:
                    f.write(f" ... (共{len(stats['files'])}个文件)")
                f.write("\n\n")

            f.write("详细识别记录:\n")
            f.write("-" * 40 + "\n")
            for result in results:
                status = "成功" if result['recognized_name'] != self.UNKNOWN_NAME else "失败"
                f.write(f"文件: {result['query_file']}\n")
                f.write(f" 识别结果: {result['recognized_name']} ({status})\n")
                f.write(f" 置信度: {result['recognition_confidence']:.3f}\n")
                f.write(f" 处理时间: {result['timestamp']}\n\n")

            f.write("数据库信息:\n")
            f.write("-" * 40 + "\n")
            for name, data in self.face_database.items():
                f.write(f"{name}: {data['samples']}个样本, {data['image_count']}张图片\n")

        print(f"✅ 对比报告已保存: {report_path}")

        # Console summary
        print(f"\n对比结果摘要:")
        print(f"总查询图片: {total_queries}")
        print(f"成功识别: {len(recognized_queries)}")
        print(f"识别失败: {len(unrecognized_queries)}")
        print(f"成功率: {success_rate:.1f}%")

        if person_stats:
            print("\n识别出的人物:")
            for name, stats in ranked:
                print(f" {name}: {stats['count']}次 (平均置信度: {stats['avg_confidence']:.3f})")
|
|
|
|
def main():
    """Run the interactive text menu for the face comparator.

    Option 1 compares the query folder against the database. Option 2
    rebuilds the database from the database folder. Option 3 exits.
    End of input (Ctrl-D) or Ctrl-C at the prompt also exits, instead of
    ending with a traceback.
    """
    comparator = SimpleFaceComparator()

    while True:
        print("\n简单人脸对比检测系统")
        print("1. 开始人脸对比")
        print("2. 重新构建数据库")
        print("3. 退出")

        try:
            # Only options 1-3 exist; the prompt previously advertised 1-5.
            choice = input("请选择操作 (1-3): ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\n再见!")
            break

        if choice == '1':
            print("开始人脸对比...")
            comparator.compare_faces()
        elif choice == '2':
            print("重新构建数据库...")
            comparator.face_database = comparator.build_database_from_folder()
        elif choice == '3':
            print("再见!")
            break
        else:
            print("无效选择,请重新输入!")
|
|
|
|
if __name__ == "__main__":
    # Start the interactive menu only when this file is executed as a script,
    # not when it is imported.
    main()
|