PrecisionRecallFScore Class

This class provides common statistical metrics for determining how well a source separation algorithm in nussl was able to create a binary mask compared to a known binary mask. The metrics used here are Precision, Recall, F-Score (sometimes called F-measure or F1-score), and Accuracy (not reflected in the name of the class; it is simply the number of correctly labeled time-frequency bins divided by the total number of bins).
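
For reference, here is a minimal sketch of how these four metrics can be computed with NumPy from two flattened boolean mask arrays (true_mask and est_mask are hypothetical names used for illustration; nussl computes the scores internally from BinaryMask objects):

import numpy as np

def prf_accuracy(true_mask, est_mask):
    # Flatten both masks to 1-D boolean arrays
    true_mask = np.asarray(true_mask, dtype=bool).ravel()
    est_mask = np.asarray(est_mask, dtype=bool).ravel()

    tp = np.sum(est_mask & true_mask)   # true positives
    fp = np.sum(est_mask & ~true_mask)  # false positives
    fn = np.sum(~est_mask & true_mask)  # false negatives

    precision = tp / (tp + fp)          # fraction of predicted bins that are correct
    recall = tp / (tp + fn)             # fraction of true bins that were found
    f1 = 2 * precision * recall / (precision + recall)
    accuracy = np.mean(est_mask == true_mask)  # correct bins / total bins
    return precision, recall, f1, accuracy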

Example:

import nussl

# mixture, drums, and flute are assumed to be previously loaded AudioSignal objects

# Run Repet
repet = nussl.Repet(mixture, mask_type=nussl.BinaryMask)  # it's important to specify BinaryMask!
repet_masks = repet()

# Get Ideal Binary Masks
ideal_mask = nussl.IdealMask(mixture, [drums, flute], mask_type=nussl.BinaryMask)  # BinaryMask here, too!
ideal_masks = ideal_mask()

# Compare Repet to Ideal Binary Mask
prf_repet = nussl.PrecisionRecallFScore(ideal_masks, repet_masks)
prf_repet_scores = prf_repet.evaluate()

Scores for each source are stored in a nested dictionary named scores. This is a dictionary of dictionaries, where each key is a source label and its value is another dictionary containing the scores for each metric for that source. For instance, the prf_repet_scores dictionary from above has the following format:

{'Source 0' : {'Accuracy': 0.83,
               'Precision': 0.78,
               'Recall': 0.81,
               'F1-Score': 0.77 },
 'Source 1' : {'Accuracy': 0.22,
               'Precision': 0.12,
               'Recall': 0.15,
               'F1-Score': 0.19 }
}
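
Individual values can be read straight out of this nested dictionary; a small usage sketch (assuming prf_repet_scores from the example above):

# Precision of the estimate for the first source
print(prf_repet_scores['Source 0']['Precision'])   # e.g., 0.78

# Accuracy for every source
for label, metrics in prf_repet_scores.items():
    print(label, metrics['Accuracy'])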

Notes

  • PrecisionRecallFScore can only be run using BinaryMask objects. The constructor expects a list of BinaryMask objects for both the ground truth sources and the estimated sources.
  • PrecisionRecallFScore does not calculate the correct permutation of the estimated and ground truth sources; they are expected to be in the correct order when they are passed into PrecisionRecallFScore (see the sketch below).
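
Because no permutation search is performed, estimates must be reordered manually if they do not line up with the ground truth. An illustrative sketch (the indices here are hypothetical):

# Illustrative only: swap the estimates if they were produced in the
# opposite order from the ground truth masks.
reordered_estimates = [repet_masks[1], repet_masks[0]]
prf = nussl.PrecisionRecallFScore(ideal_masks, reordered_estimates)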

See also

  • EvaluationBase Class for more information about derived properties that this class has.
  • IdealMask for information about how to get an array of ground truth binary masks.
class nussl.evaluation.precision_recall_fscore.PrecisionRecallFScore(true_sources_mask_list, estimated_sources_mask_list, source_labels=None)

Bases: nussl.evaluation.evaluation_base.EvaluationBase

Parameters:
  • true_sources_mask_list (list) – List of BinaryMask objects representing the ground truth sources.
  • estimated_sources_mask_list (list) – List of BinaryMask objects representing the estimates from a source separation object.
  • source_labels (list, optional) – List of str with labels for each source. If no labels are provided, sources will be labeled Source 0, Source 1, ... etc.
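
For example, labels can be supplied at construction so that the scores dictionary is keyed by source name instead of Source 0, Source 1 (a sketch reusing the masks from the example above):

prf = nussl.PrecisionRecallFScore(ideal_masks, repet_masks,
                                  source_labels=['drums', 'flute'])
scores = prf.evaluate()  # keys of scores are now 'drums' and 'flute'
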
scores

Dictionary storing the precision, recall, F1-Score, and accuracy. See nussl.PrecisionRecallFScore.evaluate below.

Type: dict
ACCURACY_KEY = 'Accuracy'
PRECISION_KEY = 'Precision'
RECALL_KEY = 'Recall'
FSCORE_KEY = 'F1-Score'
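
These class constants match the keys of the results dictionary, so they can be used for lookups in place of raw strings; a small sketch:

PRF = nussl.PrecisionRecallFScore
print(prf_repet_scores['Source 0'][PRF.PRECISION_KEY])  # same as ['Precision']
print(prf_repet_scores['Source 0'][PRF.FSCORE_KEY])     # same as ['F1-Score']
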
evaluate()

Determines the precision, recall, F-score, and accuracy of each BinaryMask object in true_sources_mask_list and estimated_sources_mask_list. Returns a dictionary of results formatted like so:

{'Source 0' : {'Accuracy': 0.83,
               'Precision': 0.78,
               'Recall': 0.81,
               'F1-Score': 0.77 },
 'Source 1' : {'Accuracy': 0.22,
               'Precision': 0.12,
               'Recall': 0.15,
               'F1-Score': 0.19 }
}

This dictionary is also stored in the scores attribute; the source labels ('Source 0', 'Source 1', etc., or the labels given in source_labels) are the keys to this dictionary.

Returns: self.scores (dict) – A dictionary of scores containing the accuracy, precision, recall, and F1-score between the BinaryMask objects in true_sources_mask_list and estimated_sources_mask_list.
scores

A dictionary that stores all scores from the evaluation method. Populated when evaluate() is run.