% LABELS is a vector of binary class labels. Positive indicates the "positive" class, negative
% or zero indicates "negative."
% SCORES is a vector of scores, with larger values indicating greater likelihood of the positive
% class. If the values
% To plot an ROC curve:
% [tp, fp] = roc(labels, scores);
% plot(fp, tp);
function [tp, fp, threshes, eer, eeThresh] = roc(labels, scores)
% Convert labels to strictly 1s (positive) and 0s (negative).
labels = (labels > 0);
% Sort from most likely to be postive to most likely to be negative.
[sscores, sscoreixes] = sort(scores, 'descend');
slabels = labels(sscoreixes);
% True positive and false positive rates at each threshold.
tp = cumsum(slabels) ./ sum(slabels);
fp = cumsum(~slabels) ./ sum(~slabels);
% Note that if we have multiple samples with the same score, we cannot really classify some of
% them as positive and others as negative. So only the last (tp, fp) pair for each score should
% be included (based on the idea than anything with score greater than *or equal to* the
% threshold is classified positive).
oldscore = sscores(end);
isgood = true(size(sscores));
for ix = length(sscores)-1:-1:1
if sscores(ix) == oldscore
isgood(ix) = false;
else
oldscore = sscores(ix);
end
end
goodpoints = find(isgood);
threshes = sscores(goodpoints);
tp = tp(goodpoints);
fp = fp(goodpoints);
% Equal error mean false_positive_rate = false_negative_rate.
% false_negative_rate = 1 - true_positive_rate, so we want to find where fp = 1 - tp, or
% where 1 - tp - fp = 0. We may not be able to hit that exactly, but get as close as we can,
% taking the mean in the case of a tie.
d = 1 - tp - fp;
iee = find(abs(d) == min(abs(d)));
assert(length(iee) >= 1);
eeThresh = mean(threshes(iee));
eetp = sum((scores >= eeThresh) & labels) / sum(labels);
eefp = sum((scores >= eeThresh) & ~labels) / sum(~labels);
eer = mean([eetp 1-eefp]);
assert(tp(1) > 0 || fp(1) > 0);
if fp(1) ~= 0
if size(tp, 1) > 1
tp = [0; tp]; fp = [0; fp]; threshes = [inf; threshes];
else
tp = [0 tp]; fp = [0 fp]; threshes = [inf threshes];
end
end
assert(tp(end) == 1 && fp(end) == 1);