For semantic segmentation we can use the classification module to compute the desired metrics.
The metrics share practically the same configuration parameters (to be sure, check the docs of each metric). They are:
average=
‘none’ : output will be the metric for each category
‘macro’: the metric is calculated for each class separately, and the results are averaged across classes (with equal weight for each class).
mdmc_reduce or mdmc_average=
‘global’: will flatten the inputs, and then apply the average as usual.
None: should be left unchanged if your data is not multi-dimensional multi-class.
ignore_index=
Integer specifying a target class to ignore.
num_classes=
Number of classes.
threshold=
Threshold for transforming probability or logit predictions to binary (0, 1) predictions, in the case of binary or multi-label inputs. The default value of 0.5 corresponds to the inputs being probabilities.
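To make the reduction options concrete, here is a small NumPy sketch (the arrays and class ids are made-up toy data, not from the text). It shows what ‘global’ flattening and per-class vs. macro averaging mean, using recall as the example metric:

```python
import numpy as np

# Toy 2x4 "image" masks with 3 classes (made-up data for illustration)
pred = np.array([[0, 0, 1, 2], [2, 1, 1, 0]])
target = np.array([[0, 1, 1, 2], [2, 2, 1, 0]])
num_classes = 3

# mdmc 'global': flatten the spatial dimensions first
pred_f, target_f = pred.ravel(), target.ravel()

# average='none': one recall value per class
per_class_recall = np.array([
    ((pred_f == c) & (target_f == c)).sum() / max((target_f == c).sum(), 1)
    for c in range(num_classes)
])

# average='macro': unweighted mean across classes
macro_recall = per_class_recall.mean()
print(per_class_recall, macro_recall)
```

The same flatten-then-reduce behavior is what torchmetrics applies internally when these parameters are set.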
Compute the confusion matrix statistics (TP, FP, TN, FN)
The reduce parameter defines the reduction that will be applied. The “macro” value computes the statistics for each class separately.
The mdmc_reduce is required because we are working with a tensor that represents an image. The “global” value flattens the inputs, and then applies the reduce as usual.
The num_classes, the number of classes, is necessary for multi-class data.
With reduce='macro' and mdmc_reduce='global' the output will have the shape (num_classes, 5), where the 5 values are TP, FP, TN, FN and sup (sup stands for support and equals TP + FN).
stat = torchmetrics.functional.stat_scores(
    pred, target, reduce="macro", mdmc_reduce="global", num_classes=num_classes
)
stat
We can pretty-print this matrix, for demonstration purposes only.
out_sequence = ["TP", "FP", "TN", "FN", "sup"]
print(" " * 10 + "\t" + "\t | \t".join([f"{t}" for t in out_sequence]))
for idx, name in enumerate(id2label.values()):
    txt = "\t | \t".join([f"{v:<5}" for v in stat[idx]])
    print(f"{name:<10} |\t{txt}")
Using average='macro', the metric is calculated for each class separately, and the results are averaged across classes (with equal weight for each class).
The mdmc_reduce or mdmc_average is required because we are working with a tensor that represents an image. The “global” value flattens the inputs, and then applies the average as usual.
The lapixdl package calculates the confusion matrix first (on the CPU), which is slower than torchmetrics, which works on PyTorch tensors. So a trick here, to avoid calculating each metric separately in torchmetrics, is to compute the confusion matrix with torchmetrics and then calculate all the metrics at once with lapixdl.
First, compute the confusion matrix with torchmetrics.
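The trick can be sketched in plain NumPy (the confusion matrix below is made-up toy data, and the rows-as-target, columns-as-prediction layout is an assumption; in practice the matrix would come from torchmetrics and be handed to lapixdl). The point is that all the per-class counts, and therefore many metrics, fall out of the single matrix:

```python
import numpy as np

# Toy 3-class confusion matrix (rows = target class, cols = predicted class)
cm = np.array([
    [50,  2,  3],
    [ 5, 40,  5],
    [ 2,  3, 45],
])

tp = np.diag(cm).astype(float)
fp = cm.sum(axis=0) - tp  # predicted as class c, but actually another class
fn = cm.sum(axis=1) - tp  # actually class c, but predicted as another class
tn = cm.sum() - tp - fp - fn

# Several per-class metrics at once from the same counts
recall = tp / (tp + fn)
precision = tp / (tp + fp)
iou = tp / (tp + fp + fn)
print(recall, precision, iou)
```

This is why computing the matrix once (fast, on tensors) and deriving every metric from it afterwards avoids re-scanning the predictions for each metric.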