# -*- coding: utf-8 -*-
"""This module provides stages that can compute metrics like overlaps between
groundtruth and detections, or simply count detections per bag.
"""
import math
import os
import pickle
import sys
from multiprocessing import Pool
from typing import Any, Dict, List, Tuple, Union
import jinja2 # type: ignore
import matplotlib # type: ignore
import matplotlib.pyplot as plt # type: ignore
import numpy as np
import pandas as pd # type: ignore
from rich.progress import track
from uval.stages.stage import uval_stage # type: ignore
from uval.stages.stage_data import SupportedDatasetSpecificationData
from uval.utils.log import logger
# Select the non-interactive "Agg" backend; the plots in this module are written to files with savefig.
matplotlib.use("Agg")
class DetectionEntry:
    """A single annotated box, used for ground truths and detections alike.

    Ground-truth and detection records are converted to this common form so
    the metric code can treat both uniformly.

    Args:
        volume_id: Identifier of the volume the box belongs to.
        class_name: Name of the annotated or predicted class.
        confidence_score: Score attached to the box.
        roi_start: Start corner of the box, one coordinate per axis.
        roi_shape: Extent of the box along each axis.
    """

    def __init__(
        self,
        volume_id: str,
        class_name: str,
        confidence_score: float,
        roi_start: Tuple[int, ...],
        roi_shape: Tuple[int, ...],
    ):
        self.volume_id = volume_id
        self.class_name = class_name
        self.confidence_score = confidence_score
        self.roi_start = roi_start
        self.roi_shape = roi_shape

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(volume_id={self.volume_id!r}, class_name={self.class_name!r}, "
            f"confidence_score={self.confidence_score!r}, roi_start={self.roi_start!r}, "
            f"roi_shape={self.roi_shape!r})"
        )
[docs]class Metrics:
"""This class implements all the evaluation metrics."""
def __init__(self, dataset: SupportedDatasetSpecificationData, metrics_settings, output_settings):
    """Store the dataset and read every setting used by the evaluation.

    Args:
        dataset: Dataset specification; its ``table`` is read later by
            ``data_preparations``.
        metrics_settings: Settings object providing ``IOU_THRESHOLD``,
            ``IOU_RANGE``, ``CONFIDENCE_THRESHOLD``, ``MAX_PROCESSES``,
            ``AP_METHOD`` and ``FACTOR``.
        output_settings: Settings object providing ``PATH``, ``TITLE``,
            ``REPORT_FILE``, ``METRICS_FILE`` and ``TEMPLATES_PATH``.
    """
    self.dataset = dataset
    self.data = None  # filled on first use by basic_metric()
    self.iou_threshold = metrics_settings.IOU_THRESHOLD

    # A (start, stop, step) triple enables the multi-threshold evaluation and
    # its dedicated report template; anything else disables it.
    iou_range = metrics_settings.IOU_RANGE
    if iou_range is not None and len(iou_range) == 3:
        start, stop, step = iou_range
        count = int(np.round((stop - start) / step)) + 1
        thresholds = np.linspace(start, stop, count, endpoint=True).tolist()
        # Rounding removes floating point noise such as 0.6000000000000001.
        self.iou_range = [round(iou, 2) for iou in thresholds]
        self.template_file = "template_range.html"
    else:
        self.iou_range = None
        self.template_file = "template.html"

    self.confidence_threshold = metrics_settings.CONFIDENCE_THRESHOLD
    self.output_path = output_settings.PATH
    # The output_path property returns an absolute path; basename() takes its
    # last component with the platform's separator instead of assuming "/".
    self.title = output_settings.TITLE or os.path.basename(self.output_path)
    self.report_file = output_settings.REPORT_FILE
    self.metrics_file = output_settings.METRICS_FILE
    self.max_workers = metrics_settings.MAX_PROCESSES
    self.templates_path = output_settings.TEMPLATES_PATH
    self.ap_method = metrics_settings.AP_METHOD
    self.factor = metrics_settings.FACTOR
def evaluate(self):
    """Run the evaluation, pickle the metrics, draw the plots and write the report.

    The single-threshold metrics are always computed. When an IoU range is
    configured, the per-threshold results of ``evaluate_range`` are added.

    Returns:
        dict: The metrics dictionary that was written to the metrics file.
    """
    single_basics = self.basic_metric(
        iou_threshold=self.iou_threshold, confidence_threshold=self.confidence_threshold
    )
    with_ap = self.get_average_precision(single_basics, method=self.ap_method)
    with_fscore = self.get_fscore(with_ap)

    metrics_output = {
        "title": self.title,
        "iou_threshold": self.iou_threshold,
        "single_threshold": with_fscore,
        "confidence_threshold": self.confidence_threshold,
    }

    if self.iou_range:
        aps, recalls, average_recalls, mean_aps, mean_ar = self.evaluate_range(iou_range=self.iou_range)
        metrics_output.update(
            {
                "AP": aps,
                "rs": recalls,
                "ars": average_recalls,
                "map": mean_aps,
                "mar": mean_ar,
                "iou_range": self.iou_range,
            }
        )

    metrics_path = os.path.join(self.output_path, self.metrics_file)
    with open(metrics_path, "wb") as f:
        pickle.dump(metrics_output, f)
    logger.info(f"metrics saved to {metrics_path}.")

    self.plot_roc_curves(with_fscore)
    self.plot_precision_recall_curve(with_fscore)
    self.generate_report(metrics_output)
    return metrics_output
