78 lines
2.8 KiB
Python
78 lines
2.8 KiB
Python
import os
|
|
import cv2
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
from MNPoseDetection.consensus import *
|
|
from MNPoseDetection.valid_test import valid_skeleton_judge
|
|
|
|
# Directory containing this module (symlinks resolved). Resource files are
# looked up next to it, so the lookup does not depend on the working directory.
SCRIPT_PATH = os.path.dirname(os.path.realpath(__file__))
# MoveNet single-pose "lightning" TFLite model (int8 variant, per its filename).
MODEL_PATH = os.path.join(
    SCRIPT_PATH, "lite-model_movenet_singlepose_lightning_tflite_int8_4.tflite"
)
# Keypoint label file read by PoseDetection.
LABEL_PATH = os.path.join(SCRIPT_PATH, "coco_labels.txt")
|
|
|
|
|
|
class PoseDetection:
    """Single-person pose estimator backed by a MoveNet TFLite model.

    The interpreter is loaded from ``MODEL_PATH`` and keypoint names are read
    from ``LABEL_PATH``. :meth:`detect` runs one frame through the model and
    passes the per-keypoint results to ``valid_skeleton_judge``.
    """

    def __init__(
        self,
        filter_confidence=0.1,
        shoulder_hip_tolerance=50,
        shin_thigh_tolerance=55,
        arm_forearm_tolerance=60,
    ) -> None:
        """Load the TFLite model and the keypoint labels.

        Args:
            filter_confidence: Minimum keypoint score for a keypoint to be
                flagged ``RELIABLE`` in :meth:`detect`'s output.
            shoulder_hip_tolerance: Stored on the instance; not read by this
                class (NOTE(review): presumably used by skeleton-validation
                code elsewhere; confirm).
            shin_thigh_tolerance: Stored on the instance; see above.
            arm_forearm_tolerance: Stored on the instance; see above.

        Raises:
            OSError: If the label file cannot be opened.
        """
        super().__init__()
        self.arm_forearm_tolerance = arm_forearm_tolerance
        self.shin_thigh_tolerance = shin_thigh_tolerance
        self.shoulder_hip_tolerance = shoulder_hip_tolerance
        # Keypoints scoring below this threshold are flagged as unreliable.
        self.filter_confidence = filter_confidence

        # Load the MoveNet model.
        self.interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
        self.interpreter.allocate_tensors()

        # Input/output tensor metadata (indices, shapes, dtypes).
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        # Model input size as [height, width]. detect() feeds a batch shaped
        # [1, H, W, C], so axes 1..2 of the input shape are height and width.
        self.scale = self.input_details[0]["shape"].tolist()[1:3]

        # Load the keypoint names, dropping each line's leading index.
        with open(LABEL_PATH, "r", encoding="utf-8") as f:
            self.labels = self._parse_labels(f.read().splitlines())

    @staticmethod
    def _parse_labels(lines):
        """Return keypoint names from ``"<index> <name>"`` label lines.

        Blank or whitespace-only lines are skipped instead of raising
        ``IndexError``. Only the first space-separated token after the
        index is kept as the name.
        """
        return [line.split(" ")[1] for line in lines if line.strip()]

    def detect(self, frame):
        """Estimate the pose in ``frame`` and validate the resulting skeleton.

        Args:
            frame: Image array shaped (height, width, channels). Its channel
                order is passed to the model unchanged (NOTE(review): MoveNet
                expects RGB; confirm callers convert from OpenCV's BGR).

        Returns:
            The value returned by ``valid_skeleton_judge`` for a dict that maps
            each label to ``{CONFIDENCE: score, KEY_POINTS: (x, y),
            RELIABLE: bool, PASS_TEST: True}``. ``(x, y)`` is an integer
            pixel position in ``frame``.
        """
        input_detail = self.input_details[0]
        input_height, input_width = self.scale

        # cv2.resize takes dsize as (width, height), while self.scale is
        # [height, width]. Passing self.scale directly would transpose the
        # input for any non-square model.
        resized = cv2.resize(frame, (input_width, input_height))
        # Cast to the model's declared input dtype (uint8 for the bundled int8
        # model) and add the batch dimension.
        input_data = resized.astype(input_detail["dtype"])[np.newaxis, ...]

        self.interpreter.set_tensor(input_detail["index"], input_data)
        self.interpreter.invoke()
        keypoints_with_scores = self.interpreter.get_tensor(self.output_details[0]["index"])
        # One row per keypoint: [y, x, score], with y/x treated as fractions
        # of the image size.
        keypoints = keypoints_with_scores[0][0]

        # The resize stretches the frame without padding, so the fractional
        # coordinates map back to the original frame by scaling with its size.
        height, width = frame.shape[:2]

        result = {}
        for i, label in enumerate(self.labels):
            y_norm, x_norm, confidence = keypoints[i][:3]
            result[label] = {
                CONFIDENCE: confidence,
                KEY_POINTS: (int(x_norm * width), int(y_norm * height)),
                RELIABLE: bool(confidence >= self.filter_confidence),
                PASS_TEST: True,
            }

        return valid_skeleton_judge(result)
|