Spaces:
Runtime error
Runtime error
| """ | |
| DeepLabCut Toolbox (deeplabcut.org) | |
| © A. & M. Mathis Labs | |
| Licensed under GNU Lesser General Public License v3.0 | |
| """ | |
| import os | |
| import ruamel.yaml | |
| import glob | |
| import warnings | |
| import numpy as np | |
| import tensorflow as tf | |
| import typing | |
| from pathlib import Path | |
| from typing import Optional, Tuple, List | |
| try: | |
| TFVER = [int(v) for v in tf.__version__.split(".")] | |
| if TFVER[1] < 14: | |
| from tensorflow.contrib.tensorrt import trt_convert as trt | |
| else: | |
| from tensorflow.python.compiler.tensorrt import trt_convert as trt | |
| except Exception: | |
| pass | |
| from dlclive.graph import ( | |
| read_graph, | |
| finalize_graph, | |
| get_output_nodes, | |
| get_output_tensors, | |
| extract_graph, | |
| ) | |
| from dlclive.pose import extract_cnn_output, argmax_pose_predict, multi_pose_predict | |
| from dlclive.display import Display | |
| from dlclive import utils | |
| from dlclive.exceptions import DLCLiveError, DLCLiveWarning | |
| if typing.TYPE_CHECKING: | |
| from dlclive.processor import Processor | |
| class DLCLive(object): | |
| """ | |
| Object that loads a DLC network and performs inference on single images (e.g. images captured from a camera feed) | |
| Parameters | |
| ----------- | |
| path : string | |
| Full path to exported model directory | |
| model_type: string, optional | |
| which model to use: 'base', 'tensorrt' for tensorrt optimized graph, 'lite' for tensorflow lite optimized graph | |
| precision : string, optional | |
| precision of model weights, only for model_type='tensorrt'. Can be 'FP16' (default), 'FP32', or 'INT8' | |
| cropping : list of int | |
| cropping parameters in pixel number: [x1, x2, y1, y2] | |
| dynamic: triple containing (state, detectiontreshold, margin) | |
| If the state is true, then dynamic cropping will be performed. That means that if an object is detected (i.e. any body part > detectiontreshold), | |
| then object boundaries are computed according to the smallest/largest x position and smallest/largest y position of all body parts. This window is | |
| expanded by the margin and from then on only the posture within this crop is analyzed (until the object is lost, i.e. <detectiontreshold). The | |
| current position is utilized for updating the crop window for the next frame (this is why the margin is important and should be set large | |
| enough given the movement of the animal). | |
| resize : float, optional | |
| Factor to resize the image. | |
| For example, resize=0.5 will downsize both the height and width of the image by a factor of 2. | |
| processor: dlc pose processor object, optional | |
| User-defined processor object. Must contain two methods: process and save. | |
| The 'process' method takes in a pose, performs some processing, and returns processed pose. | |
| The 'save' method saves any valuable data created by or used by the processor | |
| Processors can be used for two main purposes: | |
| i) to run a forward predicting model that will predict the future pose from past history of poses (history can be stored in the processor object, but is not stored in this DLCLive object) | |
| ii) to trigger external hardware based on pose estimation (e.g. see 'TeensyLaser' processor) | |
| convert2rgb : bool, optional | |
| boolean flag to convert frames from BGR to RGB color scheme | |
| display : bool, optional | |
| Display frames with DeepLabCut labels? | |
| This is useful for testing model accuracy and cropping parameters, but it is very slow. | |
| display_lik : float, optional | |
| Likelihood threshold for display | |
| display_raidus : int, optional | |
| radius for keypoint display in pixels, default=3 | |
| """ | |
| PARAMETERS = ( | |
| "path", | |
| "cfg", | |
| "model_type", | |
| "precision", | |
| "cropping", | |
| "dynamic", | |
| "resize", | |
| "processor", | |
| ) | |
| def __init__( | |
| self, | |
| model_path:str, | |
| model_type:str="base", | |
| precision:str="FP32", | |
| tf_config=None, | |
| cropping:Optional[List[int]]=None, | |
| dynamic:Tuple[bool, float, float]=(False, 0.5, 10), | |
| resize:Optional[float]=None, | |
| convert2rgb:bool=True, | |
| processor:Optional['Processor']=None, | |
| display:typing.Union[bool, Display]=False, | |
| pcutoff:float=0.5, | |
| display_radius:int=3, | |
| display_cmap:str="bmy", | |
| ): | |
| self.path = model_path | |
| self.cfg = None # type: typing.Optional[dict] | |
| self.model_type = model_type | |
| self.tf_config = tf_config | |
| self.precision = precision | |
| self.cropping = cropping | |
| self.dynamic = dynamic | |
| self.dynamic_cropping = None | |
| self.resize = resize | |
| self.processor = processor | |
| self.convert2rgb = convert2rgb | |
| if isinstance(display, Display): | |
| self.display = display | |
| elif display: | |
| self.display = Display(pcutoff=pcutoff, radius=display_radius, cmap=display_cmap) | |
| else: | |
| self.display = None | |
| self.sess = None | |
| self.inputs = None | |
| self.outputs = None | |
| self.tflite_interpreter = None | |
| self.pose = None | |
| self.is_initialized = False | |
| # checks | |
| if self.model_type == "tflite" and self.dynamic[0]: | |
| self.dynamic = (False, *self.dynamic[1:]) | |
| warnings.warn( | |
| "Dynamic cropping is not supported for tensorflow lite inference. Dynamic cropping will not be used...", | |
| DLCLiveWarning, | |
| ) | |
| self.read_config() | |
| def read_config(self): | |
| """ Reads configuration yaml file | |
| Raises | |
| ------ | |
| FileNotFoundError | |
| error thrown if pose configuration file does nott exist | |
| """ | |
| cfg_path = Path(self.path).resolve() / "pose_cfg.yaml" | |
| if not cfg_path.exists(): | |
| raise FileNotFoundError( | |
| f"The pose configuration file for the exported model at {str(cfg_path)} was not found. Please check the path to the exported model directory" | |
| ) | |
| ruamel_file = ruamel.yaml.YAML() | |
| self.cfg = ruamel_file.load(open(str(cfg_path), "r")) | |
| def parameterization(self) -> dict: | |
| """ | |
| Return | |
| Returns | |
| ------- | |
| """ | |
| return {param: getattr(self, param) for param in self.PARAMETERS} | |
| def process_frame(self, frame): | |
| """ | |
| Crops an image according to the object's cropping and dynamic properties. | |
| Parameters | |
| ----------- | |
| frame :class:`numpy.ndarray` | |
| image as a numpy array | |
| Returns | |
| ---------- | |
| frame :class:`numpy.ndarray` | |
| processed frame: convert type, crop, convert color | |
| """ | |
| if frame.dtype != np.uint8: | |
| frame = utils.convert_to_ubyte(frame) | |
| if self.cropping: | |
| frame = frame[ | |
| self.cropping[2] : self.cropping[3], self.cropping[0] : self.cropping[1] | |
| ] | |
| if self.dynamic[0]: | |
| if self.pose is not None: | |
| detected = self.pose[:, 2] > self.dynamic[1] | |
| if np.any(detected): | |
| x = self.pose[detected, 0] | |
| y = self.pose[detected, 1] | |
| x1 = int(max([0, int(np.amin(x)) - self.dynamic[2]])) | |
| x2 = int(min([frame.shape[1], int(np.amax(x)) + self.dynamic[2]])) | |
| y1 = int(max([0, int(np.amin(y)) - self.dynamic[2]])) | |
| y2 = int(min([frame.shape[0], int(np.amax(y)) + self.dynamic[2]])) | |
| self.dynamic_cropping = [x1, x2, y1, y2] | |
| frame = frame[y1:y2, x1:x2] | |
| else: | |
| self.dynamic_cropping = None | |
| if self.resize != 1: | |
| frame = utils.resize_frame(frame, self.resize) | |
| if self.convert2rgb: | |
| frame = utils.img_to_rgb(frame) | |
| return frame | |
| def init_inference(self, frame=None, **kwargs): | |
| """ | |
| Load model and perform inference on first frame -- the first inference is usually very slow. | |
| Parameters | |
| ----------- | |
| frame :class:`numpy.ndarray` | |
| image as a numpy array | |
| Returns | |
| -------- | |
| pose :class:`numpy.ndarray` | |
| the pose estimated by DeepLabCut for the input image | |
| """ | |
| # get model file | |
| model_file = glob.glob(os.path.normpath(self.path + "/*.pb"))[0] | |
| if not os.path.isfile(model_file): | |
| raise FileNotFoundError( | |
| "The model file {} does not exist.".format(model_file) | |
| ) | |
| # process frame | |
| if frame is None and (self.model_type == "tflite"): | |
| raise DLCLiveError( | |
| "No image was passed to initialize inference. An image must be passed to the init_inference method" | |
| ) | |
| if frame is not None: | |
| if frame.ndim == 2: | |
| self.convert2rgb = True | |
| processed_frame = self.process_frame(frame) | |
| # load model | |
| if self.model_type == "base": | |
| graph_def = read_graph(model_file) | |
| graph = finalize_graph(graph_def) | |
| self.sess, self.inputs, self.outputs = extract_graph( | |
| graph, tf_config=self.tf_config | |
| ) | |
| elif self.model_type == "tflite": | |
| ### | |
| # the frame size needed to initialize the tflite model as | |
| # tflite does not support saving a model with dynamic input size | |
| ### | |
| # get input and output tensor names from graph_def | |
| graph_def = read_graph(model_file) | |
| graph = finalize_graph(graph_def) | |
| output_nodes = get_output_nodes(graph) | |
| output_nodes = [on.replace("DLC/", "") for on in output_nodes] | |
| tf_version_2 = tf.__version__[0] == '2' | |
| if tf_version_2: | |
| converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph( | |
| model_file, | |
| ["Placeholder"], | |
| output_nodes, | |
| input_shapes={"Placeholder": [1, processed_frame.shape[0], processed_frame.shape[1], 3]}, | |
| ) | |
| else: | |
| converter = tf.lite.TFLiteConverter.from_frozen_graph( | |
| model_file, | |
| ["Placeholder"], | |
| output_nodes, | |
| input_shapes={"Placeholder": [1, processed_frame.shape[0], processed_frame.shape[1], 3]}, | |
| ) | |
| try: | |
| tflite_model = converter.convert() | |
| except Exception: | |
| raise DLCLiveError( | |
| ( | |
| "This model cannot be converted to tensorflow lite format. " | |
| "To use tensorflow lite for live inference, " | |
| "make sure to set TFGPUinference=False " | |
| "when exporting the model from DeepLabCut" | |
| ) | |
| ) | |
| self.tflite_interpreter = tf.lite.Interpreter(model_content=tflite_model) | |
| self.tflite_interpreter.allocate_tensors() | |
| self.inputs = self.tflite_interpreter.get_input_details() | |
| self.outputs = self.tflite_interpreter.get_output_details() | |
| elif self.model_type == "tensorrt": | |
| graph_def = read_graph(model_file) | |
| graph = finalize_graph(graph_def) | |
| output_tensors = get_output_tensors(graph) | |
| output_tensors = [ot.replace("DLC/", "") for ot in output_tensors] | |
| if (TFVER[0] > 1) | (TFVER[0] == 1 & TFVER[1] >= 14): | |
| converter = trt.TrtGraphConverter( | |
| input_graph_def=graph_def, | |
| nodes_blacklist=output_tensors, | |
| is_dynamic_op=True, | |
| ) | |
| graph_def = converter.convert() | |
| else: | |
| graph_def = trt.create_inference_graph( | |
| input_graph_def=graph_def, | |
| outputs=output_tensors, | |
| max_batch_size=1, | |
| precision_mode=self.precision, | |
| is_dynamic_op=True, | |
| ) | |
| graph = finalize_graph(graph_def) | |
| self.sess, self.inputs, self.outputs = extract_graph( | |
| graph, tf_config=self.tf_config | |
| ) | |
| else: | |
| raise DLCLiveError( | |
| "model_type = {} is not supported. model_type must be 'base', 'tflite', or 'tensorrt'".format( | |
| self.model_type | |
| ) | |
| ) | |
| # get pose of first frame (first inference is often very slow) | |
| if frame is not None: | |
| pose = self.get_pose(frame, **kwargs) | |
| else: | |
| pose = None | |
| self.is_initialized = True | |
| return pose | |
| def get_pose(self, frame=None, **kwargs): | |
| """ | |
| Get the pose of an image | |
| Parameters | |
| ----------- | |
| frame :class:`numpy.ndarray` | |
| image as a numpy array | |
| Returns | |
| -------- | |
| pose :class:`numpy.ndarray` | |
| the pose estimated by DeepLabCut for the input image | |
| """ | |
| if frame is None: | |
| raise DLCLiveError("No frame provided for live pose estimation") | |
| frame = self.process_frame(frame) | |
| if self.model_type in ["base", "tensorrt"]: | |
| pose_output = self.sess.run( | |
| self.outputs, feed_dict={self.inputs: np.expand_dims(frame, axis=0)} | |
| ) | |
| elif self.model_type == "tflite": | |
| self.tflite_interpreter.set_tensor( | |
| self.inputs[0]["index"], | |
| np.expand_dims(frame, axis=0).astype(np.float32), | |
| ) | |
| self.tflite_interpreter.invoke() | |
| if len(self.outputs) > 1: | |
| pose_output = [ | |
| self.tflite_interpreter.get_tensor(self.outputs[0]["index"]), | |
| self.tflite_interpreter.get_tensor(self.outputs[1]["index"]), | |
| ] | |
| else: | |
| pose_output = self.tflite_interpreter.get_tensor( | |
| self.outputs[0]["index"] | |
| ) | |
| else: | |
| raise DLCLiveError( | |
| "model_type = {} is not supported. model_type must be 'base', 'tflite', or 'tensorrt'".format( | |
| self.model_type | |
| ) | |
| ) | |
| # check if using TFGPUinference flag | |
| # if not, get pose from network output | |
| if len(pose_output) > 1: | |
| scmap, locref = extract_cnn_output(pose_output, self.cfg) | |
| num_outputs = self.cfg.get("num_outputs", 1) | |
| if num_outputs > 1: | |
| self.pose = multi_pose_predict( | |
| scmap, locref, self.cfg["stride"], num_outputs | |
| ) | |
| else: | |
| self.pose = argmax_pose_predict(scmap, locref, self.cfg["stride"]) | |
| else: | |
| pose = np.array(pose_output[0]) | |
| self.pose = pose[:, [1, 0, 2]] | |
| # display image if display=True before correcting pose for cropping/resizing | |
| if self.display is not None: | |
| self.display.display_frame(frame, self.pose) | |
| # if frame is cropped, convert pose coordinates to original frame coordinates | |
| if self.resize is not None: | |
| self.pose[:, :2] *= 1 / self.resize | |
| if self.cropping is not None: | |
| self.pose[:, 0] += self.cropping[0] | |
| self.pose[:, 1] += self.cropping[2] | |
| if self.dynamic_cropping is not None: | |
| self.pose[:, 0] += self.dynamic_cropping[0] | |
| self.pose[:, 1] += self.dynamic_cropping[2] | |
| # process the pose | |
| if self.processor: | |
| self.pose = self.processor.process(self.pose, **kwargs) | |
| return self.pose | |
| def close(self): | |
| """ Close tensorflow session | |
| """ | |
| self.sess.close() | |
| self.sess = None | |
| self.is_initialized = False | |
| if self.display is not None: | |
| self.display.destroy() | |