class Evaluator:
    """Class to evaluate various pose and shape estimation algorithms."""
    # The ShapeNetV2 convention is assumed for all objects and datasets.
    # For simplicity, all cans, bowls, and bottles are assumed to be rotation symmetric.
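    # Values are the index of the rotational symmetry axis (1 = y, the up axis in the
    # ShapeNetV2 convention), or None for categories not treated as symmetric.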
    SYMMETRY_AXIS_DICT = {
        "mug": None,
        "laptop": None,
        "camera": None,
        "can": 1,
        "bowl": 1,
        "bottle": 1,
    }
    def __init__(self, config: dict) -> None:
        """Initialize model wrappers and evaluator."""
        self._parse_config(config)
    def _parse_config(self, config: dict) -> None:
        """Read config and initialize method wrappers."""
        self._init_dataset(config["dataset_config"])
        self._visualize_input = config["visualize_input"]
        self._visualize_prediction = config["visualize_prediction"]
        self._visualize_gt = config["visualize_gt"]
        self._fast_eval = config["fast_eval"]
        self._store_visualization = config["store_visualization"]
        self._run_name = (
            f"{self._dataset_name}_eval_{config['run_name']}_"
            f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
        )
        self._out_dir_path = config["out_dir"]
        self._metrics = config["metrics"]
        self._num_gt_points = config["num_gt_points"]
        self._vis_camera_json = config["vis_camera_json"]
        self._render_options_json = config["render_options_json"]
        self._cam = camera_utils.Camera(**config["camera"])
        self._init_wrappers(config["methods"])
        self._config = config
    def _init_dataset(self, dataset_config: dict) -> None:
        """Initialize reading of dataset.
        This includes sanity checks whether the provided path is correct.
        """
        self._dataset_name = dataset_config["name"]
        print(f"Initializing {self._dataset_name} dataset...")
        dataset_type = utils.str_to_object(dataset_config["type"])
        self._dataset = dataset_type(config=dataset_config["config_dict"])
        # Faster, but probably only worth it if the whole evaluation supports batches
        # self._dataloader = DataLoader(self._dataset, 1, num_workers=8)
        if len(self._dataset) == 0:
            print(f"No images found for dataset {self._dataset_name}")
            exit()
        print(f"{len(self._dataset)} samples found for dataset {self._dataset_name}.")
    def _init_wrappers(self, method_configs: dict) -> None:
        """Initialize method wrappers."""
        self._wrappers = {}
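        # each method config is expected to provide "name", "method_type" (a class
        # resolved via utils.str_to_object), and "config_dict"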
        for method_dict in method_configs.values():
            method_name = method_dict["name"]
            print(f"Initializing {method_name}...")
            method_type = utils.str_to_object(method_dict["method_type"])
            if method_type is None:
                print(f"Could not find class {method_dict['method_type']}")
                continue
            self._wrappers[method_name] = method_type(
                config=method_dict["config_dict"], camera=self._cam
            )
    def _eval_method(self, method_name: str, method_wrapper: CPASMethod) -> None:
        """Run and evaluate method on all samples."""
        print(f"Run {method_name}...")
        self._init_metrics()
        indices = list(range(len(self._dataset)))
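        # fixed seed, so the shuffled evaluation order is reproducible across runs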
        random.seed(0)
        random.shuffle(indices)
        for i in tqdm(indices):
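            # fast_eval only evaluates every tenth dataset index (a deterministic subset)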
            if self._fast_eval and i % 10 != 0:
                continue
            sample = self._dataset[i]
            if self._visualize_input:
                _, ((ax1, ax2), (ax3, _)) = plt.subplots(2, 2)
                ax1.imshow(sample["color"].numpy())
                ax2.imshow(sample["depth"].numpy())
                ax3.imshow(sample["mask"].numpy())
                plt.show()
            t_start = time.time()
            prediction = method_wrapper.inference(
                color_image=sample["color"],
                depth_image=sample["depth"],
                instance_mask=sample["mask"],
                category_str=sample["category_str"],
            )
            inference_time = time.time() - t_start
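            # accumulate running statistics, finalized in _finalize_runtime_metric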
            self._runtime_data["total"] += inference_time
            self._runtime_data["total_squared"] += inference_time**2
            self._runtime_data["count"] += 1
            self._runtime_data["min"] = min(self._runtime_data["min"], inference_time)
            self._runtime_data["max"] = max(self._runtime_data["max"], inference_time)
            if self._visualize_gt:
                visualize_estimation(
                    color_image=sample["color"],
                    depth_image=sample["depth"],
                    local_cv_position=sample["position"],
                    local_cv_orientation_q=sample["quaternion"],
                    reconstructed_mesh=self._dataset.load_mesh(sample["obj_path"]),
                    extents=sample["scale"],
                    camera=self._cam,
                    vis_camera_json=self._vis_camera_json,
                    render_options_json=self._render_options_json,
                )  # GT estimate
            if self._visualize_prediction:
                visualize_estimation(
                    color_image=sample["color"],
                    depth_image=sample["depth"],
                    local_cv_position=prediction["position"],
                    local_cv_orientation_q=prediction["orientation"],
                    extents=prediction["extents"],
                    reconstructed_points=prediction["reconstructed_pointcloud"],
                    reconstructed_mesh=prediction["reconstructed_mesh"],
                    camera=self._cam,
                    vis_camera_json=self._vis_camera_json,
                    render_options_json=self._render_options_json,
                )
            if self._store_visualization:
                vis_dir_path = os.path.join(
                    self._out_dir_path, self._run_name, "visualization"
                )
                os.makedirs(vis_dir_path, exist_ok=True)
                vis_file_path = os.path.join(vis_dir_path, f"{i:06}_{method_name}.jpg")
                visualize_estimation(
                    color_image=sample["color"],
                    depth_image=sample["depth"],
                    local_cv_position=prediction["position"],
                    local_cv_orientation_q=prediction["orientation"],
                    extents=prediction["extents"],
                    reconstructed_points=prediction["reconstructed_pointcloud"],
                    reconstructed_mesh=prediction["reconstructed_mesh"],
                    camera=self._cam,
                    vis_camera_json=self._vis_camera_json,
                    render_options_json=self._render_options_json,
                    vis_file_path=vis_file_path,
                )
            self._eval_prediction(prediction, sample)
        self._finalize_metrics(method_name)
    def _eval_prediction(self, prediction: PredictionDict, sample: dict) -> None:
        """Evaluate all metrics for a prediction."""
        # evaluate every configured metric (correctness and pointwise reconstruction)
        for metric_name in self._metrics.keys():
            self._eval_metric(metric_name, prediction, sample)
    def _init_metrics(self) -> None:
        """Initialize metrics."""
        self._metric_data = {}
        self._runtime_data = {
            "total": 0.0,
            "total_squared": 0.0,
            "count": 0.0,
            "min": 1e10,
            "max": 0.0,
        }
        for metric_name, metric_config_dict in self._metrics.items():
            self._metric_data[metric_name] = self._init_metric_data(metric_config_dict)
    def _init_metric_data(self, metric_config_dict: dict) -> dict:
        """Create data structure necessary to compute a metric."""
        metric_data = {}
        if "position_thresholds" in metric_config_dict:
            pts = metric_config_dict["position_thresholds"]
            dts = metric_config_dict["deg_thresholds"]
            its = metric_config_dict["iou_thresholds"]
            fts = metric_config_dict["f_thresholds"]
            metric_data["correct_counters"] = np.zeros(
                (
                    len(pts),
                    len(dts),
                    len(its),
                    len(fts),
                    self._dataset.num_categories + 1,
                )
            )
            metric_data["total_counters"] = np.zeros(self._dataset.num_categories + 1)
        elif "pointwise_f" in metric_config_dict:
            metric_data["means"] = np.zeros(self._dataset.num_categories + 1)
            metric_data["m2s"] = np.zeros(self._dataset.num_categories + 1)
            metric_data["counts"] = np.zeros(self._dataset.num_categories + 1)
        else:
            raise NotImplementedError("Unsupported metric configuration.")
        return metric_data
    def _eval_metric(
        self, metric_name: str, prediction: PredictionDict, sample: dict
    ) -> None:
        """Evaluate and update single metric for a single prediction.
        Args:
            metric_name: Name of metric to evaluate.
            prediction: Dictionary containing prediction data.
            sample: Sample containing ground truth information.
        """
        metric_config_dict = self._metrics[metric_name]
        if "position_thresholds" in metric_config_dict:  # correctness metrics
            self._eval_correctness_metric(metric_name, prediction, sample)
        elif "pointwise_f" in metric_config_dict:  # pointwise reconstruction metrics
            self._eval_pointwise_metric(metric_name, prediction, sample)
        else:
            raise NotImplementedError(
                f"Unsupported metric configuration with name {metric_name}."
            )
    def _eval_correctness_metric(
        self, metric_name: str, prediction: PredictionDict, sample: dict
    ) -> None:
        """Evaluate and update single correctness metric for a single prediction.
        Args:
            metric_name: Name of metric to evaluate.
            prediction: Dictionary containing prediction data.
            sample: Sample containing ground truth information.
        """
        metric_dict = self._metrics[metric_name]
        correct_counters = self._metric_data[metric_name]["correct_counters"]
        total_counters = self._metric_data[metric_name]["total_counters"]
        category_id = sample["category_id"]
        total_counters[category_id] += 1
        total_counters[-1] += 1
        gt_points, pred_points = self._get_points(sample, prediction, True)
        for pi, p in enumerate(metric_dict["position_thresholds"]):
            for di, d in enumerate(metric_dict["deg_thresholds"]):
                for ii, i in enumerate(metric_dict["iou_thresholds"]):
                    for fi, f in enumerate(metric_dict["f_thresholds"]):
                        correct = metrics.correct_thresh(
                            position_gt=sample["position"].cpu().numpy(),
                            position_prediction=prediction["position"].cpu().numpy(),
                            orientation_gt=Rotation.from_quat(sample["quaternion"]),
                            orientation_prediction=Rotation.from_quat(
                                prediction["orientation"]
                            ),
                            extent_gt=sample["scale"].cpu().numpy(),
                            extent_prediction=prediction["extents"].cpu().numpy(),
                            points_gt=gt_points,
                            points_prediction=pred_points,
                            position_threshold=p,
                            degree_threshold=d,
                            iou_3d_threshold=i,
                            fscore_threshold=f,
                            rotational_symmetry_axis=self.SYMMETRY_AXIS_DICT[
                                sample["category_str"]
                            ],
                        )
                        correct_counters[pi, di, ii, fi, category_id] += correct
                        correct_counters[pi, di, ii, fi, -1] += correct  # all
    def _eval_pointwise_metric(
        self, metric_name: str, prediction: PredictionDict, sample: dict
    ) -> None:
        """Evaluate and update single pointwise metric for a single prediction.
        Args:
            metric_name: Name of metric to evaluate.
            prediction: Dictionary containing prediction data.
            sample: Sample containing ground truth information.
        """
        metric_config_dict = self._metrics[metric_name]
        means = self._metric_data[metric_name]["means"]
        m2s = self._metric_data[metric_name]["m2s"]
        counts = self._metric_data[metric_name]["counts"]
        category_id = sample["category_id"]
        point_metric = utils.str_to_object(metric_config_dict["pointwise_f"])
        gt_points, pred_points = self._get_points(
            sample, prediction, metric_config_dict["posed"]
        )
        result = point_metric(
            gt_points.numpy(), pred_points.numpy(), **metric_config_dict["kwargs"]
        )
        # Use Welford's algorithm to update the running mean and variance
        # per-category update
        counts[category_id] += 1
        delta = result - means[category_id]
        means[category_id] += delta / counts[category_id]
        delta2 = result - means[category_id]
        m2s[category_id] += delta * delta2
        # update of the aggregate over all categories (index -1)
        counts[-1] += 1
        delta = result - means[-1]
        means[-1] += delta / counts[-1]
        delta2 = result - means[-1]
        m2s[-1] += delta * delta2
    def _get_points(
        self, sample: dict, prediction: PredictionDict, posed: bool
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ground-truth points sampled from the mesh and predicted points.
        If posed is True, both point sets are transformed by their respective pose.
        """
        # load ground truth mesh
        gt_mesh = self._dataset.load_mesh(sample["obj_path"])
        gt_points = torch.from_numpy(
            np.asarray(gt_mesh.sample_points_uniformly(self._num_gt_points).points)
        )
        pred_points = prediction["reconstructed_pointcloud"]
        # transform points if posed
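        # unposed comparison only measures reconstruction quality in the canonical frame;
        # posed comparison additionally penalizes pose errors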
        if posed:
            gt_points = quaternion_utils.quaternion_apply(
                sample["quaternion"], gt_points
            )
            gt_points += sample["position"]
            pred_points = quaternion_utils.quaternion_apply(
                prediction["orientation"], pred_points
            )
            pred_points += prediction["position"]
        return gt_points, pred_points
    def _finalize_runtime_metric(self) -> dict:
        """Compute runtime statistics from the accumulated running sums."""
        mean = self._runtime_data["total"] / self._runtime_data["count"]
        mean_squared = self._runtime_data["total_squared"] / self._runtime_data["count"]
        # guard against tiny negative values caused by floating-point rounding
        variance = max(mean_squared - mean**2, 0.0)
        std = math.sqrt(variance)
        return {
            "mean": mean,
            "variance": variance,
            "std": std,
            "min": self._runtime_data["min"],
            "max": self._runtime_data["max"],
        }
    def _finalize_metrics(self, method_name: str) -> None:
        """Finalize metrics after all samples have been evaluated.
        Also writes results to disk and creates plots if applicable.
        """
        results_dir_path = os.path.join(self._out_dir_path, self._run_name)
        os.makedirs(results_dir_path, exist_ok=True)
        yaml_file_path = os.path.join(results_dir_path, "results.yaml")
        self._results_dict[method_name] = {}
        self._runtime_results_dict[method_name] = self._finalize_runtime_metric()
        for metric_name, metric_dict in self._metrics.items():
            if "position_thresholds" in metric_dict:  # correctness metrics
                correct_counter = self._metric_data[metric_name]["correct_counters"]
                total_counter = self._metric_data[metric_name]["total_counters"]
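                # division broadcasts total_counter over the four threshold axes;
                # the last category index holds the aggregate over all categories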
                correct_percentage = correct_counter / total_counter
                self._results_dict[method_name][
                    metric_name
                ] = correct_percentage.tolist()
                self._create_metric_plot(
                    method_name,
                    metric_name,
                    metric_dict,
                    correct_percentage,
                    results_dir_path,
                )
            elif "pointwise_f" in metric_dict:  # pointwise reconstruction metrics
                counts = self._metric_data[metric_name]["counts"]
                m2s = self._metric_data[metric_name]["m2s"]
                means = self._metric_data[metric_name]["means"]
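                # Welford finalization: population variance is M2 / n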
                variances = m2s / counts
                stds = np.sqrt(variances)
                self._results_dict[method_name][metric_name] = {
                    "means": means.tolist(),
                    "variances": variances.tolist(),
                    "std": stds.tolist(),
                }
            else:
                raise NotImplementedError(
                    f"Unsupported metric configuration with name {metric_name}."
                )
        results_dict = {
            **self._config,
            "results": self._results_dict,
            "runtime_results": self._runtime_results_dict,
        }
        yoco.save_config_to_file(yaml_file_path, results_dict)
        print(f"Results saved to: {yaml_file_path}")
    def _create_metric_plot(
        self,
        method_name: str,
        metric_name: str,
        metric_dict: dict,
        correct_percentage: np.ndarray,
        out_dir: str,
    ) -> None:
        """Create metric plot if applicable.
        Applicable means only one of the thresholds has multiple values.
        Args:
            correct_percentage:
                Array holding the percentage of correct predictions.
                Shape (NUM_POS_THRESH, NUM_DEG_THRESH, NUM_IOU_THRESH, NUM_F_THRESH,
                NUM_CATEGORIES + 1).
        """
        axis = None
        for i, s in enumerate(correct_percentage.shape[:4]):
            if s != 1 and axis is None:
                axis = i
            elif s != 1:  # multiple axis with != 1 size
                return
        if axis is None:
            return
        axis_to_threshold_key = {
            0: "position_thresholds",
            1: "deg_thresholds",
            2: "iou_thresholds",
            3: "f_thresholds",
        }
        threshold_key = axis_to_threshold_key[axis]
        x_values = metric_dict[threshold_key]
        for category_id in range(self._dataset.num_categories + 1):
            y_values = correct_percentage[..., category_id].flatten()
            if category_id in self._dataset.category_id_to_str:
                label = self._dataset.category_id_to_str[category_id]
            else:
                label = "all"
            plt.plot(x_values, y_values, label=label)
        figure_file_path = os.path.join(out_dir, f"{method_name}_{metric_name}")
        plt.xlabel(threshold_key)
        plt.ylabel("Correct")
        plt.legend()
        plt.grid()
        tikzplotlib.save(figure_file_path + ".tex")
        plt.savefig(figure_file_path + ".png")
        plt.close()
    def run(self) -> None:
        """Run the evaluation."""
        self._results_dict = {}
        self._runtime_results_dict = {}
        for method_name, method_wrapper in self._wrappers.items():
            self._eval_method(method_name, method_wrapper)
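
# Example usage (a minimal sketch; `load_config` stands in for however the surrounding
# project assembles the config dict consumed by Evaluator._parse_config):
#
#     config = load_config("evaluation.yaml")  # hypothetical helper and file name
#     Evaluator(config).run()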