def worker(self, iou_threshold, confidence_threshold: float = 0.1):
    """Compute per-class recall and average precision for one IoU threshold.

    Args:
        iou_threshold: IoU threshold forwarded to ``basic_metric``.
        confidence_threshold: Confidence threshold forwarded to
            ``basic_metric``. Defaults to 0.1, the value that used to be
            hard-coded here.

    Returns:
        tuple: ``(recalls, aps, mean_ap, classes)``; the three lists hold one
        entry per class, in the order returned by ``basic_metric``.
        ``mean_ap`` is 0.0 when there are no classes.
    """
    basics = self.basic_metric(iou_threshold=iou_threshold, confidence_threshold=confidence_threshold)
    recalls = [entry["Single_Recall"] for entry in basics]
    classes = [entry["Class"] for entry in basics]
    aps = [entry["AP"] for entry in self.get_average_precision(basics)]
    # Guard the mean against an evaluation without any class.
    mean_ap = sum(aps) / len(aps) if aps else 0.0
    return recalls, aps, mean_ap, classes
def evaluate_range(self, iou_range: List[float], confidence_threshold: float = None):
    """Evaluate every IoU threshold of ``iou_range`` in a process pool.

    Args:
        iou_range: Distinct IoU thresholds to evaluate.
        confidence_threshold: Accepted for interface compatibility. It is
            not forwarded: ``worker`` is mapped over the IoU thresholds
            only and applies its own confidence threshold.

    Returns:
        tuple: ``(aps, recalls, average_recalls, mean_aps, mean_ar)`` where
        ``aps`` and ``recalls`` map each IoU threshold to a per-class list,
        ``average_recalls`` holds one value per class, ``mean_aps`` maps
        each IoU threshold to the mean AP over classes and ``mean_ar`` is
        the mean of ``average_recalls`` (0.0 without classes).

    Raises:
        ValueError: If ``iou_range`` is empty or contains duplicates.
    """
    if not iou_range:
        raise ValueError("iou_range must contain at least one IoU threshold.")
    if len(set(iou_range)) != len(iou_range):
        raise ValueError("iou_range must not contain duplicate IoU thresholds.")

    with Pool(self.max_workers) as pool:
        results = pool.map(self.worker, iou_range)

    aps: Dict[float, List[float]] = {}
    recalls: Dict[float, List[float]] = {}
    mean_aps: Dict[float, float] = {}
    classes: List[str] = []
    # worker() returns (recalls, aps, mean_ap, classes) -- unpack by name so
    # recalls and average precisions cannot be swapped.
    for iou, (class_recalls, class_aps, mean_ap, class_names) in zip(iou_range, results):
        recalls[iou] = class_recalls
        aps[iou] = class_aps
        mean_aps[iou] = mean_ap
        classes = class_names

    average_recalls = self.get_average_recall(recalls, iou_range)
    self.plot_recall_iou_curve(recalls, iou_range, classes)
    mean_ar = sum(average_recalls) / len(average_recalls) if average_recalls else 0.0
    return aps, recalls, average_recalls, mean_aps, mean_ar
def plot_recall_iou_curve(
    self, recalls: Dict[float, List[float]], iou_thresholds: List[float], classes: List[str], show_ar: bool = True
) -> None:
    """Save one Recall x IOU plot per class into the output directory.

    Args:
        recalls: Maps each IoU threshold to a list holding one recall per class.
        iou_thresholds: IoU thresholds plotted on the x axis.
        classes: Class names, in the same order as the per-class recall lists.
        show_ar: If True, the mean of the plotted recalls is added to the title.
    """
    for position, class_name in enumerate(classes):
        per_iou_recall = [recalls[threshold][position] for threshold in iou_thresholds]

        plt.close()
        plt.plot(iou_thresholds, per_iou_recall, label=f"Confidence:{self.confidence_threshold}")
        plt.xlabel("IOU")
        plt.ylabel("Recall")

        if not show_ar:
            title = "Recall x IOU curve \nClass: %s" % str(class_name)
        else:
            ar_text = "{0:.2f}%".format(sum(per_iou_recall) / len(per_iou_recall) * 100)
            title = "Recall x IOU curve \nClass: %s, AR: %s" % (str(class_name), ar_text)
        plt.title(title)

        plt.legend(shadow=True)
        plt.grid()
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.0])

        target = os.path.join(self.output_path, class_name + "_recall_iou.png")
        plt.savefig(target)
        plt.close()
