class Evaluator:
"""Class to evaluate various pose and shape estimation algorithms."""
# ShapeNetV2 object convention is assumed for all objects and datasets.
# For simplicity, all cans, bowls, and bottles are treated as rotation symmetric.
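# Values are the index of the assumed symmetry axis (1 = y, the up axis in the
# ShapeNetV2 convention); None means the object is not treated as rotation symmetric.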
SYMMETRY_AXIS_DICT = {
"mug": None,
"laptop": None,
"camera": None,
"can": 1,
"bowl": 1,
"bottle": 1,
}
def __init__(self, config: dict) -> None:
"""Initialize model wrappers and evaluator."""
self._parse_config(config)
def _parse_config(self, config: dict) -> None:
"""Read config and initialize method wrappers."""
self._init_dataset(config["dataset_config"])
self._visualize_input = config["visualize_input"]
self._visualize_prediction = config["visualize_prediction"]
self._visualize_gt = config["visualize_gt"]
self._fast_eval = config["fast_eval"]
self._store_visualization = config["store_visualization"]
self._run_name = (
f"{self._dataset_name}_eval_{config['run_name']}_"
f"{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
)
self._out_dir_path = config["out_dir"]
self._metrics = config["metrics"]
self._num_gt_points = config["num_gt_points"]
self._vis_camera_json = config["vis_camera_json"]
self._render_options_json = config["render_options_json"]
self._cam = camera_utils.Camera(**config["camera"])
self._init_wrappers(config["methods"])
self._config = config
def _init_dataset(self, dataset_config: dict) -> None:
"""Initialize reading of dataset.
This includes a sanity check that the provided path is correct.
"""
self._dataset_name = dataset_config["name"]
print(f"Initializing {self._dataset_name} dataset...")
dataset_type = utils.str_to_object(dataset_config["type"])
self._dataset = dataset_type(config=dataset_config["config_dict"])
# Faster, but probably only worth it if the whole evaluation supports batches
# self._dataloader = DataLoader(self._dataset, 1, num_workers=8)
if len(self._dataset) == 0:
print(f"No images found for dataset {self._dataset_name}")
exit()
print(f"{len(self._dataset)} samples found for dataset {self._dataset_name}.")
def _init_wrappers(self, method_configs: dict) -> None:
"""Initialize method wrappers."""
self._wrappers = {}
for method_dict in method_configs.values():
method_name = method_dict["name"]
print(f"Initializing {method_name}...")
method_type = utils.str_to_object(method_dict["method_type"])
if method_type is None:
print(f"Could not find class {method_dict['method_type']}")
continue
self._wrappers[method_name] = method_type(
config=method_dict["config_dict"], camera=self._cam
)
def _eval_method(self, method_name: str, method_wrapper: CPASMethod) -> None:
"""Run and evaluate method on all samples."""
print(f"Run {method_name}...")
self._init_metrics()
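# Shuffle with a fixed seed for a reproducible evaluation order; with fast_eval
# only indices divisible by 10 are evaluated, i.e., a fixed ~10% subset.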
indices = list(range(len(self._dataset)))
random.seed(0)
random.shuffle(indices)
for i in tqdm(indices):
if self._fast_eval and i % 10 != 0:
continue
sample = self._dataset[i]
if self._visualize_input:
_, ((ax1, ax2), (ax3, _)) = plt.subplots(2, 2)
ax1.imshow(sample["color"].numpy())
ax2.imshow(sample["depth"].numpy())
ax3.imshow(sample["mask"].numpy())
plt.show()
t_start = time.time()
prediction = method_wrapper.inference(
color_image=sample["color"],
depth_image=sample["depth"],
instance_mask=sample["mask"],
category_str=sample["category_str"],
)
inference_time = time.time() - t_start
self._runtime_data["total"] += inference_time
self._runtime_data["total_squared"] += inference_time**2
self._runtime_data["count"] += 1
self._runtime_data["min"] = min(self._runtime_data["min"], inference_time)
self._runtime_data["max"] = max(self._runtime_data["max"], inference_time)
if self._visualize_gt:
visualize_estimation(
color_image=sample["color"],
depth_image=sample["depth"],
local_cv_position=sample["position"],
local_cv_orientation_q=sample["quaternion"],
reconstructed_mesh=self._dataset.load_mesh(sample["obj_path"]),
extents=sample["scale"],
camera=self._cam,
vis_camera_json=self._vis_camera_json,
render_options_json=self._render_options_json,
) # ground-truth visualization
if self._visualize_prediction:
visualize_estimation(
color_image=sample["color"],
depth_image=sample["depth"],
local_cv_position=prediction["position"],
local_cv_orientation_q=prediction["orientation"],
extents=prediction["extents"],
reconstructed_points=prediction["reconstructed_pointcloud"],
reconstructed_mesh=prediction["reconstructed_mesh"],
camera=self._cam,
vis_camera_json=self._vis_camera_json,
render_options_json=self._render_options_json,
)
if self._store_visualization:
vis_dir_path = os.path.join(
self._out_dir_path, self._run_name, "visualization"
)
os.makedirs(vis_dir_path, exist_ok=True)
vis_file_path = os.path.join(vis_dir_path, f"{i:06}_{method_name}.jpg")
visualize_estimation(
color_image=sample["color"],
depth_image=sample["depth"],
local_cv_position=prediction["position"],
local_cv_orientation_q=prediction["orientation"],
extents=prediction["extents"],
reconstructed_points=prediction["reconstructed_pointcloud"],
reconstructed_mesh=prediction["reconstructed_mesh"],
camera=self._cam,
vis_camera_json=self._vis_camera_json,
render_options_json=self._render_options_json,
vis_file_path=vis_file_path,
)
self._eval_prediction(prediction, sample)
self._finalize_metrics(method_name)
def _eval_prediction(self, prediction: PredictionDict, sample: dict) -> None:
"""Evaluate all metrics for a prediction."""
# evaluate each configured metric
for metric_name in self._metrics.keys():
self._eval_metric(metric_name, prediction, sample)
def _init_metrics(self) -> None:
"""Initialize metrics."""
self._metric_data = {}
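# Running accumulators for per-sample inference time statistics.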
self._runtime_data = {
"total": 0.0,
"total_squared": 0.0,
"count": 0.0,
"min": 1e10,
"max": 0.0,
}
for metric_name, metric_config_dict in self._metrics.items():
self._metric_data[metric_name] = self._init_metric_data(metric_config_dict)
def _init_metric_data(self, metric_config_dict: dict) -> dict:
"""Create data structure necessary to compute a metric."""
metric_data = {}
if "position_thresholds" in metric_config_dict:
pts = metric_config_dict["position_thresholds"]
dts = metric_config_dict["deg_thresholds"]
its = metric_config_dict["iou_thresholds"]
fts = metric_config_dict["f_thresholds"]
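# One counter per threshold combination and per category; the last category
# index aggregates over all categories.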
metric_data["correct_counters"] = np.zeros(
(
len(pts),
len(dts),
len(its),
len(fts),
self._dataset.num_categories + 1,
)
)
metric_data["total_counters"] = np.zeros(self._dataset.num_categories + 1)
elif "pointwise_f" in metric_config_dict:
metric_data["means"] = np.zeros(self._dataset.num_categories + 1)
metric_data["m2s"] = np.zeros(self._dataset.num_categories + 1)
metric_data["counts"] = np.zeros(self._dataset.num_categories + 1)
else:
raise NotImplementedError("Unsupported metric configuration.")
return metric_data
def _eval_metric(
self, metric_name: str, prediction: PredictionDict, sample: dict
) -> None:
"""Evaluate and update single metric for a single prediction.
Args:
metric_name: Name of metric to evaluate.
prediction: Dictionary containing prediction data.
sample: Sample containing ground truth information.
"""
metric_config_dict = self._metrics[metric_name]
if "position_thresholds" in metric_config_dict: # correctness metrics
self._eval_correctness_metric(metric_name, prediction, sample)
elif "pointwise_f" in metric_config_dict: # pointwise reconstruction metrics
self._eval_pointwise_metric(metric_name, prediction, sample)
else:
raise NotImplementedError(
f"Unsupported metric configuration with name {metric_name}."
)
def _eval_correctness_metric(
self, metric_name: str, prediction: PredictionDict, sample: dict
) -> None:
"""Evaluate and update single correctness metric for a single prediction.
Args:
metric_name: Name of metric to evaluate.
prediction: Dictionary containing prediction data.
sample: Sample containing ground truth information.
"""
metric_dict = self._metrics[metric_name]
correct_counters = self._metric_data[metric_name]["correct_counters"]
total_counters = self._metric_data[metric_name]["total_counters"]
category_id = sample["category_id"]
total_counters[category_id] += 1
total_counters[-1] += 1
gt_points, pred_points = self._get_points(sample, prediction, True)
for pi, p in enumerate(metric_dict["position_thresholds"]):
for di, d in enumerate(metric_dict["deg_thresholds"]):
for ii, i in enumerate(metric_dict["iou_thresholds"]):
for fi, f in enumerate(metric_dict["f_thresholds"]):
correct = metrics.correct_thresh(
position_gt=sample["position"].cpu().numpy(),
position_prediction=prediction["position"].cpu().numpy(),
orientation_gt=Rotation.from_quat(sample["quaternion"]),
orientation_prediction=Rotation.from_quat(
prediction["orientation"]
),
extent_gt=sample["scale"].cpu().numpy(),
extent_prediction=prediction["extents"].cpu().numpy(),
points_gt=gt_points,
points_prediction=pred_points,
position_threshold=p,
degree_threshold=d,
iou_3d_threshold=i,
fscore_threshold=f,
rotational_symmetry_axis=self.SYMMETRY_AXIS_DICT[
sample["category_str"]
],
)
correct_counters[pi, di, ii, fi, category_id] += correct
correct_counters[pi, di, ii, fi, -1] += correct # all
def _eval_pointwise_metric(
self, metric_name: str, prediction: PredictionDict, sample: dict
) -> None:
"""Evaluate and update single pointwise metric for a single prediction.
Args:
metric_name: Name of metric to evaluate.
prediction: Dictionary containing prediction data.
sample: Sample containing ground truth information.
"""
metric_config_dict = self._metrics[metric_name]
means = self._metric_data[metric_name]["means"]
m2s = self._metric_data[metric_name]["m2s"]
counts = self._metric_data[metric_name]["counts"]
category_id = sample["category_id"]
point_metric = utils.str_to_object(metric_config_dict["pointwise_f"])
gt_points, pred_points = self._get_points(
sample, prediction, metric_config_dict["posed"]
)
result = point_metric(
gt_points.numpy(), pred_points.numpy(), **metric_config_dict["kwargs"]
)
# Use Welford's online algorithm to update running mean and variance
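# (means holds the running means, m2s the running sums of squared deviations;
# the variance is recovered as m2s / counts in _finalize_metrics)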
# for category
counts[category_id] += 1
delta = result - means[category_id]
means[category_id] += delta / counts[category_id]
delta2 = result - means[category_id]
m2s[category_id] += delta * delta2
# for all
counts[-1] += 1
delta = result - means[-1]
means[-1] += delta / counts[-1]
delta2 = result - means[-1]
m2s[-1] += delta * delta2
def _get_points(
self, sample: dict, prediction: PredictionDict, posed: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
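"""Sample points from the ground-truth mesh and get the predicted point cloud.
Args:
sample: Sample containing ground truth information.
prediction: Dictionary containing prediction data.
posed: Whether to transform both point sets into the camera frame using the
ground-truth and predicted pose, respectively.
Returns:
Ground-truth points and predicted points.
"""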
# load ground truth mesh
gt_mesh = self._dataset.load_mesh(sample["obj_path"])
gt_points = torch.from_numpy(
np.asarray(gt_mesh.sample_points_uniformly(self._num_gt_points).points)
)
pred_points = prediction["reconstructed_pointcloud"]
# transform points if posed
if posed:
gt_points = quaternion_utils.quaternion_apply(
sample["quaternion"], gt_points
)
gt_points += sample["position"]
pred_points = quaternion_utils.quaternion_apply(
prediction["orientation"], pred_points
)
pred_points += prediction["position"]
return gt_points, pred_points
def _finalize_runtime_metric(self) -> dict:
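"""Compute mean, variance, std, min, and max of the recorded inference times."""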
mean = self._runtime_data["total"] / self._runtime_data["count"]
mean_squared = self._runtime_data["total_squared"] / self._runtime_data["count"]
variance = mean_squared - mean**2
std = math.sqrt(variance)
return {
"mean": mean,
"variance": variance,
"std": std,
"min": self._runtime_data["min"],
"max": self._runtime_data["max"],
}
def _finalize_metrics(self, method_name: str) -> None:
"""Finalize metrics after all samples have been evaluated.
Also writes results to disk and creates plots if applicable.
"""
results_dir_path = os.path.join(self._out_dir_path, self._run_name)
os.makedirs(results_dir_path, exist_ok=True)
yaml_file_path = os.path.join(results_dir_path, "results.yaml")
self._results_dict[method_name] = {}
self._runtime_results_dict[method_name] = self._finalize_runtime_metric()
for metric_name, metric_dict in self._metrics.items():
if "position_thresholds" in metric_dict: # correctness metrics
correct_counter = self._metric_data[metric_name]["correct_counters"]
total_counter = self._metric_data[metric_name]["total_counters"]
correct_percentage = correct_counter / total_counter
self._results_dict[method_name][
metric_name
] = correct_percentage.tolist()
self._create_metric_plot(
method_name,
metric_name,
metric_dict,
correct_percentage,
results_dir_path,
)
elif "pointwise_f" in metric_dict: # pointwise reconstruction metrics
counts = self._metric_data[metric_name]["counts"]
m2s = self._metric_data[metric_name]["m2s"]
means = self._metric_data[metric_name]["means"]
variances = m2s / counts
stds = np.sqrt(variances)
self._results_dict[method_name][metric_name] = {
"means": means.tolist(),
"variances": variances.tolist(),
"std": stds.tolist(),
}
else:
raise NotImplementedError(
f"Unsupported metric configuration with name {metric_name}."
)
results_dict = {
**self._config,
"results": self._results_dict,
"runtime_results": self._runtime_results_dict,
}
yoco.save_config_to_file(yaml_file_path, results_dict)
print(f"Results saved to: {yaml_file_path}")
def _create_metric_plot(
self,
method_name: str,
metric_name: str,
metric_dict: dict,
correct_percentage: np.ndarray,
out_dir: str,
) -> None:
"""Create metric plot if applicable.
Applicable means exactly one of the threshold lists has multiple values.
Args:
correct_percentage:
Array holding the percentage of correct predictions.
Shape (NUM_POS_THRESH, NUM_DEG_THRESH, NUM_IOU_THRESH, NUM_F_THRESH, NUM_CATEGORIES + 1).
"""
axis = None
for i, s in enumerate(correct_percentage.shape[:4]):
if s != 1 and axis is None:
axis = i
elif s != 1: # multiple axes with size != 1
return
if axis is None:
return
axis_to_threshold_key = {
0: "position_thresholds",
1: "deg_thresholds",
2: "iou_thresholds",
3: "f_thresholds",
}
threshold_key = axis_to_threshold_key[axis]
x_values = metric_dict[threshold_key]
for category_id in range(self._dataset.num_categories + 1):
y_values = correct_percentage[..., category_id].flatten()
if category_id in self._dataset.category_id_to_str:
label = self._dataset.category_id_to_str[category_id]
else:
label = "all"
plt.plot(x_values, y_values, label=label)
figure_file_path = os.path.join(out_dir, f"{method_name}_{metric_name}")
plt.xlabel(threshold_key)
plt.ylabel("Correct")
plt.legend()
plt.grid()
tikzplotlib.save(figure_file_path + ".tex")
plt.savefig(figure_file_path + ".png")
plt.close()
def run(self) -> None:
"""Run the evaluation."""
self._results_dict = {}
self._runtime_results_dict = {}
for method_name, method_wrapper in self._wrappers.items():
self._eval_method(method_name, method_wrapper)
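# Example usage (illustrative sketch, not part of the evaluation code): the
# Evaluator expects a config dict containing the keys read in _parse_config
# ("dataset_config", "methods", "metrics", "camera", "out_dir", ...), typically
# loaded from a YAML file. Assuming such a file exists at "eval_config.yaml" and
# that yoco exposes a matching loader (the function name here is an assumption):
#
#     config = yoco.load_config_from_file("eval_config.yaml")
#     evaluator = Evaluator(config)
#     evaluator.run()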