# -*- coding: utf-8 -*-
"""A StageData object acts like a data container that manages results generated by stages and read by subsequent
stages.
"""
from os import path
from uval.stages.stage_data import DatasetSpecificationData
from uval.utils.label_naming import label_short_to_long
from uval.utils.log import logger
from uval.utils.yaml_io import load_yaml_data
# @uval_stage
def load_datasplit(file_path: str, subsets=None, output=None) -> DatasetSpecificationData:
    """Load a data split file in YAML format and keep the id lists of the requested subsets.

    The YAML content is read through ``load_yaml_data`` and is expected to hold a
    top-level ``split`` mapping that is traversed as
    ``split -> subset name -> class name -> volume id -> label id(s)``.
    Every (volume, label) pair becomes one row with the fields ``volume_id``,
    ``label_id``, ``is_negative``, ``subset`` and ``class_name``.

    Args:
        file_path (str): path to the yaml file.
        subsets ([str] or str, optional): which subsets of the yaml file are to be loaded.
            A single subset name may be given as a plain string. If set to None,
            all subsets are used. Defaults to None.
        output: optional object providing the attributes ``PATH`` and
            ``DATASET_OVERVIEW_FILE``. When given, an overview of the loaded data is
            written to that file as ``html`` or ``csv``, depending on the file extension.

    Returns:
        DatasetSpecificationData: the loaded dataset. A fresh, unfilled instance is
        returned when the file yields no data or has no ``split`` entry.

    Raises:
        ValueError: if a requested subset is not present in the YAML file.
        NotImplementedError: if the overview file extension is neither ``html`` nor ``csv``.
    """
    logger.debug("Reading YAML file started")
    data_split_dict = load_yaml_data(file_path)
    logger.debug("Reading YAML file finished.")

    split = data_split_dict.get("split") if data_split_dict else None
    if not split:
        logger.error("No data could be read from data split file '{}'".format(file_path))
        return DatasetSpecificationData()

    # Set up result data store
    stage_results = DatasetSpecificationData()

    available_subsets = list(split.keys())
    if subsets is None:
        subsets = available_subsets
    else:
        if isinstance(subsets, str):
            # A bare string would otherwise be iterated character by character.
            subsets = [subsets]
        for subset in subsets:
            if subset not in available_subsets:
                raise ValueError(f"The requested split subset {subset} not found in the YAML file.")
    logger.debug(subsets)

    rows = {}
    for subset_name, subset_data in split.items():
        if subset_name not in subsets:
            continue
        if subset_data is None:
            # An entry without content is read as None; there is nothing to collect from it.
            logger.warning(f"Split subset '{subset_name}' is empty in the YAML file: '{file_path}'")
            continue
        for class_name, image_list in subset_data.items():
            if image_list is None:
                logger.warning(f"Split {subset_name}/{class_name} is empty in the YAML file: '{file_path}'")
                continue
            is_negative = class_name == "negative"
            logger.debug("Reading list of '{}' images for split {}/{}".format(len(image_list), subset_name, class_name))
            for volume_id, labels in image_list.items():
                # A volume may carry a single label or a list of labels.
                labels_list = labels if isinstance(labels, list) else [labels]
                for label_id in labels_list:
                    try:
                        # Without a label the volume id itself serves as the row key.
                        item_key = label_short_to_long(volume_id, label_id) if label_id else volume_id
                    except KeyError:
                        logger.warning(f"Unable to find proper keys in the YAML file: '{file_path}'")
                        continue
                    rows[item_key] = {
                        "volume_id": volume_id,
                        "label_id": label_id,
                        "is_negative": is_negative,
                        "subset": subset_name,
                        "class_name": class_name,
                    }
    stage_results.from_dict_as_rows(rows)

    if output:
        extension = output.DATASET_OVERVIEW_FILE.split(".")[-1]
        if extension not in ("html", "csv"):
            raise NotImplementedError("Other output formats rather than 'html' and 'csv' are not implemented yet!")
        overview_path = path.join(output.PATH, output.DATASET_OVERVIEW_FILE)
        if extension == "html":
            stage_results.to_html(overview_path)
        else:
            stage_results.to_csv(overview_path)
        logger.info(f"dataset overview saved to {overview_path}.")
    return stage_results