@property
def templates_path(self):
    """Absolute path of the directory that holds the report templates."""
    return self._templates_path


@templates_path.setter
def templates_path(self, path):
    """Store ``path`` as an absolute path, creating the directory if it is missing."""
    absolute = os.path.abspath(path)
    os.makedirs(absolute, exist_ok=True)
    self._templates_path = absolute
@property
def output_path(self):
    """Absolute path of the directory that receives metrics, plots and the report."""
    return self._output_path


@output_path.setter
def output_path(self, path):
    """Store ``path`` as an absolute path, creating the directory if it is missing."""
    absolute = os.path.abspath(path)
    os.makedirs(absolute, exist_ok=True)
    self._output_path = absolute
@property
def iou_threshold(self):
    """IoU threshold used for the single-threshold metrics, within ``[0, 1]``."""
    return self._iou_threshold


@iou_threshold.setter
def iou_threshold(self, value):
    """Validate and store the IoU threshold.

    Raises:
        ValueError: If ``value`` is NaN or lies outside ``[0, 1]``.
    """
    # The chained comparison is False for NaN, so NaN is rejected as well.
    if not 0 <= value <= 1:
        raise ValueError(f"iou_threshold must be within [0, 1], got {value!r}.")
    self._iou_threshold = value
@property
def confidence_threshold(self):
    """Confidence threshold used for the single-threshold metrics, within ``[0, 1]``."""
    return self._confidence_threshold


@confidence_threshold.setter
def confidence_threshold(self, value):
    """Validate and store the confidence threshold.

    Raises:
        ValueError: If ``value`` is NaN or lies outside ``[0, 1]``.
    """
    # The chained comparison is False for NaN, so NaN is rejected as well.
    if not 0 <= value <= 1:
        raise ValueError(f"confidence_threshold must be within [0, 1], got {value!r}.")
    self._confidence_threshold = value
@uval_stage
def data_preparations(self):
    """Convert the dataset table into ground-truth and detection entries.

    Every row of ``self.dataset.table`` is read by position. Ground-truth
    items and detection items are wrapped in ``DetectionEntry`` objects;
    ground truths get a confidence of 1.0. Detections labelled "iCMORE" are
    recorded under the class "grenade".

    Returns:
        tuple: ``(classes, volumes_soc, ground_truths, detections,
        total_negative)`` where ``classes`` lists the ground-truth class
        names in order of first appearance, ``volumes_soc`` maps the ids of
        volumes whose rows carried no ground truth to 0, and
        ``total_negative`` is the size of ``volumes_soc``.
    """
    table = self.dataset.table
    gt_entries = []
    det_entries = []
    gt_class_names = []
    det_class_names = []
    volumes_soc = {}

    for row_index in track(range(len(table)), "Preparing..."):
        row = table.loc[row_index]
        volume_id = row["volume_id"]
        volumes_soc.setdefault(volume_id, 0)
        row_detections = row["hdf5_detection"]
        row_ground_truth = row["hdf5_groundtruth"]

        if row_ground_truth:
            # A volume with ground truth is not a negative volume.
            volumes_soc.pop(volume_id, None)
            for gt_item in row_ground_truth.values():
                gt_label = gt_item["class_name"]
                gt_entries.append(
                    DetectionEntry(volume_id, gt_label, 1.0, gt_item["roi_start"], gt_item["roi_shape"])
                )
                if gt_label not in gt_class_names:
                    gt_class_names.append(gt_label)

        for det_item in row_detections:
            det_label = "grenade" if det_item["class_name"] == "iCMORE" else det_item["class_name"]
            det_class_names.append(det_label)
            det_entries.append(
                DetectionEntry(
                    volume_id, det_label, det_item["score"], det_item["roi_start"], det_item["roi_shape"]
                )
            )

    logger.info(f"detected classes are:{set(det_class_names)}")
    logger.info(f"ground truth classes are:{set(gt_class_names)}")
    return gt_class_names, volumes_soc, gt_entries, det_entries, len(volumes_soc)
