import os

import cv2
import numpy as np
import tensorflow as tf

from MNPoseDetection.consensus import *
from MNPoseDetection.valid_test import valid_skeleton_judge

# Directory containing this script; model and label files are resolved relative to it.
SCRIPT_PATH = os.path.split(os.path.realpath(__file__))[0]
MODEL_PATH = os.path.join(SCRIPT_PATH, "lite-model_movenet_singlepose_lightning_tflite_int8_4.tflite")
LABEL_PATH = os.path.join(SCRIPT_PATH, 'coco_labels.txt')


class PoseDetection:
    """Single-person pose detector backed by a MoveNet TFLite model.

    The model file at MODEL_PATH is loaded once in ``__init__``; each call to
    ``detect`` runs one inference and hands the per-keypoint results to
    ``valid_skeleton_judge``.
    """

    def __init__(
            self,
            filter_confidence=0.1,
            shoulder_hip_tolerance=50,
            shin_thigh_tolerance=55,
            arm_forearm_tolerance=60
    ) -> None:
        """Load the TFLite model and the keypoint label list.

        Args:
            filter_confidence: keypoints whose score is >= this value are
                marked RELIABLE by ``detect``.
            shoulder_hip_tolerance, shin_thigh_tolerance, arm_forearm_tolerance:
                stored as attributes only; nothing in this class reads them.
                NOTE(review): presumably consumed by skeleton-validation code
                elsewhere; units (degrees?) unconfirmed.
        """
        super().__init__()
        self.arm_forearm_tolerance = arm_forearm_tolerance
        self.shin_thigh_tolerance = shin_thigh_tolerance
        self.shoulder_hip_tolerance = shoulder_hip_tolerance
        # Minimum score for a keypoint to be flagged as reliable.
        self.filter_confidence = filter_confidence
        # Load the MoveNet model and allocate its tensors once, up front.
        self.interpreter = tf.lite.Interpreter(model_path=MODEL_PATH)
        self.interpreter.allocate_tensors()
        # Input/output tensor metadata (index, shape, dtype, ...).
        self.input_details = self.interpreter.get_input_details()
        # Dimensions 1..2 of the input shape; for an NHWC image input such as
        # MoveNet's (1, H, W, 3) this is [height, width].
        self.scale = self.input_details[0]["shape"].tolist()[1: 3]
        self.output_details = self.interpreter.get_output_details()
        # Each label line is "<index> <name>"; keep only the name and skip
        # blank lines, which would otherwise raise IndexError.
        with open(LABEL_PATH, 'r', encoding='utf-8') as f:
            self.labels = [
                line.split(" ")[1]
                for line in f.read().splitlines()
                if line.strip()
            ]

    def detect(self, frame):
        """Estimate keypoints in ``frame`` and validate the resulting skeleton.

        Args:
            frame: image array whose first two dimensions are (height, width).
                NOTE(review): frames read with OpenCV are BGR while MoveNet is
                documented to expect RGB; confirm the caller converts.

        Returns:
            The value of ``valid_skeleton_judge`` applied to a dict mapping each
            label to {CONFIDENCE, KEY_POINTS (x, y) in frame pixels, RELIABLE,
            PASS_TEST}.
        """
        input_height, input_width = self.scale
        # cv2.resize takes dsize as (width, height), whereas self.scale is
        # [height, width]; passing it unswapped only works for square inputs.
        resized = cv2.resize(frame, (input_width, input_height))
        # Cast to whatever dtype the model declares for its input tensor and
        # add the batch dimension.
        input_data = resized.astype(self.input_details[0]['dtype'])[np.newaxis, ...]

        self.interpreter.set_tensor(self.input_details[0]['index'], input_data)
        self.interpreter.invoke()
        keypoints_with_scores = self.interpreter.get_tensor(self.output_details[0]['index'])

        # Original frame size, used to map the model's normalized y/x
        # coordinates back to pixels.
        height, width = frame.shape[:2]
        keypoints = keypoints_with_scores[0][0]

        # Collect one entry per label (no drawing happens here).
        result = {}
        for i, label in enumerate(self.labels):
            row = keypoints[i]
            keypoint_y = int(row[0] * height)
            keypoint_x = int(row[1] * width)
            confidence = row[2]
            result[label] = {
                CONFIDENCE: confidence,
                KEY_POINTS: (keypoint_x, keypoint_y),
                RELIABLE: bool(confidence >= self.filter_confidence),
                # Starts True; presumably updated by valid_skeleton_judge — TODO confirm.
                PASS_TEST: True,
            }
        return valid_skeleton_judge(result)