LISHUZUOXUN_yangjiang/MNPoseDetection/pose_detection.py

78 lines
2.8 KiB
Python
Raw Normal View History

2024-09-23 14:54:15 +08:00
import os
import cv2
import numpy as np
import tensorflow as tf
from MNPoseDetection.consensus import *
from MNPoseDetection.valid_test import valid_skeleton_judge
# Directory containing this module (symlinks resolved); the model and label
# files are looked up next to the source file.
SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__))
# Bundled TFLite model file and the keypoint label file, in that order.
MODEL_PATH, LABEL_PATH = (
    os.path.join(SCRIPT_PATH, file_name)
    for file_name in (
        "lite-model_movenet_singlepose_lightning_tflite_int8_4.tflite",
        'coco_labels.txt',
    )
)
class PoseDetection:
    """Single-person pose detector backed by a TFLite MoveNet model.

    The interpreter is loaded from ``MODEL_PATH`` and the keypoint names from
    ``LABEL_PATH`` when the object is constructed. ``detect`` runs one
    inference per frame and returns the result of ``valid_skeleton_judge``.
    """

    def __init__(
            self,
            filter_confidence=0.1,
            shoulder_hip_tolerance=50,
            shin_thigh_tolerance=55,
            arm_forearm_tolerance=60
    ) -> None:
        """Load the model and the label file.

        Args:
            filter_confidence: Minimum keypoint score for a keypoint to be
                flagged as reliable in ``detect``.
            shoulder_hip_tolerance: Stored on the instance; not read by the
                methods of this class.
            shin_thigh_tolerance: Stored on the instance; not read by the
                methods of this class.
            arm_forearm_tolerance: Stored on the instance; not read by the
                methods of this class.
        """
        super().__init__()
        self.arm_forearm_tolerance = arm_forearm_tolerance
        self.shin_thigh_tolerance = shin_thigh_tolerance
        self.shoulder_hip_tolerance = shoulder_hip_tolerance
        # Confidence threshold used to mark keypoints as reliable.
        self.filter_confidence = filter_confidence
        # Load the MoveNet model.
        self.interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
        self.interpreter.allocate_tensors()
        # Fetch the input and output tensor descriptions.
        self.input_details = self.interpreter.get_input_details()
        # Spatial size of the model input, taken from axes 1 and 2 of the
        # input shape. Kept as a two-element list for existing readers of
        # this attribute; presumably [height, width] (NHWC) -- TODO confirm.
        self.scale = self.input_details[0]["shape"].tolist()[1: 3]
        self.output_details = self.interpreter.get_output_details()
        # Load the labels and strip the leading index from every line.
        with open(LABEL_PATH, 'r') as f:
            self.labels = self._parse_labels(f.read().splitlines())

    @staticmethod
    def _parse_labels(lines):
        """Return the second space-separated field of every non-blank line.

        Blank lines (for example a stray empty line in the label file) are
        skipped instead of raising ``IndexError``. A non-blank line without a
        space still raises ``IndexError``, so a malformed label file fails
        loudly at construction time.
        """
        return [line.split(" ")[1] for line in lines if line.strip()]

    def detect(self, frame):
        """Run pose inference on one frame.

        Args:
            frame: Image array whose ``shape`` unpacks into three values
                (height, width, channels).
                NOTE(review): MoveNet is documented as expecting RGB input
                while OpenCV captures are BGR -- confirm the caller converts.

        Returns:
            Whatever ``valid_skeleton_judge`` returns for a dict that maps
            each label to its confidence, ``(x, y)`` point scaled to the
            frame size, reliability flag and pass-test flag.
        """
        # cv2.resize takes dsize as (width, height), whereas self.scale holds
        # input-shape axes 1 and 2 in that order, so swap them explicitly.
        model_height, model_width = self.scale
        input_data = cv2.resize(frame, (model_width, model_height))
        # Convert to UINT8 and add the leading batch axis.
        input_data = input_data.astype(np.uint8)[np.newaxis, ...]
        # Feed the input tensor.
        self.interpreter.set_tensor(self.input_details[0]['index'], input_data)
        # Run inference.
        self.interpreter.invoke()
        # Read the output tensor.
        keypoints_with_scores = self.interpreter.get_tensor(self.output_details[0]['index'])
        # Frame size used to scale the model's coordinates to pixels.
        height, width, _ = frame.shape
        # Each output row is read as (y, x, score). zip stops at the shorter
        # of labels/rows, so surplus label lines cannot index past the output.
        result = {}
        for label, (y_norm, x_norm, confidence) in zip(self.labels, keypoints_with_scores[0][0]):
            result[label] = {
                CONFIDENCE: confidence,
                KEY_POINTS: (int(x_norm * width), int(y_norm * height)),
                RELIABLE: bool(confidence >= self.filter_confidence),
                PASS_TEST: True,
            }
        return valid_skeleton_judge(result)