@uval_stage
def basic_metric(self, iou_threshold: float = None, confidence_threshold: float = None) -> List[dict]:
    """Compute per-class TP/FP counts, precision/recall curves and image-level FPR.

    Detections of a class are visited in order of decreasing confidence. A
    detection is a true positive when its best-overlapping ground truth of
    the same class in the same volume reaches ``iou_threshold`` and has not
    been matched before; otherwise it is a false positive.

    Args:
        iou_threshold (float, optional): Threshold for IOU. Defaults to
            ``self.iou_threshold``.
        confidence_threshold (float, optional): Threshold for confidence.
            Defaults to ``self.confidence_threshold``.

    Returns:
        List[dict]: One dictionary per ground-truth class with the keys
        ``Class``, ``precision``, ``recall``, ``fpr`` (cumulative arrays
        over the sorted detections), ``Single_Recall`` (recall reached by
        the detections whose confidence is at least the threshold),
        ``Total Positives``, ``Total Negatives`` (ground truths of other
        classes), ``Total TP``/``Total FP``/``Total FN`` (counting only
        detections whose confidence is strictly above the threshold) and
        their confidence-weighted ``... soft`` counterparts.
    """
    if self.data is None:
        self.data = self.data_preparations()
    classes, volumes_soc, ground_truths, detections, total_negative = self.data
    if iou_threshold is None:
        iou_threshold = self.iou_threshold
    if confidence_threshold is None:
        confidence_threshold = self.confidence_threshold

    ret = []
    for c in classes:
        # Negative volumes that have not yet produced a detection of class c.
        pending_negatives = set(volumes_soc)

        # Ground truths of class c grouped by volume; all others count as negatives.
        gts: Dict[Any, Any] = {}
        npos = 0.0
        nneg = 0.0
        for g in ground_truths:
            if g.class_name == c:
                npos += 1.0
                gts.setdefault(g.volume_id, []).append(g)
            else:
                nneg += 1.0

        # Detections of class c, highest confidence first (stable sort).
        dects = sorted(
            (d for d in detections if d.class_name == c),
            key=lambda entry: entry.confidence_score,
            reverse=True,
        )

        tp = np.zeros(len(dects))
        fp = np.zeros(len(dects))
        fp_image_level = np.zeros(len(dects))
        total_tp = 0
        total_fp = 0
        tp_soft = 0
        fp_soft = 0
        tp_at_threshold = 0
        # One "already matched" flag per ground truth of each volume.
        matched = {volume_id: np.zeros(len(items)) for volume_id, items in gts.items()}

        for d, dect in enumerate(dects):
            score = dect.confidence_score

            # The first detection in a negative volume is an image-level false positive.
            if dect.volume_id in pending_negatives:
                fp_image_level[d] = 1
                pending_negatives.discard(dect.volume_id)

            # Best-overlapping ground truth; only a positive IoU can match,
            # so a stale or missing index can never be used.
            best_iou = 0.0
            best_idx = None
            for j, gtj in enumerate(gts.get(dect.volume_id, ())):
                overlap = Metrics.iou(dect.roi_start, dect.roi_shape, gtj.roi_start, gtj.roi_shape)
                if overlap > best_iou:
                    best_iou = overlap
                    best_idx = j

            is_tp = (
                best_idx is not None
                and best_iou >= iou_threshold
                and matched[dect.volume_id][best_idx] == 0
            )
            if is_tp:
                tp[d] = 1
                matched[dect.volume_id][best_idx] = 1  # flag as already 'seen'
                if score >= confidence_threshold:
                    tp_at_threshold += 1
                if score > confidence_threshold:
                    total_tp += 1
                    tp_soft += score
            else:
                # Either no sufficient overlap or the ground truth was matched before.
                fp[d] = 1
                if score > confidence_threshold:
                    total_fp += 1
                    fp_soft += score

        acc_fp = np.cumsum(fp)
        acc_tp = np.cumsum(tp)
        acc_fp_image_level = np.cumsum(fp_image_level)
        # Without negative volumes no image-level false positive can occur.
        fpr = acc_fp_image_level / total_negative if total_negative else np.zeros(len(dects))
        # Classes come from the ground truths, so npos is at least 1 here.
        rec = acc_tp / npos
        # Every detection is a TP or an FP, so the denominator is never 0.
        prec = np.divide(acc_tp, (acc_fp + acc_tp))
        single_recall = float(tp_at_threshold) / npos if npos else 0.0

        ret.append(
            {
                "Class": c,
                "precision": prec,
                "recall": rec,
                "Single_Recall": single_recall,
                "Total Positives": npos,
                "Total Negatives": nneg,
                "Total TP": total_tp,
                "Total FP": total_fp,
                "Total FN": npos - total_tp,
                "Total TP soft": tp_soft,
                "Total FP soft": fp_soft,
                "Total FN soft": npos - tp_soft,
                "fpr": fpr,
            }
        )
    return ret
def get_pascal_voc2012_metric(self, confidence_threshold=None) -> list:
    """Return per-class metrics at IoU 0.5 with every-point interpolated AP."""
    per_class = self.basic_metric(confidence_threshold=confidence_threshold, iou_threshold=0.5)
    return self.get_average_precision(per_class, method="EveryPointInterpolation")
def get_pascal_voc2007_metric(self, confidence_threshold=None) -> list:
    """Return per-class metrics at IoU 0.5 with eleven-point interpolated AP."""
    per_class = self.basic_metric(confidence_threshold=confidence_threshold, iou_threshold=0.5)
    return self.get_average_precision(per_class, method="ElevenPointInterpolation")
@uval_stage
def get_average_recall(self, recalls: Dict[float, list], iou_range: List) -> List:
    """Integrate each class's recall over the IoU range with the trapezoidal rule.

    Args:
        recalls: Maps each IoU threshold to a list holding one recall per class.
        iou_range: IoU thresholds in the order they should be integrated.

    Returns:
        List: One area under the recall-vs-IoU curve per class. The area is
        not normalised by the width of the IoU range.
    """
    class_num = len(recalls[iou_range[0]])
    average_recalls = []
    for c in range(class_num):
        area = 0.0
        for lower, upper in zip(iou_range, iou_range[1:]):
            # Trapezoid: interval width times the mean of the two end-point recalls.
            area += (upper - lower) * 0.5 * (recalls[lower][c] + recalls[upper][c])
        average_recalls.append(area)
    return average_recalls
@uval_stage
def get_average_precision(self, basic_metrics: List[dict], method: str = None) -> List[dict]:
    """Add the average precision and the interpolated curve to each class entry.

    Args:
        basic_metrics (List[dict]): Per-class dictionaries providing the
            ``recall`` and ``precision`` curves.
        method (str, optional): "EveryPointInterpolation" selects the
            every-point computation; any other value selects the
            eleven-point interpolation. Defaults to ``self.ap_method``.

    Returns:
        List[dict]: Copies of the input dictionaries extended with ``AP``,
        ``interpolated precision`` and ``interpolated recall``.
    """
    chosen = self.ap_method if method is None else method
    if chosen == "EveryPointInterpolation":
        compute = Metrics.calculate_average_precision
    else:
        compute = Metrics.eleven_point_interpolated_ap

    enriched = []
    for class_metrics in basic_metrics:
        ap, mpre, mrec = compute(class_metrics["recall"], class_metrics["precision"])
        # Work on a shallow copy so the input dictionaries stay untouched.
        entry = dict(class_metrics)
        entry["AP"] = ap
        entry["interpolated precision"] = mpre
        entry["interpolated recall"] = mrec
        enriched.append(entry)
    return enriched
@uval_stage
def get_fscore(self, basic_metrics: List[dict]) -> List[dict]:
    """Add the F-beta score and its confidence-weighted variant to each class entry.

    ``self.factor`` is used as beta.

    Args:
        basic_metrics (List[dict]): Per-class dictionaries providing
            ``Total TP``, ``Total FP``, ``Total FN`` and their ``... soft``
            counterparts.

    Returns:
        List[dict]: Copies of the input dictionaries extended with
        ``F score`` and ``F score soft``. A score is 0.0 when its
        denominator is zero.
    """
    beta_sq = self.factor**2

    def f_beta(tp, fp, fn):
        denominator = (1 + beta_sq) * tp + beta_sq * fn + fp
        # A zero denominator means nothing was counted; report 0 instead of dividing by zero.
        return (1 + beta_sq) * tp / denominator if denominator else 0.0

    ret = []
    for err in basic_metrics:
        r = dict(err)
        r["F score"] = f_beta(err["Total TP"], err["Total FP"], err["Total FN"])
        r["F score soft"] = f_beta(err["Total TP soft"], err["Total FP soft"], err["Total FN soft"])
        ret.append(r)
    return ret
@uval_stage
def generate_report(self, results_cluttered: dict) -> None:
    """Render the HTML report from the metrics dictionary and write it to disk.

    The input is only read: table rows are built from filtered copies, so
    the caller's metrics stay complete and the method can be called again
    with the same dictionary.

    Args:
        results_cluttered: Metrics dictionary holding ``single_threshold``
            (per-class dictionaries) and, when an IoU range is configured,
            ``map``, ``rs`` and ``mar``.

    Raises:
        KeyError: If a required entry is missing from ``results_cluttered``.
    """
    # Curve data and the class name are not shown as table columns.
    hidden_keys = ("precision", "recall", "fpr", "interpolated precision", "interpolated recall", "Class")
    single_results = results_cluttered["single_threshold"]
    classes = [res["Class"] for res in single_results]
    rows = [{key: value for key, value in res.items() if key not in hidden_keys} for res in single_results]

    def func(row):
        # Highlight the last cell of the row.
        highlight = "background-color: darkorange;"
        default = ""
        return [default] * (len(row) - 1) + [highlight]

    cell_hover = {  # highlight the hovered cell
        "selector": "td:hover",
        "props": [("background-color", "#ffffb3")],
    }
    row_hover = {  # highlight the hovered row
        "selector": "tr:hover",
        "props": [("background-color", "#ffffb3")],
    }

    # Present the classes alphabetically; the stable sort keeps ties in input order.
    sorted_idx = sorted(range(len(classes)), key=classes.__getitem__)
    sorted_classes = [classes[i] for i in sorted_idx]

    df = pd.DataFrame([rows[i] for i in sorted_idx], index=pd.Index(sorted_classes))
    styler = (
        df.style.set_caption(f"Calculated metrics for iou:{self.iou_threshold}")
        .format(precision=2)
        .set_table_styles([row_hover])
    )

    env = jinja2.Environment(loader=jinja2.FileSystemLoader(searchpath=self.templates_path))
    template = env.get_template(self.template_file)

    total_images = []
    for cls in sorted_classes:
        img_names = [f"./{cls}_roc.png", f"./{cls}_precision_recall.png"]
        if self.iou_range:
            img_names.append(f"./{cls}_recall_iou.png")
        total_images.append(img_names)

    if self.iou_range:
        # Copy before adding the "Total" column so the caller's dict is not changed.
        mean_aps = dict(results_cluttered["map"])
        overall_map = sum(mean_aps.values()) / len(mean_aps)
        mean_aps["Total"] = overall_map
        df2 = pd.DataFrame([list(mean_aps.values())], index=pd.Index(["map"]), columns=list(mean_aps.keys()))

        rs = results_cluttered["rs"]
        rs_sorted = {key: [val[i] for i in sorted_idx] for key, val in rs.items()}
        df_rs = pd.DataFrame(rs_sorted, index=pd.Index(sorted_classes))

        styler_rs = (
            df_rs.style.set_caption("Recall values for all classes and all IOU thresholds")
            .format(precision=2)
            .set_table_styles([row_hover])
        )
        styler2 = (
            df2.style.set_caption("Mean average precision for various IOU levels.")
            .format(precision=2)
            .set_table_styles([cell_hover])
            .apply(func, subset=["Total"], axis=1)
        )
        html = template.render(
            range_table=styler2.to_html(),
            single_table=styler.to_html(),
            rs_table=styler_rs.to_html(),
            total_images=total_images,
            mar=round(results_cluttered["mar"], 2),
            map=round(overall_map, 2),
            title=self.title,
        )
    else:
        html = template.render(single_table=styler.to_html(), total_images=total_images, title=self.title)

    report_path = os.path.join(self.output_path, self.report_file)
    with open(report_path, "w") as f:
        f.write(html)
    logger.info(f"Report saved to {report_path}.")
@uval_stage
def plot_precision_recall_curve(
    self,
    pascal_voc_metrics: List[dict],
    show_ap: bool = True,
    show_interpolated_precision: bool = True,
    show_graphic: bool = False,
) -> None:
    """Save one Precision x Recall plot per class into the output directory.

    Args:
        pascal_voc_metrics (List[dict]): Per-class dictionaries providing
            ``Class``, ``precision``, ``recall``, ``AP``,
            ``interpolated precision`` and ``interpolated recall``.
        show_ap (bool, optional): if True, the average precision value is
            shown in the title of the graph. Defaults to True.
        show_interpolated_precision (bool, optional): if True, the
            interpolated precision is drawn as well. Defaults to True.
        show_graphic (bool, optional): if True, the plot is also shown.
            Defaults to False.

    Raises:
        IOError: If an entry of ``pascal_voc_metrics`` is None.
        NotImplementedError: If ``show_interpolated_precision`` is False.
    """
    for result in pascal_voc_metrics:
        if result is None:
            raise IOError("Error: No data for a class was found.")
        class_id = result["Class"]
        precision = result["precision"]
        recall = result["recall"]
        average_precision = result["AP"]
        mpre = result["interpolated precision"]
        mrec = result["interpolated recall"]

        plt.close()
        if not show_interpolated_precision:
            raise NotImplementedError(
                "plot_precision_recall_curve() without show_interpolated_precision is not implemented yet!"
            )
        if self.ap_method == "EveryPointInterpolation":
            plt.plot(mrec, mpre, "--r", label="Interpolated precision (every point)")
        elif self.ap_method == "ElevenPointInterpolation":
            # Keep one point per recall value, in order of first appearance,
            # carrying the highest precision seen for that recall.
            best_precision: Dict[Any, Any] = {}
            for recall_value, precision_value in zip(mrec, mpre):
                previous = best_precision.get(recall_value, precision_value)
                best_precision[recall_value] = max(previous, precision_value)
            plt.plot(
                list(best_precision.keys()),
                list(best_precision.values()),
                "or",
                label="11-point interpolated precision",
            )

        plt.plot(recall, precision, label=f"Precision for IOU:{self.iou_threshold}")
        plt.xlabel("recall")
        plt.ylabel("precision")
        if show_ap:
            ap_str = "{0:.2f}%".format(average_precision * 100)
            plt.title("Precision x Recall curve \nClass: %s, AP: %s" % (str(class_id), ap_str))
        else:
            plt.title("Precision x Recall curve \nClass: %s" % str(class_id))
        plt.legend(shadow=True)
        plt.grid()
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.0])
        if self.output_path is not None:
            plt.savefig(os.path.join(self.output_path, class_id + "_precision_recall.png"))
        if show_graphic is True:
            plt.show(block=False)
            plt.pause(0.05)
        plt.close()
@staticmethod
def calculate_average_precision(rec: List[float], prec: List[float]) -> List[Any]:
    """Compute the every-point interpolated average precision.

    The curve is padded with (recall 0, precision 1) and (recall 1,
    precision 0), the precision is replaced by its running maximum from the
    right, and the area is summed over every change of recall.

    Args:
        rec: Recall values, ordered like ``prec``.
        prec: Precision values.

    Returns:
        List[Any]: ``[ap, mpre, mrec]`` where ``mpre`` and ``mrec`` are the
        padded interpolated curves without their final sentinel.

    Raises:
        ValueError: If ``rec`` and ``prec`` differ in length.
    """
    if len(rec) != len(prec):
        raise ValueError(f"rec and prec must have the same length, got {len(rec)} and {len(prec)}.")

    mrec = [0.0] + list(rec) + [1.0]
    mpre = [1.0] + list(prec) + [0.0]

    # Precision envelope: each point takes the maximum precision to its right.
    for i in range(len(mpre) - 1, 0, -1):
        mpre[i - 1] = max(mpre[i - 1], mpre[i])

    # Sum the rectangles at every position where the recall changes.
    ap = 0.0
    for i in range(1, len(mrec)):
        if mrec[i] != mrec[i - 1]:
            ap += (mrec[i] - mrec[i - 1]) * mpre[i]

    return [ap, mpre[:-1], mrec[:-1]]
@staticmethod
def eleven_point_interpolated_ap(rec: List[float], prec: List[float]) -> List[Any]:
    """Compute the 11-point interpolated average precision.

    For each recall level 1.0, 0.9, ..., 0.0 the interpolated precision is
    the maximum precision from the first position whose recall reaches that
    level onwards, or 0 if no position does. The AP is their mean.

    Args:
        rec: Recall values, ordered like ``prec``.
        prec: Precision values.

    Returns:
        List[Any]: ``[ap, rho_interp, recall_values]`` where the two lists
        are the de-duplicated (recall, precision) points of the step curve
        used for plotting.
    """
    # Explicit array so the level comparison is element-wise for any input type.
    recalls = np.asarray(rec, dtype=float)
    precisions = list(prec)
    recall_levels = list(np.linspace(0, 1, 11)[::-1])

    rho_interp = []
    for level in recall_levels:
        reached = np.argwhere(recalls >= level)
        pmax = 0.0
        if reached.size != 0:
            pmax = max(precisions[int(reached.min()) :])
        rho_interp.append(pmax)

    ap = sum(rho_interp) / len(recall_levels)

    # Build the step curve: every recall level is paired with the previous
    # and the current precision; duplicates are dropped in order.
    rvals = [recall_levels[0]] + recall_levels + [0.0]
    pvals = [0.0] + rho_interp + [0.0]
    points = []
    for i, recall_value in enumerate(rvals):
        for precision_value in (pvals[i - 1], pvals[i]):
            point = (recall_value, precision_value)
            if point not in points:
                points.append(point)

    recall_values_out = [point[0] for point in points]
    rho_interp_out = [point[1] for point in points]
    return [ap, rho_interp_out, recall_values_out]
@staticmethod
def iou(start_a: List[float], shape_a: List[float], start_b: List[float], shape_b: List[float]) -> float:
    """Calculate the intersection over union of two axis-aligned boxes.

    The computation runs per axis, so it works for the 3D cubes used in this
    module as well as for boxes of any other dimensionality.

    Args:
        start_a (List[float]): start corner of box A.
        shape_a (List[float]): size of each dimension of box A.
        start_b (List[float]): start corner of box B.
        shape_b (List[float]): size of each dimension of box B.

    Returns:
        float: IOU of the boxes; 0.0 if they do not overlap, if any
        coordinate or size is negative (a warning is logged), or if the
        union has no volume.
    """
    if any(np.any(np.asarray(values) < 0) for values in (start_a, start_b, shape_a, shape_b)):
        logger.warning(f"bounding box coordinates are negative!{start_a}{shape_a}{start_b}{shape_b}")
        return 0.0

    # Overlap extent per axis; a negative extent means the boxes are apart on that axis.
    overlaps = [
        min(a_start + a_size, b_start + b_size) - max(a_start, b_start)
        for a_start, a_size, b_start, b_size in zip(start_a, shape_a, start_b, shape_b)
    ]
    if any(extent < 0 for extent in overlaps):
        return 0.0

    inter_area = math.prod(overlaps)
    union = math.prod(shape_a) + math.prod(shape_b) - inter_area
    # Two zero-volume boxes have an empty union; avoid dividing by zero.
    if union <= 0:
        return 0.0
    return float(inter_area / union)
@staticmethod
def _boxes_intersect(
    start_a: List[float], shape_a: List[float], start_b: List[float], shape_b: List[float]
) -> bool:
    """Tell whether two cubes overlap or touch on each of the first three axes.

    Args:
        start_a (List[float]): start corner of the cube A.
        shape_a (List[float]): size of each dimension in the cube A.
        start_b (List[float]): start corner of the cube B.
        shape_b (List[float]): size of each dimension in the cube B.

    Returns:
        bool: False as soon as one cube starts beyond the end of the other
        on some axis, otherwise True.
    """
    for axis in range(3):
        a_starts_past_b = start_a[axis] > start_b[axis] + shape_b[axis]
        b_starts_past_a = start_b[axis] > start_a[axis] + shape_a[axis]
        if a_starts_past_b or b_starts_past_a:
            return False
    return True
@staticmethod
def _get_union_areas(
    start_a: List[float], shape_a: List[float], start_b: List[float], shape_b: List[float]
) -> float:
    """Calculate the size of the union of the two cubes A and B.

    Args:
        start_a (List[float]): start corner of the cube A.
        shape_a (List[float]): size of each dimension in the cube A.
        start_b (List[float]): start corner of the cube B.
        shape_b (List[float]): size of each dimension in the cube B.

    Returns:
        float: size of A plus size of B minus the size of their intersection.
    """
    overlap = Metrics._get_intersection_area(start_a, shape_a, start_b, shape_b)
    combined = Metrics._get_area(shape_a) + Metrics._get_area(shape_b)
    return float(combined - overlap)
@staticmethod
def _get_area(shape: List[float]) -> float:
    """Multiply the sizes of all dimensions of a cube.

    Args:
        shape (List[float]): size of each dimension in the cube.

    Returns:
        float: product of the entries of ``shape`` (1 for an empty shape).
    """
    size = 1
    for extent in shape:
        size = size * extent
    return size
@staticmethod
def _get_intersection_area(
    start_a: List[float], shape_a: List[float], start_b: List[float], shape_b: List[float]
) -> float:
    """Calculate the size of the intersection of the two cubes A and B.

    Args:
        start_a (List[float]): start corner of the cube A.
        shape_a (List[float]): size of each dimension in the cube A.
        start_b (List[float]): start corner of the cube B.
        shape_b (List[float]): size of each dimension in the cube B.

    Returns:
        float: product of the overlap along the first three axes; 0 when
        the cubes do not overlap on one or more axes.
    """
    size = 1
    for axis in range(3):
        low = max(start_a[axis], start_b[axis])
        high = min(start_a[axis] + shape_a[axis], start_b[axis] + shape_b[axis])
        # Clamp at 0: two negative extents would otherwise multiply to a
        # positive "intersection" for cubes that are actually disjoint.
        size = size * max(0, high - low)
    return size
@uval_stage
def plot_roc_curves(self, roc_metrics: List[dict], show_graphic: bool = False) -> None:
    """Save one ROC plot per class into the output directory.

    Args:
        roc_metrics (List[dict]): Per-class dictionaries providing
            ``Class``, ``recall`` and ``fpr``.
        show_graphic (bool, optional): if True, the plot is also shown.
            Defaults to False.

    Raises:
        IOError: If an entry of ``roc_metrics`` is None.
    """
    for class_metrics in roc_metrics:
        if class_metrics is None:
            raise IOError("Error:No data for this class could be found.")
        class_name = class_metrics["Class"]
        tp_rate = class_metrics["recall"]
        fp_rate = class_metrics["fpr"]

        plt.close()
        plt.plot(fp_rate, tp_rate, label=f"ROC for IOU:{self.iou_threshold}")
        plt.xlabel("FP Rate")
        plt.ylabel("TP Rate")
        plt.title("ROC curve \nClass: %s" % str(class_name))
        plt.legend(shadow=True)
        plt.grid()
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.0])

        if self.output_path is not None:
            target = os.path.join(self.output_path, class_name + "_roc.png")
            plt.savefig(target)
        if show_graphic is True:
            plt.show(block=False)
            plt.pause(0.05)
        plt.close()